source: branches/HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks/3.3/SymbolicDataAnalysisVariableImpactsAnalyzer.cs @ 13835

Last change on this file since 13835 was 13835, checked in by bburlacu, 3 years ago

#2288: Performance improvements in the SymbolicRegressionVariableImpactsAnalyzer.

File size: 15.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Analysis;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
36
37namespace HeuristicLab.VariableInteractionNetworks {
38  [Item("SymbolicRegressionVariableImpactsAnalyzer", "An analyzer which calculates variable impacts based on the average node impacts from the tree")]
39  [StorableClass]
40  public class SymbolicRegressionVariableImpactsAnalyzer : SymbolicDataAnalysisAnalyzer {
41    #region parameter names
42    private const string UpdateCounterParameterName = "UpdateCounter";
43    private const string UpdateIntervalParameterName = "UpdateInterval";
44    public const string QualityParameterName = "Quality";
45    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
46    private const string ProblemDataParameterName = "ProblemData";
47    private const string ApplyLinearScalingParameterName = "ApplyLinearScaling";
48    private const string MaxCOIterationsParameterName = "MaxCOIterations";
49    private const string EstimationLimitsParameterName = "EstimationLimits";
50    private const string EvaluatorParameterName = "Evaluator";
51    private const string PercentageBestParameterName = "PercentageBest";
52    private const string LastGenerationsParameterName = "LastGenerations";
53    private const string MaximumGenerationsParameterName = "MaximumGenerations";
54    private const string OptimizeConstantsParameterName = "OptimizeConstants";
55    private const string PruneTreesParameterName = "PruneTrees";
56    private const string AverageVariableImpactsResultName = "Average variable impacts";
57    private const string AverageVariableImpactsHistoryResultName = "Average variable impacts history";
58    #endregion
59
60    private SymbolicDataAnalysisExpressionTreeSimplifier simplifier;
61    private SymbolicRegressionSolutionImpactValuesCalculator impactsCalculator;
62
63    #region parameters
64    public ValueParameter<IntValue> UpdateCounterParameter {
65      get { return (ValueParameter<IntValue>)Parameters[UpdateCounterParameterName]; }
66    }
67    public ValueParameter<IntValue> UpdateIntervalParameter {
68      get { return (ValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
69    }
70    public IScopeTreeLookupParameter<DoubleValue> QualityParameter {
71      get { return (IScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
72    }
73    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
74      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
75    }
76    public ILookupParameter<IRegressionProblemData> ProblemDataParameter {
77      get { return (ILookupParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
78    }
79    public ILookupParameter<BoolValue> ApplyLinearScalingParameter {
80      get { return (ILookupParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
81    }
82    public IFixedValueParameter<IntValue> MaxCOIterationsParameter {
83      get { return (IFixedValueParameter<IntValue>)Parameters[MaxCOIterationsParameterName]; }
84    }
85    public ILookupParameter<DoubleLimit> EstimationLimitsParameter {
86      get { return (ILookupParameter<DoubleLimit>)Parameters[EstimationLimitsParameterName]; }
87    }
88    public ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator> EvaluatorParameter {
89      get { return (ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName]; }
90    }
91    public IFixedValueParameter<PercentValue> PercentageBestParameter {
92      get { return (IFixedValueParameter<PercentValue>)Parameters[PercentageBestParameterName]; }
93    }
94    public IFixedValueParameter<IntValue> LastGenerationsParameter {
95      get { return (IFixedValueParameter<IntValue>)Parameters[LastGenerationsParameterName]; }
96    }
97    public IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
98      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
99    }
100    public IFixedValueParameter<BoolValue> PruneTreesParameter {
101      get { return (IFixedValueParameter<BoolValue>)Parameters[PruneTreesParameterName]; }
102    }
103    private ILookupParameter<IntValue> MaximumGenerationsParameter {
104      get { return (ILookupParameter<IntValue>)Parameters[MaximumGenerationsParameterName]; }
105    }
106    #endregion
107
108    #region parameter properties
109    public int UpdateCounter {
110      get { return UpdateCounterParameter.Value.Value; }
111      set { UpdateCounterParameter.Value.Value = value; }
112    }
113    public int UpdateInterval {
114      get { return UpdateIntervalParameter.Value.Value; }
115      set { UpdateIntervalParameter.Value.Value = value; }
116    }
117    #endregion
118
119    public SymbolicRegressionVariableImpactsAnalyzer() {
120      #region add parameters
121      Parameters.Add(new ValueParameter<IntValue>(UpdateCounterParameterName, new IntValue(0)));
122      Parameters.Add(new ValueParameter<IntValue>(UpdateIntervalParameterName, new IntValue(1)));
123      Parameters.Add(new LookupParameter<IRegressionProblemData>(ProblemDataParameterName));
124      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName));
125      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The individual qualities."));
126      Parameters.Add(new LookupParameter<BoolValue>(ApplyLinearScalingParameterName));
127      Parameters.Add(new LookupParameter<DoubleLimit>(EstimationLimitsParameterName));
128      Parameters.Add(new FixedValueParameter<IntValue>(MaxCOIterationsParameterName, new IntValue(3)));
129      Parameters.Add(new FixedValueParameter<PercentValue>(PercentageBestParameterName, new PercentValue(1)));
130      Parameters.Add(new FixedValueParameter<IntValue>(LastGenerationsParameterName, new IntValue(10)));
131      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, new BoolValue(false)));
132      Parameters.Add(new FixedValueParameter<BoolValue>(PruneTreesParameterName, new BoolValue(false)));
133      Parameters.Add(new LookupParameter<IntValue>(MaximumGenerationsParameterName, "The maximum number of generations which should be processed."));
134      Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
135      #endregion
136
137      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
138      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
139    }
140
141    [StorableConstructor]
142    protected SymbolicRegressionVariableImpactsAnalyzer(bool deserializing) : base(deserializing) { }
143
144    [StorableHook(HookType.AfterDeserialization)]
145    private void AfterDeserialization() {
146      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
147      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
148
149      if (!Parameters.ContainsKey(EvaluatorParameterName))
150        Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
151    }
152
153    protected SymbolicRegressionVariableImpactsAnalyzer(SymbolicRegressionVariableImpactsAnalyzer original, Cloner cloner)
154        : base(original, cloner) {
155      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
156      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
157    }
158
159    public override IDeepCloneable Clone(Cloner cloner) {
160      return new SymbolicRegressionVariableImpactsAnalyzer(this, cloner);
161    }
162
163    public override IOperation Apply() {
164      #region Update counter & update interval
165      UpdateCounter++;
166      if (UpdateCounter != UpdateInterval) {
167        return base.Apply();
168      }
169      UpdateCounter = 0;
170      #endregion
171      var results = ResultCollectionParameter.ActualValue;
172      int maxGen = MaximumGenerationsParameter.ActualValue.Value;
173      int gen = results.ContainsKey("Generations") ? ((IntValue)results["Generations"].Value).Value : 0;
174      int lastGen = LastGenerationsParameter.Value.Value;
175
176      if (lastGen > 0 && gen < maxGen - lastGen)
177        return base.Apply();
178
179      var trees = SymbolicExpressionTree.ToArray();
180      var qualities = QualityParameter.ActualValue.ToArray();
181
182      Array.Sort(qualities, trees);
183      Array.Reverse(qualities);
184      Array.Reverse(trees);
185
186      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
187      var problemData = ProblemDataParameter.ActualValue;
188      var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;
189      var constantOptimizationIterations = MaxCOIterationsParameter.Value.Value; // fixed value parameter => Value
190      var estimationLimits = EstimationLimitsParameter.ActualValue; // lookup parameter => ActualValue
191      var percentageBest = PercentageBestParameter.Value.Value;
192      var optimizeConstants = OptimizeConstantsParameter.Value.Value;
193      var pruneTrees = PruneTreesParameter.Value.Value;
194
195      var allowedInputVariables = problemData.AllowedInputVariables.ToList();
196      DataTable dataTable;
197      if (!results.ContainsKey(AverageVariableImpactsHistoryResultName)) {
198        dataTable = new DataTable("Variable impacts", "Average impact of variables over the population");
199        dataTable.VisualProperties.XAxisTitle = "Generation";
200        dataTable.VisualProperties.YAxisTitle = "Average variable impact";
201        results.Add(new Result(AverageVariableImpactsHistoryResultName, dataTable));
202
203        foreach (var v in allowedInputVariables) {
204          dataTable.Rows.Add(new DataRow(v) { VisualProperties = { StartIndexZero = true } });
205        }
206      }
207      dataTable = (DataTable)results[AverageVariableImpactsHistoryResultName].Value;
208
209      int nTrees = (int)Math.Round(trees.Length * percentageBest);
210      var bestTrees = trees.Take(nTrees).Select(x => (ISymbolicExpressionTree)x.Clone()).ToList();
211      // simplify trees before doing anything else
212      var simplifiedTrees = bestTrees.Select(x => simplifier.Simplify(x)).ToList();
213
214      if (optimizeConstants) {
215        for (int i = 0; i < simplifiedTrees.Count; ++i) {
216          qualities[i].Value = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, simplifiedTrees[i], problemData, problemData.TrainingIndices, applyLinearScaling, constantOptimizationIterations, true, estimationLimits.Upper, estimationLimits.Lower);
217        }
218      }
219
220      if (pruneTrees) {
221        for (int i = 0; i < simplifiedTrees.Count; ++i) {
222          simplifiedTrees[i] = SymbolicRegressionPruningOperator.Prune(simplifiedTrees[i], impactsCalculator, interpreter, problemData, estimationLimits, problemData.TrainingIndices);
223        }
224      }
225      // map each variable to a list of indices of the trees that contain it
226      var variablesToTreeIndices = allowedInputVariables.ToDictionary(x => x, x => Enumerable.Range(0, simplifiedTrees.Count).Where(i => ContainsVariable(simplifiedTrees[i], x)).ToList());
227
228      // variable values used for restoring original values in the dataset
229      var variableValues = allowedInputVariables.Select(x => problemData.Dataset.GetReadOnlyDoubleValues(x).ToList()).ToList();
230      // the ds gets new variable values (not the above).
231      var variableNames = allowedInputVariables.Concat(new[] { problemData.TargetVariable }).ToList();
232      var ds = new ModifiableDataset(variableNames, variableNames.Select(x => problemData.Dataset.GetReadOnlyDoubleValues(x).ToList()));
233      var pd = new RegressionProblemData(ds, allowedInputVariables, problemData.TargetVariable);
234      pd.TrainingPartition.Start = problemData.TrainingPartition.Start;
235      pd.TrainingPartition.End = problemData.TrainingPartition.End;
236      pd.TestPartition.Start = problemData.TestPartition.Start;
237      pd.TestPartition.End = problemData.TestPartition.End;
238
239      for (int i = 0; i < allowedInputVariables.Count; ++i) {
240        var v = allowedInputVariables[i];
241        var median = problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).Median();
242        var values = new List<double>(Enumerable.Repeat(median, problemData.Dataset.Rows));
243        // replace values with median
244        ds.ReplaceVariable(v, values);
245
246        var indices = variablesToTreeIndices[v];
247        if (!indices.Any()) {
248          dataTable.Rows[v].Values.Add(0);
249          continue;
250        }
251
252        var averageImpact = 0d;
253        for (int j = 0; j < indices.Count; ++j) {
254          var tree = simplifiedTrees[j];
255          var originalQuality = qualities[j].Value;
256          double newQuality;
257          if (optimizeConstants) {
258            newQuality = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, pd, problemData.TrainingIndices, applyLinearScaling, constantOptimizationIterations, true, estimationLimits.Upper, estimationLimits.Lower);
259          } else {
260            var evaluator = EvaluatorParameter.ActualValue;
261            newQuality = evaluator.Evaluate(this.ExecutionContext, tree, pd, pd.TrainingIndices);
262          }
263          averageImpact += originalQuality - newQuality; // impact calculated this way may be negative
264        }
265        averageImpact /= indices.Count;
266        dataTable.Rows[v].Values.Add(averageImpact);
267        // restore original values
268        ds.ReplaceVariable(v, variableValues[i]);
269      }
270
271      var averageVariableImpacts = new DoubleMatrix(dataTable.Rows.Count, 1);
272      var rowNames = dataTable.Rows.Select(x => x.Name).ToList();
273      averageVariableImpacts.RowNames = rowNames;
274      for (int i = 0; i < rowNames.Count; ++i) {
275        averageVariableImpacts[i, 0] = dataTable.Rows[rowNames[i]].Values.Last();
276      }
277      if (!results.ContainsKey(AverageVariableImpactsResultName)) {
278        results.Add(new Result(AverageVariableImpactsResultName, averageVariableImpacts));
279      } else {
280        results[AverageVariableImpactsResultName].Value = averageVariableImpacts;
281      }
282      return base.Apply();
283    }
284
285    private static bool ContainsVariable(ISymbolicExpressionTree tree, string variableName) {
286      return tree.IterateNodesPrefix().OfType<VariableTreeNode>().Any(x => x.VariableName == variableName);
287    }
288  }
289}
290
Note: See TracBrowser for help on using the repository browser.