Buggy Application but fingers crossed
unknown
java
a year ago
6.5 kB
11
Indexable
package com.example;
import java.awt.*;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import javax.swing.*;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.impl.ListDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class FRBDetectorApp {
/** Length of each feature vector (e.g., signal intensity over time); also the height passed to the network's input type. */
private static final int NUM_INPUTS = 100; // Number of features (e.g., signal intensity over time)
/** Number of output classes; labels are one-hot vectors of this length (FRB or not). */
private static final int NUM_OUTPUTS = 2; // Binary classification (FRB or not)
/** Mini-batch size handed to the list-backed data set iterator. */
private static final int BATCH_SIZE = 32;
/** Number of full passes over the training iterator performed in {@code main}. */
private static final int NUM_EPOCHS = 50;
/**
 * Trains, evaluates, persists and reloads a small convolutional classifier on synthetic data.
 *
 * <p>Steps performed, in order:
 * <ol>
 *   <li>build train/test iterators with {@code generateSyntheticData};</li>
 *   <li>standardize both using mean/stdev fitted on the training iterator only;</li>
 *   <li>train for {@code NUM_EPOCHS} epochs and print evaluation statistics;</li>
 *   <li>save the model to {@code frb-detector-model.zip}, reload it and print one prediction;</li>
 *   <li>pass the iteration count and test accuracy to {@code plotTrainingResults}.</li>
 * </ol>
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    // Generate synthetic data for demonstration purposes
    DataSetIterator trainData = generateSyntheticData(1000);
    DataSetIterator testData = generateSyntheticData(200);

    // Normalize the data: statistics come from the training set only, then the same
    // transform is attached to both iterators.
    NormalizerStandardize normalizer = new NormalizerStandardize();
    normalizer.fit(trainData);
    trainData.setPreProcessor(normalizer);
    testData.setPreProcessor(normalizer);

    // Define the neural network configuration: two conv + max-pool stages over the
    // NUM_INPUTS x 1 signal, followed by a softmax output layer.
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            .seed(123)
            .updater(new Adam(0.001))
            .list()
            .layer(new ConvolutionLayer.Builder(5, 1)
                    .nIn(1) // 1 input channel
                    .nOut(16) // 16 filters
                    .stride(1, 1)
                    .activation(Activation.RELU)
                    .build())
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                    .kernelSize(2, 1)
                    .stride(2, 1)
                    .build())
            .layer(new ConvolutionLayer.Builder(3, 1)
                    .nOut(32)
                    .stride(1, 1)
                    .activation(Activation.RELU)
                    .build())
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                    .kernelSize(2, 1)
                    .stride(2, 1)
                    .build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nOut(NUM_OUTPUTS)
                    .activation(Activation.SOFTMAX)
                    .build())
            // The iterators deliver flat feature rows of length NUM_INPUTS, so declare the
            // input as "flat" convolutional data (height, width, channels); the plain
            // convolutional(...) input type would require 4-D activations instead.
            .setInputType(InputType.convolutionalFlat(NUM_INPUTS, 1, 1))
            .build();

    // Create and train the neural network
    MultiLayerNetwork model = new MultiLayerNetwork(config);
    model.init();
    model.setListeners(new ScoreIterationListener(10));
    for (int epoch = 1; epoch <= NUM_EPOCHS; epoch++) {
        model.fit(trainData);
        System.out.println("Completed epoch " + epoch + "/" + NUM_EPOCHS);
    }

    // Evaluate the model on the test data
    Evaluation eval = model.evaluate(testData);
    System.out.println(eval.stats());

    // Save the trained model, then reload it and make a prediction. Both steps share one
    // try block so a failed save does not lead to loading a missing or stale file.
    File modelFile = new File("frb-detector-model.zip");
    try {
        model.save(modelFile, true);
        MultiLayerNetwork loadedModel = MultiLayerNetwork.load(modelFile, true);

        // Use one real sample from the test iterator (rewound after evaluation) so the
        // input has NUM_INPUTS values and has gone through the attached normalizer.
        testData.reset();
        INDArray input = testData.next(1).getFeatures();
        INDArray output = loadedModel.output(input);
        System.out.println("Predicted: " + output);
    } catch (IOException e) {
        System.err.println("Could not save or reload model at " + modelFile.getAbsolutePath());
        e.printStackTrace();
    }

    // Plot training loss and accuracy
    plotTrainingResults(model.getIterationCount(), eval.accuracy());
}
/**
 * Builds an in-memory iterator over randomly generated samples for demonstration purposes.
 *
 * <p>Each sample has {@code NUM_INPUTS} features drawn uniformly from [0, 1) and a one-hot
 * label of length {@code NUM_OUTPUTS} chosen at random. Features and labels are stored as
 * single-row matrices so that samples can be stacked into mini-batches of {@code BATCH_SIZE}.
 *
 * <p>The random seed is derived from {@code numSamples}: output is reproducible for a given
 * size, but data sets of different sizes (e.g. train vs. test) no longer start with the
 * same sequence of samples, which previously made the test set a copy of training data.
 *
 * @param numSamples number of samples to generate; values {@code <= 0} yield no samples
 * @return an iterator over the generated samples with batch size {@code BATCH_SIZE}
 */
private static DataSetIterator generateSyntheticData(int numSamples) {
    Random rand = new Random(123L + numSamples);
    List<DataSet> dataSets = new ArrayList<>(Math.max(0, numSamples));
    for (int i = 0; i < numSamples; i++) {
        double[] features = new double[NUM_INPUTS];
        for (int j = 0; j < NUM_INPUTS; j++) {
            features[j] = rand.nextDouble();
        }
        double[] labels = new double[NUM_OUTPUTS];
        labels[rand.nextInt(NUM_OUTPUTS)] = 1.0; // Randomly assign a label (FRB or not)

        // Reshape to explicit [1, n] rows: a bare double[] may produce a rank-1 vector,
        // whereas batching expects one row per example.
        INDArray featureRow = Nd4j.create(features).reshape(1, NUM_INPUTS);
        INDArray labelRow = Nd4j.create(labels).reshape(1, NUM_OUTPUTS);
        dataSets.add(new DataSet(featureRow, labelRow));
    }
    return new ListDataSetIterator<>(dataSets, BATCH_SIZE);
}
private static void plotTrainingResults(int numIterations, double accuracy) {
XYSeries accuracySeries = new XYSeries("Accuracy");
for (int i = 0; i < numIterations; i++) {
accuracySeries.add(i, accuracy);
}
XYSeriesCollection dataset = new XYSeriesCollection();
dataset.addSeries(accuracySeries);
JFreeChart chart = ChartFactory.createXYLineChart(
"Training Accuracy",
"Iteration",
"Accuracy",
dataset,
PlotOrientation.VERTICAL,
true, true, false);
ChartEditor is loading...
Leave a Comment