Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
2.4 kB
10
Indexable
Never
def relu(x):
    return np.maximum(0, x)# ВАШ КОД ЗДЕСЬ (подсказка, воспользуйтесь np.maximum())
def relu_derivative(x):
    return x>0
class Perceptron_ReLU:
    def __init__(self, w=None, b=0):
        self.w = w
        self.b = b
        
    def activate(self, x):
        return relu(x)
        
    def forward(self, X):
        n = X.shape[0]
        y_pred = self.activate(X @ self.w.reshape(X.shape[1], 1) + self.b)
        return y_pred.reshape(-1, 1)
    
    def backward(self, X, y, y_pred, learning_rate=0.005):
        n = len(y)
        y = np.array(y).reshape(-1, 1)

        dw = X.T @ ((y_pred - y) * relu_derivative(X @ self.w.reshape(X.shape[1], 1) + self.b))
        db = np.sum((y_pred - y) * relu_derivative(X @ self.w.reshape(X.shape[1], 1) + self.b))

        self.w -= learning_rate * dw
        self.b -= learning_rate * db
    
    def fit(self, X, y, num_epochs=5000):
        self.w = np.zeros((X.shape[1], 1))
        self.b = 0
        loss_values = []
        
        for i in range(num_epochs):
            y_pred = self.forward(X)
            loss_values.append(mse_loss(y_pred, y))
            self.backward(X, y, y_pred)
        
        return np.array(loss_values)
perceptron = Perceptron_ReLU()
losses = perceptron.fit(X,y)

plt.figure(figsize=(10, 8))
plt.plot (losses.reshape(-1,))
plt.title('График функция потерь', fontsize=15)
plt.xlabel('Номер итерации', fontsize=14)
plt.ylabel('$Loss(\hat{y}, y)$', fontsize=14)
plt.show()

#веса изначально инициализированы нулями Поэтому обучение и не происходит.
#На графике прямая линия

#Заинициализируем их случайными числами (не забудьте закомментировать в классе инициализацию нулями)

perceptron = Perceptron_ReLU(w=np.random.normal(0, 1, (X.shape[1], 1)), b=0)
losses = perceptron.fit(X, y, num_epochs=5000)
plt.figure(figsize=(10, 8))
plt.plot(losses.reshape(-1,))
plt.title('График функция потерь', fontsize=15)
plt.xlabel('Номер итерации', fontsize=14)
plt.ylabel('$Loss(\hat{y}, y)$', fontsize=14)
plt.show()
#На графике все равно прямая линия(