Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/19/21 16:07:45 (2 years ago)
Author:
mkommend
Message:

#2521: Merged trunk changes into branch.

Location:
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM
Files:
2 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM/GeneralizedAdditiveModelAlgorithm.cs

    r17812 r18086  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    5  *
     3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     4 *
    65 * This file is part of HeuristicLab.
    76 *
     
    4039  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 600)]
    4140  public sealed class GeneralizedAdditiveModelAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
     41
    4242    #region ParameterNames
    4343
     
    4747    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    4848    private const string CreateSolutionParameterName = "CreateSolution";
     49
    4950    #endregion
    5051
     
    142143      var problemData = Problem.ProblemData;
    143144      var ds = problemData.Dataset;
    144       var trainRows = problemData.TrainingIndices;
    145       var testRows = problemData.TestIndices;
     145      var trainRows = problemData.TrainingIndices.ToArray();
     146      var testRows = problemData.TestIndices.ToArray();
    146147      var avgY = problemData.TargetVariableTrainingValues.Average();
    147148      var inputVars = problemData.AllowedInputVariables.ToArray();
     
    179180      double[] resTest = problemData.TargetVariableTestValues.Select(yi => yi - avgY).ToArray();
    180181
    181       curRMSE.Value = res.StandardDeviation();
    182       curRMSETest.Value = resTest.StandardDeviation();
    183       rmseRow.Values.Add(res.StandardDeviation());
    184       rmseRowTest.Values.Add(resTest.StandardDeviation());
     182      curRMSE.Value = RMSE(res);
     183      curRMSETest.Value = RMSE(resTest);
     184      rmseRow.Values.Add(curRMSE.Value);
     185      rmseRowTest.Values.Add(curRMSETest.Value);
    185186
    186187
     
    198199          AddInPlace(resTest, f[inputIdx].GetEstimatedValues(ds, testRows));
    199200
    200           rssTable[inputIdx, 0] = res.Variance();
     201          rssTable[inputIdx, 0] = MSE(res);
    201202          f[inputIdx] = RegressSpline(problemData, inputVar, res, lambda);
    202203
     
    205206        }
    206207
    207         curRMSE.Value = res.StandardDeviation();
    208         curRMSETest.Value = resTest.StandardDeviation();
     208        curRMSE.Value = RMSE(res);
     209        curRMSETest.Value = RMSE(resTest);
    209210        rmseRow.Values.Add(curRMSE.Value);
    210211        rmseRowTest.Values.Add(curRMSETest.Value);
     
    216217        var model = new RegressionEnsembleModel(f.Concat(new[] { new ConstantModel(avgY, problemData.TargetVariable) }));
    217218        model.AverageModelEstimates = false;
    218         var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
     219        var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());       
    219220        Results.Add(new Result("Ensemble solution", solution));
    220221      }
     222    }
     223
     224    public static double MSE(IEnumerable<double> residuals) {
     225      var mse  = residuals.Select(r => r * r).Average();
     226      return mse;
     227    }
     228
     229    public static double RMSE(IEnumerable<double> residuals) {
     230      var mse = MSE(residuals);
     231      var rmse = Math.Sqrt(mse);
     232      return rmse;
    221233    }
    222234
  • branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM/Spline1dModel.cs

    r17812 r18086  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    5  *
     3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     4 *
    65 * This file is part of HeuristicLab.
    76 *
     
    2120#endregion
    2221
    23 using System;
    2422using HEAL.Attic;
    2523using System.Collections.Generic;
     
    2725using HeuristicLab.Common;
    2826using HeuristicLab.Core;
    29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3027using HeuristicLab.Problems.DataAnalysis;
     28using System;
    3129
    3230namespace HeuristicLab.Algorithms.DataAnalysis {
     
    3836    private alglib.spline1d.spline1dinterpolant interpolant;
    3937
    40     [Storable]
    41     private readonly string[] variablesUsedForPrediction;
    42     public override IEnumerable<string> VariablesUsedForPrediction {
    43       get {
    44         return variablesUsedForPrediction;
     38    [Storable(OldName = "variablesUsedForPrediction")]
     39    private string[] StorableVariablesUsedForPrediction {
     40      set {
     41        if (value.Length > 1) throw new ArgumentException("A one-dimensional spline model supports only one input variable.");
     42        inputVariable = value[0];
    4543      }
    4644    }
     45
     46    [Storable]
     47    private string inputVariable;
     48    public override IEnumerable<string> VariablesUsedForPrediction => new[] { inputVariable };
    4749
    4850    [StorableConstructor]
     
    5254
    5355    private Spline1dModel(Spline1dModel orig, Cloner cloner) : base(orig, cloner) {
    54       this.variablesUsedForPrediction = orig.VariablesUsedForPrediction.ToArray();
    55       this.interpolant = (alglib.spline1d.spline1dinterpolant)orig.interpolant.make_copy();
     56      this.inputVariable = orig.inputVariable;
     57      if(orig.interpolant != null) this.interpolant = (alglib.spline1d.spline1dinterpolant)orig.interpolant.make_copy();
    5658    }
    5759    public Spline1dModel(alglib.spline1d.spline1dinterpolant interpolant, string targetVar, string inputVar)
    58       : base("Spline model (1d)", "Spline model (1d)") {
     60      : base(targetVar, $"Spline model ({inputVar})") {
    5961      this.interpolant = (alglib.spline1d.spline1dinterpolant)interpolant.make_copy();
    60       this.TargetVariable = targetVar;
    61       this.variablesUsedForPrediction = new string[] { inputVar };
     62      this.inputVariable = inputVar;
    6263    }
    6364
    6465
    65     public override IDeepCloneable Clone(Cloner cloner) {
    66       return new Spline1dModel(this, cloner);
     66    public override IDeepCloneable Clone(Cloner cloner) => new Spline1dModel(this, cloner);
     67
     68    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     69      var solution =  new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
     70      solution.Name = $"Regression Spline ({inputVariable})";
     71
     72      return solution;
    6773    }
    6874
    69     public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    70       return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    71     }
    72 
    73     public double GetEstimatedValue(double x) {
    74       return alglib.spline1d.spline1dcalc(interpolant, x);
    75     }
     75    public double GetEstimatedValue(double x) => alglib.spline1d.spline1dcalc(interpolant, x, null);
    7676
    7777    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    78       var x = dataset.GetDoubleValues(VariablesUsedForPrediction.First(), rows).ToArray();
    79       foreach (var xi in x) {
    80         yield return GetEstimatedValue(xi);
    81       }
     78      return dataset.GetDoubleValues(inputVariable, rows).Select(GetEstimatedValue);
    8279    }
    8380
Note: See TracChangeset for help on using the changeset viewer.