Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/29/19 15:57:35 (5 years ago)
Author:
mkommend
Message:

#2521: Merged trunk changes into problem refactoring branch.

Location:
branches/2521_ProblemRefactoring
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/2521_ProblemRefactoring

  • branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis

  • branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4

  • branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r16723 r17226  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    2020#endregion
    2121
     22using System.Collections.Generic;
     23using System.Linq;
    2224using System.Threading;
     25using HEAL.Attic;
     26using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
    2327using HeuristicLab.Common;
    2428using HeuristicLab.Core;
     
    2630using HeuristicLab.Optimization;
    2731using HeuristicLab.Parameters;
    28 using HEAL.Attic;
    2932using HeuristicLab.Problems.DataAnalysis;
    3033
     
    4346    private const string SeedParameterName = "Seed";
    4447    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    45     private const string CreateSolutionParameterName = "CreateSolution";
     48    private const string ModelCreationParameterName = "ModelCreation";
    4649
    4750    #region parameter properties
     
    6164      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    6265    }
    63     public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    64       get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
     66    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
     67      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    6568    }
    6669    #endregion
     
    8689      set { SetSeedRandomlyParameter.Value.Value = value; }
    8790    }
    88     public bool CreateSolution {
    89       get { return CreateSolutionParameter.Value.Value; }
    90       set { CreateSolutionParameter.Value.Value = value; }
     91    public ModelCreation ModelCreation {
     92      get { return ModelCreationParameter.Value.Value; }
     93      set { ModelCreationParameter.Value.Value = value; }
    9194    }
    9295    #endregion
     
    104107      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    105108      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    106       Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    107       Parameters[CreateSolutionParameterName].Hidden = true;
     109      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     110      Parameters[ModelCreationParameterName].Hidden = true;
    108111
    109112      Problem = new RegressionProblem();
     
    120123      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    121124        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    122       if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
    123         Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    124         Parameters[CreateSolutionParameterName].Hidden = true;
     125
     126      // parameter type has been changed
     127      if (Parameters.ContainsKey("CreateSolution")) {
     128        var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>;
     129        Parameters.Remove(createSolutionParam);
     130
     131        ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly;
     132        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value)));
     133        Parameters[ModelCreationParameterName].Hidden = true;
     134      } else if (!Parameters.ContainsKey(ModelCreationParameterName)) {
     135        // very old version contains neither ModelCreationParameter nor CreateSolutionParameter
     136        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     137        Parameters[ModelCreationParameterName].Hidden = true;
    125138      }
    126139      #endregion
     
    143156      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));
    144157
    145       if (CreateSolution) {
    146         var solution = new RandomForestRegressionSolution(model, (IRegressionProblemData)Problem.ProblemData.Clone());
     158      IRegressionSolution solution = null;
     159      if (ModelCreation == ModelCreation.Model) {
     160        solution = model.CreateRegressionSolution(Problem.ProblemData);
     161      } else if (ModelCreation == ModelCreation.SurrogateModel) {
     162        var problemData = Problem.ProblemData;
     163        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M);
     164        solution = surrogateModel.CreateRegressionSolution(problemData);
     165      }
     166
     167      if (solution != null) {
    147168        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
    148169      }
    149170    }
     171
    150172
    151173    // keep for compatibility with old API
     
    157179    }
    158180
    159     public static RandomForestModel CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
    160       double r, double m, int seed,
    161       out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    162       return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed,
    163         rmsError: out rmsError, avgRelError: out avgRelError, outOfBagRmsError: out outOfBagRmsError, outOfBagAvgRelError: out outOfBagAvgRelError);
     181    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
     182     double r, double m, int seed,
     183     out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     184      var model = CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     185      return model;
     186    }
     187
     188    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
     189    out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     190
     191      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     192      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);
     193
     194      alglib.dfreport rep;
     195      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     196
     197      rmsError = rep.rmserror;
     198      outOfBagRmsError = rep.oobrmserror;
     199      avgRelError = rep.avgrelerror;
     200      outOfBagAvgRelError = rep.oobavgrelerror;
     201
     202      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables);
    164203    }
    165204
Note: See TracChangeset for help on using the changeset viewer.