using System;
using System.CodeDom.Compiler;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Security;
using System.Security.AccessControl;
using System.Security.Authentication.ExtendedProtection.Configuration;
using System.Text;
using AutoDiff;
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.Instances;
using HeuristicLab.Problems.Instances.DataAnalysis;

namespace HeuristicLab.Problems.GrammaticalOptimization.SymbReg {
  /// <summary>
  /// Bridge between HeuristicLab real-world regression problem instances and grammatical optimization.
  /// Sentences over the variables 'a'..'z' (one per input variable) are scored by the R² on the training partition.
  /// </summary>
  public class SymbolicRegressionProblem : ISymbolicExpressionTreeProblem {
    // Token in the grammar templates below that is replaced by the variable range (e.g. "a .. e").
    // String.Replace throws for an empty search string, so a non-empty placeholder is required.
    private const string VariablesPlaceholder = "<variables>";

    private const string grammarString = @"
G(E):
E -> V | C | V+E | V-E | V*E | V%E | (E) | C+E | C-E | C*E | C%E
C -> 0..9
V -> <variables> 
";

    // C represents a Koza-style ERC (the symbol is the index into the ERC array); the values are initialized in the constructor.
    // S .. Sum (+), N .. Neg. sum (-), P .. Product (*), D .. Division (%)
    private const string treeGrammarString = @"
G(E):
E -> V | C | S | N | P | D
S -> EE | EEE
N -> EE | EEE
P -> EE | EEE
D -> EE
C -> 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
V -> <variables> 
";

    // Idea behind constants optimization: all constants written in a sentence can be ignored by
    // - introducing a constant factor for each complete term, and
    // - introducing a constant offset for each complete expression (including expressions in brackets),
    // e.g. 1*(2*a + b - 3 + 4) is equivalent to c0*a + c1*b + c2.
    // NOTE(review): this is presumably what ExpressionCompiler implements; it is not visible in this file.

    private readonly IGrammar grammar;

    private readonly int N;          // number of training rows
    private readonly double[,] x;    // training inputs, [row, variable]
    private readonly double[] y;     // training targets
    private readonly int d;          // number of input variables
    private readonly double[] erc;   // values of the ERC symbols '0'..'9'

    /// <summary>
    /// Loads the real-world regression instance whose name contains <paramref name="partOfName"/>.
    /// </summary>
    /// <param name="random">Source for the normally distributed ERC values.</param>
    /// <param name="partOfName">Substring that must match exactly one data descriptor name.</param>
    /// <exception cref="InvalidOperationException">No descriptor or more than one descriptor matches (thrown by Single).</exception>
    /// <exception cref="NotSupportedException">The instance has fewer than 1 or more than 26 input variables.</exception>
    public SymbolicRegressionProblem(Random random, string partOfName) {
      var instanceProvider = new RegressionRealWorldInstanceProvider();
      // NOTE(review): assumes every descriptor of this provider is a regression descriptor, so no type filter is applied — confirm.
      var dataDescriptor = instanceProvider.GetDataDescriptors().Single(ds => ds.Name.Contains(partOfName));
      var problemData = instanceProvider.LoadData(dataDescriptor);

      // materialize once; both sequences are used several times below
      var trainingRows = problemData.TrainingIndices.ToArray();
      var inputVariables = problemData.AllowedInputVariables.ToArray();
      this.N = trainingRows.Length;
      this.d = inputVariables.Length;
      // only single-character terminal symbols 'a'..'z' are supported so far
      if (d < 1 || d > 26)
        throw new NotSupportedException("Between 1 and 26 input variables are supported (single-character terminals 'a'..'z').");

      this.x = new double[N, d];
      this.y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, trainingRows).ToArray();
      for (int i = 0; i < N; i++) {
        for (int j = 0; j < d; j++) {
          x[i, j] = problemData.Dataset.GetDoubleValue(inputVariables[j], trainingRows[i]);
        }
      }

      // initialize ERC values
      erc = Enumerable.Range(0, 10).Select(_ => Rand.RandNormal(random) * 10).ToArray();

      char firstVar = 'a';
      char lastVar = (char)(firstVar + d - 1);
      string variableRange = firstVar + " .. " + lastVar;
      this.grammar = new Grammar(grammarString.Replace(VariablesPlaceholder, variableRange));
      this.TreeBasedGPGrammar = new Grammar(treeGrammarString.Replace(VariablesPlaceholder, variableRange));
    }

    /// <summary>Upper bound for the achievable quality; ideally a sentence reaches R² = 1.0.</summary>
    public double BestKnownQuality(int maxLen) {
      // for now only an upper bound is returned
      return 1.0;
    }

    public IGrammar Grammar {
      get { return grammar; }
    }

    /// <summary>Scores a sentence by its R² on the training data after fitting its constants.</summary>
    public double Evaluate(string sentence) {
      return OptimizeConstantsAndEvaluate(sentence);
    }

    /// <summary>Scores a sentence by its R² on the training data using the fixed ERC values (no fitting).</summary>
    public double SimpleEvaluate(string sentence) {
      var interpreter = new ExpressionInterpreter();
      return RSqOnTrainingData(rowData => interpreter.Interpret(sentence, rowData, erc));
    }

    /// <summary>Returns the phrase unchanged (no canonicalization yet).</summary>
    public string CanonicalRepresentation(string phrase) {
      return phrase;
    }

    // NOTE(review): the generic element type was missing in this copy; restored as Feature — confirm against the implemented interface.
    public IEnumerable<Feature> GetFeatures(string phrase) {
      throw new NotImplementedException();
    }

    /// <summary>
    /// Compiles the sentence, fits its constants with ALGLIB non-linear least squares (all constants start at 1.0),
    /// and returns the R² of the fitted model on the training data.
    /// </summary>
    /// <returns>
    /// R² of the fitted model; falls back to <see cref="SimpleEvaluate"/> when the sentence has no constants,
    /// and returns 0.0 when the fit raises an arithmetic or ALGLIB exception.
    /// </returns>
    /// <exception cref="ArgumentException">ALGLIB reports completion code -7 (gradient verification failed).</exception>
    public double OptimizeConstantsAndEvaluate(string sentence) {
      AutoDiff.Term func;
      Variable[] variables;
      Variable[] constants;
      var compiler = new ExpressionCompiler();
      compiler.Compile(sentence, out func, out variables, out constants);

      // nothing to fit
      if (!constants.Any()) return SimpleEvaluate(sentence);

      // constants are the fitted parameters; variables stay bound to the data rows
      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(constants, variables);

      double[] c = constants.Select(_ => 1.0).ToArray(); // start with ones

      alglib.lsfitstate state;
      alglib.lsfitreport rep;
      int info;

      int n = x.GetLength(0);
      int m = x.GetLength(1);
      int k = c.Length;

      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);

      const int maxIterations = 10;
      try {
        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
        // for debugging the analytic gradient: alglib.lsfitsetgradientcheck(state, 0.001);
        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
        alglib.lsfitresults(state, out info, out c, out rep);
      } catch (ArithmeticException) {
        return 0.0;
      } catch (alglib.alglibexception) {
        return 0.0;
      }

      // info == -7 => constant optimization failed due to a wrong gradient
      if (info == -7)
        throw new ArgumentException("Constant optimization failed: ALGLIB reported an incorrect gradient (info = -7).");

      return RSqOnTrainingData(rowData => compiledFunc.Evaluate(c, rowData));
    }

    // R² between the training targets and predict(row) for every training row.
    // A single row buffer is reused, which relies on RSq consuming the predictions one at a time.
    private double RSqOnTrainingData(Func<double[], double> predict) {
      var rowData = new double[d];
      return HeuristicLab.Common.Extensions.RSq(y, Enumerable.Range(0, N).Select(i => {
        for (int j = 0; j < d; j++) rowData[j] = x[i, j];
        return predict(rowData);
      }));
    }

    // Adapts the compiled term to ALGLIB's function callback (c = parameters, x = one data row).
    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, object o) => {
        func = compiledFunc.Evaluate(c, x);
      };
    }

    // Adapts the compiled term to ALGLIB's function+gradient callback (gradient w.r.t. the parameters c).
    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
        var gradientAndValue = compiledFunc.Differentiate(c, x);
        func = gradientAndValue.Item2;
        Array.Copy(gradientAndValue.Item1, grad, grad.Length);
      };
    }

    public IGrammar TreeBasedGPGrammar { get; private set; }

    /// <summary>Converts a tree built from <see cref="TreeBasedGPGrammar"/> into a sentence.</summary>
    public string ConvertTreeToSentence(ISymbolicExpressionTree tree) {
      var sb = new StringBuilder();
      // skips the two wrapper nodes above the actual expression (presumably program root and start symbol)
      TreeToSentence(tree.Root.GetSubtree(0).GetSubtree(0), sb);
      return sb.ToString();
    }

    // Appends the infix form of treeNode: terminals by symbol name, S/N/P/D joined with + - * %.
    // Sums and negated sums are wrapped in parentheses.
    // NOTE(review): products/divisions are not parenthesized, so e.g. D(a, P(b, c)) becomes "a%b*c";
    // whether that means a/(b*c) depends on the associativity used by the interpreter/compiler — verify.
    private void TreeToSentence(ISymbolicExpressionTreeNode treeNode, StringBuilder sb) {
      if (treeNode.SubtreeCount == 0) {
        // terminal
        sb.Append(treeNode.Symbol.Name);
      } else {
        string op = string.Empty;
        switch (treeNode.Symbol.Name) {
          case "S": op = "+"; break;
          case "N": op = "-"; break;
          case "P": op = "*"; break;
          case "D": op = "%"; break;
          default: {
              Debug.Assert(treeNode.SubtreeCount == 1);
              break;
            }
        }
        // nonterminal
        if (op == "+" || op == "-") sb.Append("(");
        TreeToSentence(treeNode.Subtrees.First(), sb);
        foreach (var subTree in treeNode.Subtrees.Skip(1)) {
          sb.Append(op);
          TreeToSentence(subTree, sb);
        }
        if (op == "+" || op == "-") sb.Append(")");
      }
    }
  }
}