Changeset 17888


Ignore:
Timestamp:
03/12/21 14:35:03 (5 months ago)
Author:
mkommend
Message:

#2898: Corrected calculation of MSE and RMSE in GAMs by implementing methods for their calculation.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM/GeneralizedAdditiveModelAlgorithm.cs

    r17815 r17888  
    3939  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 600)]
    4040  public sealed class GeneralizedAdditiveModelAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
     41
    4142    #region ParameterNames
    4243
     
    4647    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    4748    private const string CreateSolutionParameterName = "CreateSolution";
     49
    4850    #endregion
    4951
     
    141143      var problemData = Problem.ProblemData;
    142144      var ds = problemData.Dataset;
    143       var trainRows = problemData.TrainingIndices;
    144       var testRows = problemData.TestIndices;
     145      var trainRows = problemData.TrainingIndices.ToArray();
     146      var testRows = problemData.TestIndices.ToArray();
    145147      var avgY = problemData.TargetVariableTrainingValues.Average();
    146148      var inputVars = problemData.AllowedInputVariables.ToArray();
     
    178180      double[] resTest = problemData.TargetVariableTestValues.Select(yi => yi - avgY).ToArray();
    179181
    180       curRMSE.Value = res.StandardDeviation();
    181       curRMSETest.Value = resTest.StandardDeviation();
    182       rmseRow.Values.Add(res.StandardDeviation());
    183       rmseRowTest.Values.Add(resTest.StandardDeviation());
     182      curRMSE.Value = RMSE(res);
     183      curRMSETest.Value = RMSE(resTest);
     184      rmseRow.Values.Add(curRMSE.Value);
     185      rmseRowTest.Values.Add(curRMSETest.Value);
    184186
    185187
     
    197199          AddInPlace(resTest, f[inputIdx].GetEstimatedValues(ds, testRows));
    198200
    199           rssTable[inputIdx, 0] = res.Variance();
     201          rssTable[inputIdx, 0] = MSE(res);
    200202          f[inputIdx] = RegressSpline(problemData, inputVar, res, lambda);
    201203
     
    204206        }
    205207
    206         curRMSE.Value = res.StandardDeviation();
    207         curRMSETest.Value = resTest.StandardDeviation();
     208        curRMSE.Value = RMSE(res);
     209        curRMSETest.Value = RMSE(resTest);
    208210        rmseRow.Values.Add(curRMSE.Value);
    209211        rmseRowTest.Values.Add(curRMSETest.Value);
     
    215217        var model = new RegressionEnsembleModel(f.Concat(new[] { new ConstantModel(avgY, problemData.TargetVariable) }));
    216218        model.AverageModelEstimates = false;
    217         var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
     219        var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());       
    218220        Results.Add(new Result("Ensemble solution", solution));
    219221      }
     222    }
     223
     224    public static double MSE(IEnumerable<double> residuals) {
     225      var mse  = residuals.Select(r => r * r).Average();
     226      return mse;
     227    }
     228
     229    public static double RMSE(IEnumerable<double> residuals) {
     230      var mse = MSE(residuals);
     231      var rmse = Math.Sqrt(mse);
     232      return rmse;
    220233    }
    221234
Note: See TracChangeset for help on using the changeset viewer.