/*
 * Buggy Application but fingers crossed
 *
 * Page metadata captured from pastecode.io:
 *   mail@pastecode.io avatar
 *   unknown
 *   java
 *   23 days ago
 *   6.5 kB
 *   1
 *   Indexable
 *   Never
 */
package com.example;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.impl.ListDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class FRBDetectorApp {

    // --- Data shape ---------------------------------------------------------

    /** Length of each input example, e.g. signal intensity sampled over time. */
    private static final int NUM_INPUTS = 100;

    /** Number of output classes for the binary decision (FRB or not). */
    private static final int NUM_OUTPUTS = 2;

    // --- Training schedule --------------------------------------------------

    /** Number of full passes made over the training iterator. */
    private static final int NUM_EPOCHS = 50;

    /** Examples per mini-batch handed out by the data iterators. */
    private static final int BATCH_SIZE = 32;

    /**
     * Entry point: trains the 1-D CNN FRB classifier on synthetic data, evaluates it, saves and
     * reloads it, runs one sample prediction, and finally plots the result.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Generate synthetic data for demonstration purposes
        DataSetIterator trainData = generateSyntheticData(1000);
        DataSetIterator testData = generateSyntheticData(200);

        // Standardize features using statistics (mean/stdev) collected from the training set only,
        // then attach the same normalizer to both iterators so train and test share one scale.
        NormalizerStandardize normalizer = new NormalizerStandardize();
        normalizer.fit(trainData);
        trainData.setPreProcessor(normalizer);
        testData.setPreProcessor(normalizer);

        // Create and train the neural network
        MultiLayerNetwork model = new MultiLayerNetwork(buildNetworkConfiguration());
        model.init();
        model.setListeners(new ScoreIterationListener(10));
        for (int epoch = 1; epoch <= NUM_EPOCHS; epoch++) {
            trainData.reset(); // make every epoch start from the first mini-batch
            model.fit(trainData);
            System.out.println("Completed epoch " + epoch + "/" + NUM_EPOCHS);
        }

        // Evaluate the model on the test data
        testData.reset();
        Evaluation eval = model.evaluate(testData);
        System.out.println(eval.stats());

        saveReloadAndPredict(model, normalizer, new File("frb-detector-model.zip"));

        // NOTE(review): only the final test accuracy is available here, so the chart is a flat line
        // at that value across all iterations; per-iteration training accuracy is not recorded.
        plotTrainingResults(model.getIterationCount(), eval.accuracy());
    }

    /**
     * Builds the network: two convolution + max-pool stages over a NUM_INPUTS x 1 single-channel
     * "image", followed by a softmax output layer with NUM_OUTPUTS classes.
     */
    private static MultiLayerConfiguration buildNetworkConfiguration() {
        return new NeuralNetConfiguration.Builder()
                .seed(123)
                .updater(new Adam(0.001))
                .list()
                .layer(new ConvolutionLayer.Builder(5, 1)
                        .nIn(1) // 1 input channel
                        .nOut(16) // 16 filters
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 1)
                        .stride(2, 1)
                        .build())
                .layer(new ConvolutionLayer.Builder(3, 1)
                        .nOut(32)
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 1)
                        .stride(2, 1)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(NUM_OUTPUTS)
                        .activation(Activation.SOFTMAX)
                        .build())
                // Examples arrive as flat [batch, NUM_INPUTS] rows. convolutionalFlat (unlike
                // convolutional) inserts the preprocessor that reshapes them into the rank-4
                // activations the first ConvolutionLayer requires.
                .setInputType(InputType.convolutionalFlat(NUM_INPUTS, 1, 1))
                .build();
    }

    /**
     * Saves {@code model} to {@code modelFile}, reloads it, and prints its prediction for one
     * example. Failures are reported on stderr; the reload is skipped if saving fails.
     *
     * @param model      trained network to persist
     * @param normalizer normalizer fitted on the training data; applied to the prediction input
     * @param modelFile  destination of the saved model (updater state included)
     */
    private static void saveReloadAndPredict(
            MultiLayerNetwork model, NormalizerStandardize normalizer, File modelFile) {
        try {
            model.save(modelFile, true);
            MultiLayerNetwork loadedModel = MultiLayerNetwork.load(modelFile, true);

            // Placeholder observation: replace with a real series of NUM_INPUTS values. It is built
            // the same way as the training rows (a [1, NUM_INPUTS] double matrix).
            Random rand = new Random(42);
            double[] raw = new double[NUM_INPUTS];
            for (int j = 0; j < NUM_INPUTS; j++) {
                raw[j] = rand.nextDouble();
            }
            INDArray input = Nd4j.create(new double[][]{raw});
            // The network only ever saw standardized data, so apply the same transform (in place).
            normalizer.transform(input);

            INDArray output = loadedModel.output(input);
            System.out.println("Predicted: " + output);
        } catch (IOException e) {
            System.err.println("Could not save or reload model " + modelFile + ": " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Random source shared by every {@link #generateSyntheticData(int)} call. It is seeded once, so
     * runs are reproducible. Because it is shared, successive calls (train set, then test set)
     * continue the sequence instead of replaying it; a per-call {@code new Random(123)} made the
     * test set an exact copy of the first training samples.
     */
    private static final Random SYNTHETIC_DATA_RANDOM = new Random(123);

    /**
     * Generates a labelled synthetic data set.
     *
     * <p>Every example is {@code NUM_INPUTS} values of uniform noise in [0, 1). Its label is drawn
     * uniformly from the {@code NUM_OUTPUTS} classes. Examples labelled with class index 1 (treated
     * as "FRB") also get a Gaussian pulse added at a random position, so the label is learnable
     * from the features.
     *
     * @param numSamples number of examples to generate; must be non-negative
     * @return an iterator yielding mini-batches of {@code BATCH_SIZE} examples, with features shaped
     *     {@code [batch, NUM_INPUTS]} and one-hot labels shaped {@code [batch, NUM_OUTPUTS]}
     * @throws IllegalArgumentException if {@code numSamples} is negative
     */
    private static DataSetIterator generateSyntheticData(int numSamples) {
        if (numSamples < 0) {
            throw new IllegalArgumentException("numSamples must be non-negative, got " + numSamples);
        }
        final int frbClass = 1;
        final double pulseAmplitude = 3.0; // peak height added on top of the [0, 1) noise
        final double pulseWidth = 2.0;     // Gaussian standard deviation, in samples

        Random rand = SYNTHETIC_DATA_RANDOM;
        List<DataSet> dataSets = new ArrayList<>(numSamples);

        for (int i = 0; i < numSamples; i++) {
            double[] features = new double[NUM_INPUTS];
            for (int j = 0; j < NUM_INPUTS; j++) {
                features[j] = rand.nextDouble();
            }

            int label = rand.nextInt(NUM_OUTPUTS); // Randomly assign a label (FRB or not)
            if (label == frbClass) {
                int centre = rand.nextInt(NUM_INPUTS);
                for (int j = 0; j < NUM_INPUTS; j++) {
                    double z = (j - centre) / pulseWidth;
                    features[j] += pulseAmplitude * Math.exp(-0.5 * z * z);
                }
            }
            double[] labels = new double[NUM_OUTPUTS];
            labels[label] = 1.0;

            // Wrap each vector as a single-row matrix so merged mini-batches are always
            // [batch, n], independent of how a given ND4J version shapes Nd4j.create(double[]).
            INDArray featureRow = Nd4j.create(new double[][]{features});
            INDArray labelRow = Nd4j.create(new double[][]{labels});
            dataSets.add(new DataSet(featureRow, labelRow));
        }

        return new ListDataSetIterator<>(dataSets, BATCH_SIZE);
    }

    private static void plotTrainingResults(int numIterations, double accuracy) {
        XYSeries accuracySeries = new XYSeries("Accuracy");

        for (int i = 0; i < numIterations; i++) {
            accuracySeries.add(i, accuracy);
        }

        XYSeriesCollection dataset = new XYSeriesCollection();
        dataset.addSeries(accuracySeries);

        JFreeChart chart = ChartFactory.createXYLineChart(
                "Training Accuracy",
                "Iteration",
                "Accuracy",
                dataset,
                PlotOrientation.VERTICAL,
                true, true, false);

        Chart
Leave a Comment