Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3022-FastFunctionExtraction/FFX/FastFunctionExtraction.cs @ 17510

Last change on this file since 17510 was 17227, checked in by lleko, 5 years ago

#3022: Add implementation for FFX.

File size: 23.7 KB
RevLine 
[17218]1using System;
2using System.Threading;
[17219]3using System.Linq;
[17218]4using HeuristicLab.Common; // required for parameters collection
5using HeuristicLab.Core; // required for parameters collection
6using HeuristicLab.Data; // IntValue, ...
7using HeuristicLab.Encodings.BinaryVectorEncoding;
8using HeuristicLab.Optimization; // BasicAlgorithm
9using HeuristicLab.Parameters;
10using HeuristicLab.Problems.Binary;
11using HeuristicLab.Random; // MersenneTwister
12using HEAL.Attic;
[17219]13using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
14using HeuristicLab.Problems.DataAnalysis;
15using System.Collections.Generic;
16using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[17227]17using System.Collections;
18using System.Diagnostics;
19using HeuristicLab.Problems.DataAnalysis.Symbolic;
20using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
21using HeuristicLab.Analysis;
22using HeuristicLab.Collections;
[17218]23
namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {

  /// <summary>
  /// Fast Function Extraction (FFX) for regression problems.
  /// Step 1 generates basis functions from the allowed (numeric) input variables: exponentiations,
  /// nonlinear functions of those, pairwise interactions and, optionally, one denominator basis
  /// (target * basis) per basis. Step 2 fits the whole elastic-net path over these bases and turns
  /// every model of the path into a symbolic regression solution.
  /// </summary>
  [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
  [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing the algorithm to a file or transferring it to HeuristicLab.Hive)
  public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> {

    // exponents applied to every input variable when ConsiderExponentiations is true (otherwise only 1 is used)
    private static readonly double[] exponents = { 0.5, 1, 2 };
    // nonlinear functions that are checked by default in the NonlinearFuncs parameter
    private static readonly OpCode[] nonlinFuncs = { OpCode.Absolute, OpCode.Log, OpCode.Sin, OpCode.Cos };

    // function names used when naming basis functions; these names end up in the infix model string
    // that Tree() hands to InfixExpressionParser, so they have to be function names the parser knows
    private static readonly BidirectionalDictionary<OpCode, string> OpCodeToString = new BidirectionalDictionary<OpCode, string> {
        { OpCode.Log, "LOG" },
        { OpCode.Absolute, "ABS"},
        { OpCode.Sin, "SIN"},
        { OpCode.Cos, "COS"},
        { OpCode.Square, "SQR"},
        { OpCode.SquareRoot, "SQRT"},
        { OpCode.Cube, "CUBE"},
        { OpCode.CubeRoot, "CUBEROOT"}
    };

    private const string ConsiderInteractionsParameterName = "Consider Interactions";
    private const string ConsiderDenominationParameterName = "Consider Denomination";
    private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
    private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear functions";
    private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
    private const string PenaltyParameterName = "Penalty";
    private const string LambdaParameterName = "Lambda";
    private const string NonlinearFuncsParameterName = "Nonlinear Functions";

    #region parameters
    public IValueParameter<BoolValue> ConsiderInteractionsParameter
    {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderDenominationsParameter
    {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderExponentiationsParameter
    {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter
    {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderHingeFuncsParameter
    {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
    }
    public IValueParameter<DoubleValue> PenaltyParameter
    {
      get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
    }
    public IValueParameter<DoubleValue> LambdaParameter
    {
      get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    }
    public IValueParameter<CheckedItemCollection<EnumValue<OpCode>>> NonlinearFuncsParameter
    {
      get { return (IValueParameter<CheckedItemCollection<EnumValue<OpCode>>>)Parameters[NonlinearFuncsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Whether pairwise products of univariate bases are generated.</summary>
    public bool ConsiderInteractions
    {
      get { return ConsiderInteractionsParameter.Value.Value; }
      set { ConsiderInteractionsParameter.Value.Value = value; }
    }
    /// <summary>Whether one denominator basis (target * basis) is generated per basis function.</summary>
    public bool ConsiderDenominations
    {
      get { return ConsiderDenominationsParameter.Value.Value; }
      set { ConsiderDenominationsParameter.Value.Value = value; }
    }
    /// <summary>Whether the exponents 0.5 and 2 are applied to the inputs in addition to 1.</summary>
    public bool ConsiderExponentiations
    {
      get { return ConsiderExponentiationsParameter.Value.Value; }
      set { ConsiderExponentiationsParameter.Value.Value = value; }
    }
    /// <summary>Whether the checked functions of <see cref="NonlinearFuncs"/> are applied to the bases.</summary>
    public bool ConsiderNonlinearFuncs
    {
      get { return ConsiderNonlinearFuncsParameter.Value.Value; }
      set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
    }
    /// <summary>Hinge functions; not evaluated yet (see TODO in createUnivariateBases).</summary>
    public bool ConsiderHingeFuncs
    {
      get { return ConsiderHingeFuncsParameter.Value.Value; }
      set { ConsiderHingeFuncsParameter.Value.Value = value; }
    }
    /// <summary>Elastic-net mixing factor (alpha): 0.0 = ridge, 1.0 = lasso.</summary>
    public double Penalty
    {
      get { return PenaltyParameter.Value.Value; }
      set { PenaltyParameter.Value.Value = value; }
    }
    /// <summary>Optional fixed lambda; only used by PathwiseLearning, Run() always computes the whole path.</summary>
    public DoubleValue Lambda
    {
      get { return LambdaParameter.Value; }
      set { LambdaParameter.Value = value; }
    }
    /// <summary>Nonlinear functions that may be applied to the univariate bases.</summary>
    public CheckedItemCollection<EnumValue<OpCode>> NonlinearFuncs
    {
      get { return NonlinearFuncsParameter.Value; }
      set { NonlinearFuncsParameter.Value = value; }
    }
    #endregion


    [StorableConstructor]
    private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }

    // cloning constructor, used by Clone()
    public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
    }

    public FastFunctionExtraction() : base() {
      var items = new CheckedItemCollection<EnumValue<OpCode>>();
      foreach (var op in nonlinFuncs) {
        items.Add(new EnumValue<OpCode>(op));
      }
      base.Problem = new RegressionProblem();
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.9)));
      Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
      Parameters.Add(new ValueParameter<CheckedItemCollection<EnumValue<OpCode>>>(NonlinearFuncsParameterName, "What nonlinear functions the models should be able to include.", items));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new FastFunctionExtraction(this, cloner);
    }

    public override Type ProblemType { get { return typeof(RegressionProblem); } }
    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

    // Run() keeps no intermediate state and cannot continue where it left off; a restarted Run() would
    // also re-add results under the same names. Pausing is therefore not supported (Stop still works).
    public override bool SupportsPause { get { return false; } }

    protected override void Run(CancellationToken cancellationToken) {
      var problemData = Problem.ProblemData;

      // step 1: univariate and (optionally) interaction bases
      var basisFunctions = createBasisFunctions(problemData);
      Results.Add(new Result("Basis Functions", "A Dataset consisting of the generated Basis Functions from FFX Alg Step 1.", createProblemData(problemData, basisFunctions)));
      cancellationToken.ThrowIfCancellationRequested();

      // add one denominator basis per existing basis. Tree() relies on this layout: numerator bases in the
      // first half, the corresponding denominator bases in the second half, in the same order.
      if (ConsiderDenominations) basisFunctions = basisFunctions.Concat(createDenominatorBases(problemData, basisFunctions)).ToList();
      cancellationToken.ThrowIfCancellationRequested();

      // step 2: fit the whole elastic-net path (the optional Lambda parameter is not used here)
      LearnModels(problemData, basisFunctions);
    }

    // univariate bases followed by their pairwise interactions (if enabled)
    private List<BasisFunction> createBasisFunctions(IRegressionProblemData problemData) {
      var basisFunctions = createUnivariateBases(problemData);
      basisFunctions = basisFunctions.Concat(createMultivariateBases(basisFunctions)).ToList();
      return basisFunctions;
    }

    // For every numeric input variable and every valid exponent: the exponentiated variable itself and,
    // if enabled, each checked nonlinear function applied to it. Bases containing NaN/Infinity are skipped.
    private List<BasisFunction> createUnivariateBases(IRegressionProblemData problemData) {
      var B1 = new List<BasisFunction>();
      var ds = problemData.Dataset;
      var validExponents = ConsiderExponentiations ? exponents : new double[] { 1 };
      var validFuncs = ConsiderNonlinearFuncs
        ? NonlinearFuncs.CheckedItems.Select(val => val.Value).ToArray()
        : new OpCode[0];
      // TODO: add Hinge functions (ConsiderHingeFuncs is not evaluated yet)

      // only numeric variables can be exponentiated / transformed; factor (string) inputs are skipped
      foreach (var variableName in problemData.AllowedInputVariables.Where(ds.VariableHasType<double>)) {
        var values = ds.GetDoubleValues(variableName).ToArray();
        foreach (var exp in validExponents) {
          var data = values.Select(x => Math.Pow(x, exp)).ToArray();
          if (!ok(data)) continue;
          var name = expToString(exp, variableName);
          B1.Add(new BasisFunction(name, data, false));
          foreach (var op in validFuncs) {
            var innerData = data.Select(x => eval(op, x)).ToArray();
            if (!ok(innerData)) continue;
            var innerName = OpCodeToString.GetByFirst(op) + "(" + name + ")";
            B1.Add(new BasisFunction(innerName, innerData, true));
          }
        }
      }
      return B1;
    }

    // Returns ONLY the new interaction bases b_i * b_j (j < i); the caller concatenates them with B1.
    // Returns an empty list when interactions are disabled.
    private List<BasisFunction> createMultivariateBases(List<BasisFunction> B1) {
      var B2 = new List<BasisFunction>();
      // returning B1 here would make createBasisFunctions add every univariate basis a second time
      if (!ConsiderInteractions) return B2;
      for (int i = 0; i < B1.Count; i++) {
        var b_i = B1[i];
        for (int j = 0; j < i; j++) {
          var b_j = B1[j];
          if (b_i.IsOperator && b_j.IsOperator) continue; // disallow op() * op()
          B2.Add(b_i * b_j);
        }
      }
      return B2;
    }

    // creates 1 denominator basis function (target * basis) for each basis function, in the same order
    private IEnumerable<BasisFunction> createDenominatorBases(IRegressionProblemData problemData, IEnumerable<BasisFunction> basisFunctions) {
      var y = new BasisFunction(problemData.TargetVariable, problemData.TargetVariableValues.ToArray(), false);
      return basisFunctions.Select(func => y * func).ToList();
    }

    // Name of the basis variable^exponent, using function names for the common exponents.
    private static string expToString(double exponent, string varname) {
      // 1.0 / 2 and 1.0 / 3 must be floating-point divisions (integer 1 / 2 evaluates to 0)
      if (exponent.IsAlmost(1)) return varname;
      if (exponent.IsAlmost(1.0 / 2)) return OpCodeToString.GetByFirst(OpCode.SquareRoot) + "(" + varname + ")";
      if (exponent.IsAlmost(1.0 / 3)) return OpCodeToString.GetByFirst(OpCode.CubeRoot) + "(" + varname + ")";
      if (exponent.IsAlmost(2)) return OpCodeToString.GetByFirst(OpCode.Square) + "(" + varname + ")";
      if (exponent.IsAlmost(3)) return OpCodeToString.GetByFirst(OpCode.Cube) + "(" + varname + ")";
      return varname + " ^ " + FormatNumber(exponent);
    }

    /// <summary>
    /// Applies the nonlinear function <paramref name="op"/> to <paramref name="x"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">If <paramref name="op"/> is not Absolute, Log, Sin or Cos.</exception>
    public static double eval(OpCode op, double x) {
      switch (op) {
        case OpCode.Absolute:
          return Math.Abs(x);
        case OpCode.Log:
          // natural logarithm: the basis is written as "LOG(...)" into the model string, and OpCode.Log is
          // HeuristicLab's natural-log symbol, so the fitted values must match the parsed model's values
          return Math.Log(x);
        case OpCode.Sin:
          return Math.Sin(x);
        case OpCode.Cos:
          return Math.Cos(x);
        default:
          throw new NotSupportedException("Unimplemented operator: " + op.ToString());
      }
    }

    // Runs a nested ElasticNetLinearRegression algorithm (honoring the optional Lambda) and copies its results.
    // Currently not called by Run().
    private void PathwiseLearning(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
      ElasticNetLinearRegression reg = new ElasticNetLinearRegression();
      reg.Lambda = Lambda;
      reg.Penality = Penalty;
      reg.Problem.ProblemData = createProblemData(problemData, basisFunctions);
      reg.Start();
      Results.AddRange(reg.Results);
    }

    // Fits the elastic-net path over the basis functions, adds NMSE/coefficient charts and one
    // symbolic regression solution per lambda value to the results.
    private void LearnModels(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
      double[] lambda;
      double[] trainNMSE;
      double[] testNMSE;
      double[,] coeff;
      double[] intercept;

      // wraps the list of basis functions in a dataset, so that it can be passed on to the ElNet function
      var X_b = createProblemData(problemData, basisFunctions);

      ElasticNetLinearRegression.RunElasticNetLinearRegression(X_b, Penalty, out lambda, out trainNMSE, out testNMSE, out coeff, out intercept);

      var errorTable = NMSEGraph(coeff, lambda, trainNMSE, testNMSE);
      Results.Add(new Result(errorTable.Name, errorTable.Description, errorTable));
      var coeffTable = CoefficientGraph(coeff, lambda, X_b.AllowedInputVariables, X_b.Dataset);
      Results.Add(new Result(coeffTable.Name, coeffTable.Description, coeffTable));

      var models = new ItemCollection<IResult>();
      // one row of coeff per lambda value; GetLength (not GetUpperBound) so that the last model is included
      for (int modelIdx = 0; modelIdx < coeff.GetLength(0); modelIdx++) {
        var tree = Tree(basisFunctions, GetRow(coeff, modelIdx), intercept[modelIdx]);
        ISymbolicRegressionModel m = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
        ISymbolicRegressionSolution s = new SymbolicRegressionSolution(m, problemData);
        models.Add(new Result("Solution " + modelIdx, s));
      }

      Results.Add(new Result("Models", "The model path returned by the Elastic Net Regression (not only the pareto-optimal subset). ", models));
    }

    // Chart of the standardized coefficient paths (coefficient * std. dev. of its variable) over lambda,
    // plus the number of non-zero coefficients per lambda on the second y-axis.
    private static IndexedDataTable<double> CoefficientGraph(double[,] coeff, double[] lambda, IEnumerable<string> allowedVars, IDataset ds) {
      var coeffTable = new IndexedDataTable<double>("Coefficients", "The paths of standarized coefficient values over different lambda values");
      coeffTable.VisualProperties.YAxisMaximumAuto = false;
      coeffTable.VisualProperties.YAxisMinimumAuto = false;
      coeffTable.VisualProperties.XAxisMaximumAuto = false;
      coeffTable.VisualProperties.XAxisMinimumAuto = false;

      coeffTable.VisualProperties.XAxisLogScale = true;
      coeffTable.VisualProperties.XAxisTitle = "Lambda";
      coeffTable.VisualProperties.YAxisTitle = "Coefficients";
      coeffTable.VisualProperties.SecondYAxisTitle = "Number of variables";

      var nLambdas = lambda.Length;
      var nCoeff = coeff.GetLength(1);
      var dataRows = new IndexedDataRow<double>[nCoeff];

      var doubleVariables = allowedVars.Where(ds.VariableHasType<double>);
      var factorVariableNames = allowedVars.Where(ds.VariableHasType<string>);
      var factorVariablesAndValues = ds.GetFactorVariableValues(factorVariableNames, Enumerable.Range(0, ds.Rows)); // must consider all factor values (in train and test set)
      {
        // column index into coeff: factor indicator columns first, then the double variables
        int i = 0;
        foreach (var factorVariableAndValues in factorVariablesAndValues) {
          foreach (var factorValue in factorVariableAndValues.Value) {
            double sigma = ds.GetStringValues(factorVariableAndValues.Key)
              .Select(s => s == factorValue ? 1.0 : 0.0)
              .StandardDeviation(); // calc std dev of binary indicator
            var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
            dataRows[i] = new IndexedDataRow<double>(factorVariableAndValues.Key + "=" + factorValue, factorVariableAndValues.Key + "=" + factorValue, path);
            i++;
          }
        }

        foreach (var doubleVariable in doubleVariables) {
          double sigma = ds.GetDoubleValues(doubleVariable).StandardDeviation();
          var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
          dataRows[i] = new IndexedDataRow<double>(doubleVariable, doubleVariable, path);
          i++;
        }
        // add to coeffTable by total weight (larger area under the curve => more important);
        foreach (var r in dataRows.OrderByDescending(r => r.Values.Select(t => t.Item2).Sum(x => Math.Abs(x)))) {
          coeffTable.Rows.Add(r);
        }
      }

      var numNonZeroCoeffs = CountNonZeroCoefficients(coeff);
      if (lambda.Length > 2) {
        // lambda[0] is skipped for the upper axis bound
        coeffTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));
        coeffTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));
      }
      coeffTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));
      coeffTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      coeffTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;

      return coeffTable;
    }

    // Chart of train/test NMSE over lambda, plus the number of non-zero coefficients per lambda.
    private static IndexedDataTable<double> NMSEGraph(double[,] coeff, double[] lambda, double[] trainNMSE, double[] testNMSE) {
      var errorTable = new IndexedDataTable<double>("NMSE", "Path of NMSE values over different lambda values");
      errorTable.VisualProperties.YAxisMaximumAuto = false;
      errorTable.VisualProperties.YAxisMinimumAuto = false;
      errorTable.VisualProperties.XAxisMaximumAuto = false;
      errorTable.VisualProperties.XAxisMinimumAuto = false;

      var numNonZeroCoeffs = CountNonZeroCoefficients(coeff);

      errorTable.VisualProperties.YAxisMinimumFixedValue = 0;
      errorTable.VisualProperties.YAxisMaximumFixedValue = 1.0;
      errorTable.VisualProperties.XAxisLogScale = true;
      errorTable.VisualProperties.XAxisTitle = "Lambda";
      errorTable.VisualProperties.YAxisTitle = "Normalized mean of squared errors (NMSE)";
      errorTable.VisualProperties.SecondYAxisTitle = "Number of variables";
      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (train)", "Path of NMSE values over different lambda values", lambda.Zip(trainNMSE, (l, v) => Tuple.Create(l, v))));
      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (test)", "Path of NMSE values over different lambda values", lambda.Zip(testNMSE, (l, v) => Tuple.Create(l, v))));
      errorTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));
      if (lambda.Length > 2) {
        errorTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));
        errorTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));
      }
      errorTable.Rows["NMSE (train)"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      errorTable.Rows["NMSE (test)"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      errorTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      errorTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;

      return errorTable;
    }

    // number of coefficients that are not (almost) zero, for each row (= lambda value) of coeff
    private static int[] CountNonZeroCoefficients(double[,] coeff) {
      var counts = new int[coeff.GetLength(0)];
      for (int i = 0; i < coeff.GetLength(0); i++) {
        for (int j = 0; j < coeff.GetLength(1); j++) {
          if (!coeff[i, j].IsAlmost(0.0)) {
            counts[i]++;
          }
        }
      }
      return counts;
    }

    // Builds the symbolic model for one coefficient vector of the path:
    //   (offset + sum(c_i * b_i)) [ / (1 + sum(-c_k * b_k)) ]
    // Only bases with non-zero coefficients are included. With denominations enabled, the first half of
    // basisFunctions/coeffs holds the numerator bases and the second half the denominator bases (y * b_k).
    private ISymbolicExpressionTree Tree(List<BasisFunction> basisFunctions, double[] coeffs, double offset) {
      Debug.Assert(basisFunctions.Count == coeffs.Length);
      var numNumeratorFuncs = ConsiderDenominations ? basisFunctions.Count / 2 : basisFunctions.Count;

      string model = "(" + FormatNumber(offset);
      for (int i = 0; i < numNumeratorFuncs; i++) {
        // only generate nodes for relevant basis functions (those with non-zero coeffs)
        if (coeffs[i].IsAlmost(0.0)) continue;
        model += " + (" + FormatNumber(coeffs[i]) + ") * " + basisFunctions[i].Var;
      }

      // a denominator is needed if at least 1 coefficient in the SECOND half (denominator bases) is non-zero
      bool hasDenominator = false;
      for (int i = numNumeratorFuncs; i < coeffs.Length; i++) {
        if (!coeffs[i].IsAlmost(0.0)) { hasDenominator = true; break; }
      }

      if (ConsiderDenominations && hasDenominator) {
        model += ") / (1";
        for (int i = numNumeratorFuncs; i < basisFunctions.Count; i++) {
          if (coeffs[i].IsAlmost(0.0)) continue;
          // basis i is y * b_k with k = i - numNumeratorFuncs (createDenominatorBases keeps the order).
          // The linear fit y = N + sum(c_k * y * b_k) rearranges to y = N / (1 - sum(c_k * b_k)),
          // hence b_k itself appears in the denominator with the negated coefficient.
          var numeratorBasis = basisFunctions[i - numNumeratorFuncs];
          model += " + (" + FormatNumber(-coeffs[i]) + ") * " + numeratorBasis.Var;
        }
      }
      model += ")";
      InfixExpressionParser p = new InfixExpressionParser();
      return p.Parse(model);
    }

    // Formats a number for the infix model string independently of the current culture
    // (a culture with ',' as decimal separator would otherwise produce e.g. "0,5").
    private static string FormatNumber(double value) {
      return value.ToString(System.Globalization.CultureInfo.InvariantCulture);
    }

    // wraps the list of basis functions into an IRegressionProblemData object
    // (one dataset column per basis function plus the unmodified target; partitions copied from problemData)
    private static IRegressionProblemData createProblemData(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
      List<string> variableNames = new List<string>();
      List<IList> variableVals = new List<IList>();
      foreach (var basisFunc in basisFunctions) {
        variableNames.Add(basisFunc.Var);
        // basisFunctions already contains the calculated values of the corresponding basis function, so you can just take that value
        variableVals.Add(new List<double>(basisFunc.Val));
      }
      var matrix = new ModifiableDataset(variableNames, variableVals);

      // add the unmodified target variable to the matrix
      matrix.AddVariable(problemData.TargetVariable, problemData.TargetVariableValues.ToList());
      var allowedInputVars = matrix.VariableNames.Where(x => !x.Equals(problemData.TargetVariable));
      IRegressionProblemData rpd = new RegressionProblemData(matrix, allowedInputVars, problemData.TargetVariable);
      rpd.TrainingPartition.Start = problemData.TrainingPartition.Start;
      rpd.TrainingPartition.End = problemData.TrainingPartition.End;
      rpd.TestPartition.Start = problemData.TestPartition.Start;
      rpd.TestPartition.End = problemData.TestPartition.End;
      return rpd;
    }

    // true iff every value is finite (no NaN, no +/-Infinity)
    private static bool ok(double[] data) => data.All(x => !double.IsNaN(x) && !double.IsInfinity(x));

    // helper function which returns a row of a 2D array
    private static T[] GetRow<T>(T[,] matrix, int row) {
      var columns = matrix.GetLength(1);
      var array = new T[columns];
      for (int i = 0; i < columns; ++i)
        array[i] = matrix[row, i];
      return array;
    }

    // returns all models with pareto-optimal tradeoff between error and complexity
    // TODO: not implemented yet (not called anywhere)
    private static List<IRegressionSolution> nondominatedFilter(double[][] coefficientVectorSet, BasisFunction[] basisFunctions) {
      throw new NotImplementedException("Pareto filtering of the model path is not implemented yet.");
    }
  }
}
Note: See TracBrowser for help on using the repository browser.