Buggy Application but fingers crossed
unknown
java
10 months ago
6.5 kB
7
Indexable
package com.example; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.impl.ListDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.jfree.chart.ChartFactory; import org.jfree.chart.ChartPanel; import org.jfree.chart.JFreeChart; import org.jfree.chart.plot.PlotOrientation; import org.jfree.data.xy.XYSeries; import org.jfree.data.xy.XYSeriesCollection; import javax.swing.*; import java.awt.*; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Random; public class FRBDetectorApp { private static final int NUM_INPUTS = 100; // Number of features (e.g., signal intensity over time) private static final int NUM_OUTPUTS = 2; // Binary classification (FRB or not) private static final int BATCH_SIZE = 32; private static final int NUM_EPOCHS = 50; public static void main(String[] args) { // Generate synthetic data for demonstration purposes DataSetIterator trainData = generateSyntheticData(1000); DataSetIterator testData = generateSyntheticData(200); // Normalize the data NormalizerStandardize normalizer = new NormalizerStandardize(); normalizer.fit(trainData); // Collect statistics (mean/stdev) from the training data trainData.setPreProcessor(normalizer); testData.setPreProcessor(normalizer); // Define the neural network configuration MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() .seed(123) .updater(new Adam(0.001)) .list() .layer(new ConvolutionLayer.Builder(5, 1) .nIn(1) // 1 input channel .nOut(16) // 16 filters .stride(1, 1) .activation(Activation.RELU) .build()) .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) .kernelSize(2, 1) .stride(2, 1) .build()) .layer(new ConvolutionLayer.Builder(3, 1) .nOut(32) .stride(1, 1) .activation(Activation.RELU) .build()) .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) .kernelSize(2, 1) .stride(2, 1) .build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(NUM_OUTPUTS) .activation(Activation.SOFTMAX) .build()) .setInputType(InputType.convolutional(NUM_INPUTS, 1, 1)) // Input: 1D convolution .build(); // Create and train the neural network MultiLayerNetwork model = new MultiLayerNetwork(config); model.init(); model.setListeners(new ScoreIterationListener(10)); for (int i = 0; i < NUM_EPOCHS; i++) { model.fit(trainData); System.out.println("Completed epoch " + (i + 1) + "/" + NUM_EPOCHS); } // Evaluate the model on the test data Evaluation eval = model.evaluate(testData); System.out.println(eval.stats()); // Save the trained model try { model.save(new File("frb-detector-model.zip"), true); } catch (IOException e) { e.printStackTrace(); } // Load the trained model and make predictions try { MultiLayerNetwork loadedModel = MultiLayerNetwork.load(new File("frb-detector-model.zip"), true); INDArray input = Nd4j.create(new double[]{/* your test data here */}, 1, NUM_INPUTS); INDArray output = loadedModel.output(input); System.out.println("Predicted: " + output); } catch (IOException e) { e.printStackTrace(); } // Plot training loss and accuracy plotTrainingResults(model.getIterationCount(), eval.accuracy()); } private static DataSetIterator generateSyntheticData(int numSamples) { Random rand = new Random(123); List<DataSet> dataSets = new ArrayList<>(); for (int i = 0; i < numSamples; i++) { double[] features = new double[NUM_INPUTS]; for (int j = 0; j < NUM_INPUTS; j++) { features[j] = rand.nextDouble(); } double[] labels = new double[NUM_OUTPUTS]; labels[rand.nextInt(NUM_OUTPUTS)] = 1.0; // Randomly assign a label (FRB or not) INDArray featureNDArray = Nd4j.create(features); INDArray labelNDArray = Nd4j.create(labels); DataSet dataSet = new DataSet(featureNDArray, labelNDArray); dataSets.add(dataSet); } return new ListDataSetIterator<>(dataSets, BATCH_SIZE); } private static void plotTrainingResults(int numIterations, double accuracy) { XYSeries accuracySeries = new XYSeries("Accuracy"); for (int i = 0; i < numIterations; i++) { accuracySeries.add(i, accuracy); } XYSeriesCollection dataset = new XYSeriesCollection(); dataset.addSeries(accuracySeries); JFreeChart chart = ChartFactory.createXYLineChart( "Training Accuracy", "Iteration", "Accuracy", dataset, PlotOrientation.VERTICAL, true, true, false); Chart
Editor is loading...
Leave a Comment