Untitled

 avatar
unknown
plain_text
2 years ago
2.1 kB
7
Indexable
def entrena(self, X, y, Xv=None, yv=None, n_epochs=100, salida_epoch=False, early_stopping=False, paciencia=3):
        """Fit the model weights with mini-batch gradient descent.

        Resets ``self.weights`` to zeros (one weight per column of ``X``),
        stores the distinct values of ``y`` in ``self.classes`` and then runs
        up to ``n_epochs`` passes over a freshly shuffled copy of the training
        indices, updating the weights once per mini-batch of ``self.batch_tam``
        examples.

        Args:
            X: Training inputs; indexed by row and multiplied against the
                weight vector, so it is expected to be a 2-D NumPy array.
            y: Training targets aligned with the rows of ``X``.
                NOTE(review): the gradient uses ``y_pred - y`` directly, so
                the targets presumably must already be 0/1 -- confirm
                against the callers.
            Xv: Validation inputs. When ``None`` the training set (``X``,
                ``y``) is used for validation as well.
            yv: Validation targets aligned with ``Xv``.
            n_epochs: Maximum number of passes over the training data.
            salida_epoch: When true, print the cross-entropy and accuracy on
                the training and validation sets after every epoch.
            early_stopping: When true, stop as soon as the validation
                cross-entropy has failed to improve for ``paciencia``
                consecutive epochs. This works independently of
                ``salida_epoch``.
            paciencia: Number of consecutive non-improving epochs tolerated
                before stopping early.

        Returns:
            None. The fitted weights are left in ``self.weights``.

        Note:
            When early stopping triggers, ``"PARADA TEMPRANA"`` is printed
            even if ``salida_epoch`` is false, so that a run ending before
            ``n_epochs`` is never silent.
        """
        self.classes = np.unique(y)
        self.weights = np.zeros(X.shape[1])

        if Xv is None:
            Xv = X
            yv = y

        n_samples = len(X)
        best_loss = float('inf')
        epochs_without_improvement = 0

        for epoch in range(n_epochs):
            # Optional decay: the base rate is divided by the 1-based epoch number.
            if self.rate_decay:
                rate_n = self.rate / (1 + epoch)
            else:
                rate_n = self.rate

            # New random visiting order of the training examples each epoch.
            indices = np.arange(n_samples)
            np.random.shuffle(indices)

            for start in range(0, n_samples, self.batch_tam):
                indices_batch = indices[start:start + self.batch_tam]
                X_batch = X[indices_batch]
                y_batch = y[indices_batch]

                # Average gradient over the mini-batch (the last batch may be
                # smaller, hence the division by its actual length).
                y_pred = self.sigmoide(np.dot(X_batch, self.weights))
                gradient = np.dot(X_batch.T, y_pred - y_batch) / len(X_batch)
                self.weights -= rate_n * gradient

            # Skip the per-epoch evaluation entirely when nobody needs it.
            if not (salida_epoch or early_stopping):
                continue

            # The validation loss drives early stopping, so it must be
            # computed whenever early stopping is on, not only when
            # per-epoch output is requested.
            loss_val = self.entropia_cruzada(yv, self.clasifica_prob(Xv))

            if salida_epoch:
                loss_train = self.entropia_cruzada(y, self.clasifica_prob(X))
                acc_train = np.mean(self.clasifica(X) == y)
                acc_val = np.mean(self.clasifica(Xv) == yv)

                print(f"Epoch {epoch + 1}:")
                print(f"  en entrenamiento EC: {loss_train:.4f}, rendimiento: {acc_train:.4f}")
                print(f"  en validación EC: {loss_val:.4f}, rendimiento: {acc_val:.4f}")

            if early_stopping:
                if loss_val < best_loss:
                    best_loss = loss_val
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1

                if epochs_without_improvement >= paciencia:
                    print("PARADA TEMPRANA")
                    break
Editor is loading...