Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks/3.3/SymbolicDataAnalysisVariableImpactsAnalyzer.cs @ 13665

Last change on this file since 13665 was 13665, checked in by bburlacu, 8 years ago

#2288: Improve calculation of variable impacts in the analyzer

File size: 14.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Analysis;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.DataAnalysis;
33using HeuristicLab.Problems.DataAnalysis.Symbolic;
34using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
35
namespace HeuristicLab.VariableInteractionNetworks {
  /// <summary>
  /// Analyzer that estimates the impact of every allowed input variable. For each variable, the variable's training values
  /// are replaced by their median, and the quality change of the best (simplified, optionally pruned) trees that reference
  /// the variable is averaged and appended to a data table.
  /// </summary>
  [Item("SymbolicRegressionVariableImpactsAnalyzer", "An analyzer which calculates variable impacts based on the average node impacts from the tree")]
  [StorableClass]
  public class SymbolicRegressionVariableImpactsAnalyzer : SymbolicDataAnalysisAnalyzer {
    private const string UpdateCounterParameterName = "UpdateCounter";
    private const string UpdateIntervalParameterName = "UpdateInterval";
    public const string QualityParameterName = "Quality";
    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ApplyLinearScalingParameterName = "ApplyLinearScaling";
    private const string MaxCOIterationsParameterName = "MaxCOIterations";
    private const string EstimationLimitsParameterName = "EstimationLimits";
    private const string EvaluatorParameterName = "Evaluator";
    private const string VariableImpactsParameterName = "AverageVariableImpacts";
    private const string PercentageBestParameterName = "PercentageBest";
    private const string LastGenerationsParameterName = "LastGenerations";
    private const string MaximumGenerationsParameterName = "MaximumGenerations";
    private const string OptimizeConstantsParameterName = "OptimizeConstants";
    private const string PruneTreesParameterName = "PruneTrees";
    private const string VariableImpactsResultName = "Average variable impacts";

    // helpers are not persisted; they are recreated in every constructor and after deserialization
    private SymbolicDataAnalysisExpressionTreeSimplifier simplifier;
    private SymbolicRegressionSolutionImpactValuesCalculator impactsCalculator;

