Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/29/20 09:23:06 (4 years ago)
Author:
lleko
Message:

#3022 implement ffx

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3022-FastFunctionExtraction/FFX/FastFunctionExtraction.cs

    r17227 r17737  
    22using System.Threading;
    33using System.Linq;
    4 using HeuristicLab.Common; // required for parameters collection
    5 using HeuristicLab.Core; // required for parameters collection
    6 using HeuristicLab.Data; // IntValue, ...
    7 using HeuristicLab.Encodings.BinaryVectorEncoding;
    8 using HeuristicLab.Optimization; // BasicAlgorithm
     4using HeuristicLab.Common;
     5using HeuristicLab.Core;
     6using HeuristicLab.Data;
     7using HeuristicLab.Optimization;
    98using HeuristicLab.Parameters;
    10 using HeuristicLab.Problems.Binary;
    11 using HeuristicLab.Random; // MersenneTwister
    129using HEAL.Attic;
    1310using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
    1411using HeuristicLab.Problems.DataAnalysis;
    1512using System.Collections.Generic;
    16 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    17 using System.Collections;
    18 using System.Diagnostics;
    19 using HeuristicLab.Problems.DataAnalysis.Symbolic;
    2013using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    21 using HeuristicLab.Analysis;
    22 using HeuristicLab.Collections;
    2314
    2415namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {
    2516
    26   [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
    27   [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
    28   [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing your algorithm to a files or transfer to HeuristicLab.Hive
    29   public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> {
    30 
    31     private static readonly double[] exponents = { 0.5, 1, 2 };
    32     private static readonly OpCode[] nonlinFuncs = { OpCode.Absolute, OpCode.Log, OpCode.Sin, OpCode.Cos };
    33 
    34     private static readonly BidirectionalDictionary<OpCode, string> OpCodeToString = new BidirectionalDictionary<OpCode, string> {
    35         { OpCode.Log, "LOG" },
    36         { OpCode.Absolute, "ABS"},
    37         { OpCode.Sin, "SIN"},
    38         { OpCode.Cos, "COS"},
    39         { OpCode.Square, "SQR"},
    40         { OpCode.SquareRoot, "SQRT"},
    41         { OpCode.Cube, "CUBE"},
    42         { OpCode.CubeRoot, "CUBEROOT"}
    43     };
    44 
    45     private const string ConsiderInteractionsParameterName = "Consider Interactions";
    46     private const string ConsiderDenominationParameterName = "Consider Denomination";
    47     private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
    48     private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear functions";
    49     private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
    50     private const string PenaltyParameterName = "Penalty";
    51     private const string LambdaParameterName = "Lambda";
    52     private const string NonlinearFuncsParameterName = "Nonlinear Functions";
    53 
    54     #region parameters
    55     public IValueParameter<BoolValue> ConsiderInteractionsParameter
    56     {
    57       get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
     17    [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
     18    [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
     19    [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing your algorithm to a files or transfer to HeuristicLab.Hive
     20    public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> {
     21
     22        #region constants
     23        private static readonly HashSet<double> exponents = new HashSet<double> { -1.0, -0.5, +0.5, +1.0 };
     24        private static readonly HashSet<NonlinearOperator> nonlinFuncs = new HashSet<NonlinearOperator> { NonlinearOperator.Abs, NonlinearOperator.Log, NonlinearOperator.None };
     25        private static readonly double minHingeThr = 0.2;
     26        private static readonly double maxHingeThr = 0.8;
     27        private static readonly int numHingeThrs = 5;
     28
     29        private const string ConsiderInteractionsParameterName = "Consider Interactions";
     30        private const string ConsiderDenominationParameterName = "Consider Denomination";
     31        private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
     32        private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
     33        private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear Functions";
     34        private const string LambdaParameterName = "Elastic Net Lambda";
     35        private const string PenaltyParameterName = "Elastic Net Penalty";
     36        private const string MaxNumBasisFuncsParameterName = "Maximum Number of Basis Functions";
     37
     38        #endregion
     39
     40        #region parameters
     41        public IValueParameter<BoolValue> ConsiderInteractionsParameter {
     42            get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
     43        }
     44        public IValueParameter<BoolValue> ConsiderDenominationsParameter {
     45            get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
     46        }
     47        public IValueParameter<BoolValue> ConsiderExponentiationsParameter {
     48            get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
     49        }
     50        public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter {
     51            get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
     52        }
     53        public IValueParameter<BoolValue> ConsiderHingeFuncsParameter {
     54            get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
     55        }
     56        public IValueParameter<DoubleValue> PenaltyParameter {
     57            get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
     58        }
     59        public IValueParameter<DoubleValue> LambdaParameter {
     60            get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
     61        }
     62        public IValueParameter<IntValue> MaxNumBasisFuncsParameter {
     63            get { return (IValueParameter<IntValue>)Parameters[MaxNumBasisFuncsParameterName]; }
     64        }
     65        #endregion
     66
     67        #region properties
     68        public bool ConsiderInteractions {
     69            get { return ConsiderInteractionsParameter.Value.Value; }
     70            set { ConsiderInteractionsParameter.Value.Value = value; }
     71        }
     72        public bool ConsiderDenominations {
     73            get { return ConsiderDenominationsParameter.Value.Value; }
     74            set { ConsiderDenominationsParameter.Value.Value = value; }
     75        }
     76        public bool ConsiderExponentiations {
     77            get { return ConsiderExponentiationsParameter.Value.Value; }
     78            set { ConsiderExponentiationsParameter.Value.Value = value; }
     79        }
     80        public bool ConsiderNonlinearFunctions {
     81            get { return ConsiderNonlinearFuncsParameter.Value.Value; }
     82            set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
     83        }
     84        public bool ConsiderHingeFunctions {
     85            get { return ConsiderHingeFuncsParameter.Value.Value; }
     86            set { ConsiderHingeFuncsParameter.Value.Value = value; }
     87        }
     88        public double Penalty {
     89            get { return PenaltyParameter.Value.Value; }
     90            set { PenaltyParameter.Value.Value = value; }
     91        }
     92        public DoubleValue Lambda {
     93            get { return LambdaParameter.Value; }
     94            set { LambdaParameter.Value = value; }
     95        }
     96        public int MaxNumBasisFuncs {
     97            get { return MaxNumBasisFuncsParameter.Value.Value; }
     98            set { MaxNumBasisFuncsParameter.Value.Value = value; }
     99        }
     100        #endregion
     101
     102        #region ctor
     103
     104        [StorableConstructor]
     105        private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
     106        public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
     107        }
     108        public FastFunctionExtraction() : base() {
     109            base.Problem = new RegressionProblem();
     110            Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
     111            Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));
     112            Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
     113            Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
     114            Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));
     115            Parameters.Add(new ValueParameter<IntValue>(MaxNumBasisFuncsParameterName, "Set how many basis functions the models can have at most. if Max Num Basis Funcs is negative => no restriction on size", new IntValue(20)));
     116            Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
     117            Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.05)));
     118        }
     119
     120        [StorableHook(HookType.AfterDeserialization)]
     121        private void AfterDeserialization() { }
     122
     123        public override IDeepCloneable Clone(Cloner cloner) {
     124            return new FastFunctionExtraction(this, cloner);
     125        }
     126        #endregion
     127
     128        public override Type ProblemType { get { return typeof(RegressionProblem); } }
     129        public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }
     130
     131        protected override void Run(CancellationToken cancellationToken) {
     132            var models = Fit(Problem.ProblemData, Penalty, out var numBases, ConsiderExponentiations,
     133                ConsiderNonlinearFunctions, ConsiderInteractions,
     134                ConsiderDenominations, ConsiderHingeFunctions, MaxNumBasisFuncs
     135            );
     136
     137            int i = 0;
     138            var numBasesArr = numBases.ToArray();
     139            var solutionsArr = new List<ISymbolicRegressionSolution>();
     140
     141            foreach (var model in models) {
     142                Results.Add(new Result(
     143                    "Num Bases: " + numBasesArr[i++],
     144                    model
     145                ));
     146                solutionsArr.Add(new SymbolicRegressionSolution(model, Problem.ProblemData));
     147            }
     148            Results.Add(new Result("Model Accuracies", new ItemCollection<ISymbolicRegressionSolution>(solutionsArr)));
     149        }
     150
     151        public static IEnumerable<ISymbolicRegressionModel> Fit(IRegressionProblemData data, double elnetPenalty, out IEnumerable<int> numBases, bool exp = true, bool nonlinFuncs = true, bool interactions = true, bool denoms = false, bool hingeFuncs = true, int maxNumBases = -1) {
     152            var approaches = CreateApproaches(interactions, denoms, exp, nonlinFuncs, hingeFuncs, maxNumBases, elnetPenalty);
     153
     154            var allFFXModels = approaches
     155                .SelectMany(approach => CreateFFXModels(data, approach)).ToList();
     156
     157            // Final pareto filter over all generated models from all different approaches
     158            var nondominatedFFXModels = NondominatedModels(data, allFFXModels);
     159
     160            numBases = nondominatedFFXModels
     161                .Select(ffxModel => ffxModel.NumBases).ToArray();
     162            return nondominatedFFXModels.Select(ffxModel => ffxModel.ToSymbolicRegressionModel(data.TargetVariable));
     163        }
     164
     165        private static IEnumerable<FFXModel> NondominatedModels(IRegressionProblemData data, IEnumerable<FFXModel> ffxModels) {
     166            var numBases = ffxModels.Select(ffxModel => (double)ffxModel.NumBases).ToArray();
     167            var errors = ffxModels.Select(ffxModel => {
     168                var originalValues = data.TargetVariableTestValues.ToArray();
     169                var estimatedValues = ffxModel.Simulate(data, data.TestIndices);
     170                // do not create a regressionSolution here for better performance:
     171                // RegressionSolutions calculate all kinds of errors when calling the ctor, but we only need testMSE
     172                var testMSE = OnlineMeanSquaredErrorCalculator.Calculate(originalValues, estimatedValues, out var state);
     173                if (state != OnlineCalculatorError.None) throw new ArrayTypeMismatchException("could not calculate TestMSE");
     174                return testMSE;
     175            }).ToArray();
     176
     177            int n = numBases.Length;
     178            double[][] qualities = new double[n][];
     179            for (int i = 0; i < n; i++) {
     180                qualities[i] = new double[2];
     181                qualities[i][0] = numBases[i];
     182                qualities[i][1] = errors[i];
     183            }
     184
     185            return DominationCalculator<FFXModel>.CalculateBestParetoFront(ffxModels.ToArray(), qualities, new bool[] { false, false })
     186                .Select(tuple => tuple.Item1).OrderBy(ffxModel => ffxModel.NumBases);
     187        }
     188
     189        // Build FFX models
     190        private static IEnumerable<FFXModel> CreateFFXModels(IRegressionProblemData data, Approach approach) {
     191            // FFX Step 1
     192            var basisFunctions = BFUtils.CreateBasisFunctions(data, approach).ToArray();
     193
     194            var funcsArr = basisFunctions.ToArray();
     195            var elnetData = BFUtils.PrepareData(data, funcsArr);
     196
     197            // FFX Step 2
     198            ElasticNetLinearRegression.RunElasticNetLinearRegression(elnetData, approach.ElnetPenalty, out var _, out var _, out var _, out var candidateCoeffs, out var intercept, maxVars: approach.MaxNumBases);
     199
     200            // create models out of the learned coefficients
     201            var ffxModels = GetUniqueModelsFromCoeffs(candidateCoeffs, intercept, funcsArr, approach);
     202             
     203            // one last LS-optimization step on the training data
     204            foreach (var ffxModel in ffxModels) {
     205                //if (ffxModel.NumBases > 0) ffxModel.OptimizeCoefficients(data);
     206            }
     207            return ffxModels;
     208        }
     209
     210        // finds all models with unique combinations of basis functions
     211        private static IEnumerable<FFXModel> GetUniqueModelsFromCoeffs(double[,] candidateCoeffs, double[] intercept, IBasisFunction[] funcsArr, Approach approach) {
     212            List<FFXModel> ffxModels = new List<FFXModel>();
     213            List<int[]> bfCombinations = new List<int[]>();
     214
     215            for (int i = 0; i < intercept.Length; i++) {
     216                var row = candidateCoeffs.GetRow(i);
     217                var nonzeroIndices = row.FindAllIndices(val => val != 0).ToArray();
     218                if (nonzeroIndices.Count() > approach.MaxNumBases) continue;
     219                // ignore duplicate models (models with same combination of basis functions)
     220                if (bfCombinations.Any(arr => Enumerable.SequenceEqual(arr, nonzeroIndices))) continue;
     221                var ffxModel = new FFXModel(intercept[i], nonzeroIndices.Select(idx => (row[idx], funcsArr[idx])));
     222                bfCombinations.Add(nonzeroIndices);
     223                ffxModels.Add(ffxModel);
     224            }
     225            return ffxModels;
     226        }
     227
     228        private static IEnumerable<Approach> CreateApproaches(bool interactions, bool denominator, bool exponentiations, bool nonlinearFuncs, bool hingeFunctions, int maxNumBases, double penalty) {
     229            var approaches = new List<Approach>();
     230            var valids = new bool[5] { interactions, denominator, exponentiations, nonlinearFuncs, hingeFunctions };
     231
     232            // return true if ALL indices of true values of arr1 also have true values in arr2
     233            bool follows(bool[] arr1, bool[] arr2) {
     234                if (arr1.Length != arr2.Length) throw new ArgumentException("invalid lengths");
     235                for (int i = 0; i < arr1.Length; i++) {
     236                    if (arr1[i] && !arr2[i]) return false;
     237                }
     238                return true;
     239            }
     240
     241            for (int i = 0; i < 32; i++) {
     242                // map i to a bool array of length 5
     243                var arr = i.ToBoolArray(5);
     244                if (!follows(arr, valids)) continue;
     245                int sum = arr.Where(b => b).Count(); // how many features are enabled?
     246                if (sum >= 4) continue; // not too many features at once
     247                if (arr[0] && arr[2]) continue; // never need both exponent and inter
     248                approaches.Add(new Approach(arr[0], arr[1], arr[2], arr[3], arr[4], exponents, nonlinFuncs, maxNumBases, penalty, minHingeThr, maxHingeThr, numHingeThrs));
     249            }
     250            return approaches;
     251        }
    58252    }
    59     public IValueParameter<BoolValue> ConsiderDenominationsParameter
    60     {
    61       get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
    62     }
    63     public IValueParameter<BoolValue> ConsiderExponentiationsParameter
    64     {
    65       get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
    66     }
    67     public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter
    68     {
    69       get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
    70     }
    71     public IValueParameter<BoolValue> ConsiderHingeFuncsParameter
    72     {
    73       get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
    74     }
    75     public IValueParameter<DoubleValue> PenaltyParameter
    76     {
    77       get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
    78     }
    79     public IValueParameter<DoubleValue> LambdaParameter
    80     {
    81       get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    82     }
    83     public IValueParameter<CheckedItemCollection<EnumValue<OpCode>>> NonlinearFuncsParameter
    84     {
    85       get { return (IValueParameter<CheckedItemCollection<EnumValue<OpCode>>>)Parameters[NonlinearFuncsParameterName]; }
    86     }
    87     #endregion
    88 
    89     #region properties
    90     public bool ConsiderInteractions
    91     {
    92       get { return ConsiderInteractionsParameter.Value.Value; }
    93       set { ConsiderInteractionsParameter.Value.Value = value; }
    94     }
    95     public bool ConsiderDenominations
    96     {
    97       get { return ConsiderDenominationsParameter.Value.Value; }
    98       set { ConsiderDenominationsParameter.Value.Value = value; }
    99     }
    100     public bool ConsiderExponentiations
    101     {
    102       get { return ConsiderExponentiationsParameter.Value.Value; }
    103       set { ConsiderExponentiationsParameter.Value.Value = value; }
    104     }
    105     public bool ConsiderNonlinearFuncs
    106     {
    107       get { return ConsiderNonlinearFuncsParameter.Value.Value; }
    108       set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
    109     }
    110     public bool ConsiderHingeFuncs
    111     {
    112       get { return ConsiderHingeFuncsParameter.Value.Value; }
    113       set { ConsiderHingeFuncsParameter.Value.Value = value; }
    114     }
    115     public double Penalty
    116     {
    117       get { return PenaltyParameter.Value.Value; }
    118       set { PenaltyParameter.Value.Value = value; }
    119     }
    120     public DoubleValue Lambda
    121     {
    122       get { return LambdaParameter.Value; }
    123       set { LambdaParameter.Value = value; }
    124     }
    125     public CheckedItemCollection<EnumValue<OpCode>> NonlinearFuncs
    126     {
    127       get { return NonlinearFuncsParameter.Value; }
    128       set { NonlinearFuncsParameter.Value = value; }
    129     }
    130     #endregion
    131 
    132 
    133     [StorableConstructor]
    134     private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
    135     public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
    136     }
    137     public FastFunctionExtraction() : base() {
    138       var items = new CheckedItemCollection<EnumValue<OpCode>>();
    139       foreach (var op in nonlinFuncs) {
    140         items.Add(new EnumValue<OpCode>(op));
    141       }
    142       base.Problem = new RegressionProblem();
    143       Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
    144       Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));
    145       Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
    146       Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
    147       Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));
    148       Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.9)));
    149       Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
    150       Parameters.Add(new ValueParameter<CheckedItemCollection<EnumValue<OpCode>>>(NonlinearFuncsParameterName, "What nonlinear functions the models should be able to include.", items));
    151     }
    152 
    153     [StorableHook(HookType.AfterDeserialization)]
    154     private void AfterDeserialization() { }
    155 
    156     public override IDeepCloneable Clone(Cloner cloner) {
    157       return new FastFunctionExtraction(this, cloner);
    158     }
    159 
    160     public override Type ProblemType { get { return typeof(RegressionProblem); } }
    161     public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }
    162 
    163     public override bool SupportsPause { get { return true; } }
    164 
    165     protected override void Run(CancellationToken cancellationToken) {
    166       var basisFunctions = createBasisFunctions(Problem.ProblemData);
    167       Results.Add(new Result("Basis Functions", "A Dataset consisting of the generated Basis Functions from FFX Alg Step 1.", createProblemData(Problem.ProblemData, basisFunctions)));
    168 
    169       // add denominator bases to the already existing basis functions
    170       if (ConsiderDenominations) basisFunctions = basisFunctions.Concat(createDenominatorBases(Problem.ProblemData, basisFunctions)).ToList();
    171 
    172       // create either path of solutions, or one solution for given lambda
    173       LearnModels(Problem.ProblemData, basisFunctions);
    174     }
    175 
    176     private List<BasisFunction> createBasisFunctions(IRegressionProblemData problemData) {
    177       var basisFunctions = createUnivariateBases(problemData);
    178       basisFunctions = basisFunctions.Concat(createMultivariateBases(basisFunctions)).ToList();
    179       return basisFunctions;
    180     }
    181 
    182     private List<BasisFunction> createUnivariateBases(IRegressionProblemData problemData) {
    183       var B1 = new List<BasisFunction>();
    184       var inputVariables = problemData.AllowedInputVariables;
    185       var validExponents = ConsiderExponentiations ? exponents : new double[] { 1 };
    186       var validFuncs = NonlinearFuncs.CheckedItems.Select(val => val.Value);
    187       // TODO: add Hinge functions
    188 
    189       foreach (var variableName in inputVariables) {
    190         foreach (var exp in validExponents) {
    191           var data = problemData.Dataset.GetDoubleValues(variableName).Select(x => Math.Pow(x, exp)).ToArray();
    192           if (!ok(data)) continue;
    193           var name = expToString(exp, variableName);
    194           B1.Add(new BasisFunction(name, data, false));
    195           foreach (OpCode _op in validFuncs) {
    196             var inner_data = data.Select(x => eval(_op, x)).ToArray();
    197             if (!ok(inner_data)) continue;
    198             var inner_name = OpCodeToString.GetByFirst(_op) + "(" + name + ")";
    199             B1.Add(new BasisFunction(inner_name, inner_data, true));
    200           }
    201         }
    202       }
    203       return B1;
    204     }
    205 
    206     private List<BasisFunction> createMultivariateBases(List<BasisFunction> B1) {
    207       if (!ConsiderInteractions) return B1;
    208       var B2 = new List<BasisFunction>();
    209       for (int i = 0; i < B1.Count(); i++) {
    210         var b_i = B1.ElementAt(i);
    211         for (int j = 0; j < i; j++) {
    212           var b_j = B1.ElementAt(j);
    213           if (b_j.IsOperator) continue; // disallow op() * op()
    214           var b_inter = b_i * b_j;
    215           B2.Add(b_inter);
    216         }
    217       }
    218 
    219       return B2;
    220       // return union of B1 and B2
    221     }
    222 
    223     // creates 1 denominator basis function for each corresponding basis function from basisFunctions
    224     private IEnumerable<BasisFunction> createDenominatorBases(IRegressionProblemData problemData, IEnumerable<BasisFunction> basisFunctions) {
    225       var y = new BasisFunction(problemData.TargetVariable, problemData.TargetVariableValues.ToArray(), false);
    226       var denomBasisFuncs = new List<BasisFunction>();
    227       foreach (var func in basisFunctions) {
    228         var denomFunc = y * func;
    229         denomBasisFuncs.Add(denomFunc);
    230       }
    231       return denomBasisFuncs;
    232     }
    233 
    234     private static string expToString(double exponent, string varname) {
    235       if (exponent.IsAlmost(1)) return varname;
    236       if (exponent.IsAlmost(1 / 2)) return OpCodeToString.GetByFirst(OpCode.SquareRoot) + "(" + varname + ")";
    237       if (exponent.IsAlmost(1 / 3)) return OpCodeToString.GetByFirst(OpCode.CubeRoot) + "(" + varname + ")";
    238       if (exponent.IsAlmost(2)) return OpCodeToString.GetByFirst(OpCode.Square) + "(" + varname + ")";
    239       if (exponent.IsAlmost(3)) return OpCodeToString.GetByFirst(OpCode.Cube) + "(" + varname + ")";
    240       else return varname + " ^ " + exponent;
    241     }
    242 
    243     public static double eval(OpCode op, double x) {
    244       switch (op) {
    245         case OpCode.Absolute:
    246           return Math.Abs(x);
    247         case OpCode.Log:
    248           return Math.Log10(x);
    249         case OpCode.Sin:
    250           return Math.Sin(x);
    251         case OpCode.Cos:
    252           return Math.Cos(x);
    253         default:
    254           throw new Exception("Unimplemented operator: " + op.ToString());
    255       }
    256     }
    257 
    258     private void PathwiseLearning(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
    259       ElasticNetLinearRegression reg = new ElasticNetLinearRegression();
    260       reg.Lambda = Lambda;
    261       reg.Penality = Penalty;
    262       reg.Problem.ProblemData = createProblemData(problemData, basisFunctions);
    263       reg.Start();
    264       Results.AddRange(reg.Results);
    265     }
    266 
    // Runs the full elastic-net regularization path on the basis-function dataset,
    // charts the NMSE and coefficient paths, and adds one symbolic regression
    // solution per lambda step to the results.
    private void LearnModels(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
      // wraps the list of basis functions in a dataset, so that it can be passed on to the ElNet function
      var X_b = createProblemData(problemData, basisFunctions);

      ElasticNetLinearRegression.RunElasticNetLinearRegression(X_b, Penalty, out double[] lambda, out double[] trainNMSE, out double[] testNMSE, out double[,] coeff, out double[] intercept);

      var errorTable = NMSEGraph(coeff, lambda, trainNMSE, testNMSE);
      Results.Add(new Result(errorTable.Name, errorTable.Description, errorTable));
      var coeffTable = CoefficientGraph(coeff, lambda, X_b.AllowedInputVariables, X_b.Dataset);
      Results.Add(new Result(coeffTable.Name, coeffTable.Description, coeffTable));

      // One model per row of coeff (one per lambda step). GetLength(0) covers the
      // whole path; the previous bound GetUpperBound(0) == GetLength(0) - 1
      // silently dropped the last model of the path.
      ItemCollection<IResult> models = new ItemCollection<IResult>();
      for (int modelIdx = 0; modelIdx < coeff.GetLength(0); modelIdx++) {
        var tree = Tree(basisFunctions, GetRow(coeff, modelIdx), intercept[modelIdx]);
        ISymbolicRegressionModel m = new SymbolicRegressionModel(Problem.ProblemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
        ISymbolicRegressionSolution s = new SymbolicRegressionSolution(m, Problem.ProblemData);
        models.Add(new Result("Solution " + modelIdx, s));
      }

      Results.Add(new Result("Models", "The model path returned by the Elastic Net Regression (not only the pareto-optimal subset). ", models));
    }
    295 
    // Builds the coefficient-path chart: one data row per input column tracing its
    // standardized coefficient value over the lambda path, plus the number of
    // non-zero coefficients on a secondary y-axis.
    private static IndexedDataTable<double> CoefficientGraph(double[,] coeff, double[] lambda, IEnumerable<string> allowedVars, IDataset ds) {
      var coeffTable = new IndexedDataTable<double>("Coefficients", "The paths of standarized coefficient values over different lambda values");
      coeffTable.VisualProperties.YAxisMaximumAuto = false;
      coeffTable.VisualProperties.YAxisMinimumAuto = false;
      coeffTable.VisualProperties.XAxisMaximumAuto = false;
      coeffTable.VisualProperties.XAxisMinimumAuto = false;

      coeffTable.VisualProperties.XAxisLogScale = true;
      coeffTable.VisualProperties.XAxisTitle = "Lambda";
      coeffTable.VisualProperties.YAxisTitle = "Coefficients";
      coeffTable.VisualProperties.SecondYAxisTitle = "Number of variables";

      var nLambdas = lambda.Length;
      var nCoeff = coeff.GetLength(1);
      var dataRows = new IndexedDataRow<double>[nCoeff];
      var numNonZeroCoeffs = new int[nLambdas];

      // split the inputs into factor (string) variables — expanded into one binary
      // indicator column per factor value — and plain numeric variables
      var doubleVariables = allowedVars.Where(ds.VariableHasType<double>);
      var factorVariableNames = allowedVars.Where(ds.VariableHasType<string>);
      var factorVariablesAndValues = ds.GetFactorVariableValues(factorVariableNames, Enumerable.Range(0, ds.Rows)); // must consider all factor values (in train and test set)
      {
        // `i` is the running column index into coeff.
        // NOTE(review): assumes the coeff columns are laid out factor-indicator
        // columns first, then numeric columns — verify against the glmnet wrapper.
        int i = 0;
        foreach (var factorVariableAndValues in factorVariablesAndValues) {
          foreach (var factorValue in factorVariableAndValues.Value) {
            double sigma = ds.GetStringValues(factorVariableAndValues.Key)
              .Select(s => s == factorValue ? 1.0 : 0.0)
              .StandardDeviation(); // calc std dev of binary indicator
            // standardize each coefficient by scaling with its column's std. dev.
            var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
            dataRows[i] = new IndexedDataRow<double>(factorVariableAndValues.Key + "=" + factorValue, factorVariableAndValues.Key + "=" + factorValue, path);
            i++;
          }
        }

        foreach (var doubleVariable in doubleVariables) {
          double sigma = ds.GetDoubleValues(doubleVariable).StandardDeviation();
          var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
          dataRows[i] = new IndexedDataRow<double>(doubleVariable, doubleVariable, path);
          i++;
        }
        // add to coeffTable by total weight (larger area under the curve => more important);
        foreach (var r in dataRows.OrderByDescending(r => r.Values.Select(t => t.Item2).Sum(x => Math.Abs(x)))) {
          coeffTable.Rows.Add(r);
        }
      }

      // count the non-zero coefficients per lambda step (secondary-axis row)
      for (int i = 0; i < coeff.GetLength(0); i++) {
        for (int j = 0; j < coeff.GetLength(1); j++) {
          if (!coeff[i, j].IsAlmost(0.0)) {
            numNonZeroCoeffs[i]++;
          }
        }
      }
      if (lambda.Length > 2) {
        // NOTE(review): assumes lambda is sorted descending (glmnet path convention):
        // last element is the smallest, second the largest finite value — confirm.
        coeffTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));
        coeffTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));
      }
      coeffTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));
      coeffTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      coeffTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;

      return coeffTable;
    }
    358 
    // Builds the NMSE-over-lambda path chart (train and test curves) with the
    // number of non-zero coefficients plotted on a secondary y-axis.
    private static IndexedDataTable<double> NMSEGraph(double[,] coeff, double[] lambda, double[] trainNMSE, double[] testNMSE) {
      var errorTable = new IndexedDataTable<double>("NMSE", "Path of NMSE values over different lambda values");

      // count the non-zero coefficients for every lambda step
      var numNonZeroCoeffs = new int[lambda.Length];
      for (int row = 0; row < coeff.GetLength(0); row++) {
        int nonZero = 0;
        for (int col = 0; col < coeff.GetLength(1); col++) {
          if (!coeff[row, col].IsAlmost(0.0)) nonZero++;
        }
        numNonZeroCoeffs[row] = nonZero;
      }

      // axis configuration: NMSE is normalized to [0, 1], lambda on a log scale
      errorTable.VisualProperties.YAxisMaximumAuto = false;
      errorTable.VisualProperties.YAxisMinimumAuto = false;
      errorTable.VisualProperties.XAxisMaximumAuto = false;
      errorTable.VisualProperties.XAxisMinimumAuto = false;
      errorTable.VisualProperties.YAxisMinimumFixedValue = 0;
      errorTable.VisualProperties.YAxisMaximumFixedValue = 1.0;
      errorTable.VisualProperties.XAxisLogScale = true;
      errorTable.VisualProperties.XAxisTitle = "Lambda";
      errorTable.VisualProperties.YAxisTitle = "Normalized mean of squared errors (NMSE)";
      errorTable.VisualProperties.SecondYAxisTitle = "Number of variables";
      if (lambda.Length > 2) {
        // NOTE(review): assumes lambda is sorted descending (glmnet path convention)
        errorTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));
        errorTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));
      }

      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (train)", "Path of NMSE values over different lambda values", lambda.Zip(trainNMSE, (l, v) => Tuple.Create(l, v))));
      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (test)", "Path of NMSE values over different lambda values", lambda.Zip(testNMSE, (l, v) => Tuple.Create(l, v))));
      errorTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));
      foreach (var rowName in new[] { "NMSE (train)", "NMSE (test)", "Number of variables" }) {
        errorTable.Rows[rowName].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      }
      errorTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;

      return errorTable;
    }
    395 
    // Assembles an infix model string "(offset + c_i * b_i + ...) / (1 + c_j * b_j + ...)"
    // from the basis functions and the coefficient vector of one elastic-net model,
    // then parses it into a symbolic expression tree.
    // coeffs: one coefficient per basis function; when ConsiderDenominations is set,
    // the first half belongs to the numerator, the second half to the denominator.
    private ISymbolicExpressionTree Tree(List<BasisFunction> basisFunctions, double[] coeffs, double offset) {
      Debug.Assert(basisFunctions.Count == coeffs.Length);
      var numNumeratorFuncs = ConsiderDenominations ? basisFunctions.Count / 2 : basisFunctions.Count;

      // returns true if there exists at least 1 coefficient value in the model that is part of the denominator
      // (i.e. if there exists at least 1 non-zero value in the second half of the array)
      // fix: previously Take() inspected the FIRST half (the numerator coefficients)
      bool withDenom(double[] coeffarr) => coeffarr.Skip(coeffarr.Length / 2).Any(val => !val.IsAlmost(0.0));

      // fix: format with the invariant culture so the infix parser always sees '.'
      // as decimal separator, independent of the current UI locale
      var inv = System.Globalization.CultureInfo.InvariantCulture;

      string model = "(" + offset.ToString(inv);
      for (int i = 0; i < numNumeratorFuncs; i++) {
        // only generate nodes for relevant basis functions (those with non-zero coeffs)
        if (!coeffs[i].IsAlmost(0.0))
          model += " + (" + coeffs[i].ToString(inv) + ") * " + basisFunctions[i].Var;
      }
      if (ConsiderDenominations && withDenom(coeffs)) {
        model += ") / (1";
        for (int i = numNumeratorFuncs; i < basisFunctions.Count; i++) {
          // only generate nodes for relevant basis functions (those with non-zero coeffs)
          if (!coeffs[i].IsAlmost(0.0))
            model += " + (" + coeffs[i].ToString(inv) + ") * " + basisFunctions[i].Var.Substring(4);
        }
      }
      model += ")";
      InfixExpressionParser p = new InfixExpressionParser();
      return p.Parse(model);
    }
    425 
    // wraps the list of basis functions into an IRegressionProblemData object
    private static IRegressionProblemData createProblemData(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
      // each basis function already carries its evaluated values, so the dataset
      // columns can be assembled directly from those cached vectors
      List<string> names = basisFunctions.Select(bf => bf.Var).ToList();
      List<IList> columns = basisFunctions.Select(bf => (IList)new List<double>(bf.Val)).ToList();
      var matrix = new ModifiableDataset(names, columns);

      // append the unmodified target variable as an additional column
      matrix.AddVariable(problemData.TargetVariable, problemData.TargetVariableValues.ToList());

      // every column except the target is a legal input
      var inputs = matrix.VariableNames.Where(name => !name.Equals(problemData.TargetVariable));
      IRegressionProblemData rpd = new RegressionProblemData(matrix, inputs, problemData.TargetVariable);

      // carry over the original train/test split unchanged
      rpd.TrainingPartition.Start = problemData.TrainingPartition.Start;
      rpd.TrainingPartition.End = problemData.TrainingPartition.End;
      rpd.TestPartition.Start = problemData.TestPartition.Start;
      rpd.TestPartition.End = problemData.TestPartition.End;
      return rpd;
    }
    447 
    448     private static bool ok(double[] data) => data.All(x => !double.IsNaN(x) && !double.IsInfinity(x));
    449 
    // helper function which returns a row of a 2D array as a newly allocated 1D array
    private static T[] GetRow<T>(T[,] matrix, int row) {
      var result = new T[matrix.GetLength(1)];
      for (int col = 0; col < result.Length; col++) {
        result[col] = matrix[row, col];
      }
      return result;
    }
    458 
    // returns all models with pareto-optimal tradeoff between error and complexity
    // NOTE(review): still an unimplemented stub — always returns null, so any caller
    // would hit a NullReferenceException. TODO: implement the non-dominated filtering
    // step of FFX (or remove if the filtering happens elsewhere).
    private static List<IRegressionSolution> nondominatedFilter(double[][] coefficientVectorSet, BasisFunction[] basisFunctions) {
      return null;
    }
    463   }
    464253}
Note: See TracChangeset for help on using the changeset viewer.