using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;

namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {
  /// <summary>
  /// A model of the form
  /// <c>(intercept + sum(c_i * n_i)) / (1.0 + sum(d_j * m_j))</c>,
  /// where the <c>n_i</c> are the basis functions with <c>IsDenominator == false</c>
  /// and the <c>m_j</c> are the basis functions with <c>IsDenominator == true</c>.
  /// When there are no denominator functions the model is just the numerator.
  /// </summary>
  internal class FFXModel {
    // Culture used to render numbers in ToString(); the rendered text is later
    // handed to InfixExpressionParser by ToSymbolicRegressionModel.
    private static readonly CultureInfo FormatCulture = new CultureInfo("en-US");

    /// <summary>Creates a model from an intercept and weighted basis functions.</summary>
    /// <param name="intercept">The constant term of the numerator.</param>
    /// <param name="basisFunctions">Pairs of coefficient and basis function.</param>
    /// <exception cref="ArgumentNullException"><paramref name="basisFunctions"/> is null.</exception>
    public FFXModel(double intercept, IEnumerable<(double coeff, IBasisFunction function)> basisFunctions) {
      Intercept = intercept;
      BasisFunctions = basisFunctions ?? throw new ArgumentNullException(nameof(basisFunctions));
    }

    /// <summary>The constant term of the numerator.</summary>
    public double Intercept { get; set; }

    /// <summary>All weighted basis functions, numerator and denominator alike.</summary>
    public IEnumerable<(double coeff, IBasisFunction function)> BasisFunctions { get; set; }

    /// <summary>The weighted basis functions that are NOT flagged as denominator.</summary>
    public IEnumerable<(double coeff, IBasisFunction function)> NominatorFunctions =>
      BasisFunctions.Where(bf => !bf.function.IsDenominator);

    /// <summary>The weighted basis functions that are flagged as denominator.</summary>
    public IEnumerable<(double coeff, IBasisFunction function)> DenominatorFunctions =>
      BasisFunctions.Where(bf => bf.function.IsDenominator);

    /// <summary>Number of entries in <see cref="NominatorFunctions"/>.</summary>
    public int NumNumeratorFunctions => NominatorFunctions.Count();

    /// <summary>Number of entries in <see cref="DenominatorFunctions"/>.</summary>
    public int NumDenominatorFunctions => DenominatorFunctions.Count();

    /// <summary>Total number of basis functions (numerator plus denominator).</summary>
    public int NumBases => NumNumeratorFunctions + NumDenominatorFunctions;

    /// <summary>All basis functions without coefficients: numerator ones first, then denominator ones.</summary>
    public IEnumerable<IBasisFunction> Bases =>
      NominatorFunctions.Select(tuple => tuple.function)
        .Concat(DenominatorFunctions.Select(tuple => tuple.function));

    /// <summary>
    /// Structural complexity of the model: a leading constant, then for each base a
    /// coefficient, a multiply and a plus (3) in addition to the complexity of the base itself.
    /// A model with denominator functions adds the denominator's complexity (computed the
    /// same way) plus one more unit.
    /// </summary>
    public int Complexity {
      get {
        int numeratorComplexity = 1 + NominatorFunctions.Sum(bf => 3 + bf.function.Complexity);

        // Materialize once so the emptiness check and the sum see the same sequence.
        var denominators = DenominatorFunctions.ToArray();
        if (denominators.Length == 0) {
          return numeratorComplexity;
        }

        int denominatorComplexity = 1 + denominators.Sum(bf => 3 + bf.function.Complexity);
        return 1 + numeratorComplexity + denominatorComplexity;
      }
    }

    /// <summary>
    /// Renders the model as an infix expression, e.g.
    /// <c>(b + (c) * f) / (1.0 + (d) * g)</c>, or only the numerator part when no
    /// denominator function exists. Coefficients are written unrounded because
    /// <see cref="ToSymbolicRegressionModel"/> parses this text to build the executable model.
    /// </summary>
    public override string ToString() {
      var numerator = Intercept.ToString(FormatCulture) + FormatTerms(NominatorFunctions);

      var denominators = DenominatorFunctions.ToArray();
      if (denominators.Length == 0) {
        return numerator;
      }

      return "(" + numerator + ") / (1.0" + FormatTerms(denominators) + ")";
    }

    // Renders each pair as " + (coeff) * function" and concatenates the pieces in order.
    private static string FormatTerms(IEnumerable<(double coeff, IBasisFunction function)> terms) {
      return string.Concat(terms.Select(term => $" + ({term.coeff.ToString(FormatCulture)}) * {term.function}"));
    }

    /// <summary>Evaluates the model on the given rows of the problem data's dataset.</summary>
    /// <param name="data">Supplies the target variable name and the dataset.</param>
    /// <param name="rows">The rows to evaluate; passed through to <c>GetEstimatedValues</c>.</param>
    /// <returns>The estimated values as an array.</returns>
    public double[] Simulate(IRegressionProblemData data, IEnumerable<int> rows) {
      var symbolicRegressionModel = ToSymbolicRegressionModel(data.TargetVariable);
      return symbolicRegressionModel.GetEstimatedValues(data.Dataset, rows).ToArray();
    }

    /// <summary>
    /// Converts the model into a symbolic regression model by parsing the text
    /// produced by <see cref="ToString"/> with an <c>InfixExpressionParser</c>.
    /// </summary>
    /// <param name="targetVariable">Name of the target variable of the resulting model.</param>
    public ISymbolicRegressionModel ToSymbolicRegressionModel(string targetVariable) {
      var expression = ToString();
      var tree = new InfixExpressionParser().Parse(expression);
      return new SymbolicRegressionModel(targetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
    }

    /// <summary>
    /// Re-fits the intercept and all coefficients on <paramref name="data"/> using
    /// <c>ElasticNetLinearRegression.CalculateModelCoefficients</c> (called with 0, 0 for
    /// its two numeric arguments) and stores the result in this model. The basis
    /// functions themselves are kept.
    /// </summary>
    /// <param name="data">The problem data to fit on.</param>
    public void OptimizeCoefficients(IRegressionProblemData data) {
      // Snapshot once: the same ordering must be used for preparing the data and for
      // writing the fitted coefficients back.
      var bfArray = BasisFunctions.ToArray();
      var elnetData = BFUtils.PrepareData(data, bfArray.Select(tuple => tuple.function));

      var fitted = ElasticNetLinearRegression.CalculateModelCoefficients(elnetData, 0, 0, out var _, out var _);

      // The last entry of the result is the intercept; the entries before it are the
      // coefficients, in the order of the basis functions.
      for (int i = 0; i < bfArray.Length; i++) {
        bfArray[i] = (fitted[i], bfArray[i].function);
      }

      BasisFunctions = bfArray;
      Intercept = fitted[fitted.Length - 1];
    }
  }
}