Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2789_MathNetNumerics-Exploration/HeuristicLab.Algorithms.DataAnalysis.Experimental/RBF.cs @ 17078

Last change on this file since 17078 was 15352, checked in by gkronber, 7 years ago

#2789 testing alglib RBF and splines

File size: 7.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Concurrent;
24using System.Collections.Generic;
25using System.Linq;
26using System.Threading;
27using System.Threading.Tasks;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.Optimization;
33using HeuristicLab.Parameters;
34using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
35using HeuristicLab.Problems.DataAnalysis;
36using HeuristicLab.Problems.DataAnalysis.Symbolic;
37using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
38
39namespace HeuristicLab.Algorithms.DataAnalysis.Experimental {
40  [Item("RBF (alglib)", "")]
41  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 102)]
42  [StorableClass]
43  public sealed class RBF : FixedDataAnalysisAlgorithm<IRegressionProblem> {
44    [StorableConstructor]
45    private RBF(bool deserializing) : base(deserializing) { }
46    [StorableHook(HookType.AfterDeserialization)]
47    private void AfterDeserialization() {
48    }
49
50    private RBF(RBF original, Cloner cloner)
51      : base(original, cloner) {
52    }
53    public override IDeepCloneable Clone(Cloner cloner) {
54      return new RBF(this, cloner);
55    }
56
57    public RBF()
58      : base() {
59      Problem = new RegressionProblem();
60      Parameters.Add(new ValueParameter<DoubleValue>("RBase", new DoubleValue(1.0)));
61      Parameters.Add(new ValueParameter<IntValue>("NLayers", new IntValue(3)));
62      Parameters.Add(new ValueParameter<DoubleValue>("LambdaNS", new DoubleValue(1.0)));
63    }
64
65
66    protected override void Run(CancellationToken cancellationToken) {
67      var scaling = CreateScaling(Problem.ProblemData.Dataset, Problem.ProblemData.TrainingIndices.ToArray(), Problem.ProblemData.AllowedInputVariables.ToArray());
68
69      double[,] inputMatrix = ExtractData(Problem.ProblemData.Dataset, Problem.ProblemData.TrainingIndices, Problem.ProblemData.AllowedInputVariables.ToArray(), scaling);
70
71      double[,] target = ExtractData(Problem.ProblemData.Dataset, Problem.ProblemData.TrainingIndices, new string[] { Problem.ProblemData.TargetVariable });
72      inputMatrix = inputMatrix.HorzCat(target);
73
74      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
75        throw new NotSupportedException("Splines does not support NaN or infinity values in the input dataset.");
76
77
78      var inputVars = Problem.ProblemData.AllowedInputVariables.ToArray();
79      if (inputVars.Length > 3) throw new NotSupportedException();
80
81      alglib.rbfmodel model;
82      alglib.rbfreport rep;
83
84      alglib.rbfcreate(inputVars.Length, 1, out model);
85
86      alglib.rbfsetzeroterm(model);
87      var rbase = ((DoubleValue)Parameters["RBase"].ActualValue).Value;
88      var nlayers = ((IntValue)Parameters["NLayers"].ActualValue).Value;
89      var lambdans = ((DoubleValue)Parameters["LambdaNS"].ActualValue).Value;
90      alglib.rbfsetalgohierarchical(model, rbase, nlayers, lambdans);
91      alglib.rbfsetpoints(model, inputMatrix);
92      alglib.rbfbuildmodel(model, out rep);
93
94      Results.Add(new Result("TerminationType", new DoubleValue(rep.terminationtype)));
95      Results.Add(new Result("RMSE", new DoubleValue(rep.rmserror)));
96
97      Results.Add(new Result("Solution", new RegressionSolution(new RBFModel(model, Problem.ProblemData.TargetVariable, inputVars, scaling),
98        (IRegressionProblemData)Problem.ProblemData.Clone())));
99    }
100
101
102    private static ITransformation<double>[] CreateScaling(IDataset dataset, int[] rows, IReadOnlyCollection<string> allowedInputVariables) {
103      var trans = new ITransformation<double>[allowedInputVariables.Count];
104      int i = 0;
105      foreach (var variable in allowedInputVariables) {
106        var lin = new LinearTransformation(allowedInputVariables);
107        var max = dataset.GetDoubleValues(variable, rows).Max();
108        var min = dataset.GetDoubleValues(variable, rows).Min();
109        lin.Multiplier = 1.0 / (max - min);
110        lin.Addend = -min / (max - min);
111        trans[i] = lin;
112        i++;
113      }
114      return trans;
115    }
116
117    private static double[,] ExtractData(IDataset dataset, IEnumerable<int> rows, IReadOnlyCollection<string> allowedInputVariables, ITransformation<double>[] scaling = null) {
118      double[][] variables;
119      if (scaling != null) {
120        variables =
121          allowedInputVariables.Select((var, i) => scaling[i].Apply(dataset.GetDoubleValues(var, rows)).ToArray())
122            .ToArray();
123      } else {
124        variables =
125        allowedInputVariables.Select(var => dataset.GetDoubleValues(var, rows).ToArray()).ToArray();
126      }
127      int n = variables.First().Length;
128      var res = new double[n, variables.Length];
129      for (int r = 0; r < n; r++)
130        for (int c = 0; c < variables.Length; c++) {
131          res[r, c] = variables[c][r];
132        }
133      return res;
134    }
135  }
136
137}
138
139
140// UNFINISHED
/// <summary>
/// Regression model wrapping an alglib RBF model together with the per-variable input
/// transformations that are applied before the model is evaluated.
/// </summary>
public class RBFModel : NamedItem, IRegressionModel {
  private alglib.rbfmodel model;
  private string targetVariable;
  private ITransformation<double>[] scaling;

  /// <summary>
  /// Name of the target variable. Assigning a different value raises
  /// <see cref="TargetVariableChanged"/>.
  /// </summary>
  public string TargetVariable {
    get { return targetVariable; }
    set {
      if (targetVariable == value) {
        return;
      }
      targetVariable = value;
      OnTargetVariableChanged();
    }
  }

  /// <summary>Input variables, in the order expected by the model and the scaling array.</summary>
  public IEnumerable<string> VariablesUsedForPrediction { get; private set; }

  /// <summary>Raised after <see cref="TargetVariable"/> has been changed through its setter.</summary>
  public event EventHandler TargetVariableChanged;

  /// <summary>Cloning constructor: copies the alglib model and clones the transformations.</summary>
  public RBFModel(RBFModel orig, Cloner cloner) : base(orig, cloner) {
    // Assign the backing field directly; nobody can be subscribed during construction.
    this.targetVariable = orig.TargetVariable;
    this.VariablesUsedForPrediction = orig.VariablesUsedForPrediction.ToArray();
    this.model = (alglib.rbfmodel)orig.model.make_copy();
    this.scaling = orig.scaling.Select(s => cloner.Clone(s)).ToArray();
  }

  /// <summary>Creates a model from an already built alglib RBF model.</summary>
  /// <param name="model">The alglib model that is evaluated for predictions.</param>
  /// <param name="targetVar">Name of the target variable.</param>
  /// <param name="inputs">Input variable names, in model column order.</param>
  /// <param name="scaling">Transformations, one per entry of <paramref name="inputs"/>, in the same order.</param>
  public RBFModel(alglib.rbfmodel model, string targetVar, string[] inputs, IEnumerable<ITransformation<double>> scaling) : base("RBFModel", "RBFModel") {
    this.model = model;
    this.targetVariable = targetVar;
    this.VariablesUsedForPrediction = inputs;
    this.scaling = scaling.ToArray();
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new RBFModel(this, cloner);
  }

  /// <summary>Wraps this model and a clone of <paramref name="problemData"/> in a regression solution.</summary>
  public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
  }

  /// <summary>
  /// Lazily evaluates the model for each requested row: every input value is scaled with its
  /// transformation and the first output of <c>alglib.rbfcalc</c> is yielded.
  /// </summary>
  /// <param name="dataset">Dataset the input values are read from.</param>
  /// <param name="rows">Rows to predict, evaluated in enumeration order.</param>
  /// <returns>One estimated value per row.</returns>
  public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    // Materialize the names once instead of re-enumerating the sequence for every row.
    string[] variableNames = VariablesUsedForPrediction.ToArray();
    double[] x = new double[variableNames.Length];
    foreach (var r in rows) {
      for (int c = 0; c < variableNames.Length; c++) {
        // The transformation is only applied to sequences here, so each single value is
        // wrapped in a one-element enumerable; this is known to be inefficient.
        x[c] = scaling[c].Apply(dataset.GetDoubleValue(variableNames[c], r).ToEnumerable()).First();
      }
      double[] y;
      alglib.rbfcalc(model, x, out y);
      yield return y[0];
    }
  }

  /// <summary>Raises <see cref="TargetVariableChanged"/>.</summary>
  private void OnTargetVariableChanged() {
    var handler = TargetVariableChanged;
    if (handler != null) {
      handler(this, EventArgs.Empty);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.