diff --git a/A2.py b/A2.py index 33b6d0a..a9b5e08 100644 --- a/A2.py +++ b/A2.py @@ -86,3 +86,79 @@ plt.tight_layout() plt.show() #__________________________________________________________________________________ +#1.3 Model fitting +#for now I used the whole data but idk we that's what they asked for that part +class GaussianRegression: + """Linear Regression with Gaussian Basis Functions""" + + def __init__(self, sigma=1.0): + self.sigma = sigma + self.weights = None + self.mus = None + self.D = None + + def fit(self, x_train, y_train, D): + # Store D for later use in predict + self.D = D + + # create features for training and fit using least squares + X_train_basis = gaussian_features(x_train, D, self.sigma) + self.weights = np.linalg.lstsq(X_train_basis, y_train, rcond=None)[0] + + return self + + + def predict(self, x_predict): + # create features for prediction and predict + X_predict_basis = gaussian_features(x_predict, self.D, self.sigma) + y_pred = X_predict_basis @ self.weights + + return y_pred + + + +def true_function(x): + return (np.log(x + 1e-10) + 1) * np.cos(x) + np.sin(2*x) + +# fit models with different numbers of basis functions and plot +D_i = [0, 2, 5, 10, 13, 15, 17, 20, 25, 30, 35, 45] +x_plot = np.linspace(0, 10, 300) + +plt.figure(figsize=(18, 12)) + +for i, D in enumerate(D_i): + plt.subplot(4, 3, i+1) + + # Create new model for each D value, fit and get predictions + model = GaussianRegression(sigma=1.0) + model.fit(x, y_noisy, D) + y_hat = model.predict(x_plot) + + # Ensure y_hat is 1D and has same length as x_plot + y_hat = y_hat.flatten() if y_hat.ndim > 1 else y_hat + + # Plot + plt.plot(x_plot, true_function(x_plot), 'b-', label='True Function', linewidth=2, alpha=0.7) + plt.plot(x, y_noisy, 'ro', label='Noisy Data', alpha=0.4, markersize=3) + plt.plot(x_plot, y_hat, 'g-', label=f'Fitted (D={D})', linewidth=2) + + plt.ylim(-6, 6) + plt.title(f'D = {D}') + plt.grid(True, alpha=0.3) + plt.legend(fontsize=8) + + # x and y labels + if i % 3 == 0: + plt.ylabel('y') + if i >= 9: + plt.xlabel('x') + + + +plt.tight_layout() +plt.show() + + + +#__________________________________________________________________________________ +#1.4 Model Selection