Update A2.py
This commit is contained in:
parent
f13d13947e
commit
57eefddad8
1 changed files with 8 additions and 9 deletions
17
A2.py
17
A2.py
|
|
@ -93,26 +93,25 @@ class GaussianRegression:
|
|||
|
||||
def __init__(self, sigma=1.0):
|
||||
self.sigma = sigma
|
||||
self.weights = None
|
||||
self.w = None
|
||||
self.D = None
|
||||
|
||||
def fit(self, x_train, y_train, D):
|
||||
def fit(self, x, y, D):
|
||||
# Store D for later use in predict
|
||||
self.D = D
|
||||
|
||||
# create features for training and fit using least squares
|
||||
X_train_basis = gaussian_features(x_train, D, self.sigma)
|
||||
self.weights = np.linalg.lstsq(X_train_basis, y_train, rcond=None)[0]
|
||||
X = gaussian_features(x, D, self.sigma)
|
||||
self.w = np.linalg.lstsq(X, y, rcond=None)[0]
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def predict(self, x_predict):
|
||||
def predict(self, x):
|
||||
# create features for prediction and predict
|
||||
X_predict_basis = gaussian_features(x_predict, self.D, self.sigma)
|
||||
y_pred = X_predict_basis @ self.weights
|
||||
X = gaussian_features(x, self.D, self.sigma)
|
||||
yh = X @ self.w
|
||||
|
||||
return y_pred
|
||||
return yh
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue