Update A2.py

This commit is contained in:
ShaaniBel 2025-10-18 12:22:05 -04:00
parent f13d13947e
commit 57eefddad8

17
A2.py
View file

@ -93,26 +93,25 @@ class GaussianRegression:
def __init__(self, sigma=1.0):
self.sigma = sigma
self.weights = None
self.w = None
self.D = None
def fit(self, x_train, y_train, D):
def fit(self, x, y, D):
# Store D for later use in predict
self.D = D
# create features for training and fit using least squares
X_train_basis = gaussian_features(x_train, D, self.sigma)
self.weights = np.linalg.lstsq(X_train_basis, y_train, rcond=None)[0]
X = gaussian_features(x, D, self.sigma)
self.w = np.linalg.lstsq(X, y, rcond=None)[0]
return self
def predict(self, x_predict):
def predict(self, x):
# create features for prediction and predict
X_predict_basis = gaussian_features(x_predict, self.D, self.sigma)
y_pred = X_predict_basis @ self.weights
X = gaussian_features(x, self.D, self.sigma)
yh = X @ self.w
return y_pred
return yh