Untitled
unknown
python
5 months ago
4.1 kB
4
Indexable
def train(self, X_train, y_train, X_val, y_val, epochs=10, learning_rate=0.001, batch_size=32, loss_function='mse'):
    """Train the model with mini-batch gradient descent.

    Parameters
    ----------
    X_train, y_train : training inputs and targets (array-like with .shape).
    X_val, y_val : validation inputs and targets.
    epochs : int
        Number of full passes over the training set.
    learning_rate : float
        Step size passed to ``self.update``.
    batch_size : int
        Examples per mini-batch; the final batch may be smaller.
    loss_function : {'mse', 'cce'}
        Which loss to compute.

    Raises
    ------
    ValueError
        If ``loss_function`` is neither 'mse' nor 'cce'.

    Side effects: resets and fills ``self.train_losses`` / ``self.val_losses``
    with per-epoch averages, prints progress, and updates model parameters.
    """
    self.train_losses = []  # per-epoch average training loss
    self.val_losses = []    # per-epoch validation loss

    for epoch in range(epochs):
        total_loss = 0
        # Ceiling division: the last (possibly partial) batch is counted
        # exactly once. Slicing past the end of the array is safe — it
        # simply yields the remaining < batch_size examples — so there is
        # NO separate "leftover" pass.  (BUG FIX: the original code
        # re-processed the final partial batch a second time after the
        # loop, inflating total_loss and applying a duplicate gradient
        # step whenever len(X_train) % batch_size != 0.)
        num_batches = (X_train.shape[0] + batch_size - 1) // batch_size

        with tqdm(total=num_batches, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch") as pbar:
            for batch_idx in range(num_batches):
                start = batch_idx * batch_size
                X_batch = X_train[start: start + batch_size]
                y_batch = y_train[start: start + batch_size]

                # Forward pass -> loss -> backward pass -> parameter update.
                y_pred = self.forward(X_batch)
                if loss_function == 'cce':
                    loss = compute_CCE_loss(y_pred, y_batch)
                elif loss_function == 'mse':
                    loss = compute_MSE_loss(y_pred, y_batch)
                else:
                    raise ValueError("Unsupported loss function")
                total_loss += loss
                self.backward(y_pred, y_batch)
                self.update(learning_rate)

                # Refresh the displayed running loss every 5 batches.
                if (batch_idx + 1) % 5 == 0:
                    pbar.set_postfix(loss=total_loss / (batch_idx + 1))
                pbar.update(1)

        avg_train_loss = total_loss / num_batches
        self.train_losses.append(avg_train_loss)
        print(f'Epoch {epoch + 1}/{epochs}, Training Loss: {avg_train_loss}')

        # Validation: one full-batch forward pass, no parameter update.
        y_pred = self.forward(X_val)
        if loss_function == 'cce':
            val_loss = compute_CCE_loss(y_pred, y_val)
        elif loss_function == 'mse':
            val_loss = compute_MSE_loss(y_pred, y_val)
        self.val_losses.append(val_loss)
        print(f'Epoch {epoch + 1}/{epochs}, Validation Loss: {val_loss}')
Editor is loading...
Leave a Comment