Free cookie consent management tool by TermsFeed Policy Generator

Changeset 15775


Ignore:
Timestamp:
02/14/18 00:19:49 (7 years ago)
Author:
gkronber
Message:

#2898 added simple implementation of GAM based on univariate penalized regression splines with the same penalization factor for each term

Location:
branches/2898_GeneralizedAdditiveModels/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
2 added
1 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/2898_GeneralizedAdditiveModels/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM/GeneralizedAdditiveModelAlgorithm.cs

    r15774 r15775  
    2929using HeuristicLab.Core;
    3030using HeuristicLab.Data;
    31 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    3231using HeuristicLab.Optimization;
    3332using HeuristicLab.Parameters;
    3433using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3534using HeuristicLab.Problems.DataAnalysis;
    36 using HeuristicLab.Problems.DataAnalysis.Symbolic;
    37 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    3835using HeuristicLab.Random;
    39 using HeuristicLab.Selection;
    40 
    41 namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    42   [Item("Gradient Boosting Machine Regression (GBM)",
    43     "Gradient boosting for any regression base learner (e.g. MCTS symbolic regression)")]
     36
     37namespace HeuristicLab.Algorithms.DataAnalysis {
     38  [Item("Generalized Additive Model (GAM)",
     39    "Generalized Additive Model Algorithm")]
    4440  [StorableClass]
    45   [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 350)]
    46   public class GradientBoostingRegressionAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    47 
     41  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 600)]
     42  public sealed class GeneralizedAdditiveModelAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    4843    #region ParameterNames
    4944
    5045    private const string IterationsParameterName = "Iterations";
    51     private const string NuParameterName = "Nu";
    52     private const string MParameterName = "M";
    53     private const string RParameterName = "R";
    54     private const string RegressionAlgorithmParameterName = "RegressionAlgorithm";
     46    private const string LambdaParameterName = "Lambda";
    5547    private const string SeedParameterName = "Seed";
    5648    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    5749    private const string CreateSolutionParameterName = "CreateSolution";
    58     private const string StoreRunsParameterName = "StoreRuns";
    59     private const string RegressionAlgorithmSolutionResultParameterName = "RegressionAlgorithmResult";
    60 
    6150    #endregion
    6251
     
    6756    }
    6857
    69     public IFixedValueParameter<DoubleValue> NuParameter {
    70       get { return (IFixedValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    71     }
    72 
    73     public IFixedValueParameter<DoubleValue> RParameter {
    74       get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    75     }
    76 
    77     public IFixedValueParameter<DoubleValue> MParameter {
    78       get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    79     }
    80 
    81     // regression algorithms are currently: DataAnalysisAlgorithms, BasicAlgorithms and engine algorithms with no common interface
    82     public IConstrainedValueParameter<IAlgorithm> RegressionAlgorithmParameter {
    83       get { return (IConstrainedValueParameter<IAlgorithm>)Parameters[RegressionAlgorithmParameterName]; }
    84     }
    85 
    86     public IFixedValueParameter<StringValue> RegressionAlgorithmSolutionResultParameter {
    87       get { return (IFixedValueParameter<StringValue>)Parameters[RegressionAlgorithmSolutionResultParameterName]; }
     58    public IFixedValueParameter<DoubleValue> LambdaParameter {
     59      get { return (IFixedValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    8860    }
    8961
     
    9870    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    9971      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    100     }
    101     public IFixedValueParameter<BoolValue> StoreRunsParameter {
    102       get { return (IFixedValueParameter<BoolValue>)Parameters[StoreRunsParameterName]; }
    10372    }
    10473
     
    11281    }
    11382
     83    public double Lambda {
     84      get { return LambdaParameter.Value.Value; }
     85      set { LambdaParameter.Value.Value = value; }
     86    }
     87
    11488    public int Seed {
    11589      get { return SeedParameter.Value.Value; }
     
    12296    }
    12397
    124     public double Nu {
    125       get { return NuParameter.Value.Value; }
    126       set { NuParameter.Value.Value = value; }
    127     }
    128 
    129     public double R {
    130       get { return RParameter.Value.Value; }
    131       set { RParameter.Value.Value = value; }
    132     }
    133 
    134     public double M {
    135       get { return MParameter.Value.Value; }
    136       set { MParameter.Value.Value = value; }
    137     }
    138 
    13998    public bool CreateSolution {
    14099      get { return CreateSolutionParameter.Value.Value; }
     
    142101    }
    143102
    144     public bool StoreRuns {
    145       get { return StoreRunsParameter.Value.Value; }
    146       set { StoreRunsParameter.Value.Value = value; }
    147     }
    148 
    149     public IAlgorithm RegressionAlgorithm {
    150       get { return RegressionAlgorithmParameter.Value; }
    151     }
    152 
    153     public string RegressionAlgorithmResult {
    154       get { return RegressionAlgorithmSolutionResultParameter.Value.Value; }
    155       set { RegressionAlgorithmSolutionResultParameter.Value.Value = value; }
    156     }
    157 
    158103    #endregion
    159104
    160105    [StorableConstructor]
    161     protected GradientBoostingRegressionAlgorithm(bool deserializing)
     106    private GeneralizedAdditiveModelAlgorithm(bool deserializing)
    162107      : base(deserializing) {
    163108    }
    164109
    165     protected GradientBoostingRegressionAlgorithm(GradientBoostingRegressionAlgorithm original, Cloner cloner)
     110    private GeneralizedAdditiveModelAlgorithm(GeneralizedAdditiveModelAlgorithm original, Cloner cloner)
    166111      : base(original, cloner) {
    167112    }
    168113
    169114    public override IDeepCloneable Clone(Cloner cloner) {
    170       return new GradientBoostingRegressionAlgorithm(this, cloner);
    171     }
    172 
    173     public GradientBoostingRegressionAlgorithm() {
     115      return new GeneralizedAdditiveModelAlgorithm(this, cloner);
     116    }
     117
     118    public GeneralizedAdditiveModelAlgorithm() {
    174119      Problem = new RegressionProblem(); // default problem
    175       var osgp = CreateOSGP();
    176       var regressionAlgs = new ItemSet<IAlgorithm>(new IAlgorithm[] {
    177         new RandomForestRegression(),
    178         osgp,
    179       });
    180       foreach (var alg in regressionAlgs) alg.Prepare();
    181 
    182120
    183121      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName,
    184         "Number of iterations", new IntValue(100)));
     122        "Number of iterations. Try a large value and check convergence of the error over iterations. Usually, only a few iterations (e.g. 10) are needed for convergence.", new IntValue(10)));
     123      Parameters.Add(new FixedValueParameter<DoubleValue>(LambdaParameterName,
     124        "The penalty parameter for the penalized regression splines. Set to a value between -8 (weak smoothing) and 8 (strong smooting). Usually, a value between -4 and 4 should be fine", new DoubleValue(3)));
    185125      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName,
    186126        "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    187127      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName,
    188128        "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    189       Parameters.Add(new FixedValueParameter<DoubleValue>(NuParameterName,
    190         "The learning rate nu when updating predictions in GBM (0 < nu <= 1)", new DoubleValue(0.5)));
    191       Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName,
    192         "The fraction of rows that are sampled randomly for the base learner in each iteration (0 < r <= 1)",
    193         new DoubleValue(1)));
    194       Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName,
    195         "The fraction of variables that are sampled randomly for the base learner in each iteration (0 < m <= 1)",
    196         new DoubleValue(0.5)));
    197       Parameters.Add(new ConstrainedValueParameter<IAlgorithm>(RegressionAlgorithmParameterName,
    198         "The regression algorithm to use as a base learner", regressionAlgs, osgp));
    199       Parameters.Add(new FixedValueParameter<StringValue>(RegressionAlgorithmSolutionResultParameterName,
    200         "The name of the solution produced by the regression algorithm", new StringValue("Solution")));
    201       Parameters[RegressionAlgorithmSolutionResultParameterName].Hidden = true;
    202129      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
    203130        "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    204131      Parameters[CreateSolutionParameterName].Hidden = true;
    205       Parameters.Add(new FixedValueParameter<BoolValue>(StoreRunsParameterName,
    206         "Flag that indicates if the results of the individual runs should be stored for detailed analysis", new BoolValue(false)));
    207       Parameters[StoreRunsParameterName].Hidden = true;
    208132    }
    209133
     
    213137      var rand = new MersenneTwister((uint)Seed);
    214138
     139      // calculates a GAM model using univariate non-linear functions
     140      // using backfitting algorithm (see The Elements of Statistical Learning page 298)
     141
     142      // init
     143      var problemData = Problem.ProblemData;
     144      var ds = problemData.Dataset;
     145      var trainRows = problemData.TrainingIndices;
     146      var testRows = problemData.TestIndices;
     147      var avgY = problemData.TargetVariableTrainingValues.Average();
     148      var inputVars = problemData.AllowedInputVariables.ToArray();
     149
     150      int nTerms = inputVars.Length;
     151
     152      #region init results
    215153      // Set up the results display
    216154      var iterations = new IntValue(0);
     
    218156
    219157      var table = new DataTable("Qualities");
    220       table.Rows.Add(new DataRow("R² (train)"));
    221       table.Rows.Add(new DataRow("R² (test)"));
     158      var rmseRow = new DataRow("RMSE (train)");
     159      var rmseRowTest = new DataRow("RMSE (test)");
     160      table.Rows.Add(rmseRow);
     161      table.Rows.Add(rmseRowTest);
    222162      Results.Add(new Result("Qualities", table));
    223       var curLoss = new DoubleValue();
    224       var curTestLoss = new DoubleValue();
    225       Results.Add(new Result("R² (train)", curLoss));
    226       Results.Add(new Result("R² (test)", curTestLoss));
    227       var runCollection = new RunCollection();
    228       if (StoreRuns)
    229         Results.Add(new Result("Runs", runCollection));
    230 
    231       // init
    232       var problemData = Problem.ProblemData;
    233       var targetVarName = problemData.TargetVariable;
    234       var activeVariables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
    235       var modifiableDataset = new ModifiableDataset(
    236         activeVariables,
    237         activeVariables.Select(v => problemData.Dataset.GetDoubleValues(v).ToList()));
    238 
    239       var trainingRows = problemData.TrainingIndices;
    240       var testRows = problemData.TestIndices;
    241       var yPred = new double[trainingRows.Count()];
    242       var yPredTest = new double[testRows.Count()];
    243       var y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
    244       var curY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
    245 
    246       var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToArray();
    247       var curYTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToArray();
    248       var nu = Nu;
    249       var mVars = (int)Math.Ceiling(M * problemData.AllowedInputVariables.Count());
    250       var rRows = (int)Math.Ceiling(R * problemData.TrainingIndices.Count());
    251       var alg = RegressionAlgorithm;
    252       List<IRegressionModel> models = new List<IRegressionModel>();
    253       try {
    254 
    255         // Loop until iteration limit reached or canceled.
    256         for (int i = 0; i < Iterations; i++) {
    257           cancellationToken.ThrowIfCancellationRequested();
    258 
    259           modifiableDataset.RemoveVariable(targetVarName);
    260           modifiableDataset.AddVariable(targetVarName, curY.Concat(curYTest).ToList());
    261 
    262           SampleTrainingData(rand, modifiableDataset, rRows, problemData.Dataset, curY, problemData.TargetVariable, problemData.TrainingIndices); // all training indices from the original problem data are allowed
    263           var modifiableProblemData = new RegressionProblemData(modifiableDataset,
    264             problemData.AllowedInputVariables.SampleRandomWithoutRepetition(rand, mVars),
    265             problemData.TargetVariable);
    266           modifiableProblemData.TrainingPartition.Start = 0;
    267           modifiableProblemData.TrainingPartition.End = rRows;
    268           modifiableProblemData.TestPartition.Start = problemData.TestPartition.Start;
    269           modifiableProblemData.TestPartition.End = problemData.TestPartition.End;
    270 
    271           if (!TrySetProblemData(alg, modifiableProblemData))
    272             throw new NotSupportedException("The algorithm cannot be used with GBM.");
    273 
    274           IRegressionModel model;
    275           IRun run;
    276 
    277           // try to find a model. The algorithm might fail to produce a model. In this case we just retry until the iterations are exhausted
    278           if (TryExecute(alg, rand.Next(), RegressionAlgorithmResult, out model, out run)) {
    279             int row = 0;
    280             // update predictions for training and test
    281             // update new targets (in the case of squared error loss we simply use negative residuals)
    282             foreach (var pred in model.GetEstimatedValues(problemData.Dataset, trainingRows)) {
    283               yPred[row] = yPred[row] + nu * pred;
    284               curY[row] = y[row] - yPred[row];
    285               row++;
    286             }
    287             row = 0;
    288             foreach (var pred in model.GetEstimatedValues(problemData.Dataset, testRows)) {
    289               yPredTest[row] = yPredTest[row] + nu * pred;
    290               curYTest[row] = yTest[row] - yPredTest[row];
    291               row++;
    292             }
    293             // determine quality
    294             OnlineCalculatorError error;
    295             var trainR = OnlinePearsonsRCalculator.Calculate(yPred, y, out error);
    296             var testR = OnlinePearsonsRCalculator.Calculate(yPredTest, yTest, out error);
    297 
    298             // iteration results
    299             curLoss.Value = error == OnlineCalculatorError.None ? trainR * trainR : 0.0;
    300             curTestLoss.Value = error == OnlineCalculatorError.None ? testR * testR : 0.0;
    301 
    302             models.Add(model);
    303 
    304 
    305           }
    306 
    307           if (StoreRuns)
    308             runCollection.Add(run);
    309           table.Rows["R² (train)"].Values.Add(curLoss.Value);
    310           table.Rows["R² (test)"].Values.Add(curTestLoss.Value);
    311           iterations.Value = i + 1;
     163      var curRMSE = new DoubleValue();
     164      var curRMSETest = new DoubleValue();
     165      Results.Add(new Result("RMSE (train)", curRMSE));
     166      Results.Add(new Result("RMSE (test)", curRMSETest));
     167
     168      // calculate table with residual contributions of each term
     169      var rssTable = new DoubleMatrix(nTerms, 1, new string[] { "RSS" }, inputVars);
     170      Results.Add(new Result("RSS Values", rssTable));
     171      #endregion
     172
     173      // start with a set of constant models = 0
     174      IRegressionModel[] f = new IRegressionModel[nTerms];
     175      for (int i = 0; i < f.Length; i++) {
     176        f[i] = new ConstantModel(0.0, problemData.TargetVariable);
     177      }
     178      // init res which contains the current residual vector
     179      double[] res = problemData.TargetVariableTrainingValues.Select(yi => yi - avgY).ToArray();
     180      double[] resTest = problemData.TargetVariableTestValues.Select(yi => yi - avgY).ToArray();
     181
     182      curRMSE.Value = res.StandardDeviation();
     183      curRMSETest.Value = resTest.StandardDeviation();
     184      rmseRow.Values.Add(res.StandardDeviation());
     185      rmseRowTest.Values.Add(resTest.StandardDeviation());
     186
     187
     188      double lambda = Lambda;
     189      var idx = Enumerable.Range(0, nTerms).ToArray();
     190
     191      // Loop until iteration limit reached or canceled.
     192      for (int i = 0; i < Iterations && !cancellationToken.IsCancellationRequested; i++) {
     193        // shuffle order of terms in each iteration to remove bias on earlier terms
     194        idx.ShuffleInPlace(rand);
     195        foreach (var inputIdx in idx) {
     196          var inputVar = inputVars[inputIdx];
     197          // first remove the effect of the previous model for the inputIdx (by adding the output of the current model to the residual)
     198          AddInPlace(res, f[inputIdx].GetEstimatedValues(ds, trainRows));
     199          AddInPlace(resTest, f[inputIdx].GetEstimatedValues(ds, testRows));
     200
     201          rssTable[inputIdx, 0] = res.Variance();
     202          f[inputIdx] = RegressSpline(problemData, inputVar, res, lambda);
     203
     204          SubtractInPlace(res, f[inputIdx].GetEstimatedValues(ds, trainRows));
     205          SubtractInPlace(resTest, f[inputIdx].GetEstimatedValues(ds, testRows));
    312206        }
    313207
    314         // produce solution
    315         if (CreateSolution) {
    316           // when all our models are symbolic models we can easily combine them to a single model
    317           if (models.All(m => m is ISymbolicRegressionModel)) {
    318             Results.Add(new Result("Solution", CreateSymbolicSolution(models, Nu, (IRegressionProblemData)problemData.Clone())));
    319           }
    320           // just produce an ensemble solution for now (TODO: correct scaling or linear regression for ensemble model weights)
    321 
    322           var ensembleSolution = CreateEnsembleSolution(models, (IRegressionProblemData)problemData.Clone());
    323           Results.Add(new Result("EnsembleSolution", ensembleSolution));
    324         }
    325       }
    326       finally {
    327         // reset everything
    328         alg.Prepare(true);
    329       }
    330     }
    331 
    332     private static IRegressionEnsembleSolution CreateEnsembleSolution(List<IRegressionModel> models,
    333       IRegressionProblemData problemData) {
    334       var rows = problemData.TrainingPartition.Size;
    335       var features = models.Count;
    336       double[,] inputMatrix = new double[rows, features + 1];
    337       //add model estimates
    338       for (int m = 0; m < models.Count; m++) {
    339         var model = models[m];
    340         var estimates = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
    341         int estimatesCounter = 0;
    342         foreach (var estimate in estimates) {
    343           inputMatrix[estimatesCounter, m] = estimate;
    344           estimatesCounter++;
    345         }
    346       }
    347 
    348       //add target
    349       var targets = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    350       int targetCounter = 0;
    351       foreach (var target in targets) {
    352         inputMatrix[targetCounter, models.Count] = target;
    353         targetCounter++;
    354       }
    355 
    356       alglib.linearmodel lm = new alglib.linearmodel();
    357       alglib.lrreport ar = new alglib.lrreport();
    358       double[] coefficients;
    359       int retVal = 1;
    360       alglib.lrbuildz(inputMatrix, rows, features, out retVal, out lm, out ar);
    361       if (retVal != 1) throw new ArgumentException("Error in calculation of linear regression solution");
    362 
    363       alglib.lrunpack(lm, out coefficients, out features);
    364 
    365       var ensembleModel = new RegressionEnsembleModel(models, coefficients.Take(models.Count)) { AverageModelEstimates = false };
    366       var ensembleSolution = (IRegressionEnsembleSolution)ensembleModel.CreateRegressionSolution(problemData);
    367       return ensembleSolution;
    368     }
    369 
    370 
    371     private IAlgorithm CreateOSGP() {
    372       // configure strict osgp
    373       var alg = new OffspringSelectionGeneticAlgorithm.OffspringSelectionGeneticAlgorithm();
    374       var prob = new SymbolicRegressionSingleObjectiveProblem();
    375       prob.MaximumSymbolicExpressionTreeDepth.Value = 7;
    376       prob.MaximumSymbolicExpressionTreeLength.Value = 15;
    377       alg.Problem = prob;
    378       alg.SuccessRatio.Value = 1.0;
    379       alg.ComparisonFactorLowerBound.Value = 1.0;
    380       alg.ComparisonFactorUpperBound.Value = 1.0;
    381       alg.MutationProbability.Value = 0.15;
    382       alg.PopulationSize.Value = 200;
    383       alg.MaximumSelectionPressure.Value = 100;
    384       alg.MaximumEvaluatedSolutions.Value = 20000;
    385       alg.SelectorParameter.Value = alg.SelectorParameter.ValidValues.OfType<GenderSpecificSelector>().First();
    386       alg.MutatorParameter.Value = alg.MutatorParameter.ValidValues.OfType<MultiSymbolicExpressionTreeManipulator>().First();
    387       alg.StoreAlgorithmInEachRun = false;
    388       return alg;
    389     }
    390 
    391     private void SampleTrainingData(MersenneTwister rand, ModifiableDataset ds, int rRows,
    392       IDataset sourceDs, double[] curTarget, string targetVarName, IEnumerable<int> trainingIndices) {
    393       var selectedRows = trainingIndices.SampleRandomWithoutRepetition(rand, rRows).ToArray();
    394       int t = 0;
    395       object[] srcRow = new object[ds.Columns];
    396       var varNames = ds.DoubleVariables.ToArray();
    397       foreach (var r in selectedRows) {
    398         // take all values from the original dataset
    399         for (int c = 0; c < srcRow.Length; c++) {
    400           var col = sourceDs.GetReadOnlyDoubleValues(varNames[c]);
    401           srcRow[c] = col[r];
    402         }
    403         ds.ReplaceRow(t, srcRow);
    404         // but use the updated target values
    405         ds.SetVariableValue(curTarget[r], targetVarName, t);
    406         t++;
    407       }
    408     }
    409 
    410     private static ISymbolicRegressionSolution CreateSymbolicSolution(List<IRegressionModel> models, double nu, IRegressionProblemData problemData) {
    411       var symbModels = models.OfType<ISymbolicRegressionModel>();
    412       var lowerLimit = symbModels.Min(m => m.LowerEstimationLimit);
    413       var upperLimit = symbModels.Max(m => m.UpperEstimationLimit);
    414       var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
    415       var progRootNode = new ProgramRootSymbol().CreateTreeNode();
    416       var startNode = new StartSymbol().CreateTreeNode();
    417 
    418       var addNode = new Addition().CreateTreeNode();
    419       var mulNode = new Multiplication().CreateTreeNode();
    420       var scaleNode = (ConstantTreeNode)new Constant().CreateTreeNode(); // all models are scaled using the same nu
    421       scaleNode.Value = nu;
    422 
    423       foreach (var m in symbModels) {
    424         var relevantPart = m.SymbolicExpressionTree.Root.GetSubtree(0).GetSubtree(0); // skip root and start
    425         addNode.AddSubtree((ISymbolicExpressionTreeNode)relevantPart.Clone());
    426       }
    427 
    428       mulNode.AddSubtree(addNode);
    429       mulNode.AddSubtree(scaleNode);
    430       startNode.AddSubtree(mulNode);
    431       progRootNode.AddSubtree(startNode);
    432       var t = new SymbolicExpressionTree(progRootNode);
    433       var combinedModel = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerLimit, upperLimit);
    434       var sol = new SymbolicRegressionSolution(combinedModel, problemData);
    435       return sol;
    436     }
    437 
    438     private static bool TrySetProblemData(IAlgorithm alg, IRegressionProblemData problemData) {
    439       var prob = alg.Problem as IRegressionProblem;
    440       // there is already a problem and it is compatible -> just set problem data
    441       if (prob != null) {
    442         prob.ProblemDataParameter.Value = problemData;
    443         return true;
    444       } else return false;
    445     }
    446 
    447     private static bool TryExecute(IAlgorithm alg, int seed, string regressionAlgorithmResultName, out IRegressionModel model, out IRun run) {
    448       model = null;
    449       SetSeed(alg, seed);
    450       using (var wh = new AutoResetEvent(false)) {
    451         Exception ex = null;
    452         EventHandler<EventArgs<Exception>> handler = (sender, args) => {
    453           ex = args.Value;
    454           wh.Set();
    455         };
    456         EventHandler handler2 = (sender, args) => wh.Set();
    457         alg.ExceptionOccurred += handler;
    458         alg.Stopped += handler2;
    459         try {
    460           alg.Prepare();
    461           alg.Start();
    462           wh.WaitOne();
    463 
    464           if (ex != null) throw new AggregateException(ex);
    465           run = alg.Runs.Last();
    466           alg.Runs.Clear();
    467           var sols = alg.Results.Select(r => r.Value).OfType<IRegressionSolution>();
    468           if (!sols.Any()) return false;
    469           var sol = sols.First();
    470           if (sols.Skip(1).Any()) {
    471             // more than one solution => use regressionAlgorithmResult
    472             if (alg.Results.ContainsKey(regressionAlgorithmResultName)) {
    473               sol = (IRegressionSolution)alg.Results[regressionAlgorithmResultName].Value;
    474             }
    475           }
    476           var symbRegSol = sol as SymbolicRegressionSolution;
    477           // only accept symb reg solutions that do not hit the estimation limits
    478           // NaN evaluations would not be critical but are problematic if we want to combine all symbolic models into a single symbolic model
    479           if (symbRegSol == null ||
    480             (symbRegSol.TrainingLowerEstimationLimitHits == 0 && symbRegSol.TrainingUpperEstimationLimitHits == 0 &&
    481              symbRegSol.TestLowerEstimationLimitHits == 0 && symbRegSol.TestUpperEstimationLimitHits == 0) &&
    482             symbRegSol.TrainingNaNEvaluations == 0 && symbRegSol.TestNaNEvaluations == 0) {
    483             model = sol.Model;
    484           }
    485         }
    486         finally {
    487           alg.ExceptionOccurred -= handler;
    488           alg.Stopped -= handler2;
    489         }
    490       }
    491       return model != null;
    492     }
    493 
    494     private static void SetSeed(IAlgorithm alg, int seed) {
    495       // no common interface for algs that use a PRNG -> use naming convention to set seed
    496       var paramItem = alg as IParameterizedItem;
    497 
    498       if (paramItem.Parameters.ContainsKey("SetSeedRandomly")) {
    499         ((BoolValue)paramItem.Parameters["SetSeedRandomly"].ActualValue).Value = false;
    500         ((IntValue)paramItem.Parameters["Seed"].ActualValue).Value = seed;
    501       } else {
    502         throw new ArgumentException("Base learner does not have a seed parameter (algorithm {0})", alg.Name);
    503       }
    504 
     208        curRMSE.Value = res.StandardDeviation();
     209        curRMSETest.Value = resTest.StandardDeviation();
     210        rmseRow.Values.Add(curRMSE.Value);
     211        rmseRowTest.Values.Add(curRMSETest.Value);
     212        iterations.Value = i;
     213      }
     214
     215      // produce solution
     216      if (CreateSolution) {
     217        var model = new RegressionEnsembleModel(f.Concat(new[] { new ConstantModel(avgY, problemData.TargetVariable) }));
     218        model.AverageModelEstimates = false;
     219        var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
     220        Results.Add(new Result("Ensemble solution", solution));
     221      }
     222    }
     223
     224    private IRegressionModel RegressSpline(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) {
     225      var x = problemData.Dataset.GetDoubleValues(inputVar, problemData.TrainingIndices).ToArray();
     226      var y = (double[])target.Clone();
     227      int info;
     228      alglib.spline1dinterpolant s;
     229      alglib.spline1dfitreport rep;
     230      int numKnots = (int)Math.Min(50, 3 * Math.Sqrt(x.Length)); // heuristic for number of knots  (forgot the source, but it is probably the R documentation or Elements of Statistical Learning)
     231
     232      alglib.spline1dfitpenalized(x, y, numKnots, lambda, out info, out s, out rep);
     233
     234      return new Spline1dModel(s.innerobj, problemData.TargetVariable, inputVar);
     235    }
     236
     237
     238    private static void AddInPlace(double[] a, IEnumerable<double> enumerable) {
     239      int i = 0;
     240      foreach (var elem in enumerable) {
     241        a[i] += elem;
     242        i++;
     243      }
     244    }
     245
     246    private static void SubtractInPlace(double[] a, IEnumerable<double> enumerable) {
     247      int i = 0;
     248      foreach (var elem in enumerable) {
     249        a[i] -= elem;
     250        i++;
     251      }
    505252    }
    506253  }
  • branches/2898_GeneralizedAdditiveModels/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r15532 r15775  
    132132    </Compile>
    133133    <Compile Include="FixedDataAnalysisAlgorithm.cs" />
     134    <Compile Include="GAM\Spline1dModel.cs" />
     135    <Compile Include="GAM\GeneralizedAdditiveModelAlgorithm.cs" />
    134136    <Compile Include="GaussianProcess\CovarianceFunctions\CovarianceSpectralMixture.cs" />
    135137    <Compile Include="GaussianProcess\CovarianceFunctions\CovariancePiecewisePolynomial.cs" />
     
    320322    <Compile Include="TSNE\Distances\IndexedItemDistance.cs" />
    321323    <Compile Include="TSNE\Distances\ManhattanDistance.cs" />
    322   <Compile Include="TSNE\Distances\WeightedEuclideanDistance.cs" />
     324    <Compile Include="TSNE\Distances\WeightedEuclideanDistance.cs" />
    323325    <Compile Include="TSNE\Distances\IDistance.cs" />
    324326    <Compile Include="TSNE\PriorityQueue.cs" />
Note: See TracChangeset for help on using the changeset viewer.