[17737] | 1 | using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
|
---|
| 2 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 3 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
| 4 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
| 5 | using System;
|
---|
| 6 | using System.Collections.Generic;
|
---|
| 7 | using System.Globalization;
|
---|
| 8 | using System.Linq;
|
---|
| 9 |
|
---|
| 10 | namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {
|
---|
internal class FFXModel {
  // Number of decimal places used for coefficients in the human-readable ToString() output only.
  private const int DisplayDigits = 4;

  /// <summary>
  /// Creates an FFX model of the form
  /// <c>(Intercept + sum(c_i * f_i)) / (1 + sum(d_j * g_j))</c>, where the denominator is omitted
  /// when no denominator basis functions exist.
  /// </summary>
  /// <param name="intercept">Constant term of the numerator.</param>
  /// <param name="basisFunctions">Coefficient/basis-function pairs. Functions whose
  /// <c>IsDenominator</c> flag is set belong to the denominator; all others to the numerator.</param>
  /// <exception cref="ArgumentNullException"><paramref name="basisFunctions"/> is null.</exception>
  public FFXModel(double intercept, IEnumerable<(double coeff, IBasisFunction function)> basisFunctions) {
    Intercept = intercept;
    BasisFunctions = basisFunctions ?? throw new ArgumentNullException(nameof(basisFunctions));
  }

  /// <summary>Constant term of the numerator.</summary>
  public double Intercept { get; set; }

  /// <summary>All weighted basis functions of the model (numerator and denominator terms).</summary>
  public IEnumerable<(double coeff, IBasisFunction function)> BasisFunctions { get; set; }

  /// <summary>Weighted basis functions of the numerator, i.e. those NOT flagged as denominator terms.</summary>
  public IEnumerable<(double coeff, IBasisFunction function)> NominatorFunctions =>
    BasisFunctions.Where(bf => !bf.function.IsDenominator);

  /// <summary>Weighted basis functions of the denominator, i.e. those flagged with <c>IsDenominator</c>.</summary>
  public IEnumerable<(double coeff, IBasisFunction function)> DenominatorFunctions =>
    BasisFunctions.Where(bf => bf.function.IsDenominator);

  /// <summary>Number of numerator basis functions.</summary>
  public int NumNumeratorFunctions => NominatorFunctions.Count();

  /// <summary>Number of denominator basis functions.</summary>
  public int NumDenominatorFunctions => DenominatorFunctions.Count();

  /// <summary>Total number of basis functions.</summary>
  public int NumBases => NumNumeratorFunctions + NumDenominatorFunctions;

  /// <summary>All basis functions without coefficients: numerator functions first, then denominator functions.</summary>
  public IEnumerable<IBasisFunction> Bases =>
    NominatorFunctions.Select(tuple => tuple.function)
      .Concat(DenominatorFunctions.Select(tuple => tuple.function));

  /// <summary>Expression-size complexity of the model.</summary>
  public int Complexity {
    get {
      // We have a leading constant, then for each base we have a coefficient,
      // a multiply, and a plus, plus the complexity of the base itself.
      int numeratorComplexity = 1 + NominatorFunctions.Sum(bf => 3 + bf.function.Complexity);
      if (NumDenominatorFunctions == 0) return numeratorComplexity;

      // The denominator has the same structure (its leading constant is the 1.0);
      // the extra 1 accounts for the division.
      int denominatorComplexity = 1 + DenominatorFunctions.Sum(bf => 3 + bf.function.Complexity);
      return 1 + numeratorComplexity + denominatorComplexity;
    }
  }

  /// <summary>
  /// Human-readable infix representation of the model. Coefficients are rounded to
  /// <see cref="DisplayDigits"/> decimal places, so this string is for display only;
  /// <see cref="ToSymbolicRegressionModel"/> uses full precision.
  /// </summary>
  public override string ToString() => ToInfixString(roundCoefficients: true);

  /// <summary>
  /// Evaluates the model on the given rows of <paramref name="data"/>.
  /// </summary>
  /// <param name="data">Problem data providing the dataset and target variable name.</param>
  /// <param name="rows">Row indices to evaluate.</param>
  /// <returns>One estimated value per requested row.</returns>
  public double[] Simulate(IRegressionProblemData data, IEnumerable<int> rows) {
    var symbolicRegressionModel = ToSymbolicRegressionModel(data.TargetVariable);
    return symbolicRegressionModel.GetEstimatedValues(data.Dataset, rows).ToArray();
  }

  /// <summary>
  /// Converts the model into a symbolic regression model by rendering it as an infix string
  /// (with round-trip precision) and parsing it with <c>InfixExpressionParser</c>.
  /// </summary>
  /// <param name="targetVariable">Name of the target variable of the resulting model.</param>
  public ISymbolicRegressionModel ToSymbolicRegressionModel(string targetVariable) {
    // Full precision: rounded display coefficients would change the model's predictions.
    var str = ToInfixString(roundCoefficients: false);
    var tree = new InfixExpressionParser().Parse(str);
    return new SymbolicRegressionModel(targetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
  }

  /// <summary>
  /// Refits the intercept and all coefficients on <paramref name="data"/> using the current basis
  /// functions, replacing <see cref="Intercept"/> and <see cref="BasisFunctions"/>.
  /// </summary>
  /// <exception cref="InvalidOperationException">The solver returned an unexpected number of coefficients.</exception>
  public void OptimizeCoefficients(IRegressionProblemData data) {
    // Materialize once so the design-matrix columns and the coefficient assignment below
    // are guaranteed to refer to the same basis functions in the same order.
    var bfArray = BasisFunctions.ToArray();
    var elnetData = BFUtils.PrepareData(data, bfArray.Select(tuple => tuple.function));

    // NOTE(review): the two zeros are presumably penalty and lambda (i.e. an unregularized fit) — confirm
    // against ElasticNetLinearRegression.CalculateModelCoefficients.
    var coefficients = ElasticNetLinearRegression.CalculateModelCoefficients(elnetData, 0, 0, out _, out _);

    // The solver returns one coefficient per column followed by the intercept.
    if (coefficients.Length != bfArray.Length + 1)
      throw new InvalidOperationException(
        $"Expected {bfArray.Length + 1} coefficients (including intercept) but got {coefficients.Length}.");

    for (int i = 0; i < bfArray.Length; i++) {
      bfArray[i] = (coefficients[i], bfArray[i].function);
    }

    BasisFunctions = bfArray;
    Intercept = coefficients[coefficients.Length - 1];
  }

  // Renders the model as "(intercept + (c) * f ...) / (1.0 + (d) * g ...)", omitting the denominator
  // when there are no denominator terms. Invariant culture keeps the text parseable regardless of
  // the user's locale settings. When roundCoefficients is false, every number is emitted with the
  // round-trip ("R") format so re-parsing reproduces the exact double values.
  private string ToInfixString(bool roundCoefficients) {
    var culture = CultureInfo.InvariantCulture;

    string FormatCoefficient(double value) =>
      roundCoefficients ? Math.Round(value, DisplayDigits).ToString(culture) : value.ToString("R", culture);

    string FormatTerms(IEnumerable<(double coeff, IBasisFunction function)> terms) =>
      string.Concat(terms.Select(t => $" + ({FormatCoefficient(t.coeff)}) * {t.function}"));

    var intercept = roundCoefficients ? Intercept.ToString(culture) : Intercept.ToString("R", culture);
    var numerator = intercept + FormatTerms(NominatorFunctions);

    var denominatorTerms = DenominatorFunctions.ToArray();
    if (denominatorTerms.Length == 0) return numerator;

    return "(" + numerator + ") / (1.0" + FormatTerms(denominatorTerms) + ")";
  }
}
|
---|
| 90 | }
|
---|