Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2288_HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks/3.3/SymbolicDataAnalysisVariableImpactsAnalyzer.cs @ 15929

Last change on this file since 15929 was 15421, checked in by bburlacu, 7 years ago

#2288: Sync with trunk + Minor refactor.

File size: 15.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Analysis;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
36
namespace HeuristicLab.VariableInteractionNetworks {
  /// <summary>
  /// Estimates the impact of every allowed input variable of a symbolic regression problem.
  /// For the best trees of the population, each variable is replaced by its training median and the
  /// resulting quality loss is averaged over all analyzed trees that reference the variable.
  /// The per-generation averages are recorded in a history data table, and the most recent values
  /// are published as a single-column matrix result.
  /// </summary>
  [Item("SymbolicRegressionVariableImpactsAnalyzer", "An analyzer which calculates variable impacts based on the average node impacts from the tree")]
  [StorableClass]
  public class SymbolicRegressionVariableImpactsAnalyzer : SymbolicDataAnalysisAnalyzer {
    #region parameter names
    private const string UpdateCounterParameterName = "UpdateCounter";
    private const string UpdateIntervalParameterName = "UpdateInterval";
    public const string QualityParameterName = "Quality";
    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ApplyLinearScalingParameterName = "ApplyLinearScaling";
    private const string MaxCOIterationsParameterName = "MaxCOIterations";
    private const string EstimationLimitsParameterName = "EstimationLimits";
    private const string EvaluatorParameterName = "Evaluator";
    private const string PercentageBestParameterName = "PercentageBest";
    private const string LastGenerationsParameterName = "LastGenerations";
    private const string MaximumGenerationsParameterName = "MaximumGenerations";
    private const string OptimizeConstantsParameterName = "OptimizeConstants";
    private const string PruneTreesParameterName = "PruneTrees";
    private const string AverageVariableImpactsResultName = "Average variable impacts";
    private const string AverageVariableImpactsHistoryResultName = "Average variable impacts history";
    #endregion

    // Not persisted: recreated in every constructor and after deserialization.
    private SymbolicRegressionSolutionImpactValuesCalculator impactsCalculator;

