Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/02/16 09:08:09 (8 years ago)
Author:
gkronber
Message:

#1795: merged r13646,13653,13655,13699,13703,13707,13889,13898,13917 from trunk to stable

Location:
stable
Files:
3 edited
1 copied

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GBM/GradientBoostingRegressionAlgorithm.cs

    r13646 r13977  
    3737using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    3838using HeuristicLab.Random;
     39using HeuristicLab.Selection;
    3940
    4041namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
     
    6364    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    6465    private const string CreateSolutionParameterName = "CreateSolution";
     66    private const string StoreRunsParameterName = "StoreRuns";
    6567    private const string RegressionAlgorithmSolutionResultParameterName = "RegressionAlgorithmResult";
    6668
     
    105107      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    106108    }
     109    public IFixedValueParameter<BoolValue> StoreRunsParameter {
     110      get { return (IFixedValueParameter<BoolValue>)Parameters[StoreRunsParameterName]; }
     111    }
    107112
    108113    #endregion
     
    145150    }
    146151
     152    public bool StoreRuns {
     153      get { return StoreRunsParameter.Value.Value; }
     154      set { StoreRunsParameter.Value.Value = value; }
     155    }
     156
    147157    public IAlgorithm RegressionAlgorithm {
    148158      get { return RegressionAlgorithmParameter.Value; }
     
    172182      Problem = new RegressionProblem(); // default problem
    173183      var mctsSymbReg = new MctsSymbolicRegressionAlgorithm();
    174       // var sgp = CreateSGP();
     184      mctsSymbReg.Iterations = 10000;
     185      mctsSymbReg.StoreAlgorithmInEachRun = false;
     186      var sgp = CreateOSGP();
    175187      var regressionAlgs = new ItemSet<IAlgorithm>(new IAlgorithm[] {
    176         new LinearRegression(), new RandomForestRegression(), new NearestNeighbourRegression(),
    177         // sgp,
     188        new RandomForestRegression(),
     189        sgp,
    178190        mctsSymbReg
    179191      });
     
    203215        "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    204216      Parameters[CreateSolutionParameterName].Hidden = true;
     217      Parameters.Add(new FixedValueParameter<BoolValue>(StoreRunsParameterName,
     218        "Flag that indicates if the results of the individual runs should be stored for detailed analysis", new BoolValue(false)));
     219      Parameters[StoreRunsParameterName].Hidden = true;
    205220    }
    206221
     
    215230
    216231      var table = new DataTable("Qualities");
    217       table.Rows.Add(new DataRow("Loss (train)"));
    218       table.Rows.Add(new DataRow("Loss (test)"));
     232      table.Rows.Add(new DataRow(" (train)"));
     233      table.Rows.Add(new DataRow(" (test)"));
    219234      Results.Add(new Result("Qualities", table));
    220235      var curLoss = new DoubleValue();
    221236      var curTestLoss = new DoubleValue();
    222       Results.Add(new Result("Loss (train)", curLoss));
    223       Results.Add(new Result("Loss (test)", curTestLoss));
     237      Results.Add(new Result(" (train)", curLoss));
     238      Results.Add(new Result(" (test)", curTestLoss));
    224239      var runCollection = new RunCollection();
    225       Results.Add(new Result("Runs", runCollection));
     240      if (StoreRuns)
     241        Results.Add(new Result("Runs", runCollection));
    226242
    227243      // init
    228244      var problemData = Problem.ProblemData;
    229       var targetVarName = Problem.ProblemData.TargetVariable;
     245      var targetVarName = problemData.TargetVariable;
     246      var activeVariables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
    230247      var modifiableDataset = new ModifiableDataset(
    231         problemData.Dataset.VariableNames,
    232         problemData.Dataset.VariableNames.Select(v => problemData.Dataset.GetDoubleValues(v).ToList()));
     248        activeVariables,
     249        activeVariables.Select(v => problemData.Dataset.GetDoubleValues(v).ToList()));
    233250
    234251      var trainingRows = problemData.TrainingIndices;
     
    269286          IRegressionModel model;
    270287          IRun run;
     288
    271289          // try to find a model. The algorithm might fail to produce a model. In this case we just retry until the iterations are exhausted
    272           if (TryExecute(alg, RegressionAlgorithmResult, out model, out run)) {
     290          if (TryExecute(alg, rand.Next(), RegressionAlgorithmResult, out model, out run)) {
    273291            int row = 0;
    274292            // update predictions for training and test
     
    299317          }
    300318
    301           runCollection.Add(run);
    302           table.Rows["Loss (train)"].Values.Add(curLoss.Value);
    303           table.Rows["Loss (test)"].Values.Add(curTestLoss.Value);
     319          if (StoreRuns)
     320            runCollection.Add(run);
     321          table.Rows["R² (train)"].Values.Add(curLoss.Value);
     322          table.Rows["R² (test)"].Values.Add(curTestLoss.Value);
    304323          iterations.Value = i + 1;
    305324        }
     
    312331          }
    313332          // just produce an ensemble solution for now (TODO: correct scaling or linear regression for ensemble model weights)
    314           Results.Add(new Result("EnsembleSolution", new RegressionEnsembleSolution(models, (IRegressionProblemData)problemData.Clone())));
     333
     334          var ensembleSolution = CreateEnsembleSolution(models, (IRegressionProblemData)problemData.Clone());
     335          Results.Add(new Result("EnsembleSolution", ensembleSolution));
    315336        }
    316       } finally {
     337      }
     338      finally {
    317339        // reset everything
    318340        alg.Prepare(true);
     
    320342    }
    321343
    322     // this is probably slow as hell
     344    private static IRegressionEnsembleSolution CreateEnsembleSolution(List<IRegressionModel> models,
     345      IRegressionProblemData problemData) {
     346      var rows = problemData.TrainingPartition.Size;
     347      var features = models.Count;
     348      double[,] inputMatrix = new double[rows, features + 1];
     349      //add model estimates
     350      for (int m = 0; m < models.Count; m++) {
     351        var model = models[m];
     352        var estimates = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
     353        int estimatesCounter = 0;
     354        foreach (var estimate in estimates) {
     355          inputMatrix[estimatesCounter, m] = estimate;
     356          estimatesCounter++;
     357        }
     358      }
     359
     360      //add target
     361      var targets = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     362      int targetCounter = 0;
     363      foreach (var target in targets) {
     364        inputMatrix[targetCounter, models.Count] = target;
     365        targetCounter++;
     366      }
     367
     368      alglib.linearmodel lm = new alglib.linearmodel();
     369      alglib.lrreport ar = new alglib.lrreport();
     370      double[] coefficients;
     371      int retVal = 1;
     372      alglib.lrbuildz(inputMatrix, rows, features, out retVal, out lm, out ar);
     373      if (retVal != 1) throw new ArgumentException("Error in calculation of linear regression solution");
     374
     375      alglib.lrunpack(lm, out coefficients, out features);
     376
     377      var ensembleModel = new RegressionEnsembleModel(models, coefficients.Take(models.Count)) { AverageModelEstimates = false };
     378      var ensembleSolution = ensembleModel.CreateRegressionSolution(problemData);
     379      return ensembleSolution;
     380    }
     381
     382
     383    private IAlgorithm CreateOSGP() {
     384      // configure strict osgp
     385      var alg = new OffspringSelectionGeneticAlgorithm.OffspringSelectionGeneticAlgorithm();
     386      var prob = new SymbolicRegressionSingleObjectiveProblem();
     387      prob.MaximumSymbolicExpressionTreeDepth.Value = 7;
     388      prob.MaximumSymbolicExpressionTreeLength.Value = 15;
     389      alg.Problem = prob;
     390      alg.SuccessRatio.Value = 1.0;
     391      alg.ComparisonFactorLowerBound.Value = 1.0;
     392      alg.ComparisonFactorUpperBound.Value = 1.0;
     393      alg.MutationProbability.Value = 0.15;
     394      alg.PopulationSize.Value = 200;
     395      alg.MaximumSelectionPressure.Value = 100;
     396      alg.MaximumEvaluatedSolutions.Value = 20000;
     397      alg.SelectorParameter.Value = alg.SelectorParameter.ValidValues.OfType<GenderSpecificSelector>().First();
     398      alg.MutatorParameter.Value = alg.MutatorParameter.ValidValues.OfType<MultiSymbolicExpressionTreeManipulator>().First();
     399      alg.StoreAlgorithmInEachRun = false;
     400      return alg;
     401    }
     402
    323403    private void SampleTrainingData(MersenneTwister rand, ModifiableDataset ds, int rRows,
    324404      IDataset sourceDs, double[] curTarget, string targetVarName, IEnumerable<int> trainingIndices) {
     
    374454        prob.ProblemDataParameter.Value = problemData;
    375455        return true;
    376       } else if (alg.Problem != null) {
    377         // a problem is set and it is not compatible
    378         return false;
    379       } else {
    380         try {
    381           // we try to use a symbolic regression problem (works for simple regression algs and GP)
    382           alg.Problem = new SymbolicRegressionSingleObjectiveProblem();
    383         } catch (Exception) {
    384           return false;
    385         }
    386         return true;
    387       }
    388     }
    389 
    390     private static bool TryExecute(IAlgorithm alg, string regressionAlgorithmResultName, out IRegressionModel model, out IRun run) {
     456      } else return false;
     457    }
     458
     459    private static bool TryExecute(IAlgorithm alg, int seed, string regressionAlgorithmResultName, out IRegressionModel model, out IRun run) {
    391460      model = null;
     461      SetSeed(alg, seed);
    392462      using (var wh = new AutoResetEvent(false)) {
    393         EventHandler<EventArgs<Exception>> handler = (sender, args) => wh.Set();
     463        Exception ex = null;
     464        EventHandler<EventArgs<Exception>> handler = (sender, args) => {
     465          ex = args.Value;
     466          wh.Set();
     467        };
    394468        EventHandler handler2 = (sender, args) => wh.Set();
    395469        alg.ExceptionOccurred += handler;
     
    400474          wh.WaitOne();
    401475
     476          if (ex != null) throw new AggregateException(ex);
    402477          run = alg.Runs.Last();
     478          alg.Runs.Clear();
    403479          var sols = alg.Results.Select(r => r.Value).OfType<IRegressionSolution>();
    404480          if (!sols.Any()) return false;
     
    419495            model = sol.Model;
    420496          }
    421         } finally {
     497        }
     498        finally {
    422499          alg.ExceptionOccurred -= handler;
    423500          alg.Stopped -= handler2;
     
    426503      return model != null;
    427504    }
     505
     506    private static void SetSeed(IAlgorithm alg, int seed) {
     507      // no common interface for algs that use a PRNG -> use naming convention to set seed
     508      var paramItem = alg as IParameterizedItem;
     509
     510      if (paramItem.Parameters.ContainsKey("SetSeedRandomly")) {
     511        ((BoolValue)paramItem.Parameters["SetSeedRandomly"].ActualValue).Value = false;
     512        ((IntValue)paramItem.Parameters["Seed"].ActualValue).Value = seed;
     513      } else {
     514        throw new ArgumentException("Base learner does not have a seed parameter (algorithm {0})", alg.Name);
     515      }
     516
     517    }
    428518  }
    429519}
Note: See TracChangeset for help on using the changeset viewer.