From 57eefddad82ef13f6ddd9aff11d29f8030c62762 Mon Sep 17 00:00:00 2001 From: ShaaniBel Date: Sat, 18 Oct 2025 12:22:05 -0400 Subject: [PATCH] Update A2.py --- A2.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/A2.py b/A2.py index e1459e2..3af0426 100644 --- a/A2.py +++ b/A2.py @@ -93,26 +93,25 @@ class GaussianRegression: def __init__(self, sigma=1.0): self.sigma = sigma - self.weights = None + self.w = None self.D = None - def fit(self, x_train, y_train, D): + def fit(self, x, y, D): # Store D for later use in predict self.D = D - # create features for training and fit using least squares - X_train_basis = gaussian_features(x_train, D, self.sigma) - self.weights = np.linalg.lstsq(X_train_basis, y_train, rcond=None)[0] + X = gaussian_features(x, D, self.sigma) + self.w = np.linalg.lstsq(X, y, rcond=None)[0] return self - def predict(self, x_predict): + def predict(self, x): # create features for prediction and predict - X_predict_basis = gaussian_features(x_predict, self.D, self.sigma) - y_pred = X_predict_basis @ self.weights + X = gaussian_features(x, self.D, self.sigma) + yh = X @ self.w - return y_pred + return yh