Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3022-FastFunctionExtraction/FFX/FFXModel.cs @ 17779

Last change on this file since 17779 was 17779, checked in by gkronber, 3 years ago

#3022: made a few changes while reviewing the code.

File size: 4.5 KB
Line 
1using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
2using HeuristicLab.Problems.DataAnalysis;
3using HeuristicLab.Problems.DataAnalysis.Symbolic;
4using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
5using System;
6using System.Collections.Generic;
7using System.Globalization;
8using System.Linq;
9
namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {
    /// <summary>
    /// Model found by Fast Function Extraction (FFX). Without denominator basis functions it is
    /// <c>Intercept + sum(c_i * f_i)</c>; otherwise it is the rational form
    /// <c>(Intercept + sum(c_i * f_i)) / (1.0 + sum(d_j * g_j))</c>, where the <c>f_i</c> are the
    /// basis functions with <c>IsDenominator == false</c> and the <c>g_j</c> those with
    /// <c>IsDenominator == true</c>.
    /// </summary>
    internal class FFXModel {
        /// <summary>Creates a model from an intercept and a set of weighted basis functions.</summary>
        /// <param name="intercept">Constant term of the numerator.</param>
        /// <param name="basisFunctions">Coefficient/basis-function pairs; stored as given (not copied).</param>
        /// <exception cref="ArgumentNullException"><paramref name="basisFunctions"/> is null.</exception>
        public FFXModel(double intercept, IEnumerable<(double coeff, IBasisFunction function)> basisFunctions) {
            Intercept = intercept;
            BasisFunctions = basisFunctions ?? throw new ArgumentNullException(nameof(basisFunctions));
        }

        /// <summary>Constant term of the numerator.</summary>
        public double Intercept { get; set; }

        /// <summary>
        /// All weighted basis functions (numerator and denominator). The sequence is re-enumerated
        /// by each of the derived properties below.
        /// </summary>
        public IEnumerable<(double coeff, IBasisFunction function)> BasisFunctions { get; set; }

        /// <summary>
        /// Weighted basis functions of the numerator, i.e. those NOT flagged as denominator
        /// functions. (The misspelled name is kept for compatibility with existing callers.)
        /// </summary>
        public IEnumerable<(double coeff, IBasisFunction function)> NominatorFunctions =>
            BasisFunctions.Where(bf => !bf.function.IsDenominator);

        /// <summary>Weighted basis functions of the denominator, i.e. those flagged as denominator functions.</summary>
        public IEnumerable<(double coeff, IBasisFunction function)> DenominatorFunctions =>
            BasisFunctions.Where(bf => bf.function.IsDenominator);

        /// <summary>Number of numerator basis functions.</summary>
        public int NumNumeratorFunctions => NominatorFunctions.Count();

        /// <summary>Number of denominator basis functions.</summary>
        public int NumDenominatorFunctions => DenominatorFunctions.Count();

        /// <summary>Total number of basis functions (numerator plus denominator).</summary>
        public int NumBases => NumNumeratorFunctions + NumDenominatorFunctions;

        /// <summary>The basis functions without their coefficients: numerator ones first, then denominator ones.</summary>
        public IEnumerable<IBasisFunction> Bases =>
            NominatorFunctions.Select(tuple => tuple.function)
            .Concat(DenominatorFunctions.Select(tuple => tuple.function));

        /// <summary>
        /// Expression-size measure of the model. There is a leading constant, then for each basis
        /// function a coefficient, a multiply and a plus (3 nodes) plus the complexity of the basis
        /// function itself. A rational model adds its denominator (which also starts with a
        /// constant) and one node for the division.
        /// </summary>
        public int Complexity {
            get {
                int numeratorComplexity = 1 + NominatorFunctions.Sum(bf => 3 + bf.function.Complexity);
                if (NumDenominatorFunctions == 0) return numeratorComplexity;

                int denominatorComplexity = 1 + DenominatorFunctions.Sum(bf => 3 + bf.function.Complexity);
                return 1 + numeratorComplexity + denominatorComplexity;
            }
        }

        /// <summary>
        /// Renders the model as an infix expression, e.g. <c>b + (c1) * f1 + (c2) * f2</c> or, with
        /// denominator functions, <c>(b + (c1) * f1) / (1.0 + (d1) * g1)</c>.
        /// </summary>
        /// <remarks>
        /// The result is parsed back by <see cref="ToSymbolicRegressionModel"/> (and therefore used by
        /// <see cref="Simulate"/>). All numbers are therefore written in round-trip format with the
        /// invariant culture. Rounding them here would silently change the predictions of the
        /// converted model.
        /// </remarks>
        public override string ToString() {
            string numerator = FormatNumber(Intercept)
                + string.Concat(NominatorFunctions.Select(t => FormatTerm(t.coeff, t.function)));
            if (!DenominatorFunctions.Any()) return numerator;

            string denominator = "1.0"
                + string.Concat(DenominatorFunctions.Select(t => FormatTerm(t.coeff, t.function)));
            return "(" + numerator + ") / (" + denominator + ")";
        }

        // Formats one weighted basis function as " + (coeff) * function".
        private static string FormatTerm(double coeff, IBasisFunction function) =>
            $" + ({FormatNumber(coeff)}) * {function}";

        // Round-trip, culture-independent number formatting so the text can be parsed back exactly.
        private static string FormatNumber(double value) =>
            value.ToString("R", CultureInfo.InvariantCulture);

        /// <summary>
        /// Evaluates the model on the given rows of <c>data.Dataset</c>. It does this by converting
        /// the model to a symbolic regression model first.
        /// </summary>
        /// <param name="data">Problem data providing the dataset and target variable name.</param>
        /// <param name="rows">Row indices to evaluate.</param>
        /// <returns>One estimated value per requested row.</returns>
        public double[] Simulate(IRegressionProblemData data, IEnumerable<int> rows) {
            var symbolicRegressionModel = ToSymbolicRegressionModel(data.TargetVariable);
            return symbolicRegressionModel.GetEstimatedValues(data.Dataset, rows).ToArray();
        }

        /// <summary>
        /// Converts the model into a symbolic regression model. It parses <see cref="ToString"/> with
        /// <c>InfixExpressionParser</c> and uses <c>SymbolicDataAnalysisExpressionTreeInterpreter</c>
        /// for evaluation.
        /// </summary>
        /// <param name="targetVariable">Name of the target variable of the resulting model.</param>
        public ISymbolicRegressionModel ToSymbolicRegressionModel(string targetVariable) {
            var str = this.ToString();
            var tree = new InfixExpressionParser().Parse(str);
            return new SymbolicRegressionModel(targetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
        }

        /// <summary>
        /// Refits the intercept and all coefficients on <paramref name="data"/> with
        /// <c>ElasticNetLinearRegression.CalculateModelCoefficients</c>, with both numeric
        /// arguments set to 0. This is presumably an unregularized fit; confirm against
        /// <c>ElasticNetLinearRegression</c>. The basis functions themselves are not changed.
        /// </summary>
        /// <param name="data">Problem data used for the fit.</param>
        /// <exception cref="InvalidOperationException">
        /// The regression returned fewer than one coefficient per basis function plus an intercept.
        /// </exception>
        public void OptimizeCoefficients(IRegressionProblemData data) {
            // Enumerate the basis functions once, so the columns handed to the regression and the
            // coefficients written back below are guaranteed to be in the same order.
            var bfArray = BasisFunctions.ToArray();
            var elnetData = BFUtils.PrepareData(data, bfArray.Select(tuple => tuple.function));
            var coefficients = ElasticNetLinearRegression.CalculateModelCoefficients(elnetData, 0, 0, out _, out _);

            // Expected layout: one coefficient per basis function (same order), followed by the intercept.
            if (coefficients.Length < bfArray.Length + 1)
                throw new InvalidOperationException(
                    $"Expected at least {bfArray.Length + 1} coefficients but the regression returned {coefficients.Length}.");

            for (int i = 0; i < bfArray.Length; i++) {
                bfArray[i] = (coefficients[i], bfArray[i].function);
            }

            this.BasisFunctions = bfArray;
            this.Intercept = coefficients[coefficients.Length - 1];
        }
    }
}
Note: See TracBrowser for help on using the repository browser.