Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/22/17 15:01:46 (7 years ago)
Author:
bwerth
Message:

#2745 fixed bug concerning new Start and StartAsync methods; passed CancellationToken to sub algorithms

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/EfficientGlobalOptimization/HeuristicLab.Algorithms.EGO/Operators/ModelBuilder.cs

    r15064 r15338  
    2222using System;
    2323using System.Linq;
     24using System.Threading;
    2425using HeuristicLab.Algorithms.DataAnalysis;
    2526using HeuristicLab.Common;
     
    3132using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3233using HeuristicLab.Problems.DataAnalysis;
     34using HeuristicLab.Problems.SurrogateProblem;
    3335
    3436namespace HeuristicLab.Algorithms.EGO {
     
    3840  [Item("ModelBuilder", "Builds a model from a dataset and a given RegressionAlgorithm")]
    3941  [StorableClass]
    40   public class ModelBuilder : InstrumentedOperator, IStochasticOperator {
     42  public class ModelBuilder : InstrumentedOperator, IStochasticOperator, ICancellableOperator {
    4143    public override bool CanChangeName => true;
     44    public CancellationToken Cancellation { get; set; }
    4245
     46    #region Parameter properties
    4347    public ILookupParameter<IDataAnalysisAlgorithm<IRegressionProblem>> RegressionAlgorithmParameter => (ILookupParameter<IDataAnalysisAlgorithm<IRegressionProblem>>)Parameters["RegressionAlgorithm"];
    4448    public ILookupParameter<IRegressionSolution> ModelParameter => (ILookupParameter<IRegressionSolution>)Parameters["Model"];
     
    4751    public ILookupParameter<IntValue> MaxModelSizeParameter => (ILookupParameter<IntValue>)Parameters["Maximal Model Size"];
    4852    public ILookupParameter<DoubleMatrix> InfillBoundsParameter => (ILookupParameter<DoubleMatrix>)Parameters["InfillBounds"];
     53    #endregion
    4954
    5055    [StorableConstructor]
     
    9398    }
    9499
    95     private static IRegressionSolution BuildModel(IRandom random, IDataAnalysisAlgorithm<IRegressionProblem> regressionAlgorithm, IDataset dataset, IRegressionSolution oldSolution) {
     100    private IRegressionSolution BuildModel(IRandom random, IDataAnalysisAlgorithm<IRegressionProblem> regressionAlgorithm, IDataset dataset, IRegressionSolution oldSolution) {
    96101      //var dataset = EgoUtilities.GetDataSet(dataSamples, RemoveDuplicates);
    97102      var problemdata = new RegressionProblemData(dataset, dataset.VariableNames.Where(x => !x.Equals("output")), "output");
     
    108113
    109114      while (solution == null && i++ < 100) {
    110         var results = EgoUtilities.SyncRunSubAlgorithm(regressionAlgorithm, random.Next(int.MaxValue));
     115        var results = EgoUtilities.SyncRunSubAlgorithm(regressionAlgorithm, random.Next(int.MaxValue), Cancellation);
    111116        solution = results.Select(x => x.Value).OfType<IRegressionSolution>().SingleOrDefault();
    112117      }
    113118
    114       //try creating a model with old hyperparameters and new dataset;
    115       var gp = regressionAlgorithm as GaussianProcessRegression;
    116       var oldmodel = oldSolution as GaussianProcessRegressionSolution;
    117       if (gp != null && oldmodel != null) {
    118         var mean = (IMeanFunction)oldmodel.Model.MeanFunction.Clone();
    119         var cov = (ICovarianceFunction)oldmodel.Model.CovarianceFunction.Clone();
    120         try {
    121           var model = new GaussianProcessModel(problemdata.Dataset, problemdata.TargetVariable,
    122             problemdata.AllowedInputVariables, problemdata.TrainingIndices, new[] { 0.0 }, mean, cov);
    123           model.FixParameters();
    124           var sol = new GaussianProcessRegressionSolution(model, problemdata);
    125           if (solution == null || solution.TrainingMeanSquaredError > sol.TrainingMeanSquaredError) {
    126             solution = sol;
    127           }
    128         }
    129         catch (ArgumentException) { }
    130       }
     119      if (regressionAlgorithm is GaussianProcessRegression && oldSolution != null)
     120        solution = SanitizeGaussianProcess(oldSolution as GaussianProcessRegressionSolution, solution as GaussianProcessRegressionSolution, Cancellation);
     121
     122      if (regressionAlgorithm is M5RegressionTree && oldSolution != null)
     123        solution = SanitizeM5Regression(oldSolution.Model as M5Model, solution, random, Cancellation);
     124
    131125
    132126      regressionAlgorithm.Runs.Clear();
     
    134128
    135129    }
     130
     131    private static IRegressionSolution SanitizeM5Regression(M5Model oldmodel, IRegressionSolution newSolution, IRandom random, CancellationToken cancellation) {
     132      var problemdata = newSolution.ProblemData;
     133      oldmodel.UpdateLeafModels(problemdata, problemdata.AllIndices, random, cancellation);
     134      var oldSolution = oldmodel.CreateRegressionSolution(problemdata);
     135      var magicDecision = newSolution.TrainingRSquared < oldSolution.TrainingRSquared - 0.05;
     136      return magicDecision ? newSolution : oldmodel.CreateRegressionSolution(problemdata);
     137    }
     138
     139    //try creating a model with old hyperparameters and new dataset;
     140    private static IRegressionSolution SanitizeGaussianProcess(GaussianProcessRegressionSolution oldmodel, GaussianProcessRegressionSolution newSolution, CancellationToken cancellation) {
     141      var problemdata = newSolution.ProblemData;
     142      var mean = (IMeanFunction)oldmodel.Model.MeanFunction.Clone();
     143      var cov = (ICovarianceFunction)oldmodel.Model.CovarianceFunction.Clone();
     144      try {
     145        var model = new GaussianProcessModel(problemdata.Dataset, problemdata.TargetVariable, problemdata.AllowedInputVariables, problemdata.TrainingIndices, new[] { 0.0 }, mean, cov);
     146        cancellation.ThrowIfCancellationRequested();
     147        model.FixParameters();
     148        var sol = new GaussianProcessRegressionSolution(model, problemdata);
     149        if (newSolution.TrainingMeanSquaredError > sol.TrainingMeanSquaredError) {
     150          newSolution = sol;
     151        }
     152      } catch (ArgumentException) { }
     153      return newSolution;
     154    }
     155
    136156  }
    137157}
Note: See TracChangeset for help on using the changeset viewer.