    #region parameters
    public ValueParameter<IntValue> UpdateCounterParameter {
      get { return (ValueParameter<IntValue>)Parameters[UpdateCounterParameterName]; }
    }
    public ValueParameter<IntValue> UpdateIntervalParameter {
      get { return (ValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
    }
    public IScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (IScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
    }
    public ILookupParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (ILookupParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public ILookupParameter<BoolValue> ApplyLinearScalingParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
    }
    public IFixedValueParameter<IntValue> MaxCOIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxCOIterationsParameterName]; }
    }
    public ILookupParameter<DoubleLimit> EstimationLimitsParameter {
      get { return (ILookupParameter<DoubleLimit>)Parameters[EstimationLimitsParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public IFixedValueParameter<PercentValue> PercentageBestParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[PercentageBestParameterName]; }
    }
    public IFixedValueParameter<IntValue> LastGenerationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[LastGenerationsParameterName]; }
    }
    public IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
    }
    public IFixedValueParameter<BoolValue> PruneTreesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[PruneTreesParameterName]; }
    }
    private ILookupParameter<IntValue> MaximumGenerationsParameter {
      get { return (ILookupParameter<IntValue>)Parameters[MaximumGenerationsParameterName]; }
    }
    #endregion

    #region parameter properties
    public int UpdateCounter {
      get { return UpdateCounterParameter.Value.Value; }
      set { UpdateCounterParameter.Value.Value = value; }
    }
    public int UpdateInterval {
      get { return UpdateIntervalParameter.Value.Value; }
      set { UpdateIntervalParameter.Value.Value = value; }
    }
    #endregion

    public SymbolicRegressionVariableImpactsAnalyzer() {
      #region add parameters
      Parameters.Add(new ValueParameter<IntValue>(UpdateCounterParameterName, new IntValue(0)));
      Parameters.Add(new ValueParameter<IntValue>(UpdateIntervalParameterName, new IntValue(1)));
      Parameters.Add(new LookupParameter<IRegressionProblemData>(ProblemDataParameterName));
      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The individual qualities."));
      Parameters.Add(new LookupParameter<BoolValue>(ApplyLinearScalingParameterName));
      Parameters.Add(new LookupParameter<DoubleLimit>(EstimationLimitsParameterName));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxCOIterationsParameterName, new IntValue(3)));
      Parameters.Add(new FixedValueParameter<PercentValue>(PercentageBestParameterName, new PercentValue(1)));
      Parameters.Add(new FixedValueParameter<IntValue>(LastGenerationsParameterName, new IntValue(10)));
      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<BoolValue>(PruneTreesParameterName, new BoolValue(false)));
      Parameters.Add(new LookupParameter<IntValue>(MaximumGenerationsParameterName, "The maximum number of generations which should be processed."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
      #endregion

      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
    }

    [StorableConstructor]
    protected SymbolicRegressionVariableImpactsAnalyzer(bool deserializing) : base(deserializing) { }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();

      // Backwards compatibility: older persisted instances were created without the evaluator parameter.
      if (!Parameters.ContainsKey(EvaluatorParameterName))
        Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
    }

    protected SymbolicRegressionVariableImpactsAnalyzer(SymbolicRegressionVariableImpactsAnalyzer original, Cloner cloner)
        : base(original, cloner) {
      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionVariableImpactsAnalyzer(this, cloner);
    }

    /// <summary>
    /// Every <c>UpdateInterval</c> calls (and, when <c>LastGenerations</c> &gt; 0, only within the last
    /// <c>LastGenerations</c> generations), computes the average impact of each allowed input variable.
    /// The impact is averaged over the best <c>PercentageBest</c> trees that reference the variable.
    /// Per tree, it is the quality on the original data minus the quality after the variable has been
    /// replaced by its training median. Neither the population's trees nor its quality values are modified.
    /// </summary>
    public override IOperation Apply() {
      #region Update counter & update interval
      UpdateCounter++;
      // '<' rather than '!=': if UpdateInterval is lowered below the current counter value,
      // the analyzer fires on the next call instead of never again.
      if (UpdateCounter < UpdateInterval) {
        return base.Apply();
      }
      UpdateCounter = 0;
      #endregion

      var results = ResultCollectionParameter.ActualValue;
      int gen = results.ContainsKey("Generations") ? ((IntValue)results["Generations"].Value).Value : 0;
      int lastGen = LastGenerationsParameter.Value.Value;
      if (lastGen > 0) {
        // Only analyze the last 'lastGen' generations. The maximum is only needed for this check;
        // if it is not available, the generation window cannot be applied and every generation is analyzed.
        var maxGenValue = MaximumGenerationsParameter.ActualValue;
        if (maxGenValue != null && gen < maxGenValue.Value - lastGen)
          return base.Apply();
      }

      // Rank trees by quality, best first. Plain doubles are sorted so the population's quality
      // objects are never touched.
      // NOTE(review): "best first" assumes a maximized quality measure (e.g. R²) — confirm for the configured evaluator.
      var trees = SymbolicExpressionTree.ToArray();
      var qualities = QualityParameter.ActualValue.Select(q => q.Value).ToArray();
      Array.Sort(qualities, trees);
      Array.Reverse(trees);

      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;
      var constantOptimizationIterations = MaxCOIterationsParameter.Value.Value; // fixed value parameter => Value
      var estimationLimits = EstimationLimitsParameter.ActualValue; // lookup parameter => ActualValue
      var percentageBest = PercentageBestParameter.Value.Value;
      var optimizeConstants = OptimizeConstantsParameter.Value.Value;
      var pruneTrees = PruneTreesParameter.Value.Value;
      var trainingRows = problemData.TrainingIndices.ToList();

      var allowedInputVariables = problemData.AllowedInputVariables.ToList();
      DataTable dataTable;
      if (!results.ContainsKey(AverageVariableImpactsHistoryResultName)) {
        dataTable = new DataTable("Variable impacts", "Average impact of variables over the population");
        dataTable.VisualProperties.XAxisTitle = "Generation";
        dataTable.VisualProperties.YAxisTitle = "Average variable impact";
        results.Add(new Result(AverageVariableImpactsHistoryResultName, dataTable));

        foreach (var v in allowedInputVariables) {
          dataTable.Rows.Add(new DataRow(v) { VisualProperties = { StartIndexZero = true } });
        }
      } else {
        dataTable = (DataTable)results[AverageVariableImpactsHistoryResultName].Value;
      }

      // Work on simplified clones of the best trees so the population's trees are never modified.
      int nTrees = (int)Math.Round(trees.Length * percentageBest);
      var simplifiedTrees = trees.Take(nTrees)
        .Select(x => TreeSimplifier.Simplify((ISymbolicExpressionTree)x.Clone()))
        .ToList();

      if (optimizeConstants) {
        // Fit constants in place on the clones, so that pruning and the evaluations below start from fitted constants.
        foreach (var tree in simplifiedTrees) {
          SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, trainingRows, applyLinearScaling, constantOptimizationIterations, true, estimationLimits.Lower, estimationLimits.Upper);
        }
      }

      if (pruneTrees) {
        for (int i = 0; i < simplifiedTrees.Count; ++i) {
          simplifiedTrees[i] = SymbolicRegressionPruningOperator.Prune(simplifiedTrees[i], impactsCalculator, interpreter, problemData, estimationLimits, trainingRows);
        }
      }

      // map each variable to the indices of the analyzed trees that reference it
      var variablesToTreeIndices = allowedInputVariables.ToDictionary(
        x => x,
        x => Enumerable.Range(0, simplifiedTrees.Count).Where(i => ContainsVariable(simplifiedTrees[i], x)).ToList());

      // original variable values, used to restore the dataset after each replacement
      var variableValues = allowedInputVariables.Select(x => problemData.Dataset.GetReadOnlyDoubleValues(x).ToList()).ToList();
      // private, modifiable copy of the inputs and target on which variables get replaced by their median
      var variableNames = allowedInputVariables.Concat(new[] { problemData.TargetVariable }).ToList();
      var ds = new ModifiableDataset(variableNames, variableNames.Select(x => problemData.Dataset.GetReadOnlyDoubleValues(x).ToList()));
      var pd = new RegressionProblemData(ds, allowedInputVariables, problemData.TargetVariable);
      pd.TrainingPartition.Start = problemData.TrainingPartition.Start;
      pd.TrainingPartition.End = problemData.TrainingPartition.End;
      pd.TestPartition.Start = problemData.TestPartition.Start;
      pd.TestPartition.End = problemData.TestPartition.End;

      // Quality of a tree on the current state of 'ds'. When constants are optimized, a clone is refit,
      // so that fitting on a modified dataset never alters the tree that is reused for the other variables.
      Func<ISymbolicExpressionTree, double> evaluateQuality;
      if (optimizeConstants) {
        evaluateQuality = tree => SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(
          interpreter, (ISymbolicExpressionTree)tree.Clone(), pd, trainingRows, applyLinearScaling,
          constantOptimizationIterations, true, estimationLimits.Lower, estimationLimits.Upper);
      } else {
        var evaluator = EvaluatorParameter.ActualValue;
        evaluateQuality = tree => evaluator.Evaluate(ExecutionContext, tree, pd, trainingRows);
      }

      // Baseline: the same trees, rows and metric as the perturbed evaluations, on the unmodified data.
      var baselineQualities = simplifiedTrees.Select(evaluateQuality).ToArray();

      for (int i = 0; i < allowedInputVariables.Count; ++i) {
        var variable = allowedInputVariables[i];
        var treeIndices = variablesToTreeIndices[variable];
        if (treeIndices.Count == 0) {
          // none of the analyzed trees uses this variable
          dataTable.Rows[variable].Values.Add(0);
          continue;
        }

        // replace the variable with its training median
        var median = problemData.Dataset.GetDoubleValues(variable, trainingRows).Median();
        ds.ReplaceVariable(variable, Enumerable.Repeat(median, problemData.Dataset.Rows).ToList());

        // average quality loss over the trees that actually contain the variable (may be negative)
        var averageImpact = treeIndices.Average(idx => baselineQualities[idx] - evaluateQuality(simplifiedTrees[idx]));
        dataTable.Rows[variable].Values.Add(averageImpact);

        // restore original values
        ds.ReplaceVariable(variable, variableValues[i]);
      }

      // publish the most recent value of every history row as a single-column matrix
      var rowNames = dataTable.Rows.Select(x => x.Name).ToList();
      var averageVariableImpacts = new DoubleMatrix(rowNames.Count, 1);
      averageVariableImpacts.RowNames = rowNames;
      for (int i = 0; i < rowNames.Count; ++i) {
        averageVariableImpacts[i, 0] = dataTable.Rows[rowNames[i]].Values.Last();
      }
      if (results.ContainsKey(AverageVariableImpactsResultName)) {
        results[AverageVariableImpactsResultName].Value = averageVariableImpacts;
      } else {
        results.Add(new Result(AverageVariableImpactsResultName, averageVariableImpacts));
      }
      return base.Apply();
    }

    /// <summary>Returns true if any variable node in <paramref name="tree"/> references <paramref name="variableName"/>.</summary>
    private static bool ContainsVariable(ISymbolicExpressionTree tree, string variableName) {
      return tree.IterateNodesPrefix().OfType<VariableTreeNode>().Any(x => x.VariableName == variableName);
    }
  }
}
286
Note: See TracBrowser for help on using the repository browser.