From 3b94b39368cbb488dacaa94169184e8f1debdfe0 Mon Sep 17 00:00:00 2001
From: ShaaniBel
Date: Sat, 18 Oct 2025 13:12:12 -0400
Subject: [PATCH] task 1.4 without plots

---
Reviewer notes (free-form text in this position is ignored by git am):
 * The patch had been flattened onto a single line; its line structure is
   restored here. The indentation of the three unchanged plt.* context lines
   in the second hunk is assumed to be 4 spaces per level -- confirm against
   A2.py before applying.
 * Task 1.4 now flattens the targets together with the predictions before
   the SSE is computed, and no longer appends blank lines at end of file.
 * The diffstat and the second hunk header were recomputed by hand; the blob
   ids on the "index" line still describe the original post-image.

 A2.py | 36 +++++++++++++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/A2.py b/A2.py
index 3af0426..ec2c1f1 100644
--- a/A2.py
+++ b/A2.py
@@ -1,10 +1,11 @@
 import numpy as np
 import matplotlib.pyplot as plt
 import warnings
+from sklearn.model_selection import train_test_split
 warnings.filterwarnings('ignore')
 
 #reproducibility
-np.random.seed(1)
+np.random.seed(2)
 
 #__________________________________________________________________________________
 #Task 1
@@ -150,13 +151,38 @@ for i, D in enumerate(D_i):
         plt.ylabel('y')
     if i >= 9:
         plt.xlabel('x')
-
-
 
 plt.tight_layout()
 plt.show()
-
-
 
 #__________________________________________________________________________________
 #1.4 Model Selection
+
+# Hold out 20% of the samples as a validation set (fixed random_state)
+x_train, x_val, y_train, y_val = train_test_split(x, y_noisy, test_size=0.2, random_state=42)
+
+# Numbers of basis functions to try: D = 0, 1, ..., 45
+D_values = list(range(46))
+
+# Sum of squared errors for every D, in the same order as D_values
+train_sse = []
+val_sse = []
+
+for D in D_values:
+    # Fit a fresh model with D basis functions on the training split only
+    model = GaussianRegression(sigma=1.0)
+    model.fit(x_train, y_train, D)
+
+    # Flatten predictions and targets alike: subtracting an (n, 1) array
+    # from an (n,) array would broadcast to (n, n) and inflate the SSE
+    yh_train = np.ravel(model.predict(x_train))
+    yh_val = np.ravel(model.predict(x_val))
+
+    # Sum of squared errors on each split
+    sse_train = np.sum((np.ravel(y_train) - yh_train)**2)
+    sse_val = np.sum((np.ravel(y_val) - yh_val)**2)
+
+    train_sse.append(sse_train)
+    val_sse.append(sse_val)
+
+    print(f"D={D:2d}: Train SSE = {sse_train:8.2f}, Val SSE = {sse_val:8.2f}")