task 1.4 without plots
This commit is contained in:
parent
57eefddad8
commit
3b94b39368
1 changed files with 39 additions and 5 deletions
44
A2.py
44
A2.py
|
|
@ -1,10 +1,11 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import warnings
|
||||
from sklearn.model_selection import train_test_split
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
#reproducibility
|
||||
np.random.seed(1)
|
||||
np.random.seed(2)
|
||||
|
||||
#__________________________________________________________________________________
|
||||
#Task 1
|
||||
|
|
@ -150,13 +151,46 @@ for i, D in enumerate(D_i):
|
|||
plt.ylabel('y')
|
||||
if i >= 9:
|
||||
plt.xlabel('x')
|
||||
|
||||
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
#__________________________________________________________________________________
|
||||
#1.4 Model Selection
|
||||
|
||||
# Split the data into training and validation sets
|
||||
x_train, x_val, y_train, y_val = train_test_split(x, y_noisy, test_size=0.2, random_state=42)
|
||||
|
||||
# range of basis functions to test
|
||||
D_values = list(range(0, 46)) # 0 to 45
|
||||
|
||||
# Initialize arrays to store errors
|
||||
train_sse = []
|
||||
val_sse = []
|
||||
|
||||
|
||||
# For each number of basis functions
|
||||
for D in D_values:
|
||||
# Create and fit the model
|
||||
model = GaussianRegression(sigma=1.0)
|
||||
model.fit(x_train, y_train, D)
|
||||
|
||||
# predict on training then validation
|
||||
yh_train = model.predict(x_train)
|
||||
yh_train = yh_train.flatten() if yh_train.ndim > 1 else yh_train
|
||||
|
||||
yh_val = model.predict(x_val)
|
||||
yh_val = yh_val.flatten() if yh_val.ndim > 1 else yh_val
|
||||
|
||||
# compute SSE
|
||||
sse_train = np.sum((y_train - yh_train)**2)
|
||||
sse_val = np.sum((y_val - yh_val)**2)
|
||||
|
||||
train_sse.append(sse_train)
|
||||
val_sse.append(sse_val)
|
||||
|
||||
print(f"D={D:2d}: Train SSE = {sse_train:8.2f}, Val SSE = {sse_val:8.2f}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue