using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Threading;
using HEAL.Attic;
using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
using HeuristicLab.Common; // required for parameters collection
using HeuristicLab.Core; // required for parameters collection
using HeuristicLab.Data; // IntValue, ...
using HeuristicLab.Encodings.BinaryVectorEncoding;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization; // BasicAlgorithm
using HeuristicLab.Parameters;
using HeuristicLab.Problems.Binary;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random; // MersenneTwister

namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction
{
    /// <summary>
    /// Work-in-progress FFX regression algorithm. It builds a set of candidate basis
    /// functions from the problem's input variables; the tree generation, coefficient
    /// fitting and Pareto filtering steps are still placeholders.
    /// </summary>
    [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
    [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
    [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing the algorithm to a file or transferring it to HeuristicLab.Hive)
    public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem>
    {
        /// <summary>Unary operators applied on top of every exponentiated variable.</summary>
        private enum Operator { Abs, Log }

        /// <summary>Exponents applied to every input variable when building univariate bases.</summary>
        private static readonly double[] exponents = { 0.5, 1, 2 };

        private const string PenaltyParameterName = "Penalty";
        private const string ConsiderInteractionsParameterName = "Consider Interactions";
        private const string ConsiderDenominationParameterName = "Consider Denomination";
        private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
        private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear functions";
        private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";

        #region parameters
        /// <summary>Parameter object behind <see cref="ConsiderInteractions"/>.</summary>
        public IValueParameter<BoolValue> ConsiderInteractionsParameter
        {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
        }
        #endregion

        #region properties
        /// <summary>Gets or sets the value of the "Consider Interactions" parameter.</summary>
        public bool ConsiderInteractions
        {
            get { return ConsiderInteractionsParameter.Value.Value; }
            set { ConsiderInteractionsParameter.Value.Value = value; }
        }
        #endregion

        /// <summary>Constructor used by the persistence framework.</summary>
        [StorableConstructor]
        private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }

        /// <summary>Cloning constructor; only forwards to the base class.</summary>
        public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner)
        {
            // The base cloning constructor must be called.
            // This class has no instance fields of its own, so nothing else needs to be cloned.
        }

        /// <summary>Creates the algorithm and registers its parameters with their default values.</summary>
        public FastFunctionExtraction() : base()
        {
            // algorithm parameters are shown in the GUI
            Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.5)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want to consider interactions, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want to consider denominations, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want to consider exponentiation, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want to consider nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want to consider Hinge Functions, otherwise false.", new BoolValue(true)));
        }

        /// <summary>Post-deserialization hook; currently there is nothing to restore.</summary>
        [StorableHook(HookType.AfterDeserialization)]
        private void AfterDeserialization() { }

        /// <summary>Creates a deep clone of this algorithm via the cloning constructor.</summary>
        public override IDeepCloneable Clone(Cloner cloner)
        {
            return new FastFunctionExtraction(this, cloner);
        }

        /// <summary>The problem type this algorithm can be applied to.</summary>
        public override Type ProblemType { get { return typeof(RegressionProblem); } }

        /// <summary>The assigned problem, typed as <see cref="RegressionProblem"/>.</summary>
        public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

        /// <summary>
        /// Executes the algorithm: generates the basis functions, converts each one to a
        /// symbolic expression tree, then runs the (placeholder) coefficient search and
        /// non-dominated filter.
        /// </summary>
        /// <param name="cancellationToken">Checked before each tree is generated.</param>
        /// <exception cref="NotImplementedException">
        /// Thrown as soon as one basis function exists, because the tree generation
        /// step is not implemented yet.
        /// </exception>
        /// <exception cref="OperationCanceledException">Thrown when cancellation was requested.</exception>
        protected override void Run(CancellationToken cancellationToken)
        {
            var basisFunctions = GenerateBasisFunctions(Problem.ProblemData);

            var trees = new List<SymbolicExpressionTree>();
            foreach (var basisFunc in basisFunctions)
            {
                cancellationToken.ThrowIfCancellationRequested();
                // add tree representation of basisFunc to trees
                trees.Add(GenerateSymbolicExpressionTree(basisFunc));
            }

            // TODO: evaluate every tree with the interpreter to create the new (transformed) data.

            var coefficientVectorSet = FindCoefficientValues(basisFunctions);
            var paretoFront = NondominatedFilter(coefficientVectorSet);
        }

        /// <summary>Builds the tree representation of one basis function. Not implemented yet.</summary>
        /// <exception cref="NotImplementedException">Always.</exception>
        private SymbolicExpressionTree GenerateSymbolicExpressionTree(KeyValuePair<string, double[]> basisFunc)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Generates all candidate basis functions for the given problem data.
        /// Currently only the univariate bases are produced.
        /// </summary>
        /// <returns>Basis name mapped to its values on the training rows.</returns>
        private static Dictionary<string, double[]> GenerateBasisFunctions(IRegressionProblemData problemData)
        {
            return GenerateUnivariateBases(problemData);
        }

        /// <summary>
        /// Generates, for every allowed input variable and every exponent, the base
        /// "variable ** exponent" plus one base per <see cref="Operator"/> applied to it.
        /// </summary>
        /// <param name="problemData">Source of the dataset, the allowed inputs and the training rows.</param>
        /// <returns>Basis name mapped to its values on the training rows.</returns>
        private static Dictionary<string, double[]> GenerateUnivariateBases(IRegressionProblemData problemData)
        {
            var dataset = problemData.Dataset;
            var rows = problemData.TrainingIndices;
            var bases = new Dictionary<string, double[]>();

            // Only the allowed inputs are used: iterating all dataset variables would
            // also turn the target variable into a basis function.
            foreach (var variableName in problemData.AllowedInputVariables)
            {
                // Read the column once; it is the same for every exponent.
                var values = dataset.GetDoubleValues(variableName, rows).ToArray();

                foreach (var exp in exponents)
                {
                    // Invariant culture keeps the names identical on every machine ("0.5", never "0,5").
                    var name = variableName + " ** " + exp.ToString(CultureInfo.InvariantCulture);
                    var data = values.Select(x => Math.Pow(x, exp)).ToArray();
                    bases.Add(name, data);

                    foreach (Operator op in Enum.GetValues(typeof(Operator)))
                    {
                        var innerName = op.ToString() + "(" + name + ")";
                        var innerData = data.Select(x => ExecuteOperator(x, op)).ToArray();
                        bases.Add(innerName, innerData);
                    }
                }
            }

            return bases;
        }

        /// <summary>Applies a unary operator to a single value.</summary>
        /// <returns>|x| for <see cref="Operator.Abs"/>, log10(x) for <see cref="Operator.Log"/>.</returns>
        /// <exception cref="NotImplementedException">The operator has no implementation.</exception>
        private static double ExecuteOperator(double x, Operator op)
        {
            switch (op)
            {
                case Operator.Abs:
                    // Math.Abs also maps 0 to +0.0 (the manual negation produced -0.0).
                    return Math.Abs(x);
                case Operator.Log:
                    return Math.Log10(x);
                default:
                    throw new NotImplementedException();
            }
        }

        /// <summary>
        /// Combines every unordered pair of distinct bases from <paramref name="B1"/> into
        /// an interaction base (element-wise product) and returns the union of the
        /// interaction bases and <paramref name="B1"/>. Not called by the algorithm yet.
        /// </summary>
        /// <param name="B1">Univariate bases: name mapped to values.</param>
        /// <returns>A new dictionary with the interaction bases followed by the entries of <paramref name="B1"/>.</returns>
        private static Dictionary<string, double[]> GenerateMultiVariateBases(Dictionary<string, double[]> B1)
        {
            // Snapshot into a list: indexed access is O(1) and the order is fixed for both loops.
            var univariate = B1.ToList();
            var B2 = new Dictionary<string, double[]>();

            for (int i = 0; i < univariate.Count; i++)
            {
                var bI = univariate[i];
                for (int j = 0; j < i; j++)
                {
                    var bJ = univariate[j];
                    var name = "(" + bI.Key + ") * (" + bJ.Key + ")";
                    var data = bI.Value.Zip(bJ.Value, (a, b) => a * b).ToArray();
                    B2.Add(name, data);
                }
            }

            // return union of B1 and B2
            return B2.Concat(B1).ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
        }

        /// <summary>Placeholder for the coefficient fitting step; returns a dummy object.</summary>
        private static object FindCoefficientValues(IEnumerable<KeyValuePair<string, double[]>> basisFunctions)
        {
            return new object();
        }

        /// <summary>Placeholder for the non-dominated (Pareto) filter; returns a dummy object.</summary>
        private static object NondominatedFilter(object coefficientVectorSet)
        {
            return new object();
        }

        /// <summary>This algorithm cannot be paused.</summary>
        public override bool SupportsPause
        {
            get { return false; }
        }
    }
}