    #region parameters
    public ValueParameter<IntValue> UpdateCounterParameter {
      get { return (ValueParameter<IntValue>)Parameters[UpdateCounterParameterName]; }
    }
    public ValueParameter<IntValue> UpdateIntervalParameter {
      get { return (ValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
    }
    public IScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (IScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
    }
    public ILookupParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (ILookupParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public ILookupParameter<BoolValue> ApplyLinearScalingParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
    }
    public IFixedValueParameter<IntValue> MaxCOIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxCOIterationsParameterName]; }
    }
    public ILookupParameter<DoubleLimit> EstimationLimitsParameter {
      get { return (ILookupParameter<DoubleLimit>)Parameters[EstimationLimitsParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<DataTable> VariableImpactsParameter {
      get { return (ILookupParameter<DataTable>)Parameters[VariableImpactsParameterName]; }
    }
    public IFixedValueParameter<PercentValue> PercentageBestParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[PercentageBestParameterName]; }
    }
    public IFixedValueParameter<IntValue> LastGenerationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[LastGenerationsParameterName]; }
    }
    public IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
    }
    public IFixedValueParameter<BoolValue> PruneTreesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[PruneTreesParameterName]; }
    }
    private ILookupParameter<IntValue> MaximumGenerationsParameter {
      get { return (ILookupParameter<IntValue>)Parameters[MaximumGenerationsParameterName]; }
    }
    #endregion

    #region parameter properties
    public int UpdateCounter {
      get { return UpdateCounterParameter.Value.Value; }
      set { UpdateCounterParameter.Value.Value = value; }
    }
    public int UpdateInterval {
      get { return UpdateIntervalParameter.Value.Value; }
      set { UpdateIntervalParameter.Value.Value = value; }
    }
    #endregion

    public SymbolicRegressionVariableImpactsAnalyzer() {
      #region add parameters
      Parameters.Add(new ValueParameter<IntValue>(UpdateCounterParameterName, new IntValue(0)));
      Parameters.Add(new ValueParameter<IntValue>(UpdateIntervalParameterName, new IntValue(1)));
      Parameters.Add(new LookupParameter<IRegressionProblemData>(ProblemDataParameterName));
      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The individual qualities."));
      Parameters.Add(new LookupParameter<BoolValue>(ApplyLinearScalingParameterName));
      Parameters.Add(new LookupParameter<DoubleLimit>(EstimationLimitsParameterName));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxCOIterationsParameterName, new IntValue(3)));
      Parameters.Add(new LookupParameter<DataTable>(VariableImpactsParameterName, "The relative variable relevance calculated as the average relative variable frequency over the whole run."));
      // PercentValue stores a fraction (1 == 100 %); the former default of 100 meant 10000 %.
      Parameters.Add(new FixedValueParameter<PercentValue>(PercentageBestParameterName, new PercentValue(1)));
      Parameters.Add(new FixedValueParameter<IntValue>(LastGenerationsParameterName, new IntValue(10)));
      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<BoolValue>(PruneTreesParameterName, new BoolValue(false)));
      Parameters.Add(new LookupParameter<IntValue>(MaximumGenerationsParameterName, "The maximum number of generations which should be processed."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
      #endregion

      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
    }

    [StorableConstructor]
    protected SymbolicRegressionVariableImpactsAnalyzer(bool deserializing) : base(deserializing) { }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();

      // backwards compatibility: older persisted instances lack the evaluator parameter
      if (!Parameters.ContainsKey(EvaluatorParameterName))
        Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
    }

    protected SymbolicRegressionVariableImpactsAnalyzer(SymbolicRegressionVariableImpactsAnalyzer original, Cloner cloner)
        : base(original, cloner) {
      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionVariableImpactsAnalyzer(this, cloner);
    }

    /// <summary>
    /// Every <see cref="UpdateInterval"/> calls (and, when LastGenerations &gt; 0, only within the last LastGenerations
    /// generations) appends one average-impact value per allowed input variable to the variable impacts data table.
    /// The impact of a variable on a tree is the tree's quality on the original training data minus its quality when the
    /// variable's training values are replaced by their median; both qualities are computed the same way.
    /// </summary>
    public override IOperation Apply() {
      #region Update counter & update interval
      UpdateCounter++;
      if (UpdateCounter != UpdateInterval) {
        return base.Apply();
      }
      UpdateCounter = 0;
      #endregion
      var results = ResultCollectionParameter.ActualValue;

      // restrict the (expensive) analysis to the last generations; lastGen <= 0 disables the restriction
      int lastGen = LastGenerationsParameter.Value.Value;
      if (lastGen > 0) {
        int maxGen = MaximumGenerationsParameter.ActualValue.Value;
        int gen = ((IntValue)results["Generations"].Value).Value;
        if (gen < maxGen - lastGen)
          return base.Apply();
      }

      var trees = SymbolicExpressionTree.ToArray();
      // copy the raw values: the DoubleValue objects belong to the population and must not be modified by the analyzer
      var qualities = QualityParameter.ActualValue.Select(q => q.Value).ToArray();

      // order the trees by descending quality
      // NOTE(review): "best" assumes a maximization objective (e.g. R²); confirm for minimizing evaluators.
      Array.Sort(qualities, trees);
      Array.Reverse(trees);

      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;
      var constantOptimizationIterations = MaxCOIterationsParameter.Value.Value; // fixed value parameter => Value
      var estimationLimits = EstimationLimitsParameter.ActualValue; // lookup parameter => ActualValue
      var percentageBest = PercentageBestParameter.Value.Value;
      var optimizeConstants = OptimizeConstantsParameter.Value.Value;
      var pruneTrees = PruneTreesParameter.Value.Value;
      // the evaluator is only needed (and only looked up) when constants are not optimized
      var evaluator = optimizeConstants ? null : EvaluatorParameter.ActualValue;

      // Scores a tree on the training rows of the given problem data. Constant optimization runs on a clone so that
      // constants fitted to one (perturbed) dataset never leak into later evaluations of the same tree.
      Func<ISymbolicExpressionTree, IRegressionProblemData, double> evaluate = (tree, data) => optimizeConstants
        ? SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, (ISymbolicExpressionTree)tree.Clone(), data, data.TrainingIndices, applyLinearScaling, constantOptimizationIterations, estimationLimits.Upper, estimationLimits.Lower)
        : evaluator.Evaluate(ExecutionContext, tree, data, data.TrainingIndices);

      var allowedInputVariables = problemData.AllowedInputVariables.ToList();
      var dataTable = VariableImpactsParameter.ActualValue;
      if (dataTable == null) {
        dataTable = new DataTable("Variable impacts", "Average impact of variables over the population");
        dataTable.VisualProperties.XAxisTitle = "Generation";
        dataTable.VisualProperties.YAxisTitle = "Average variable impact";
        foreach (var v in allowedInputVariables) {
          dataTable.Rows.Add(new DataRow(v) { VisualProperties = { StartIndexZero = true } });
        }
        VariableImpactsParameter.ActualValue = dataTable;
      }
      if (!results.ContainsKey(VariableImpactsResultName))
        results.Add(new Result(VariableImpactsResultName, "The relative variable relevance calculated as the average relative variable frequency over the whole run.", dataTable));

      // number of best trees to analyze, clamped to the population size
      int nTrees = Math.Max(0, Math.Min(trees.Length, (int)Math.Round(trees.Length * percentageBest)));
      var bestTrees = new ISymbolicExpressionTree[nTrees];
      var baseQualities = new double[nTrees];
      for (int i = 0; i < nTrees; ++i) {
        // simplify a clone before doing anything else, so the population itself stays untouched
        var tree = simplifier.Simplify((ISymbolicExpressionTree)trees[i].Clone());
        if (optimizeConstants)
          SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices, applyLinearScaling, constantOptimizationIterations, estimationLimits.Upper, estimationLimits.Lower);
        if (pruneTrees)
          tree = SymbolicRegressionPruningOperator.Prune(tree, impactsCalculator, interpreter, problemData, estimationLimits, problemData.TrainingIndices);
        bestTrees[i] = tree;
        // baseline measured exactly like the perturbed quality below, on the final (simplified / pruned) tree
        baseQualities[i] = evaluate(tree, problemData);
      }

      foreach (var variableName in allowedInputVariables) {
        // indices (into bestTrees) of the analyzed trees that reference the variable
        var treeIndices = Enumerable.Range(0, nTrees).Where(i => ContainsVariable(bestTrees[i], variableName)).ToArray();
        // a variable used by none of the analyzed trees is reported with impact 0 instead of 0/0 = NaN
        var averageImpact = 0d;
        if (treeIndices.Length > 0) {
          var perturbedData = CreatePerturbedProblemData(problemData, variableName);
          foreach (var treeIndex in treeIndices) {
            averageImpact += baseQualities[treeIndex] - evaluate(bestTrees[treeIndex], perturbedData); // impact calculated this way may be negative
          }
          averageImpact /= treeIndices.Length;
        }
        dataTable.Rows[variableName].Values.Add(averageImpact);
      }

      results[VariableImpactsResultName].Value = dataTable;
      return base.Apply();
    }

    /// <summary>
    /// Creates a copy of <paramref name="problemData"/> (double variables only) in which the training values of
    /// <paramref name="variableName"/> are replaced by their training median; partitions are copied from the original.
    /// </summary>
    private static IRegressionProblemData CreatePerturbedProblemData(IRegressionProblemData problemData, string variableName) {
      var dataset = problemData.Dataset;
      var median = dataset.GetDoubleValues(variableName, problemData.TrainingIndices).Median();
      var ds = new ModifiableDataset(dataset.DoubleVariables, dataset.DoubleVariables.Select(x => dataset.GetReadOnlyDoubleValues(x).ToList()));
      foreach (var i in problemData.TrainingIndices) {
        ds.SetVariableValue(median, variableName, i);
      }
      var pd = new RegressionProblemData(ds, problemData.AllowedInputVariables, problemData.TargetVariable);
      pd.TrainingPartition.Start = problemData.TrainingPartition.Start;
      pd.TrainingPartition.End = problemData.TrainingPartition.End;
      pd.TestPartition.Start = problemData.TestPartition.Start;
      pd.TestPartition.End = problemData.TestPartition.End;
      return pd;
    }

    // true if any variable node of the tree references the given variable
    private static bool ContainsVariable(ISymbolicExpressionTree tree, string variableName) {
      return tree.IterateNodesPrefix().OfType<VariableTreeNode>().Any(x => x.VariableName == variableName);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.