Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3022-FastFunctionExtraction/FFX/BFUtils.cs @ 17737

Last change on this file since 17737 was 17737, checked in by lleko, 4 years ago

#3022 implement ffx

File size: 10.6 KB
Line 
1using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
2using HeuristicLab.Data;
3using HeuristicLab.Problems.DataAnalysis;
4using System;
5using System.Collections.Generic;
6using System.Linq;
7using System.Runtime.CompilerServices;
8
9[assembly: InternalsVisibleTo("UnitTests")]
10namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {
11    // utility functions for creating Basis Functions
12    internal static class BFUtils {
13        public static IEnumerable<IBasisFunction> CreateBasisFunctions(IRegressionProblemData data, Approach approach) {
14            var exponents = approach.AllowExp ? approach.Exponents : new HashSet<double> { 1 };
15            var funcs = approach.AllowNonlinFuncs ? approach.NonlinFuncs : new HashSet<NonlinearOperator> { NonlinearOperator.None };
16            var simpleBasisFuncs = CreateSimpleBases(data, exponents, funcs);
17
18            if (approach.AllowHinge) {
19                // only allow hinge functions for features with exponent = 1 (deemed too complex otherwise)
20                var linearSimpleBasisFuncs = simpleBasisFuncs.Where(simpleBf => simpleBf.Exponent == 1);
21                simpleBasisFuncs = simpleBasisFuncs.Concat(CreateHingeBases(data, linearSimpleBasisFuncs, approach.MinHingeThr, approach.MaxHingeThr, approach.NumHingeThrs));
22            }
23
24            IEnumerable<IBasisFunction> functions = simpleBasisFuncs;
25
26            if (approach.AllowInter) {
27                var multivariateBases = CreateMultivariateBases(data, simpleBasisFuncs.ToArray());
28                functions = functions.Concat(multivariateBases);
29            }
30
31            if (approach.AllowDenom) {
32                var denominatorBases = CreateDenominatorBases(functions);
33                functions = functions.Concat(denominatorBases);
34            }
35            return functions;
36        }
37
38        public static IEnumerable<ISimpleBasisFunction> CreateSimpleBases(IRegressionProblemData problemData, HashSet<double> exponents, HashSet<NonlinearOperator> nonlinFuncs) {
39            var simpleBasisFunctions = new List<ISimpleBasisFunction>();
40            foreach (var variableName in problemData.AllowedInputVariables) {
41                var vals = problemData.Dataset.GetDoubleValues(variableName).ToArray();
42                var min = vals.Min();
43                foreach (var exp in exponents) {
44                    var simpleBase = new SimpleBasisFunction(variableName, exp, NonlinearOperator.None);
45                    // if the basis function is not valid without any operator, then it won't be valid in combination with any nonlinear operator -> skip
46                    if (!Ok(simpleBase.Simulate(problemData))) continue;
47
48                    foreach (NonlinearOperator op in nonlinFuncs) {
49                        // ignore cases where op has no effect
50                        if (op.Equals(NonlinearOperator.Abs) && new[] { -2.0, 2.0 }.Contains(exp) && nonlinFuncs.Contains(NonlinearOperator.None)) continue;
51                        if (op.Equals(NonlinearOperator.Abs) && min >= 0) continue;
52                        var nonsimpleBase = (SimpleBasisFunction)simpleBase.DeepCopy();
53                        nonsimpleBase.Operator = op;
54                        if (!Ok(nonsimpleBase.Simulate(problemData))) continue;
55                        simpleBasisFunctions.Add(nonsimpleBase);
56                    }
57                }
58            }
59            return simpleBasisFunctions;
60        }
61
62        public static IEnumerable<IBasisFunction> CreateMultivariateBases(IRegressionProblemData data, IList<ISimpleBasisFunction> univariateBases) {
63            var orderedFuncs = OrderBasisFuncsByImportance(data, univariateBases).ToArray();
64            var multivariateBases = new List<IBasisFunction>();
65            int maxSize = 2 * orderedFuncs.Length;
66            foreach (var bf in orderedFuncs) {
67                // disallow bases with exponents
68                if (bf.Exponent != 1) continue;
69                multivariateBases.Add(new ProductBaseFunction(bf, bf, true));
70            }
71
72            for (int i = 0; i < orderedFuncs.Count(); i++) {
73                var b_i = orderedFuncs.ElementAt(i);
74                for (int j = 0; j < i; j++) {
75                    var b_j = orderedFuncs.ElementAt(j);
76                    if (b_j.Operator != NonlinearOperator.None) continue; // disallow op() * op(); deemed to complex
77                    var b_inter = new ProductBaseFunction(b_i, b_j, true);
78                    if (!Ok(b_inter.Simulate(data))) continue;
79                    multivariateBases.Add(b_inter);
80                    if (multivariateBases.Count() >= maxSize)
81                        return multivariateBases;
82                }
83            }
84            return multivariateBases;
85        }
86
87        // order basis functions by importance (decr)
88        // the importance of a basis function is measured by the absolute value of its coefficient when optimized on the data
89        public static IEnumerable<ISimpleBasisFunction> OrderBasisFuncsByImportance(IRegressionProblemData data, IList<ISimpleBasisFunction> candidateFunctions) {
90            var elnetData = PrepareData(Normalize(data), candidateFunctions);
91            var coeff = ElasticNetLinearRegression.CalculateModelCoefficients(elnetData, 0, 0, out var trainNMSE, out var testNMSE); // LS-fit
92            var intercept = coeff.Last();
93            coeff = coeff.Take(coeff.Length - 1).ToArray();
94            var order = Utils.Argsort(coeff);
95            Array.Reverse(order);
96            return order.Select(idx => candidateFunctions[idx]);
97        }
98
99        public static IList<ISimpleBasisFunction> CreateHingeBases(IRegressionProblemData data, IEnumerable<ISimpleBasisFunction> simple_bfs, double relative_start_thr = 0.0, double relative_end_thr = 1.0, int num_thrs = 3, IntRange trainingPartition = null) {
100            var hingeBases = new List<ISimpleBasisFunction>();
101
102            foreach (var simple_bf in simple_bfs) {
103                hingeBases.AddRange(CreateHingeBases(data, simple_bf, relative_start_thr, relative_end_thr, num_thrs, trainingPartition));
104            }
105            return hingeBases;
106        }
107
108        private static IEnumerable<ISimpleBasisFunction> CreateHingeBases(IRegressionProblemData data, ISimpleBasisFunction simple_bf, double relative_start_thr, double relative_end_thr, int num_thrs, IntRange trainingPartition) {
109            if (relative_start_thr >= relative_end_thr) throw new ArgumentException($"{nameof(relative_start_thr)} must be smaller than {nameof(relative_end_thr)}.");
110            var ans = new List<ISimpleBasisFunction>();
111
112            var vals = simple_bf.Simulate(data);
113            var temp = trainingPartition ?? data.TrainingPartition;
114            double min = Double.MaxValue;
115            double max = Double.MinValue;
116            for (int i = temp.Start; i < temp.End; i++) {
117                min = Math.Min(min, vals[i]);
118                max = Math.Max(max, vals[i]);
119            }
120            if (max - min == 0) return ans;
121            var full_range = max - min;
122            var start_thr = min + relative_start_thr * full_range;
123            var end_thr = min + relative_end_thr * full_range;
124            var thresholds = Utils.Linspace(start_thr, end_thr, num_thrs);
125
126            foreach (var thr in thresholds) {
127                ans.Add(new SimpleBasisFunction(simple_bf.Feature, 1, NonlinearOperator.Gth, true, thr));
128                ans.Add(new SimpleBasisFunction(simple_bf.Feature, 1, NonlinearOperator.Lth, true, thr));
129            }
130            return ans;
131        }
132
133        public static IEnumerable<IBasisFunction> CreateDenominatorBases(IEnumerable<IBasisFunction> basisFunctions) {
134            List<IBasisFunction> ans = new List<IBasisFunction>();
135            foreach (var bf in basisFunctions) {
136                if (!bf.IsNominator) continue;
137                var denomFunc = bf.DeepCopy();
138                denomFunc.IsNominator = false;
139                ans.Add(denomFunc);
140            }
141            return ans;
142        }
143
144        public static IRegressionProblemData PrepareData(IRegressionProblemData problemData, IEnumerable<IBasisFunction> basisFunctions) {
145            int numRows = problemData.Dataset.Rows;
146            int numCols = basisFunctions.Count();
147            HashSet<string> allowedInputVars = new HashSet<string>();
148            double[,] variableValues = new double[numRows, numCols + 1]; // +1 for target var
149
150            int col = 0;
151            foreach (var basisFunc in basisFunctions) {
152                allowedInputVars.Add(basisFunc.ToString() + (!basisFunc.IsNominator ? " * " + problemData.TargetVariable : ""));
153                var vals = basisFunc.Simulate(problemData);
154                for (int i = 0; i < numRows; i++) {
155                    variableValues[i, col] = vals[i];
156                }
157                col++;
158            }
159
160            // add the unmodified target variable to the dataset
161            var allVariables = new HashSet<string>(allowedInputVars);
162            allVariables.Add(problemData.TargetVariable);
163
164            var targetVals = problemData.TargetVariableValues.ToArray();
165            for (int i = 0; i < numRows; i++) {
166                variableValues[i, col] = targetVals[i];
167            }
168
169            var temp = new Dataset(allVariables, variableValues);
170
171            IRegressionProblemData rpd = new RegressionProblemData(temp, allowedInputVars, problemData.TargetVariable);
172            rpd.TrainingPartition.Start = problemData.TrainingPartition.Start;
173            rpd.TrainingPartition.End = problemData.TrainingPartition.End;
174            rpd.TestPartition.Start = problemData.TestPartition.Start;
175            rpd.TestPartition.End = problemData.TestPartition.End;
176            return rpd;
177        }
178
179        private static IRegressionProblemData Normalize(IRegressionProblemData data)
180            => new RegressionProblemData(Normalize(data.Dataset), data.AllowedInputVariables, data.TargetVariable);
181
182        // return a normalized version of IDataset ds
183        private static IDataset Normalize(IDataset ds) {
184            var doubleNames = ds.DoubleVariables.ToArray();
185            if (ds.VariableNames.Count() != doubleNames.Length) throw new ArgumentException(nameof(ds));
186            var variableVals = new List<List<double>>();
187            foreach (var name in doubleNames) {
188                var vals = Utils.Normalize(ds.GetDoubleValues(name).ToArray());
189                variableVals.Add(vals.ToList());
190            }
191            return new Dataset(doubleNames, variableVals);
192        }
193
194        private static bool Ok(IEnumerable<double> data) => data.All(x => !double.IsNaN(x) && !double.IsInfinity(x));
195    }
196}
Note: See TracBrowser for help on using the repository browser.