Update A2.py
This commit is contained in:
parent
f13d13947e
commit
57eefddad8
1 changed files with 8 additions and 9 deletions
17
A2.py
17
A2.py
|
|
@ -93,26 +93,25 @@ class GaussianRegression:
|
||||||
|
|
||||||
def __init__(self, sigma=1.0):
|
def __init__(self, sigma=1.0):
|
||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
self.weights = None
|
self.w = None
|
||||||
self.D = None
|
self.D = None
|
||||||
|
|
||||||
def fit(self, x_train, y_train, D):
|
def fit(self, x, y, D):
|
||||||
# Store D for later use in predict
|
# Store D for later use in predict
|
||||||
self.D = D
|
self.D = D
|
||||||
|
|
||||||
# create features for training and fit using least squares
|
# create features for training and fit using least squares
|
||||||
X_train_basis = gaussian_features(x_train, D, self.sigma)
|
X = gaussian_features(x, D, self.sigma)
|
||||||
self.weights = np.linalg.lstsq(X_train_basis, y_train, rcond=None)[0]
|
self.w = np.linalg.lstsq(X, y, rcond=None)[0]
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def predict(self, x_predict):
|
def predict(self, x):
|
||||||
# create features for prediction and predict
|
# create features for prediction and predict
|
||||||
X_predict_basis = gaussian_features(x_predict, self.D, self.sigma)
|
X = gaussian_features(x, self.D, self.sigma)
|
||||||
y_pred = X_predict_basis @ self.weights
|
yh = X @ self.w
|
||||||
|
|
||||||
return y_pred
|
return yh
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue