Untitled
unknown
python
a year ago
4.1 kB
7
Indexable
def train(self, X_train, y_train, X_val, y_val, epochs=10, learning_rate=0.001, batch_size=32, loss_function='mse'):
    """Train the model with mini-batch gradient descent and record losses.

    Each epoch walks ``X_train``/``y_train`` in order, in consecutive slices
    of ``batch_size`` rows. The final slice may be shorter than
    ``batch_size``; every example is used exactly once per epoch. For each
    batch the model runs ``self.forward``, ``self.backward`` and
    ``self.update``. At the end of each epoch, two values are appended to
    ``self.train_losses`` and ``self.val_losses`` and printed:

    - the mean per-batch training loss;
    - the loss on the full validation set.

    Args:
        X_train, y_train: Training inputs and targets. Assumed to be
            row-sliceable (e.g. NumPy arrays) with the same number of rows.
            TODO: confirm against callers.
        X_val, y_val: Validation inputs and targets. They are evaluated in
            one forward pass per epoch, with no parameter update.
        epochs: Number of passes over the training data.
        learning_rate: Passed to ``self.update`` after every batch.
        batch_size: Maximum number of rows per mini-batch; must be >= 1.
        loss_function: ``'mse'`` (uses ``compute_MSE_loss``) or ``'cce'``
            (uses ``compute_CCE_loss``).

    Raises:
        ValueError: If ``loss_function`` is unsupported, ``batch_size`` is
            less than 1, or ``X_train`` is empty.
    """
    # Resolve the loss once, so an unsupported name fails before any
    # parameter is touched and training/validation always agree on the loss.
    loss_fns = {'cce': compute_CCE_loss, 'mse': compute_MSE_loss}
    if loss_function not in loss_fns:
        raise ValueError("Unsupported loss function")
    compute_loss = loss_fns[loss_function]

    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Ceiling division: when the sample count is not a multiple of
    # batch_size, the last batch simply holds the leftover rows. No separate
    # "remainder" pass is needed; one would train on those rows twice.
    num_batches = (len(X_train) + batch_size - 1) // batch_size
    if num_batches == 0:
        raise ValueError("X_train is empty; nothing to train on")

    self.train_losses = []  # mean training loss, one entry per epoch
    self.val_losses = []    # validation loss, one entry per epoch
    for epoch in range(epochs):
        total_loss = 0
        with tqdm(total=num_batches, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch") as pbar:
            for batch_idx in range(num_batches):
                # Slicing past the end is truncated, so the final batch may be short.
                start = batch_idx * batch_size
                X_batch = X_train[start:start + batch_size]
                y_batch = y_train[start:start + batch_size]

                y_pred = self.forward(X_batch)
                total_loss += compute_loss(y_pred, y_batch)
                # backward() takes predictions and targets directly; its return
                # value is not used here. Presumably it stores gradients that
                # update() then applies — confirm against the class.
                self.backward(y_pred, y_batch)
                self.update(learning_rate)

                # Show the running mean batch loss on the bar every 5 batches.
                if (batch_idx + 1) % 5 == 0:
                    pbar.set_postfix(loss=total_loss / (batch_idx + 1))
                pbar.update(1)  # Increment the progress bar by 1 unit

        avg_train_loss = total_loss / num_batches  # mean over batches actually run
        self.train_losses.append(avg_train_loss)
        print(f'Epoch {epoch + 1}/{epochs}, Training Loss: {avg_train_loss}')

        # Validation: one forward pass over the whole set, no parameter update.
        val_loss = compute_loss(self.forward(X_val), y_val)
        self.val_losses.append(val_loss)
        print(f'Epoch {epoch + 1}/{epochs}, Validation Loss: {val_loss}')
Leave a Comment