Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3022-FastFunctionExtraction/FFX/FastFunctionExtraction.cs @ 17223

Last change on this file since 17223 was 17219, checked in by lleko, 5 years ago

#3022 add generateUnivariateBases(), add BasisFunction class

File size: 8.1 KB
Line 
1using System;
2using System.Threading;
3using System.Linq;
4using HeuristicLab.Common; // required for parameters collection
5using HeuristicLab.Core; // required for parameters collection
6using HeuristicLab.Data; // IntValue, ...
7using HeuristicLab.Encodings.BinaryVectorEncoding;
8using HeuristicLab.Optimization; // BasicAlgorithm
9using HeuristicLab.Parameters;
10using HeuristicLab.Problems.Binary;
11using HeuristicLab.Random; // MersenneTwister
12using HEAL.Attic;
13using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
14using HeuristicLab.Problems.DataAnalysis;
15using System.Collections.Generic;
16using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
17
namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction
{

    /// <summary>
    /// Fast Function Extraction (FFX) regression algorithm (work in progress).
    /// </summary>
    /// <remarks>
    /// Current state: univariate basis functions are generated from the allowed input variables on the
    /// training partition. Conversion of bases to symbolic expression trees is not implemented yet (it throws
    /// <see cref="NotImplementedException"/>), and coefficient fitting / Pareto filtering are placeholders.
    /// </remarks>
    [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
    [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
    [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // required for persistence (storing the algorithm to a file or transferring it to HeuristicLab.Hive)
    public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem>
    {
        /// <summary>Nonlinear operators applied on top of every exponentiated input.</summary>
        private enum Operator { Abs, Log };

        /// <summary>Exponents applied to every allowed input variable when building univariate bases.</summary>
        private static readonly double[] exponents = { 0.5, 1, 2 };

        private const string PenaltyParameterName = "Penalty";
        private const string ConsiderInteractionsParameterName = "Consider Interactions";
        private const string ConsiderDenominationParameterName = "Consider Denomination";
        private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
        private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear functions";
        private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";

        #region parameters
        public IFixedValueParameter<DoubleValue> PenaltyParameter
        {
            get { return (IFixedValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderInteractionsParameter
        {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderDenominationParameter
        {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderExponentiationParameter
        {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter
        {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderHingeFuncsParameter
        {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
        }
        #endregion

        #region properties
        /// <summary>Elastic-net mixing factor: 0.0 = ridge, 1.0 = lasso.</summary>
        public double Penalty
        {
            get { return PenaltyParameter.Value.Value; }
            set { PenaltyParameter.Value.Value = value; }
        }
        public bool ConsiderInteractions
        {
            get { return ConsiderInteractionsParameter.Value.Value; }
            set { ConsiderInteractionsParameter.Value.Value = value; }
        }
        public bool ConsiderDenomination
        {
            get { return ConsiderDenominationParameter.Value.Value; }
            set { ConsiderDenominationParameter.Value.Value = value; }
        }
        public bool ConsiderExponentiation
        {
            get { return ConsiderExponentiationParameter.Value.Value; }
            set { ConsiderExponentiationParameter.Value.Value = value; }
        }
        public bool ConsiderNonlinearFuncs
        {
            get { return ConsiderNonlinearFuncsParameter.Value.Value; }
            set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
        }
        public bool ConsiderHingeFuncs
        {
            get { return ConsiderHingeFuncsParameter.Value.Value; }
            set { ConsiderHingeFuncsParameter.Value.Value = value; }
        }
        #endregion

        [StorableConstructor]
        private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }

        /// <summary>Cloning constructor; this class has no fields of its own, so only the base state is cloned.</summary>
        public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner)
        {
        }

        /// <summary>Creates the algorithm and registers its GUI-visible parameters.</summary>
        public FastFunctionExtraction() : base()
        {
            Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.5)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want to consider interactions, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want to consider denominations, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want to consider exponentiation, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want to consider nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want to consider Hinge Functions, otherwise false.", new BoolValue(true)));
        }

        [StorableHook(HookType.AfterDeserialization)]
        private void AfterDeserialization() { }

        public override IDeepCloneable Clone(Cloner cloner)
        {
            return new FastFunctionExtraction(this, cloner);
        }

        public override Type ProblemType { get { return typeof(RegressionProblem); } }
        public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

        /// <summary>
        /// Runs the (partially implemented) FFX pipeline: basis generation, tree conversion,
        /// coefficient fitting and non-dominated filtering.
        /// </summary>
        /// <exception cref="OperationCanceledException">When <paramref name="cancellationToken"/> is cancelled.</exception>
        /// <exception cref="NotImplementedException">Tree conversion is not implemented yet.</exception>
        protected override void Run(CancellationToken cancellationToken)
        {
            var basisFunctions = generateBasisFunctions(Problem.ProblemData);
            cancellationToken.ThrowIfCancellationRequested();

            var trees = new List<SymbolicExpressionTree>(basisFunctions.Count);
            foreach (var basisFunc in basisFunctions)
            {
                cancellationToken.ThrowIfCancellationRequested();
                trees.Add(generateSymbolicExpressionTree(basisFunc));
            }

            // TODO: evaluate each tree with the interpreter to build the regression design matrix.

            var coefficientVectorSet = findCoefficientValues(basisFunctions);
            // TODO: publish the resulting Pareto front in Results.
            nondominatedFilter(coefficientVectorSet);
        }

        /// <summary>Converts a named basis function into a symbolic expression tree (not implemented yet).</summary>
        private SymbolicExpressionTree generateSymbolicExpressionTree(KeyValuePair<string, double[]> basisFunc)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Generates all candidate basis functions, keyed by a human-readable name and mapped to their
        /// values on the training rows. Currently only univariate bases are produced.
        /// </summary>
        private static Dictionary<string, double[]> generateBasisFunctions(IRegressionProblemData problemData)
        {
            return generateUnivariateBases(problemData);
        }

        /// <summary>
        /// Builds univariate bases: for every allowed input variable and every exponent in
        /// <see cref="exponents"/>, the term <c>x ** e</c> plus each <see cref="Operator"/> applied to it.
        /// </summary>
        /// <param name="problemData">Source of the dataset, the training rows and the allowed input variables.</param>
        /// <returns>
        /// Map from basis name (e.g. <c>"x ** 0.5"</c>, <c>"Log(x ** 2)"</c>) to its values on the training rows.
        /// Bases that are not finite on every training row are omitted.
        /// </returns>
        private static Dictionary<string, double[]> generateUnivariateBases(IRegressionProblemData problemData)
        {
            var dataset = problemData.Dataset;
            // Materialize once instead of re-enumerating the training indices for every variable/exponent.
            var rows = problemData.TrainingIndices.ToArray();
            var B1 = new Dictionary<string, double[]>();

            // Only allowed inputs: iterating all dataset columns would leak the target variable into the bases.
            foreach (var variableName in problemData.AllowedInputVariables)
            {
                var values = dataset.GetDoubleValues(variableName, rows).ToArray();
                foreach (var exp in exponents)
                {
                    // Invariant culture so names do not depend on the UI locale ("0.5", not "0,5").
                    var name = variableName + " ** " + exp.ToString(System.Globalization.CultureInfo.InvariantCulture);
                    var data = values.Select(v => Math.Pow(v, exp)).ToArray();
                    addIfFinite(B1, name, data);
                    foreach (Operator op in Enum.GetValues(typeof(Operator)))
                    {
                        var innerName = op.ToString() + "(" + name + ")";
                        var innerData = data.Select(v => executeOperator(v, op)).ToArray();
                        addIfFinite(B1, innerName, innerData);
                    }
                }
            }

            return B1;
        }

        /// <summary>
        /// Adds a basis only if all of its values are finite. Negative inputs to <c>** 0.5</c> yield NaN, and
        /// non-positive inputs to Log yield NaN or -infinity; such columns cannot be used by a linear solver.
        /// </summary>
        private static void addIfFinite(Dictionary<string, double[]> bases, string name, double[] data)
        {
            if (data.All(v => !double.IsNaN(v) && !double.IsInfinity(v)))
            {
                bases.Add(name, data);
            }
        }

        /// <summary>Applies <paramref name="op"/> to a single value.</summary>
        /// <exception cref="NotImplementedException">A new <see cref="Operator"/> member has no case here yet.</exception>
        private static double executeOperator(double x, Operator op)
        {
            switch (op)
            {
                case Operator.Abs:
                    return Math.Abs(x);
                case Operator.Log:
                    return Math.Log10(x);
                default:
                    throw new NotImplementedException("Operator " + op + " is not implemented.");
            }
        }

        /// <summary>
        /// Builds bivariate bases as the element-wise product of every unordered pair of distinct univariate
        /// bases, and returns them together with the univariate bases.
        /// </summary>
        /// <param name="B1">Univariate bases; all value arrays are expected to have the same length (same rows).</param>
        /// <returns>Union of the pairwise products (named <c>"b_i * b_j"</c>) and <paramref name="B1"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="B1"/> is null.</exception>
        private static Dictionary<string, double[]> generateMultiVariateBases(Dictionary<string, double[]> B1)
        {
            if (B1 == null) throw new ArgumentNullException(nameof(B1));

            // Snapshot once: ElementAt on a dictionary is O(n) per call.
            var univariate = B1.ToArray();
            var B2 = new Dictionary<string, double[]>();
            for (int i = 0; i < univariate.Length; i++)
            {
                var b_i = univariate[i];
                for (int j = 0; j < i; j++)
                {
                    var b_j = univariate[j];
                    var name = b_i.Key + " * " + b_j.Key;
                    var data = b_i.Value.Zip(b_j.Value, (a, b) => a * b).ToArray();
                    B2.Add(name, data);
                }
            }

            // return union of B1 and B2
            return B2.Concat(B1).ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
        }

        /// <summary>Placeholder for fitting coefficient vectors (e.g. an elastic-net path) over the bases.</summary>
        private static object findCoefficientValues(IEnumerable<KeyValuePair<string, double[]>> basisFunctions)
        {
            return new object();
        }

        /// <summary>Placeholder for selecting the non-dominated (accuracy vs. complexity) coefficient vectors.</summary>
        private static object nondominatedFilter(object coefficientVectorSet)
        {
            return new object();
        }

        public override bool SupportsPause
        {
            get { return false; }
        }
    }
}
Note: See TracBrowser for help on using the repository browser.