import com.imsl.datamining.neural.*; import java.io.*; import java.util.logging.*; import com.imsl.math.PrintMatrix; import com.imsl.math.PrintMatrixFormat; import java.util.Random; //***************************************************************************** // Two Layer Feed-Forward Network with 11 inputs: 4 nominal with 2,2,3,4 categories, // encoded using binary encoding, and 1 output target (class). // // new classification training_ex1.c //***************************************************************************** public class BinaryClassificationEx1 implements Serializable { // Network Settings private static int nObs = 48; // number of training patterns private static int nInputs = 11; // four nominal with 2,2,3,4 categories private static int nCategorical = 11; // three categorical attributes private static int nOutputs = 1; // one continuous output (nClasses=2) private static int nPerceptrons1 = 3; // perceptrons in 1st hidden layer private static int nPerceptrons2 = 2; // perceptrons in 2nd hidden layer private static boolean trace = true; // Turns on/off training log private static Activation hiddenLayerActivation = Activation.LINEAR; private static Activation outputLayerActivation = Activation.LOGISTIC; /* 2 classifications */ private static int[] x1 = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; /* 2 classifications */ private static int[] x2 = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; /* 3 classifications */ private static int[] x3 = { 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3 }; /* 4 classifications */ private static int[] x4 = { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }; // ********************************************************************** // MAIN // ********************************************************************** public static void main(String[] args) throws Exception { double x[]; // temporary x space for generating forecasts double xData[][]; // Input Attributes for Trainer int yData[]; // Output Attributes for Trainer int i, j; // array indicies int nWeights = 0; // Number of weights obtained from network String trainLogName = "BinaryClassificationExample.log"; // ****************************************************************** // Binary encode 4 categorical variables. // Var x1 contains 2 classes // Var x2 contains 2 classes // Var x3 contains 3 classes // Var x4 contains 4 classes // ******************************************************************* int[][] z1; int[][] z2; int[][] z3; int[][] z4; UnsupervisedNominalFilter filter = new UnsupervisedNominalFilter(2); z1 = filter.encode(x1); z2 = filter.encode(x2); filter = new UnsupervisedNominalFilter(3); z3 = filter.encode(x3); filter = new UnsupervisedNominalFilter(4); z4 = filter.encode(x4); /* Concatenate binary encoded z's */ xData = new double[nObs][nInputs]; yData = new int[nObs]; for (i=0; i<(nObs); i++) { for (j=0; j 1 && j < 4) xData[i][j] = (double) z2[i][j-2]; if (j > 3 && j < 7) xData[i][j] = (double) z3[i][j-4]; if (j > 6) xData[i][j] = (double)z4[i][j-7]; } yData[i] = ((x1[i] +x2[i] == 2) ? 1 : 0); } // ********************************************************************** // CREATE FEEDFORWARD NETWORK // ********************************************************************** long t0 = System.currentTimeMillis(); FeedForwardNetwork network = new FeedForwardNetwork(); network.getInputLayer().createInputs(nInputs); network.createHiddenLayer().createPerceptrons(nPerceptrons1); network.createHiddenLayer().createPerceptrons(nPerceptrons2); network.getOutputLayer().createPerceptrons(nOutputs); BinaryClassification classification = new BinaryClassification(network); network.linkAll(); Random r = new Random(123457L); network.setRandomWeights(xData, r); Perceptron perceptrons[] = network.getPerceptrons(); for (i=0; i < perceptrons.length-1; i++) { perceptrons[i].setActivation(hiddenLayerActivation); } perceptrons[perceptrons.length-1].setActivation(outputLayerActivation); // ********************************************************************** // TRAIN NETWORK USING QUASI-NEWTON TRAINER // ********************************************************************** QuasiNewtonTrainer trainer = new QuasiNewtonTrainer(); trainer.setError(classification.getError()); trainer.setMaximumTrainingIterations(1000); trainer.setMaximumStepsize(3.0); trainer.setGradientTolerance(1.0e-20); trainer.setFalseConvergenceTolerance(1.0e-20); trainer.setStepTolerance(1.0e-20); trainer.setRelativeTolerance(1.0e-20); if (trace) { try { Handler handler = new FileHandler(trainLogName); Logger logger = Logger.getLogger("com.imsl.datamining.neural"); logger.setLevel(Level.FINEST); logger.addHandler(handler); handler.setFormatter(QuasiNewtonTrainer.getFormatter()); System.out.println("--> Training Log Created in "+ trainLogName); } catch (Exception e) { System.out.println("--> Cannot Create Training Log."); } } classification.train(trainer, xData, yData); // ********************************************************************** // DISPLAY TRAINING STATISTICS // ********************************************************************** double stats[] = classification.computeStatistics(xData, yData); System.out.println("***********************************************"); System.out.println("--> Cross-entropy error: "+(float)stats[0]); System.out.println("--> Classification error rate: "+(float)stats[1]); System.out.println("***********************************************"); System.out.println(""); // ********************************************************************** // OBTAIN AND DISPLAY NETWORK WEIGHTS AND GRADIENTS // ********************************************************************** double weight[] = network.getWeights(); double gradient[] = trainer.getErrorGradient(); double wg[][] = new double[weight.length][2]; for(i = 0; i < weight.length; i++) { wg[i][0] = weight[i]; wg[i][1] = gradient[i]; } PrintMatrixFormat pmf = new PrintMatrixFormat(); pmf.setNumberFormat(new java.text.DecimalFormat("0.000000")); pmf.setColumnLabels(new String[]{"Weights", "Gradients"}); new PrintMatrix().print(pmf,wg); // **************************** // forecast the network // **************************** double report[][] = new double[nObs][6]; for ( i = 0; i < nObs; i++) { report[i][0] = x1[i]; report[i][1] = x2[i]; report[i][2] = x3[i]; report[i][3] = x4[i]; report[i][4] = yData[i]; report[i][5] = classification.predictedClass(xData[i]); } pmf = new PrintMatrixFormat(); pmf.setColumnLabels(new String[]{ "X1", "X2", "X3", "X4", "Expected", "Predicted"}); new PrintMatrix("Forecast").print(pmf, report); // ********************************************************************** // DISPLAY CLASSIFICATION STATISTICS // ********************************************************************** double statsClass[] = classification.computeStatistics(xData, yData); // Display Network Errors System.out.println("***********************************************"); System.out.println("--> Cross-Entropy Error: "+(float)statsClass[0]); System.out.println("--> Classification Error: "+(float)statsClass[1]); System.out.println("***********************************************"); System.out.println(""); long t1 = System.currentTimeMillis(); double small = 1.e-7; double time = t1-t0; time = time/1000; System.out.println("****************Time: "+time); System.out.println("trainer.getErrorValue = "+trainer.getErrorValue()); } }