Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2745_EfficientGlobalOptimization/HeuristicLab.Algorithms.EGO/Operators/ModelBuilder.cs @ 17456

Last change on this file since 17456 was 17332, checked in by bwerth, 5 years ago

#2745 updated persistence to HEAL.Attic

File size: 8.1 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Linq;
[15338]24using System.Threading;
[17332]25using HEAL.Attic;
[15064]26using HeuristicLab.Algorithms.DataAnalysis;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Operators;
31using HeuristicLab.Optimization;
32using HeuristicLab.Parameters;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
34using HeuristicLab.Problems.DataAnalysis;
35
namespace HeuristicLab.Algorithms.EGO {
  /// <summary>
  /// Operator that (re)builds the surrogate regression model.
  /// It reads the sample dataset from the scope and, when a positive "Maximal Model Size"
  /// is exceeded, reduces it to the rows with the smallest "output" values (updating the
  /// infill bounds accordingly). It then runs the configured regression algorithm on the
  /// dataset and writes the resulting solution into the "Model" parameter.
  /// </summary>
  [Item("ModelBuilder", "Builds a model from a dataset and a given RegressionAlgorithm")]
  [StorableType("8b80026f-b6a5-4892-9826-86ffba1e4e10")]
  public class ModelBuilder : InstrumentedOperator, IStochasticOperator, ICancellableOperator {
    // Name of the dataset column that holds the objective values (the regression target).
    private const string OutputVariableName = "output";
    // How often the sub-algorithm is re-run at most when it does not yield a regression solution.
    private const int MaxTrainingAttempts = 100;

    public override bool CanChangeName => true;
    public CancellationToken Cancellation { get; set; }

    #region Parameter properties
    public ILookupParameter<IDataAnalysisAlgorithm<IRegressionProblem>> RegressionAlgorithmParameter => (ILookupParameter<IDataAnalysisAlgorithm<IRegressionProblem>>)Parameters["RegressionAlgorithm"];
    public ILookupParameter<IRegressionSolution> ModelParameter => (ILookupParameter<IRegressionSolution>)Parameters["Model"];
    public ILookupParameter<ModifiableDataset> DatasetParameter => (ILookupParameter<ModifiableDataset>)Parameters["Dataset"];
    public ILookupParameter<IRandom> RandomParameter => (ILookupParameter<IRandom>)Parameters["Random"];
    public ILookupParameter<IntValue> MaxModelSizeParameter => (ILookupParameter<IntValue>)Parameters["Maximal Model Size"];
    public ILookupParameter<DoubleMatrix> InfillBoundsParameter => (ILookupParameter<DoubleMatrix>)Parameters["InfillBounds"];
    #endregion

    [StorableConstructor]
    protected ModelBuilder(StorableConstructorFlag deserializing) : base(deserializing) { }
    protected ModelBuilder(ModelBuilder original, Cloner cloner) : base(original, cloner) { }
    public ModelBuilder() {
      Parameters.Add(new LookupParameter<IDataAnalysisAlgorithm<IRegressionProblem>>("RegressionAlgorithm", "The algorithm used to build a model") { Hidden = true });
      Parameters.Add(new LookupParameter<IRegressionSolution>("Model", "The resulting model") { Hidden = true });
      Parameters.Add(new LookupParameter<ModifiableDataset>("Dataset", "The Dataset from which the model is created") { Hidden = true });
      Parameters.Add(new LookupParameter<IRandom>("Random", "A random number generator") { Hidden = true });
      Parameters.Add(new LookupParameter<IntValue>("Maximal Model Size", "The maximum number of sample points used to build the model (Set -1 for infinite size)") { Hidden = true });
      Parameters.Add(new LookupParameter<DoubleMatrix>("InfillBounds", "The bounds applied for infill solving") { Hidden = true });
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ModelBuilder(this, cloner);
    }

    /// <summary>
    /// Optionally shrinks the dataset to the maximal model size, trains a new model on it,
    /// and stores the result in the "Model" parameter.
    /// </summary>
    public override IOperation InstrumentedApply() {
      var regressionAlg = RegressionAlgorithmParameter.ActualValue;
      IDataset data = DatasetParameter.ActualValue;
      var random = RandomParameter.ActualValue;
      var oldModel = ModelParameter.ActualValue;
      var max = MaxModelSizeParameter.ActualValue.Value;
      // A non-positive maximal model size disables the size limit.
      if (max > 0 && data.Rows > max) {
        data = SelectBestSamples(data, max);
        // The retained samples may cover a smaller region; update the infill bounds to match.
        InfillBoundsParameter.ActualValue = GetBounds(data);
      }
      ModelParameter.ActualValue = BuildModel(random, regressionAlg, data, oldModel);
      return base.InstrumentedApply();
    }

    /// <summary>
    /// Computes the [min, max] range of every input variable, i.e. every double variable except
    /// the output column. Row i of the result belongs to the i-th input variable in dataset order;
    /// column 0 holds the minimum and column 1 the maximum.
    /// </summary>
    /// <remarks>The output column is excluded by name rather than by assuming it is the last column.</remarks>
    private static DoubleMatrix GetBounds(IDataset data) {
      var inputs = data.DoubleVariables.Where(v => v != OutputVariableName).ToArray();
      var res = new DoubleMatrix(inputs.Length, 2);
      for (var i = 0; i < inputs.Length; i++) {
        // Materialize once so the column is not enumerated twice for Min and Max.
        var values = data.GetDoubleValues(inputs[i]).ToArray();
        res[i, 0] = values.Min();
        res[i, 1] = values.Max();
      }
      return res;
    }

