Untitled

mail@pastecode.io avatar
unknown
plain_text
a month ago
2.5 kB
1
Indexable
Never
import java.util.Random;

/**
 * A single-layer, single-output neuron with a step activation, trained with the
 * error-correction update {@code w += learningRate * (target - prediction) * x}.
 *
 * <p>Note: because the output passes through a step function before the error is
 * computed, this is strictly the perceptron learning rule; the classic delta
 * (Widrow-Hoff) rule computes the error on the linear output instead.
 *
 * <p>Weight index 0 is the bias; weight {@code i + 1} multiplies input feature {@code i}.
 */
public class DeltaRule {

    /** Default learning rate used by {@link #train(double[][], double[])}. */
    private static final double LEARNING_RATE = 0.1;
    /** Default maximum number of passes over the training set. */
    private static final int EPOCHS = 1000;

    /** weights[0] is the bias; weights[i + 1] pairs with input feature i. */
    private final double[] weights;
    private final Random random;

    /**
     * Creates a neuron with {@code inputSize} inputs and weights drawn uniformly
     * from [-1, 1) using an unseeded {@link Random}.
     *
     * @param inputSize number of input features (must be non-negative)
     * @throws IllegalArgumentException if {@code inputSize} is negative
     */
    public DeltaRule(int inputSize) {
        this(inputSize, new Random());
    }

    /**
     * Creates a neuron whose initial weights come from the supplied generator.
     * Pass a seeded {@link Random} for reproducible runs.
     *
     * @param inputSize number of input features (must be non-negative)
     * @param random    source of the initial weights
     * @throws IllegalArgumentException if {@code inputSize} is negative
     * @throws NullPointerException     if {@code random} is null
     */
    public DeltaRule(int inputSize, Random random) {
        if (inputSize < 0) {
            throw new IllegalArgumentException("inputSize must be non-negative: " + inputSize);
        }
        if (random == null) {
            throw new NullPointerException("random");
        }
        this.weights = new double[inputSize + 1]; // +1 for bias
        this.random = random;
        initializeWeights();
    }

    /** Fills every weight (bias included) with a uniform value in [-1, 1). */
    private void initializeWeights() {
        for (int i = 0; i < weights.length; i++) {
            weights[i] = random.nextDouble() * 2 - 1;
        }
    }

    /**
     * Computes the step-activated output for one sample.
     *
     * @param inputs feature vector; its length must equal the constructor's {@code inputSize}
     * @return 1.0 if the weighted sum plus bias is {@code >= 0}, otherwise 0.0
     * @throws IllegalArgumentException if {@code inputs} has the wrong length
     */
    public double predict(double[] inputs) {
        checkInputLength(inputs);
        double sum = weights[0]; // bias
        for (int i = 0; i < inputs.length; i++) {
            sum += inputs[i] * weights[i + 1]; // skip bias weight
        }
        return activate(sum);
    }

    /** Rejects feature vectors whose length does not match the weight vector. */
    private void checkInputLength(double[] inputs) {
        int expected = weights.length - 1;
        if (inputs.length != expected) {
            throw new IllegalArgumentException(
                    "expected " + expected + " inputs but got " + inputs.length);
        }
    }

    /** Heaviside step: 1 for non-negative sums, 0 otherwise. */
    private double activate(double sum) {
        return sum >= 0 ? 1 : 0;
    }

    /**
     * Trains with the default learning rate (0.1) for at most 1000 epochs,
     * stopping early once an epoch produces no errors.
     *
     * @param inputs  one feature vector per sample
     * @param targets expected output per sample, parallel to {@code inputs}
     * @throws IllegalArgumentException if the arrays differ in length or a sample has the wrong width
     */
    public void train(double[][] inputs, double[] targets) {
        train(inputs, targets, LEARNING_RATE, EPOCHS);
    }

    /**
     * Trains with an explicit learning rate and epoch budget. Training stops early
     * after the first epoch in which every sample is predicted exactly. Because
     * {@link #predict} is deterministic, no later epoch could change any weight.
     *
     * @param inputs       one feature vector per sample
     * @param targets      expected output per sample, parallel to {@code inputs}
     * @param learningRate step size applied to each weight correction
     * @param maxEpochs    maximum number of passes over the data (non-negative)
     * @throws IllegalArgumentException if the arrays differ in length, a sample has the
     *                                  wrong width, or {@code maxEpochs} is negative
     */
    public void train(double[][] inputs, double[] targets, double learningRate, int maxEpochs) {
        if (inputs.length != targets.length) {
            throw new IllegalArgumentException(
                    "inputs and targets differ in length: " + inputs.length + " vs " + targets.length);
        }
        if (maxEpochs < 0) {
            throw new IllegalArgumentException("maxEpochs must be non-negative: " + maxEpochs);
        }
        for (int epoch = 0; epoch < maxEpochs; epoch++) {
            int misclassified = 0;
            for (int i = 0; i < inputs.length; i++) {
                double[] input = inputs[i];
                double error = targets[i] - predict(input); // predict() validates the width
                if (error == 0) {
                    continue; // a zero error would leave every weight unchanged
                }
                misclassified++;
                weights[0] += learningRate * error; // bias sees a constant input of 1
                for (int j = 0; j < input.length; j++) {
                    weights[j + 1] += learningRate * error * input[j];
                }
            }
            if (misclassified == 0) {
                return; // converged: further epochs cannot change the weights
            }
        }
    }

    /** Demo: learns the two-input OR function and prints outputs before and after. */
    public static void main(String[] args) {
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[] targets = {0, 1, 1, 1}; // OR truth table

        DeltaRule deltaRule = new DeltaRule(2); // 2 input neurons (for OR gate)

        System.out.println("Before training:");
        printPredictions(deltaRule, inputs);

        deltaRule.train(inputs, targets);

        System.out.println("\nAfter training:");
        printPredictions(deltaRule, inputs);
    }

    /** Prints one line per two-feature sample with the model's current output. */
    private static void printPredictions(DeltaRule model, double[][] inputs) {
        for (double[] input : inputs) {
            System.out.println("Input: " + input[0] + ", " + input[1] + " => Output: " + model.predict(input));
        }
    }
}
Leave a Comment