Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2288_HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks/3.3/SymbolicDataAnalysisVariableImpactsAnalyzer.cs @ 17021

Last change on this file since 17021 was 16864, checked in by gkronber, 6 years ago

#2288: updated to .NET 4.6.1 and new persistence backend for compatibility with current trunk

File size: 15.2 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Analysis;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
36using HEAL.Attic;
37
38namespace HeuristicLab.VariableInteractionNetworks {
39  [Item("SymbolicRegressionVariableImpactsAnalyzer", "An analyzer which calculates variable impacts based on the average node impacts from the tree")]
40  [StorableType("9FF5C215-6A7D-4947-A449-A554D2AA6496")]
41  public class SymbolicRegressionVariableImpactsAnalyzer : SymbolicDataAnalysisAnalyzer {
42    #region parameter names
43    private const string UpdateCounterParameterName = "UpdateCounter";
44    private const string UpdateIntervalParameterName = "UpdateInterval";
45    public const string QualityParameterName = "Quality";
46    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
47    private const string ProblemDataParameterName = "ProblemData";
48    private const string ApplyLinearScalingParameterName = "ApplyLinearScaling";
49    private const string MaxCOIterationsParameterName = "MaxCOIterations";
50    private const string EstimationLimitsParameterName = "EstimationLimits";
51    private const string EvaluatorParameterName = "Evaluator";
52    private const string PercentageBestParameterName = "PercentageBest";
53    private const string LastGenerationsParameterName = "LastGenerations";
54    private const string MaximumGenerationsParameterName = "MaximumGenerations";
55    private const string OptimizeConstantsParameterName = "OptimizeConstants";
56    private const string PruneTreesParameterName = "PruneTrees";
57    private const string AverageVariableImpactsResultName = "Average variable impacts";
58    private const string AverageVariableImpactsHistoryResultName = "Average variable impacts history";
59    #endregion
60
61    private SymbolicRegressionSolutionImpactValuesCalculator impactsCalculator;
62
63    #region parameters
64    public ValueParameter<IntValue> UpdateCounterParameter {
65      get { return (ValueParameter<IntValue>)Parameters[UpdateCounterParameterName]; }
66    }
67    public ValueParameter<IntValue> UpdateIntervalParameter {
68      get { return (ValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
69    }
70    public IScopeTreeLookupParameter<DoubleValue> QualityParameter {
71      get { return (IScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
72    }
73    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
74      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
75    }
76    public ILookupParameter<IRegressionProblemData> ProblemDataParameter {
77      get { return (ILookupParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
78    }
79    public ILookupParameter<BoolValue> ApplyLinearScalingParameter {
80      get { return (ILookupParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
81    }
82    public IFixedValueParameter<IntValue> MaxCOIterationsParameter {
83      get { return (IFixedValueParameter<IntValue>)Parameters[MaxCOIterationsParameterName]; }
84    }
85    public ILookupParameter<DoubleLimit> EstimationLimitsParameter {
86      get { return (ILookupParameter<DoubleLimit>)Parameters[EstimationLimitsParameterName]; }
87    }
88    public ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator> EvaluatorParameter {
89      get { return (ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName]; }
90    }
91    public IFixedValueParameter<PercentValue> PercentageBestParameter {
92      get { return (IFixedValueParameter<PercentValue>)Parameters[PercentageBestParameterName]; }
93    }
94    public IFixedValueParameter<IntValue> LastGenerationsParameter {
95      get { return (IFixedValueParameter<IntValue>)Parameters[LastGenerationsParameterName]; }
96    }
97    public IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
98      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
99    }
100    public IFixedValueParameter<BoolValue> PruneTreesParameter {
101      get { return (IFixedValueParameter<BoolValue>)Parameters[PruneTreesParameterName]; }
102    }
103    private ILookupParameter<IntValue> MaximumGenerationsParameter {
104      get { return (ILookupParameter<IntValue>)Parameters[MaximumGenerationsParameterName]; }
105    }
106    #endregion
107
108    #region parameter properties
109    public int UpdateCounter {
110      get { return UpdateCounterParameter.Value.Value; }
111      set { UpdateCounterParameter.Value.Value = value; }
112    }
113    public int UpdateInterval {
114      get { return UpdateIntervalParameter.Value.Value; }
115      set { UpdateIntervalParameter.Value.Value = value; }
116    }
117    #endregion
118
119    public SymbolicRegressionVariableImpactsAnalyzer() {
120      #region add parameters
121      Parameters.Add(new ValueParameter<IntValue>(UpdateCounterParameterName, new IntValue(0)));
122      Parameters.Add(new ValueParameter<IntValue>(UpdateIntervalParameterName, new IntValue(1)));
123      Parameters.Add(new LookupParameter<IRegressionProblemData>(ProblemDataParameterName));
124      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName));
125      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The individual qualities."));
126      Parameters.Add(new LookupParameter<BoolValue>(ApplyLinearScalingParameterName));
127      Parameters.Add(new LookupParameter<DoubleLimit>(EstimationLimitsParameterName));
128      Parameters.Add(new FixedValueParameter<IntValue>(MaxCOIterationsParameterName, new IntValue(3)));
129      Parameters.Add(new FixedValueParameter<PercentValue>(PercentageBestParameterName, new PercentValue(1)));
130      Parameters.Add(new FixedValueParameter<IntValue>(LastGenerationsParameterName, new IntValue(10)));
131      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, new BoolValue(false)));
132      Parameters.Add(new FixedValueParameter<BoolValue>(PruneTreesParameterName, new BoolValue(false)));
133      Parameters.Add(new LookupParameter<IntValue>(MaximumGenerationsParameterName, "The maximum number of generations which should be processed."));
134      Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
135      #endregion
136
137      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
138    }
139
140    [StorableConstructor]
141    protected SymbolicRegressionVariableImpactsAnalyzer(StorableConstructorFlag _) : base(_) { }
142
143    [StorableHook(HookType.AfterDeserialization)]
144    private void AfterDeserialization() {
145      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
146
147      if (!Parameters.ContainsKey(EvaluatorParameterName))
148        Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
149    }
150
151    protected SymbolicRegressionVariableImpactsAnalyzer(SymbolicRegressionVariableImpactsAnalyzer original, Cloner cloner)
152        : base(original, cloner) {
153      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
154    }
155
156    public override IDeepCloneable Clone(Cloner cloner) {
157      return new SymbolicRegressionVariableImpactsAnalyzer(this, cloner);
158    }
159
160    public override IOperation Apply() {
161      #region Update counter & update interval
162      UpdateCounter++;
163      if (UpdateCounter != UpdateInterval) {
164        return base.Apply();
165      }
166      UpdateCounter = 0;
167      #endregion
168      var results = ResultCollectionParameter.ActualValue;
169      int maxGen = MaximumGenerationsParameter.ActualValue.Value;
170      int gen = results.ContainsKey("Generations") ? ((IntValue)results["Generations"].Value).Value : 0;
171      int lastGen = LastGenerationsParameter.Value.Value;
172
173      if (lastGen > 0 && gen < maxGen - lastGen)
174        return base.Apply();
175
176      var trees = SymbolicExpressionTree.ToArray();
177      var qualities = QualityParameter.ActualValue.ToArray();
178
179      Array.Sort(qualities, trees);
180      Array.Reverse(qualities);
181      Array.Reverse(trees);
182
183      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
184      var problemData = ProblemDataParameter.ActualValue;
185      var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;
186      var constantOptimizationIterations = MaxCOIterationsParameter.Value.Value; // fixed value parameter => Value
187      var estimationLimits = EstimationLimitsParameter.ActualValue; // lookup parameter => ActualValue
188      var percentageBest = PercentageBestParameter.Value.Value;
189      var optimizeConstants = OptimizeConstantsParameter.Value.Value;
190      var pruneTrees = PruneTreesParameter.Value.Value;
191
192      var allowedInputVariables = problemData.AllowedInputVariables.ToList();
193      DataTable dataTable;
194      if (!results.ContainsKey(AverageVariableImpactsHistoryResultName)) {
195        dataTable = new DataTable("Variable impacts", "Average impact of variables over the population");
196        dataTable.VisualProperties.XAxisTitle = "Generation";
197        dataTable.VisualProperties.YAxisTitle = "Average variable impact";
198        results.Add(new Result(AverageVariableImpactsHistoryResultName, dataTable));
199
200        foreach (var v in allowedInputVariables) {
201          dataTable.Rows.Add(new DataRow(v) { VisualProperties = { StartIndexZero = true } });
202        }
203      }
204      dataTable = (DataTable)results[AverageVariableImpactsHistoryResultName].Value;
205
206      int nTrees = (int)Math.Round(trees.Length * percentageBest);
207      var bestTrees = trees.Take(nTrees).Select(x => (ISymbolicExpressionTree)x.Clone()).ToList();
208      // simplify trees before doing anything else
209      var simplifiedTrees = bestTrees.Select(x => TreeSimplifier.Simplify(x)).ToList();
210
211      if (optimizeConstants) {
212        for (int i = 0; i < simplifiedTrees.Count; ++i) {
213          qualities[i].Value = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, simplifiedTrees[i], problemData, problemData.TrainingIndices, applyLinearScaling, constantOptimizationIterations, true, estimationLimits.Lower, estimationLimits.Upper);
214        }
215      }
216
217      if (pruneTrees) {
218        for (int i = 0; i < simplifiedTrees.Count; ++i) {
219          simplifiedTrees[i] = SymbolicRegressionPruningOperator.Prune(simplifiedTrees[i], impactsCalculator, interpreter, problemData, estimationLimits, problemData.TrainingIndices);
220        }
221      }
222      // map each variable to a list of indices of the trees that contain it
223      var variablesToTreeIndices = allowedInputVariables.ToDictionary(x => x, x => Enumerable.Range(0, simplifiedTrees.Count).Where(i => ContainsVariable(simplifiedTrees[i], x)).ToList());
224
225      // variable values used for restoring original values in the dataset
226      var variableValues = allowedInputVariables.Select(x => problemData.Dataset.GetReadOnlyDoubleValues(x).ToList()).ToList();
227      // the ds gets new variable values (not the above).
228      var variableNames = allowedInputVariables.Concat(new[] { problemData.TargetVariable }).ToList();
229      var ds = new ModifiableDataset(variableNames, variableNames.Select(x => problemData.Dataset.GetReadOnlyDoubleValues(x).ToList()));
230      var pd = new RegressionProblemData(ds, allowedInputVariables, problemData.TargetVariable);
231      pd.TrainingPartition.Start = problemData.TrainingPartition.Start;
232      pd.TrainingPartition.End = problemData.TrainingPartition.End;
233      pd.TestPartition.Start = problemData.TestPartition.Start;
234      pd.TestPartition.End = problemData.TestPartition.End;
235
236      for (int i = 0; i < allowedInputVariables.Count; ++i) {
237        var v = allowedInputVariables[i];
238        var median = problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).Median();
239        var values = new List<double>(Enumerable.Repeat(median, problemData.Dataset.Rows));
240        // replace values with median
241        ds.ReplaceVariable(v, values);
242
243        var indices = variablesToTreeIndices[v];
244        if (!indices.Any()) {
245          dataTable.Rows[v].Values.Add(0);
246          continue;
247        }
248
249        var averageImpact = 0d;
250        for (int j = 0; j < indices.Count; ++j) {
251          var tree = simplifiedTrees[j];
252          var originalQuality = qualities[j].Value;
253          double newQuality;
254          if (optimizeConstants) {
255            newQuality = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, pd, problemData.TrainingIndices, applyLinearScaling, constantOptimizationIterations, true, estimationLimits.Lower, estimationLimits.Upper);
256          } else {
257            var evaluator = EvaluatorParameter.ActualValue;
258            newQuality = evaluator.Evaluate(this.ExecutionContext, tree, pd, pd.TrainingIndices);
259          }
260          averageImpact += originalQuality - newQuality; // impact calculated this way may be negative
261        }
262        averageImpact /= indices.Count;
263        dataTable.Rows[v].Values.Add(averageImpact);
264        // restore original values
265        ds.ReplaceVariable(v, variableValues[i]);
266      }
267
268      var averageVariableImpacts = new DoubleMatrix(dataTable.Rows.Count, 1);
269      var rowNames = dataTable.Rows.Select(x => x.Name).ToList();
270      averageVariableImpacts.RowNames = rowNames;
271      for (int i = 0; i < rowNames.Count; ++i) {
272        averageVariableImpacts[i, 0] = dataTable.Rows[rowNames[i]].Values.Last();
273      }
274      if (!results.ContainsKey(AverageVariableImpactsResultName)) {
275        results.Add(new Result(AverageVariableImpactsResultName, averageVariableImpacts));
276      } else {
277        results[AverageVariableImpactsResultName].Value = averageVariableImpacts;
278      }
279      return base.Apply();
280    }
281
282    private static bool ContainsVariable(ISymbolicExpressionTree tree, string variableName) {
283      return tree.IterateNodesPrefix().OfType<VariableTreeNode>().Any(x => x.VariableName == variableName);
284    }
285  }
286}
287
Note: See TracBrowser for help on using the repository browser.