    /// <summary>
    /// Returns a new dataset containing only the <paramref name="max"/> rows with the smallest
    /// output values. Rows are emitted in ascending output order; ties keep their original order
    /// because OrderBy is stable.
    /// </summary>
    /// <remarks>NOTE(review): "best" = smallest output, i.e. this assumes minimization — confirm for maximization problems.</remarks>
    private static Dataset SelectBestSamples(IDataset data, int max) {
      var outputs = data.GetDoubleValues(OutputVariableName).ToArray();
      var bestSampleIndices = Enumerable.Range(0, outputs.Length)
        .OrderBy(i => outputs[i])
        .Take(max)
        .ToArray();
      return new Dataset(data.VariableNames, data.VariableNames.Select(v => data.GetDoubleValues(v, bestSampleIndices).ToList()));
    }

    /// <summary>
    /// Trains <paramref name="regressionAlgorithm"/> on all rows of <paramref name="dataset"/>.
    /// The output column is the target and all other variables are inputs.
    /// </summary>
    /// <param name="random">Source of the seeds passed to each sub-algorithm run.</param>
    /// <param name="regressionAlgorithm">Algorithm whose problem receives the new problem data; its runs are cleared afterwards.</param>
    /// <param name="dataset">Samples to train on.</param>
    /// <param name="oldSolution">Previous model (may be null); used to try re-applying old GP hyperparameters.</param>
    /// <returns>The trained solution, or null if no attempt produced a regression solution.</returns>
    private IRegressionSolution BuildModel(IRandom random, IDataAnalysisAlgorithm<IRegressionProblem> regressionAlgorithm, IDataset dataset, IRegressionSolution oldSolution) {
      var problemdata = new RegressionProblemData(dataset, dataset.VariableNames.Where(x => !x.Equals(OutputVariableName)), OutputVariableName);
      // Train on every row; the test partition is deliberately empty.
      problemdata.TrainingPartition.Start = 0;
      problemdata.TrainingPartition.End = dataset.Rows;
      problemdata.TestPartition.Start = dataset.Rows;
      problemdata.TestPartition.End = dataset.Rows;

      var problem = (RegressionProblem)regressionAlgorithm.Problem;
      problem.ProblemDataParameter.Value = problemdata;

      IRegressionSolution solution = null;
      try {
        // Re-run with fresh seeds until a regression solution is produced or the attempt budget is spent.
        for (var attempt = 0; solution == null && attempt < MaxTrainingAttempts; attempt++) {
          Cancellation.ThrowIfCancellationRequested();
          var results = EgoUtilities.SyncRunSubAlgorithm(regressionAlgorithm, random.Next(int.MaxValue), Cancellation);
          solution = results.Select(x => x.Value).OfType<IRegressionSolution>().SingleOrDefault();
        }

        // Only sanitize when both the previous and the new model are GP solutions; otherwise there
        // are no old hyperparameters to reuse (and the previous code dereferenced null here).
        var oldGp = oldSolution as GaussianProcessRegressionSolution;
        var newGp = solution as GaussianProcessRegressionSolution;
        if (regressionAlgorithm is GaussianProcessRegression && oldGp != null && newGp != null)
          solution = SanitizeGaussianProcess(oldGp, newGp, Cancellation);

        //if (regressionAlgorithm is M5RegressionTree && oldSolution != null) solution = SanitizeM5Regression(oldSolution.Model as M5Model, solution, random, Cancellation);
      } finally {
        // Always discard the sub-algorithm's runs, even when training failed or was cancelled.
        regressionAlgorithm.Runs.Clear();
      }
      return solution;
    }

    //private static IRegressionSolution SanitizeM5Regression(M5Model oldmodel, IRegressionSolution newSolution, IRandom random, CancellationToken cancellation) {
    //  var problemdata = newSolution.ProblemData;
    //  oldmodel.UpdateLeafModels(problemdata, problemdata.AllIndices, random, cancellation);
    //  var oldSolution = oldmodel.CreateRegressionSolution(problemdata);
    //  var magicDecision = newSolution.TrainingRSquared < oldSolution.TrainingRSquared - 0.05;
    //  return magicDecision ? newSolution : oldmodel.CreateRegressionSolution(problemdata);
    //}

    /// <summary>
    /// Builds a GP model on the new problem data from clones of the old model's mean and
    /// covariance functions. Returns that model's solution if its training MSE is lower than
    /// the newly trained one's; otherwise returns <paramref name="newSolution"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the hyperparameter vector { 0.0 } presumably relies on the cloned mean and
    /// covariance functions keeping their previous parameter values — confirm against GaussianProcessModel.
    /// </remarks>
    private static IRegressionSolution SanitizeGaussianProcess(GaussianProcessRegressionSolution oldmodel, GaussianProcessRegressionSolution newSolution, CancellationToken cancellation) {
      var problemdata = newSolution.ProblemData;
      var mean = (IMeanFunction)oldmodel.Model.MeanFunction.Clone();
      var cov = (ICovarianceFunction)oldmodel.Model.CovarianceFunction.Clone();
      try {
        var model = new GaussianProcessModel(problemdata.Dataset, problemdata.TargetVariable, problemdata.AllowedInputVariables, problemdata.TrainingIndices, new[] { 0.0 }, mean, cov);
        cancellation.ThrowIfCancellationRequested();
        model.FixParameters();
        var sol = new GaussianProcessRegressionSolution(model, problemdata);
        if (newSolution.TrainingMeanSquaredError > sol.TrainingMeanSquaredError) {
          newSolution = sol;
        }
      } catch (ArgumentException) {
        // Deliberate best effort: if the old functions cannot be applied to the new data,
        // keep the freshly trained solution.
      }
      return newSolution;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.