task 1.4 without plots

This commit is contained in:
ShaaniBel 2025-10-18 13:12:12 -04:00
parent 57eefddad8
commit 3b94b39368

44
A2.py
View file

@ -1,10 +1,11 @@
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import warnings import warnings
from sklearn.model_selection import train_test_split
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
#reproducibility #reproducibility
np.random.seed(1) np.random.seed(2)
#__________________________________________________________________________________ #__________________________________________________________________________________
#Task 1 #Task 1
@ -151,12 +152,45 @@ for i, D in enumerate(D_i):
if i >= 9: if i >= 9:
plt.xlabel('x') plt.xlabel('x')
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
#__________________________________________________________________________________ #__________________________________________________________________________________
#1.4 Model Selection #1.4 Model Selection
# Split the data into training and validation sets
x_train, x_val, y_train, y_val = train_test_split(x, y_noisy, test_size=0.2, random_state=42)
# range of basis functions to test
D_values = list(range(0, 46)) # 0 to 45
# Initialize arrays to store errors
train_sse = []
val_sse = []
# For each number of basis functions
for D in D_values:
# Create and fit the model
model = GaussianRegression(sigma=1.0)
model.fit(x_train, y_train, D)
# predict on training then validation
yh_train = model.predict(x_train)
yh_train = yh_train.flatten() if yh_train.ndim > 1 else yh_train
yh_val = model.predict(x_val)
yh_val = yh_val.flatten() if yh_val.ndim > 1 else yh_val
# compute SSE
sse_train = np.sum((y_train - yh_train)**2)
sse_val = np.sum((y_val - yh_val)**2)
train_sse.append(sse_train)
val_sse.append(sse_val)
print(f"D={D:2d}: Train SSE = {sse_train:8.2f}, Val SSE = {sse_val:8.2f}")