Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GaussianProcessEvolution/HeuristicLab.Problems.GaussianProcessTuning/Provider.cs @ 8753

Last change on this file since 8753 was 8753, checked in by gkronber, 12 years ago

#1967 initial import of Gaussian process evolution plugin

File size: 8.6 KB
Line 
1using System;
2using System.Linq;
3using System.Threading;
4using HeuristicLab.Algorithms.OffspringSelectionGeneticAlgorithm;
5using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
6using HeuristicLab.Optimization;
7using HeuristicLab.Problems.DataAnalysis;
8using HeuristicLab.Random;
9using HeuristicLab.Selection;
10
11namespace HeuristicLab.Problems.GaussianProcessTuning {
12  public class Provider {
13    public static void Main(string[] args) {
14      // use random data for testing (3000 rows, 500 variables)
15
16      int rows = 13;
17      double[,] inputData = GenerateRandomData(rows, 5);
18      double[] targetData = GenerateRandomData(rows);
19
20      var prov = new Provider();
21      var p = new Parameters();
22      p.Generations = 1;
23      p.PopulationSize = 10;
24      var sol = prov.Train(targetData, inputData, null, p);
25      // Console.WriteLine(sol.TestRSquared);
26
27      //double[,] inputDataTest = GenerateRandomData(rows, 10);
28      //double[] predictionTest = prov.Predict(inputDataTest, sol);
29      //double[] targetDataTest = GenerateRandomData(rows);
30      //OnlineCalculatorError error;
31      //double r2 = OnlinePearsonsRSquaredCalculator.Calculate(predictionTest, targetDataTest, out error);
32      //if (error != OnlineCalculatorError.None) Console.WriteLine(error);
33      //else Console.WriteLine(r2);
34    }
35
36    // train a symb reg model
37    public IRegressionSolution Train(double[] target, double[,] input, Solution oldSolution, Parameters parameters) {
38      var targetVariable = "y";
39
40      if (target.Length != input.GetLength(0)) throw new ArgumentException("length of input vectors does not match.");
41      int rows = target.Length;
42      int columns = input.GetLength(1);
43
44      var variableNames =
45        new string[] { targetVariable }.Concat(Enumerable.Range(1, input.GetLength(1)).Select(i => string.Format("x{0}", i)));
46
47      double[,] combinedMatrix = new double[rows, columns + 1];
48      for (int r = 0; r < rows; r++) {
49        combinedMatrix[r, 0] = target[r];
50        for (int c = 0; c < columns; c++) {
51          combinedMatrix[r, c + 1] = input[r, c];
52        }
53      }
54
55      var allowedInputVariables = variableNames.Skip(1);
56      var ds = new Dataset(variableNames, combinedMatrix);
57      var probData = new RegressionProblemData(ds, allowedInputVariables, targetVariable);
58      probData.TrainingPartition.Start = 0;
59      probData.TrainingPartition.End = rows;
60      probData.TestPartition.Start = rows;
61      probData.TestPartition.End = rows;
62      var prob = new Problem();
63      prob.ProblemDataParameter.Value = probData;
64      prob.DimensionParameter.Value.Value = allowedInputVariables.Count();
65
66
67      var configuredAlgorithm = CreateGeneticProgrammingAlgorithm(prob, parameters);
68      // save the configured algorithm
69      // XmlGenerator.Serialize(configuredAlgorithm, "osgp_" + targetVariable + ".hl");
70
71
72      var resetEvent = new AutoResetEvent(false);
73      configuredAlgorithm.ExceptionOccurred += (sender, args) => {
74        Console.WriteLine("Exception: " + args.Value);
75        configuredAlgorithm.Stop();
76      };
77
78      Solution solution = null;
79      configuredAlgorithm.Stopped += (sender, args) => {
80        try {
81          Console.WriteLine("Run finished (execution time: {0})", configuredAlgorithm.ExecutionTime);
82          solution = (Solution)configuredAlgorithm.Results["Best solution"].Value;
83        }
84        finally {
85          resetEvent.Set();
86        }
87      };
88
89
90      configuredAlgorithm.Prepare();
91      configuredAlgorithm.Start(); // start run asynchronously
92      resetEvent.WaitOne(); // wait until stopped
93
94      //save ensemble solution
95      //XmlGenerator.Serialize(ensemble, "osgp_ensemble" + targetVariable + ".hl");
96
97      return solution;
98    }
99
100    // predict a single row
101    public double Predict(double[] input, IRegressionSolution solution) {
102      var allowedInputVariables = Enumerable.Range(1, input.GetLength(0)).Select(i => string.Format("x{0}", i));
103      int columns = input.GetLength(0);
104      double[,] combinedMatrix = new double[1, columns + 1];
105      combinedMatrix[0, 0] = 0.0;
106      for (int c = 0; c < columns; c++) {
107        combinedMatrix[0, c + 1] = input[c];
108      }
109
110      var ds = new Dataset(allowedInputVariables.Concat(new string[] { "y" }), combinedMatrix);
111      var newProbData = new RegressionProblemData(ds, allowedInputVariables, "y");
112      var newSolution = solution.Model.CreateRegressionSolution(newProbData);
113      return newSolution.EstimatedValues.Single();
114    }
115
116
117    // predict multiple rows
118    public double[] Predict(double[,] input, IRegressionSolution solution) {
119      var allowedInputVariables = Enumerable.Range(1, input.GetLength(1)).Select(i => string.Format("x{0}", i));
120      int rows = input.GetLength(0);
121      int columns = input.GetLength(1);
122      double[,] combinedMatrix = new double[rows, columns + 1];
123      for (int r = 0; r < rows; r++) {
124        combinedMatrix[r, 0] = 0.0;
125        for (int c = 0; c < columns; c++) {
126          combinedMatrix[r, c + 1] = input[r, c];
127        }
128      }
129
130      var ds = new Dataset(allowedInputVariables.Concat(new string[] { "y" }), combinedMatrix);
131      var newProbData = new RegressionProblemData(ds, allowedInputVariables, "y");
132      var newSolution = solution.Model.CreateRegressionSolution(newProbData);
133      return newSolution.EstimatedValues.ToArray();
134    }
135
136    #region helper for algorithm and problem configuration
137    public HeuristicOptimizationEngineAlgorithm CreateGeneticProgrammingAlgorithm(Problem problem, Parameters parameters) {
138      // offspring selection GA proved to be very successful for GP
139      var osga = new OffspringSelectionGeneticAlgorithm();
140
141      problem.MaxGaussianProcessConfigurationLengthParameter.Value.Value = parameters.MaxTreeLength;
142      problem.MaxGaussianProcessConfigurationDepthParameter.Value.Value = parameters.MaxTreeDepth;
143
144      var grammar = new Grammar();
145
146      problem.GrammarParameter.Value = grammar;
147
148      // prepare grammar for time series prognosis
149      // combine algorithm + problem
150      osga.Problem = problem;
151
152      // set algorithm parameters
153      ConfigureOffspringSelectionGeneticAlgorithmParameters<GenderSpecificSelector, SubtreeCrossover, MultiSymbolicExpressionTreeManipulator>
154        (osga, parameters.PopulationSize, 0, parameters.Generations, parameters.MaxSelectionPressure, parameters.MutationRate);
155
156      return osga;
157    }
158
159    private void ConfigureOffspringSelectionGeneticAlgorithmParameters<S, C, M>(OffspringSelectionGeneticAlgorithm osga, int popSize, int elites, int maxGens, double maxSelectionPressure, double mutationRate = 0.15, double comparisonFactorLowerBound = 1.0, double comparisonFactorUpperBound = 1.0, double successRatio = 1.0)
160      where S : ISelector
161      where C : ICrossover
162      where M : IManipulator {
163      osga.Elites.Value = elites;
164      osga.MaximumGenerations.Value = maxGens;
165      osga.MutationProbability.Value = mutationRate;
166      osga.PopulationSize.Value = popSize;
167      osga.ComparisonFactorLowerBound.Value = comparisonFactorLowerBound;
168      osga.ComparisonFactorUpperBound.Value = comparisonFactorUpperBound;
169      osga.SuccessRatio.Value = successRatio;
170      osga.OffspringSelectionBeforeMutation.Value = false;
171      osga.MaximumSelectionPressure.Value = maxSelectionPressure;
172
173      osga.Seed.Value = 0;
174      osga.SetSeedRandomly.Value = true;
175      osga.Selector = osga.SelectorParameter.ValidValues
176        .OfType<S>()
177        .Single();
178
179      osga.Crossover = osga.CrossoverParameter.ValidValues
180        .OfType<C>()
181        .Single();
182
183      osga.Mutator = osga.MutatorParameter.ValidValues
184        .OfType<M>()
185        .Single();
186
187      osga.Engine = new SequentialEngine.SequentialEngine();
188    }
189    #endregion
190
191
192    #region helper for testing
193    private static double[,] GenerateRandomData(int rows, int columns) {
194      var data = new double[rows, columns];
195      // generate data using a PRNG with N(0, 1) distribution
196      var normalPRNG = new NormalDistributedRandom(new MersenneTwister(), 0.0, 1.0);
197      for (int r = 0; r < rows; r++)
198        for (int c = 0; c < columns; c++) {
199          data[r, c] = normalPRNG.NextDouble();
200        }
201      return data;
202    }
203    private static double[] GenerateRandomData(int rows) {
204      var data = new double[rows];
205      // generate data using a PRNG with N(0, 1) distribution
206      var normalPRNG = new NormalDistributedRandom(new MersenneTwister(), 0.0, 1.0);
207      for (int r = 0; r < rows; r++)
208
209        data[r] = normalPRNG.NextDouble();
210
211      return data;
212    }
213    #endregion
214  }
215}
Note: See TracBrowser for help on using the repository browser.