Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/19/21 16:07:45 (3 years ago)
Author:
mkommend
Message:

#2521: Merged trunk changes into branch.

Location:
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM
Files:
1 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM/GeneralizedAdditiveModelAlgorithm.cs

    r17812 r18086  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    5  *
     3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     4 *
    65 * This file is part of HeuristicLab.
    76 *
     
    4039  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 600)]
    4140  public sealed class GeneralizedAdditiveModelAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
     41
    4242    #region ParameterNames
    4343
     
    4747    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    4848    private const string CreateSolutionParameterName = "CreateSolution";
     49
    4950    #endregion
    5051
     
    142143      var problemData = Problem.ProblemData;
    143144      var ds = problemData.Dataset;
    144       var trainRows = problemData.TrainingIndices;
    145       var testRows = problemData.TestIndices;
     145      var trainRows = problemData.TrainingIndices.ToArray();
     146      var testRows = problemData.TestIndices.ToArray();
    146147      var avgY = problemData.TargetVariableTrainingValues.Average();
    147148      var inputVars = problemData.AllowedInputVariables.ToArray();
     
    179180      double[] resTest = problemData.TargetVariableTestValues.Select(yi => yi - avgY).ToArray();
    180181
    181       curRMSE.Value = res.StandardDeviation();
    182       curRMSETest.Value = resTest.StandardDeviation();
    183       rmseRow.Values.Add(res.StandardDeviation());
    184       rmseRowTest.Values.Add(resTest.StandardDeviation());
     182      curRMSE.Value = RMSE(res);
     183      curRMSETest.Value = RMSE(resTest);
     184      rmseRow.Values.Add(curRMSE.Value);
     185      rmseRowTest.Values.Add(curRMSETest.Value);
    185186
    186187
     
    198199          AddInPlace(resTest, f[inputIdx].GetEstimatedValues(ds, testRows));
    199200
    200           rssTable[inputIdx, 0] = res.Variance();
     201          rssTable[inputIdx, 0] = MSE(res);
    201202          f[inputIdx] = RegressSpline(problemData, inputVar, res, lambda);
    202203
     
    205206        }
    206207
    207         curRMSE.Value = res.StandardDeviation();
    208         curRMSETest.Value = resTest.StandardDeviation();
     208        curRMSE.Value = RMSE(res);
     209        curRMSETest.Value = RMSE(resTest);
    209210        rmseRow.Values.Add(curRMSE.Value);
    210211        rmseRowTest.Values.Add(curRMSETest.Value);
     
    216217        var model = new RegressionEnsembleModel(f.Concat(new[] { new ConstantModel(avgY, problemData.TargetVariable) }));
    217218        model.AverageModelEstimates = false;
    218         var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
     219        var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());       
    219220        Results.Add(new Result("Ensemble solution", solution));
    220221      }
     222    }
     223
     224    public static double MSE(IEnumerable<double> residuals) {
     225      var mse  = residuals.Select(r => r * r).Average();
     226      return mse;
     227    }
     228
     229    public static double RMSE(IEnumerable<double> residuals) {
     230      var mse = MSE(residuals);
     231      var rmse = Math.Sqrt(mse);
     232      return rmse;
    221233    }
    222234
Note: See TracChangeset for help on using the changeset viewer.