Free cookie consent management tool by TermsFeed Policy Generator

Changeset 15436


Ignore:
Timestamp:
10/27/17 16:19:11 (7 years ago)
Author:
gkronber
Message:

#2789: added parameter properties to GAM

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/GAM.cs

    r15433 r15436  
    4444  [StorableClass]
    4545  public sealed class GAM : FixedDataAnalysisAlgorithm<IRegressionProblem> {
     46
     47    private const string LambdaParameterName = "Lambda";
     48    private const string MaxIterationsParameterName = "Max iterations";
     49    private const string MaxInteractionsParameterName = "Max interactions";
     50
     51    public IFixedValueParameter<DoubleValue> LambdaParameter {
     52      get { return (IFixedValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
     53    }
     54    public IFixedValueParameter<IntValue> MaxIterationsParameter {
     55      get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
     56    }
     57    public IFixedValueParameter<IntValue> MaxInteractionsParameter {
     58      get { return (IFixedValueParameter<IntValue>)Parameters[MaxInteractionsParameterName]; }
     59    }
     60
     61    public double Lambda {
     62      get { return LambdaParameter.Value.Value; }
     63      set { LambdaParameter.Value.Value = value; }
     64    }
     65    public int MaxIterations {
     66      get { return MaxIterationsParameter.Value.Value; }
     67      set { MaxIterationsParameter.Value.Value = value; }
     68    }
     69    public int MaxInteractions {
     70      get { return MaxInteractionsParameter.Value.Value; }
     71      set { MaxInteractionsParameter.Value.Value = value; }
     72    }
     73
    4674    [StorableConstructor]
    4775    private GAM(bool deserializing) : base(deserializing) { }
    4876    [StorableHook(HookType.AfterDeserialization)]
    49     private void AfterDeserialization() {     
     77    private void AfterDeserialization() {
    5078    }
    5179
     
    6088      : base() {
    6189      Problem = new RegressionProblem();
    62       Parameters.Add(new ValueParameter<DoubleValue>("Lambda", "Regularization for smoothing splines", new DoubleValue(1.0)));
    63       Parameters.Add(new ValueParameter<IntValue>("Max iterations", "", new IntValue(100)));
    64       Parameters.Add(new ValueParameter<IntValue>("Max interactions", "", new IntValue(1)));
    65     }   
     90      Parameters.Add(new FixedValueParameter<DoubleValue>(LambdaParameterName, "Regularization for smoothing splines", new DoubleValue(1.0)));
     91      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "", new IntValue(100)));
     92      Parameters.Add(new FixedValueParameter<IntValue>(MaxInteractionsParameterName, "", new IntValue(1)));
     93    }
     94
    6695
    6796    protected override void Run(CancellationToken cancellationToken) {
    68 
    69       double lambda = ((IValueParameter<DoubleValue>)Parameters["Lambda"]).Value.Value;
    70       int maxIters = ((IValueParameter<IntValue>)Parameters["Max iterations"]).Value.Value;
    71       int maxInteractions = ((IValueParameter<IntValue>)Parameters["Max interactions"]).Value.Value;
     97      double lambda = Lambda;
     98      int maxIters = MaxIterations ;
     99      int maxInteractions = MaxInteractions;
    72100      if (maxInteractions < 1 || maxInteractions > 5) throw new ArgumentException("Max interactions is outside the valid range [1 .. 5]");
    73101
     
    80108      var inputVars = Problem.ProblemData.AllowedInputVariables.ToArray();
    81109      var nTerms = inputVars.Length; // LR
    82       for(int i=1;i<=maxInteractions;i++) {
     110      for (int i = 1; i <= maxInteractions; i++) {
    83111        nTerms += inputVars.Combinations(i).Count();
    84112      }
    85       IRegressionModel[] f = new IRegressionModel[nTerms]; 
    86       for(int i=0;i<f.Length;i++) {
     113      IRegressionModel[] f = new IRegressionModel[nTerms];
     114      for (int i = 0; i < f.Length; i++) {
    87115        f[i] = new ConstantModel(0.0, problemData.TargetVariable);
    88116      }
     
    116144        }
    117145
    118         for(int interaction = 1; interaction <= maxInteractions;interaction++) {
     146        for (int interaction = 1; interaction <= maxInteractions; interaction++) {
    119147          var selectedVars = HeuristicLab.Common.EnumerableExtensions.Combinations(inputVars, interaction);
    120148
     
    141169      var model = new RegressionEnsembleModel(f.Concat(new[] { new ConstantModel(avgY, problemData.TargetVariable) }));
    142170      model.AverageModelEstimates = false;
    143       Results.Add(new Result("Ensemble solution", model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone())));
    144 
     171      var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
     172      Results.Add(new Result("Ensemble solution", solution));
    145173    }
    146174
     
    187215      if (inputVars.All(problemData.Dataset.VariableHasType<double>)) {
    188216        var product = problemData.Dataset.GetDoubleValues(inputVars.First(), problemData.TrainingIndices).ToArray();
    189         for(int i = 1;i<inputVars.Length;i++) {
     217        for (int i = 1; i < inputVars.Length; i++) {
    190218          product = product.Zip(problemData.Dataset.GetDoubleValues(inputVars[i], problemData.TrainingIndices), (pi, vi) => pi * vi).ToArray();
    191219        }
     
    210238        pd.TestPartition.End = problemData.TestPartition.End;
    211239        double rmsError, oobRmsError;
    212         double avgRelError, oobAvgRelError; 
     240        double avgRelError, oobAvgRelError;
    213241        return RandomForestRegression.CreateRandomForestRegressionModel(pd, 100, 0.5, 0.5, 1234, out rmsError, out avgRelError, out oobRmsError, out oobAvgRelError);
    214242      } else return new ConstantModel(target.Average(), problemData.TargetVariable);
Note: See TracChangeset for help on using the changeset viewer.