Update A2.py
This commit is contained in:
parent
3b94b39368
commit
65fe618f93
1 changed files with 3 additions and 3 deletions
6
A2.py
6
A2.py
|
|
@ -16,7 +16,7 @@ def generate_data(n_samples=100, noise_std=1.0):
|
||||||
x = np.linspace(0, 10, n_samples)
|
x = np.linspace(0, 10, n_samples)
|
||||||
|
|
||||||
#y values without noise
|
#y values without noise
|
||||||
y_clean = (np.log(x + 1e-10) + 1) * np.cos(x) + np.sin(2*x)
|
y_clean = np.log(x + 1) * np.cos(x) + np.sin(2*x)
|
||||||
|
|
||||||
#noise
|
#noise
|
||||||
noise = np.random.normal(0, noise_std, n_samples)
|
noise = np.random.normal(0, noise_std, n_samples)
|
||||||
|
|
@ -117,7 +117,7 @@ class GaussianRegression:
|
||||||
|
|
||||||
|
|
||||||
def true_function(x):
|
def true_function(x):
|
||||||
return (np.log(x + 1e-10) + 1) * np.cos(x) + np.sin(2*x)
|
return np.log(x + 1) * np.cos(x) + np.sin(2*x)
|
||||||
|
|
||||||
# fit models with different numbers of basis functions and plot
|
# fit models with different numbers of basis functions and plot
|
||||||
D_i = [0, 2, 5, 10, 13, 15, 17, 20, 25, 30, 35, 45]
|
D_i = [0, 2, 5, 10, 13, 15, 17, 20, 25, 30, 35, 45]
|
||||||
|
|
@ -178,7 +178,7 @@ for D in D_values:
|
||||||
# predict on training then validation
|
# predict on training then validation
|
||||||
yh_train = model.predict(x_train)
|
yh_train = model.predict(x_train)
|
||||||
yh_train = yh_train.flatten() if yh_train.ndim > 1 else yh_train
|
yh_train = yh_train.flatten() if yh_train.ndim > 1 else yh_train
|
||||||
|
|
||||||
yh_val = model.predict(x_val)
|
yh_val = model.predict(x_val)
|
||||||
yh_val = yh_val.flatten() if yh_val.ndim > 1 else yh_val
|
yh_val = yh_val.flatten() if yh_val.ndim > 1 else yh_val
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue