using System;
using System.Collections.Generic;
using System.Linq;
using System.Security;
using System.Security.AccessControl;
using System.Text;

using HeuristicLab.Common;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.Instances;
using HeuristicLab.Problems.Instances.DataAnalysis;

namespace HeuristicLab.Problems.GrammaticalOptimization.SymbReg {
  // provides bridge to HL regression problem instances
  public class SymbolicRegressionProblem : IProblem {

    // C represents Koza-style ERCs (the digit is an index into the erc array,
    // whose values are initialized in the constructor); <variables> is replaced
    // by the terminal range 'a' .. last variable symbol.
    private const string grammarString = @"
G(E):
E -> V | V+E | V-E | V*E | V/E | (E) | C | C+E | C-E | C*E | C/E
C -> 0..9
V -> <variables>
";

    private readonly IGrammar grammar;

    private readonly int N;          // number of training rows
    private readonly double[][] x;   // input values, one row per training index
    private readonly double[] y;     // target values for the training rows
    private readonly int d;          // number of allowed input variables
    private readonly double[] erc;   // ephemeral random constants (10 values)

    /// <summary>
    /// Loads the HeuristicLab real-world regression instance whose name contains
    /// <paramref name="partOfName"/> and prepares training data, ERC values and
    /// the grammar for sentence generation.
    /// </summary>
    /// <param name="random">Source of randomness for the ERC values.</param>
    /// <param name="partOfName">Substring that identifies exactly one data descriptor.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="random"/> is null.</exception>
    /// <exception cref="ArgumentException">If <paramref name="partOfName"/> is null or empty.</exception>
    /// <exception cref="InvalidOperationException">If zero or more than one descriptor matches (from Single()).</exception>
    /// <exception cref="NotSupportedException">If the instance has more than 26 input variables.</exception>
    public SymbolicRegressionProblem(Random random, string partOfName) {
      if (random == null) throw new ArgumentNullException("random");
      if (string.IsNullOrEmpty(partOfName)) throw new ArgumentException("A non-empty (partial) dataset name is required.", "partOfName");

      var instanceProvider = new RegressionRealWorldInstanceProvider();
      var dds = instanceProvider.GetDataDescriptors().OfType<RegressionDataDescriptor>();

      // Single() intentionally fails when the name is ambiguous or unknown
      var problemData = instanceProvider.LoadData(dds.Single(ds => ds.Name.Contains(partOfName)));

      this.N = problemData.TrainingIndices.Count();
      this.d = problemData.AllowedInputVariables.Count();
      // we only allow single-character terminal symbols ('a'..'z') so far
      if (d > 26) throw new NotSupportedException("Only up to 26 input variables are supported (single-character terminal symbols 'a'..'z').");
      this.x = new double[N][];
      this.y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();

      // copy the training partition into a row-major jagged array for fast evaluation
      int i = 0;
      foreach (var r in problemData.TrainingIndices) {
        x[i] = new double[d];
        int j = 0;
        foreach (var inputVariable in problemData.AllowedInputVariables) {
          x[i][j++] = problemData.Dataset.GetDoubleValue(inputVariable, r);
        }
        i++;
      }

      // initialize ERC values (Koza-style, drawn from N(0, 10))
      erc = Enumerable.Range(0, 10).Select(_ => Rand.RandNormal(random) * 10).ToArray();

      // terminal symbols 'a' .. ('a' + d - 1), one per allowed input variable
      char firstVar = 'a';
      char lastVar = (char)('a' + d - 1);
      this.grammar = new Grammar(grammarString.Replace("<variables>", firstVar + " .. " + lastVar));
    }

    /// <summary>
    /// Upper bound on the achievable quality; ideally a model reaches an R² of 1.0.
    /// </summary>
    /// <param name="maxLen">Maximum sentence length (currently ignored).</param>
    public double BestKnownQuality(int maxLen) {
      // for now only an upper bound is returned, ideally we have an R² of 1.0
      return 1.0;
    }

    /// <summary>Grammar for generating candidate expressions.</summary>
    public IGrammar Grammar {
      get { return grammar; }
    }

    /// <summary>
    /// Evaluates a sentence of the grammar on all training rows and returns the
    /// squared Pearson correlation (R²) of predictions vs. targets.
    /// </summary>
    public double Evaluate(string sentence) {
      var interpreter = new ExpressionInterpreter();
      return HeuristicLab.Common.Extensions.RSq(y, Enumerable.Range(0, N).Select(i => interpreter.Interpret(sentence, x[i], erc)));
    }

    /// <summary>
    /// Canonical form of a phrase. Canonicalization (sorting of terms and
    /// factors) is not implemented yet; the phrase is returned unchanged.
    /// </summary>
    public string CanonicalRepresentation(string phrase) {
      return phrase;
      //var terms = phrase.Split('+');
      //return string.Join("+", terms.Select(term => string.Join("", term.Replace("*", "").OrderBy(ch => ch)))
      //  .OrderBy(term => term));
    }

    /// <summary>Not implemented yet.</summary>
    public IEnumerable<Feature> GetFeatures(string phrase) {
      throw new NotImplementedException();
    }

    // NOTE(review): the region below is non-compiling work-in-progress code for
    // ALGLIB-based constant optimization (e.g. `if ()` has an empty condition and
    // `node`, `tree`, `variableNames`, `rows` are undefined). Kept for reference.
    /*
    public static double OptimizeConstants(string sentence) {

      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
      List<string> variableNames = new List<string>();

      AutoDiff.Term func;
      if (!TryTransformToAutoDiff(sentence, 0, variables, parameters, out func))
        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
      if (variableNames.Count == 0) return 0.0;

      AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());

      List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
      double[] c = new double[variables.Count];

      {
        c[0] = 0.0;
        c[1] = 1.0;
        //extract inital constants
        int i = 2;
        foreach (var node in terminalNodes) {
          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
          VariableTreeNode variableTreeNode = node as VariableTreeNode;
          if (constantTreeNode != null)
            c[i++] = constantTreeNode.Value;
          else if (variableTreeNode != null)
            c[i++] = variableTreeNode.Weight;
        }
      }
      double[] originalConstants = (double[])c.Clone();
      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);

      alglib.lsfitstate state;
      alglib.lsfitreport rep;
      int info;

      Dataset ds = problemData.Dataset;
      double[,] x = new double[rows.Count(), variableNames.Count];
      int row = 0;
      foreach (var r in rows) {
        for (int col = 0; col < variableNames.Count; col++) {
          x[row, col] = ds.GetDoubleValue(variableNames[col], r);
        }
        row++;
      }
      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
      int n = x.GetLength(0);
      int m = x.GetLength(1);
      int k = c.Length;

      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);

      try {
        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
        //alglib.lsfitsetgradientcheck(state, 0.001);
        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
        alglib.lsfitresults(state, out info, out c, out rep);
      } catch (ArithmeticException) {
        return originalQuality;
      } catch (alglib.alglibexception) {
        return originalQuality;
      }

      //info == -7 => constant optimization failed due to wrong gradient
      if (info != -7) throw new ArgumentException();
    }


    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, object o) => {
        func = compiledFunc.Evaluate(c, x);
      };
    }

    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
        var tupel = compiledFunc.Differentiate(c, x);
        func = tupel.Item2;
        Array.Copy(tupel.Item1, grad, grad.Length);
      };
    }

    private static bool TryTransformToAutoDiff(string phrase, int symbolPos, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, out AutoDiff.Term term)
    {
      var curSy = phrase[0];
      if () {
        var var = new AutoDiff.Variable();
        variables.Add(var);
        term = var;
        return true;
      }
      if (node.Symbol is Variable) {
        var varNode = node as VariableTreeNode;
        var par = new AutoDiff.Variable();
        parameters.Add(par);
        variableNames.Add(varNode.VariableName);
        var w = new AutoDiff.Variable();
        variables.Add(w);
        term = AutoDiff.TermBuilder.Product(w, par);
        return true;
      }
      if (node.Symbol is Addition) {
        List<AutoDiff.Term> terms = new List<Term>();
        foreach (var subTree in node.Subtrees) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
            term = null;
            return false;
          }
          terms.Add(t);
        }
        term = AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
      if (node.Symbol is Subtraction) {
        List<AutoDiff.Term> terms = new List<Term>();
        for (int i = 0; i < node.SubtreeCount; i++) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, out t)) {
            term = null;
            return false;
          }
          if (i > 0) t = -t;
          terms.Add(t);
        }
        term = AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
      if (node.Symbol is Multiplication) {
        AutoDiff.Term a, b;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
          term = null;
          return false;
        } else {
          List<AutoDiff.Term> factors = new List<Term>();
          foreach (var subTree in node.Subtrees.Skip(2)) {
            AutoDiff.Term f;
            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
              term = null;
              return false;
            }
            factors.Add(f);
          }
          term = AutoDiff.TermBuilder.Product(a, b, factors.ToArray());
          return true;
        }
      }
      if (node.Symbol is Division) {
        // only works for at least two subtrees
        AutoDiff.Term a, b;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
          term = null;
          return false;
        } else {
          List<AutoDiff.Term> factors = new List<Term>();
          foreach (var subTree in node.Subtrees.Skip(2)) {
            AutoDiff.Term f;
            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
              term = null;
              return false;
            }
            factors.Add(1.0 / f);
          }
          term = AutoDiff.TermBuilder.Product(a, 1.0 / b, factors.ToArray());
          return true;
        }
      }

      if (node.Symbol is StartSymbol) {
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
          term = branchTerm * alpha + beta;
          return true;
        } else {
          term = null;
          return false;
        }
      }
      term = null;
      return false;
    }
    */
  }
}
|
---|