fixed an error.
This commit is contained in:
parent
455b48c89b
commit
6508fcbbab
2 changed files with 9 additions and 9 deletions
|
|
@ -74,19 +74,19 @@ class LogisticRegression:
|
|||
n_samples = self.x.shape[0]
|
||||
batch_size = self.batch_size or n_samples
|
||||
|
||||
# number of batches per iteration
|
||||
n_batches = int(np.ceil(n_samples / batch_size))
|
||||
|
||||
for epoch in range(1, self.n_iter + 1):
|
||||
shuffled_idx = np.random.permutation(n_samples) # random permutation of the indices
|
||||
x_shuffled = self.x[shuffled_idx]
|
||||
y_shuffled = self.y[shuffled_idx]
|
||||
|
||||
# process execution for each mini‑batch
|
||||
for b in range(0, n_samples, batch_size):
|
||||
for b in range(n_batches):
|
||||
start = b * batch_size
|
||||
end = start + batch_size
|
||||
end = min(start + batch_size, n_samples)
|
||||
idx = shuffled_idx[start:end]
|
||||
|
||||
x_batch = x_shuffled[idx]
|
||||
y_batch = y_shuffled[idx]
|
||||
x_batch = self.x[idx]
|
||||
y_batch = self.y[idx]
|
||||
|
||||
|
||||
z = x_batch.dot(self.w)
|
||||
p = self.sigmoid(z)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue