Free cookie consent management tool by TermsFeed Policy Generator

source: branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 14266

Last change on this file since 14266 was 14266, checked in by gkronber, 8 years ago

#2650: improved handling of factors in ConstantOptimizationEvaluator (create binary indicators only once)

File size: 27.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics.Contracts;
25using System.Linq;
26using AutoDiff;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33
34namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
35  [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constant used.")]
36  [StorableClass]
37  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
38    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
39    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
40    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
41    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
42    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
43    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";
44
45    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
46      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
47    }
48    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
49      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
50    }
51    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
52      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
53    }
54    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
55      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
56    }
57    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
58      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
59    }
60    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
61      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
62    }
63
64
65    public IntValue ConstantOptimizationIterations {
66      get { return ConstantOptimizationIterationsParameter.Value; }
67    }
68    public DoubleValue ConstantOptimizationImprovement {
69      get { return ConstantOptimizationImprovementParameter.Value; }
70    }
71    public PercentValue ConstantOptimizationProbability {
72      get { return ConstantOptimizationProbabilityParameter.Value; }
73    }
74    public PercentValue ConstantOptimizationRowsPercentage {
75      get { return ConstantOptimizationRowsPercentageParameter.Value; }
76    }
77    public bool UpdateConstantsInTree {
78      get { return UpdateConstantsInTreeParameter.Value.Value; }
79      set { UpdateConstantsInTreeParameter.Value.Value = value; }
80    }
81
82    public bool UpdateVariableWeights {
83      get { return UpdateVariableWeightsParameter.Value.Value; }
84      set { UpdateVariableWeightsParameter.Value.Value = value; }
85    }
86
87    public override bool Maximization {
88      get { return true; }
89    }
90
91    [StorableConstructor]
92    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
93    protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
94      : base(original, cloner) {
95    }
96    public SymbolicRegressionConstantOptimizationEvaluator()
97      : base() {
98      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
99      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true) { Hidden = true });
100      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
101      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
102      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
103      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)) { Hidden = true });
104    }
105
106    public override IDeepCloneable Clone(Cloner cloner) {
107      return new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
108    }
109
110    [StorableHook(HookType.AfterDeserialization)]
111    private void AfterDeserialization() {
112      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
113        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
114      if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
115        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)));
116    }
117
118    public override IOperation InstrumentedApply() {
119      var solution = SymbolicExpressionTreeParameter.ActualValue;
120      double quality;
121      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
122        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
123        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
124           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value, updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower, upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper, updateConstantsInTree: UpdateConstantsInTree);
125
126        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
127          var evaluationRows = GenerateRowsToEvaluate();
128          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
129        }
130      } else {
131        var evaluationRows = GenerateRowsToEvaluate();
132        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
133      }
134      QualityParameter.ActualValue = new DoubleValue(quality);
135
136      return base.InstrumentedApply();
137    }
138
139    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
140      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
141      EstimationLimitsParameter.ExecutionContext = context;
142      ApplyLinearScalingParameter.ExecutionContext = context;
143
144      // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
145      // because Evaluate() is used to get the quality of evolved models on
146      // different partitions of the dataset (e.g., best validation model)
147      double r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);
148
149      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
150      EstimationLimitsParameter.ExecutionContext = null;
151      ApplyLinearScalingParameter.ExecutionContext = null;
152
153      return r2;
154    }
155
156    #region derivations of functions
157    // create function factory for arctangent
158    private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
159      eval: Math.Atan,
160      diff: x => 1 / (1 + x * x));
161    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
162      eval: Math.Sin,
163      diff: Math.Cos);
164    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
165       eval: Math.Cos,
166       diff: x => -Math.Sin(x));
167    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
168      eval: Math.Tan,
169      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
170    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
171      eval: alglib.errorfunction,
172      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
173    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
174      eval: alglib.normaldistribution,
175      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
176    #endregion
177
178
179    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) {
180
181      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
182      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
183      List<string> variableNames = new List<string>();
184      List<string> categoricalVariableValues = new List<string>();
185
186      AutoDiff.Term func;
187      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out func))
188        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
189      if (variableNames.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
190
191      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());
192
193      List<SymbolicExpressionTreeTerminalNode> terminalNodes = null; // gkronber only used for extraction of initial constants
194      if (updateVariableWeights)
195        terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
196      else
197        terminalNodes = new List<SymbolicExpressionTreeTerminalNode>
198          (tree.Root.IterateNodesPrefix()
199          .OfType<SymbolicExpressionTreeTerminalNode>()
200          .Where(node => node is ConstantTreeNode || node is FactorVariableTreeNode));
201
202      //extract inital constants
203      double[] c = new double[variables.Count];
204      {
205        c[0] = 0.0;
206        c[1] = 1.0;
207        int i = 2;
208        foreach (var node in terminalNodes) {
209          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
210          VariableTreeNode variableTreeNode = node as VariableTreeNode;
211          BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
212          FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
213          if (constantTreeNode != null)
214            c[i++] = constantTreeNode.Value;
215          else if (updateVariableWeights && variableTreeNode != null)
216            c[i++] = variableTreeNode.Weight;
217          else if (updateVariableWeights && binFactorVarTreeNode != null)
218            c[i++] = binFactorVarTreeNode.Weight;
219          else if (factorVarTreeNode != null) {
220            // gkronber: a factorVariableTreeNode holds a category-specific constant therefore we can consider factors to be the same as constants
221            foreach (var w in factorVarTreeNode.Weights) c[i++] = w;
222          }
223        }
224      }
225      double[] originalConstants = (double[])c.Clone();
226      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
227
228      alglib.lsfitstate state;
229      alglib.lsfitreport rep;
230      int info;
231
232      IDataset ds = problemData.Dataset;
233      double[,] x = new double[rows.Count(), variableNames.Count];
234      int row = 0;
235      foreach (var r in rows) {
236        for (int col = 0; col < variableNames.Count; col++) {
237          if (ds.VariableHasType<double>(variableNames[col])) {
238            x[row, col] = ds.GetDoubleValue(variableNames[col], r);
239          } else if (ds.VariableHasType<string>(variableNames[col])) {
240            x[row, col] = ds.GetStringValue(variableNames[col], r) == categoricalVariableValues[col] ? 1 : 0;
241          } else throw new InvalidProgramException("found a variable of unknown type");
242        }
243        row++;
244      }
245      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
246      int n = x.GetLength(0);
247      int m = x.GetLength(1);
248      int k = c.Length;
249
250      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
251      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
252
253      try {
254        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
255        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
256        //alglib.lsfitsetgradientcheck(state, 0.001);
257        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
258        alglib.lsfitresults(state, out info, out c, out rep);
259      }
260      catch (ArithmeticException) {
261        return originalQuality;
262      }
263      catch (alglib.alglibexception) {
264        return originalQuality;
265      }
266
267      //info == -7  => constant optimization failed due to wrong gradient
268      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
269      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
270
271      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
272      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
273        UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
274        return originalQuality;
275      }
276      return quality;
277    }
278
279    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
280      int i = 0;
281      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
282        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
283        VariableTreeNode variableTreeNode = node as VariableTreeNode;
284        BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
285        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
286        if (constantTreeNode != null)
287          constantTreeNode.Value = constants[i++];
288        else if (updateVariableWeights && variableTreeNode != null)
289          variableTreeNode.Weight = constants[i++];
290        else if (updateVariableWeights && binFactorVarTreeNode != null)
291          binFactorVarTreeNode.Weight = constants[i++];
292        else if (factorVarTreeNode != null) {
293          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
294            factorVarTreeNode.Weights[j] = constants[i++];
295        }
296      }
297    }
298
299    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
300      return (double[] c, double[] x, ref double func, object o) => {
301        func = compiledFunc.Evaluate(c, x);
302      };
303    }
304
305    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
306      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
307        var tupel = compiledFunc.Differentiate(c, x);
308        func = tupel.Item2;
309        Array.Copy(tupel.Item1, grad, grad.Length);
310      };
311    }
312
313    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters,
314      List<string> variableNames, List<string> categoricalVariableValues, bool updateVariableWeights, out AutoDiff.Term term) {
315      if (node.Symbol is Constant) {
316        var var = new AutoDiff.Variable();
317        variables.Add(var);
318        term = var;
319        return true;
320      }
321      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
322        var varNode = node as VariableTreeNodeBase;
323        var factorVarNode = node as BinaryFactorVariableTreeNode;
324        // factor variable values are only 0 or 1 and set in x accordingly
325        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
326        var par = FindOrCreateParameter(varNode.VariableName, varValue, parameters, variableNames, categoricalVariableValues);
327
328        if (updateVariableWeights) {
329          var w = new AutoDiff.Variable();
330          variables.Add(w);
331          term = AutoDiff.TermBuilder.Product(w, par);
332        } else {
333          term = par;
334        }
335        return true;
336      }
337      if (node.Symbol is FactorVariable) {
338        var factorVarNode = node as FactorVariableTreeNode;
339        var products = new List<Term>();
340        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
341          var par = FindOrCreateParameter(factorVarNode.VariableName, variableValue, parameters, variableNames, categoricalVariableValues);
342
343          var wVar = new AutoDiff.Variable();
344          variables.Add(wVar);
345
346          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
347        }
348        term = AutoDiff.TermBuilder.Sum(products);
349        return true;
350      }
351      if (node.Symbol is Addition) {
352        List<AutoDiff.Term> terms = new List<Term>();
353        foreach (var subTree in node.Subtrees) {
354          AutoDiff.Term t;
355          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
356            term = null;
357            return false;
358          }
359          terms.Add(t);
360        }
361        term = AutoDiff.TermBuilder.Sum(terms);
362        return true;
363      }
364      if (node.Symbol is Subtraction) {
365        List<AutoDiff.Term> terms = new List<Term>();
366        for (int i = 0; i < node.SubtreeCount; i++) {
367          AutoDiff.Term t;
368          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
369            term = null;
370            return false;
371          }
372          if (i > 0) t = -t;
373          terms.Add(t);
374        }
375        if (terms.Count == 1) term = -terms[0];
376        else term = AutoDiff.TermBuilder.Sum(terms);
377        return true;
378      }
379      if (node.Symbol is Multiplication) {
380        List<AutoDiff.Term> terms = new List<Term>();
381        foreach (var subTree in node.Subtrees) {
382          AutoDiff.Term t;
383          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
384            term = null;
385            return false;
386          }
387          terms.Add(t);
388        }
389        if (terms.Count == 1) term = terms[0];
390        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
391        return true;
392
393      }
394      if (node.Symbol is Division) {
395        List<AutoDiff.Term> terms = new List<Term>();
396        foreach (var subTree in node.Subtrees) {
397          AutoDiff.Term t;
398          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
399            term = null;
400            return false;
401          }
402          terms.Add(t);
403        }
404        if (terms.Count == 1) term = 1.0 / terms[0];
405        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
406        return true;
407      }
408      if (node.Symbol is Logarithm) {
409        AutoDiff.Term t;
410        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
411          term = null;
412          return false;
413        } else {
414          term = AutoDiff.TermBuilder.Log(t);
415          return true;
416        }
417      }
418      if (node.Symbol is Exponential) {
419        AutoDiff.Term t;
420        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
421          term = null;
422          return false;
423        } else {
424          term = AutoDiff.TermBuilder.Exp(t);
425          return true;
426        }
427      }
428      if (node.Symbol is Square) {
429        AutoDiff.Term t;
430        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
431          term = null;
432          return false;
433        } else {
434          term = AutoDiff.TermBuilder.Power(t, 2.0);
435          return true;
436        }
437      }
438      if (node.Symbol is SquareRoot) {
439        AutoDiff.Term t;
440        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
441          term = null;
442          return false;
443        } else {
444          term = AutoDiff.TermBuilder.Power(t, 0.5);
445          return true;
446        }
447      }
448      if (node.Symbol is Sine) {
449        AutoDiff.Term t;
450        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
451          term = null;
452          return false;
453        } else {
454          term = sin(t);
455          return true;
456        }
457      }
458      if (node.Symbol is Cosine) {
459        AutoDiff.Term t;
460        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
461          term = null;
462          return false;
463        } else {
464          term = cos(t);
465          return true;
466        }
467      }
468      if (node.Symbol is Tangent) {
469        AutoDiff.Term t;
470        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
471          term = null;
472          return false;
473        } else {
474          term = tan(t);
475          return true;
476        }
477      }
478      if (node.Symbol is Erf) {
479        AutoDiff.Term t;
480        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
481          term = null;
482          return false;
483        } else {
484          term = erf(t);
485          return true;
486        }
487      }
488      if (node.Symbol is Norm) {
489        AutoDiff.Term t;
490        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
491          term = null;
492          return false;
493        } else {
494          term = norm(t);
495          return true;
496        }
497      }
498      if (node.Symbol is StartSymbol) {
499        var alpha = new AutoDiff.Variable();
500        var beta = new AutoDiff.Variable();
501        variables.Add(beta);
502        variables.Add(alpha);
503        AutoDiff.Term branchTerm;
504        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out branchTerm)) {
505          term = branchTerm * alpha + beta;
506          return true;
507        } else {
508          term = null;
509          return false;
510        }
511      }
512      term = null;
513      return false;
514    }
515
516    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
517    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
518    private static Term FindOrCreateParameter(string varName, string varValue,
519      List<AutoDiff.Variable> parameters, List<string> variableNames, List<string> variableValues) {
520      Contract.Assert(variableNames.Count == variableValues.Count);
521      int idx = -1;
522      for (int i = 0; i < variableNames.Count; i++) {
523        if (variableNames[i] == varName && variableValues[i] == varValue) {
524          idx = i;
525          break;
526        }
527      }
528
529      AutoDiff.Variable par = null;
530      if (idx == -1) {
531        // not found -> create new parameter and entries in names and values lists
532        par = new AutoDiff.Variable();
533        parameters.Add(par);
534        variableNames.Add(varName);
535        variableValues.Add(varValue);
536      } else {
537        par = parameters[idx];
538      }
539      return par;
540    }
541
542    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
543      var containsUnknownSymbol = (
544        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
545        where
546         !(n.Symbol is Variable) &&
547         !(n.Symbol is BinaryFactorVariable) &&
548         !(n.Symbol is FactorVariable) &&
549         !(n.Symbol is Constant) &&
550         !(n.Symbol is Addition) &&
551         !(n.Symbol is Subtraction) &&
552         !(n.Symbol is Multiplication) &&
553         !(n.Symbol is Division) &&
554         !(n.Symbol is Logarithm) &&
555         !(n.Symbol is Exponential) &&
556         !(n.Symbol is SquareRoot) &&
557         !(n.Symbol is Square) &&
558         !(n.Symbol is Sine) &&
559         !(n.Symbol is Cosine) &&
560         !(n.Symbol is Tangent) &&
561         !(n.Symbol is Erf) &&
562         !(n.Symbol is Norm) &&
563         !(n.Symbol is StartSymbol)
564        select n).
565      Any();
566      return !containsUnknownSymbol;
567    }
568  }
569}
Note: See TracBrowser for help on using the repository browser.