diff --git a/A2.py b/A2.py index e1459e2..3af0426 100644 --- a/A2.py +++ b/A2.py @@ -93,26 +93,25 @@ class GaussianRegression: def __init__(self, sigma=1.0): self.sigma = sigma - self.weights = None + self.w = None self.D = None - def fit(self, x_train, y_train, D): + def fit(self, x, y, D): # Store D for later use in predict self.D = D - # create features for training and fit using least squares - X_train_basis = gaussian_features(x_train, D, self.sigma) - self.weights = np.linalg.lstsq(X_train_basis, y_train, rcond=None)[0] + X = gaussian_features(x, D, self.sigma) + self.w = np.linalg.lstsq(X, y, rcond=None)[0] return self - def predict(self, x_predict): + def predict(self, x): # create features for prediction and predict - X_predict_basis = gaussian_features(x_predict, self.D, self.sigma) - y_pred = X_predict_basis @ self.weights + X = gaussian_features(x, self.D, self.sigma) + yh = X @ self.w - return y_pred + return yh