Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks/3.3/SymbolicDataAnalysisVariableImpactsAnalyzer.cs @ 13728

Last change on this file since 13728 was 13728, checked in by bburlacu, 8 years ago

#2288: Added VariableInteractionNetwork graph class. Small improvements to the impacts analyzer. Add license header and improve formatting in CreateTargetVariationExperiment.cs

File size: 14.8 KB
RevLine 
[13665]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
[12460]23using System.Linq;
24using HeuristicLab.Analysis;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
[13665]28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[12460]29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.DataAnalysis;
33using HeuristicLab.Problems.DataAnalysis.Symbolic;
34using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
35
[13665]36namespace HeuristicLab.VariableInteractionNetworks {
37  [Item("SymbolicRegressionVariableImpactsAnalyzer", "An analyzer which calculates variable impacts based on the average node impacts from the tree")]
38  [StorableClass]
39  public class SymbolicRegressionVariableImpactsAnalyzer : SymbolicDataAnalysisAnalyzer {
[13728]40    #region parameter names
[13665]41    private const string UpdateCounterParameterName = "UpdateCounter";
42    private const string UpdateIntervalParameterName = "UpdateInterval";
43    public const string QualityParameterName = "Quality";
44    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
45    private const string ProblemDataParameterName = "ProblemData";
46    private const string ApplyLinearScalingParameterName = "ApplyLinearScaling";
47    private const string MaxCOIterationsParameterName = "MaxCOIterations";
48    private const string EstimationLimitsParameterName = "EstimationLimits";
49    private const string EvaluatorParameterName = "Evaluator";
50    private const string PercentageBestParameterName = "PercentageBest";
51    private const string LastGenerationsParameterName = "LastGenerations";
52    private const string MaximumGenerationsParameterName = "MaximumGenerations";
53    private const string OptimizeConstantsParameterName = "OptimizeConstants";
54    private const string PruneTreesParameterName = "PruneTrees";
[13728]55    private const string AverageVariableImpactsResultName = "Average variable impacts";
56    private const string AverageVariableImpactsHistoryResultName = "Average variable impacts history";
57    #endregion
[12460]58
[13665]59    private SymbolicDataAnalysisExpressionTreeSimplifier simplifier;
60    private SymbolicRegressionSolutionImpactValuesCalculator impactsCalculator;
[12460]61
[13665]62    #region parameters
63    public ValueParameter<IntValue> UpdateCounterParameter {
64      get { return (ValueParameter<IntValue>)Parameters[UpdateCounterParameterName]; }
65    }
66    public ValueParameter<IntValue> UpdateIntervalParameter {
67      get { return (ValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
68    }
69    public IScopeTreeLookupParameter<DoubleValue> QualityParameter {
70      get { return (IScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
71    }
72    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
73      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
74    }
75    public ILookupParameter<IRegressionProblemData> ProblemDataParameter {
76      get { return (ILookupParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
77    }
78    public ILookupParameter<BoolValue> ApplyLinearScalingParameter {
79      get { return (ILookupParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
80    }
81    public IFixedValueParameter<IntValue> MaxCOIterationsParameter {
82      get { return (IFixedValueParameter<IntValue>)Parameters[MaxCOIterationsParameterName]; }
83    }
84    public ILookupParameter<DoubleLimit> EstimationLimitsParameter {
85      get { return (ILookupParameter<DoubleLimit>)Parameters[EstimationLimitsParameterName]; }
86    }
87    public ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator> EvaluatorParameter {
88      get { return (ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName]; }
89    }
90    public IFixedValueParameter<PercentValue> PercentageBestParameter {
91      get { return (IFixedValueParameter<PercentValue>)Parameters[PercentageBestParameterName]; }
92    }
93    public IFixedValueParameter<IntValue> LastGenerationsParameter {
94      get { return (IFixedValueParameter<IntValue>)Parameters[LastGenerationsParameterName]; }
95    }
96    public IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
97      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
98    }
99    public IFixedValueParameter<BoolValue> PruneTreesParameter {
100      get { return (IFixedValueParameter<BoolValue>)Parameters[PruneTreesParameterName]; }
101    }
102    private ILookupParameter<IntValue> MaximumGenerationsParameter {
103      get { return (ILookupParameter<IntValue>)Parameters[MaximumGenerationsParameterName]; }
104    }
105    #endregion
[12568]106
[13665]107    #region parameter properties
108    public int UpdateCounter {
109      get { return UpdateCounterParameter.Value.Value; }
110      set { UpdateCounterParameter.Value.Value = value; }
111    }
112    public int UpdateInterval {
113      get { return UpdateIntervalParameter.Value.Value; }
114      set { UpdateIntervalParameter.Value.Value = value; }
115    }
116    #endregion
[12460]117
[13665]118    public SymbolicRegressionVariableImpactsAnalyzer() {
119      #region add parameters
120      Parameters.Add(new ValueParameter<IntValue>(UpdateCounterParameterName, new IntValue(0)));
121      Parameters.Add(new ValueParameter<IntValue>(UpdateIntervalParameterName, new IntValue(1)));
122      Parameters.Add(new LookupParameter<IRegressionProblemData>(ProblemDataParameterName));
123      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName));
124      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The individual qualities."));
125      Parameters.Add(new LookupParameter<BoolValue>(ApplyLinearScalingParameterName));
126      Parameters.Add(new LookupParameter<DoubleLimit>(EstimationLimitsParameterName));
127      Parameters.Add(new FixedValueParameter<IntValue>(MaxCOIterationsParameterName, new IntValue(3)));
[13728]128      Parameters.Add(new FixedValueParameter<PercentValue>(PercentageBestParameterName, new PercentValue(1)));
[13665]129      Parameters.Add(new FixedValueParameter<IntValue>(LastGenerationsParameterName, new IntValue(10)));
130      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, new BoolValue(false)));
131      Parameters.Add(new FixedValueParameter<BoolValue>(PruneTreesParameterName, new BoolValue(false)));
132      Parameters.Add(new LookupParameter<IntValue>(MaximumGenerationsParameterName, "The maximum number of generations which should be processed."));
133      Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
134      #endregion
[12460]135
[13665]136      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
137      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
138    }
[12460]139
[13665]140    [StorableConstructor]
141    protected SymbolicRegressionVariableImpactsAnalyzer(bool deserializing) : base(deserializing) { }
[12460]142
[13665]143    [StorableHook(HookType.AfterDeserialization)]
144    private void AfterDeserialization() {
145      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
146      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
[12460]147
[13665]148      if (!Parameters.ContainsKey(EvaluatorParameterName))
149        Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName));
150    }
[12568]151
[13665]152    protected SymbolicRegressionVariableImpactsAnalyzer(SymbolicRegressionVariableImpactsAnalyzer original, Cloner cloner)
153        : base(original, cloner) {
154      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
155      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
156    }
[12460]157
[13665]158    public override IDeepCloneable Clone(Cloner cloner) {
159      return new SymbolicRegressionVariableImpactsAnalyzer(this, cloner);
160    }
[12460]161
[13665]162    public override IOperation Apply() {
163      #region Update counter & update interval
164      UpdateCounter++;
165      if (UpdateCounter != UpdateInterval) {
166        return base.Apply();
167      }
168      UpdateCounter = 0;
169      #endregion
170      var results = ResultCollectionParameter.ActualValue;
171      int maxGen = MaximumGenerationsParameter.ActualValue.Value;
[13728]172      int gen = results.ContainsKey("Generations") ? ((IntValue)results["Generations"].Value).Value : 0;
[13665]173      int lastGen = LastGenerationsParameter.Value.Value;
[12460]174
[13665]175      if (lastGen > 0 && gen < maxGen - lastGen)
176        return base.Apply();
[12460]177
[13665]178      var trees = SymbolicExpressionTree.ToArray();
179      var qualities = QualityParameter.ActualValue.ToArray();
[12460]180
[13665]181      Array.Sort(qualities, trees);
182      Array.Reverse(qualities);
183      Array.Reverse(trees);
[12460]184
[13665]185      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
186      var problemData = ProblemDataParameter.ActualValue;
187      var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;
188      var constantOptimizationIterations = MaxCOIterationsParameter.Value.Value; // fixed value parameter => Value
189      var estimationLimits = EstimationLimitsParameter.ActualValue; // lookup parameter => ActualValue
190      var percentageBest = PercentageBestParameter.Value.Value;
191      var optimizeConstants = OptimizeConstantsParameter.Value.Value;
192      var pruneTrees = PruneTreesParameter.Value.Value;
[12460]193
[13665]194      var allowedInputVariables = problemData.AllowedInputVariables.ToList();
195      DataTable dataTable;
[13728]196      if (!results.ContainsKey(AverageVariableImpactsHistoryResultName)) {
[13665]197        dataTable = new DataTable("Variable impacts", "Average impact of variables over the population");
198        dataTable.VisualProperties.XAxisTitle = "Generation";
199        dataTable.VisualProperties.YAxisTitle = "Average variable impact";
[13728]200        results.Add(new Result(AverageVariableImpactsHistoryResultName, dataTable));
[12568]201
[13665]202        foreach (var v in allowedInputVariables) {
203          dataTable.Rows.Add(new DataRow(v) { VisualProperties = { StartIndexZero = true } });
204        }
205      }
[13728]206      dataTable = (DataTable)results[AverageVariableImpactsHistoryResultName].Value;
207
[13665]208      int nTrees = (int)Math.Round(trees.Length * percentageBest);
209      var bestTrees = trees.Take(nTrees).Select(x => (ISymbolicExpressionTree)x.Clone()).ToList();
210      // simplify trees before doing anything else
211      var simplifiedTrees = bestTrees.Select(x => simplifier.Simplify(x)).ToList();
[12460]212
[13665]213      if (optimizeConstants) {
214        for (int i = 0; i < simplifiedTrees.Count; ++i) {
[13728]215          qualities[i].Value = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, simplifiedTrees[i], problemData, problemData.TrainingIndices, applyLinearScaling, constantOptimizationIterations, true, estimationLimits.Upper, estimationLimits.Lower);
[13665]216        }
217      }
[12460]218
[13665]219      if (pruneTrees) {
220        for (int i = 0; i < simplifiedTrees.Count; ++i) {
221          simplifiedTrees[i] = SymbolicRegressionPruningOperator.Prune(simplifiedTrees[i], impactsCalculator, interpreter, problemData, estimationLimits, problemData.TrainingIndices);
222        }
223      }
224      // map each variable to a list of indices of the trees that contain it
225      var variablesToTreeIndices = allowedInputVariables.ToDictionary(x => x, x => Enumerable.Range(0, simplifiedTrees.Count).Where(i => ContainsVariable(simplifiedTrees[i], x)).ToList());
[12568]226
[13665]227      foreach (var mapping in variablesToTreeIndices) {
228        var variableName = mapping.Key;
229        var median = problemData.Dataset.GetDoubleValues(variableName, problemData.TrainingIndices).Median();
230        var ds = new ModifiableDataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(x => problemData.Dataset.GetReadOnlyDoubleValues(x).ToList()));
231        foreach (var i in problemData.TrainingIndices) {
232          ds.SetVariableValue(median, variableName, i);
233        }
234        var pd = new RegressionProblemData(ds, allowedInputVariables, problemData.TargetVariable);
235        pd.TrainingPartition.Start = problemData.TrainingPartition.Start;
236        pd.TrainingPartition.End = problemData.TrainingPartition.End;
237        pd.TestPartition.Start = problemData.TestPartition.Start;
238        pd.TestPartition.End = problemData.TestPartition.End;
[12460]239
[13665]240        var indices = mapping.Value;
241        var averageImpact = 0d;
242        for (int i = 0; i < indices.Count; ++i) {
[13728]243          var tree = simplifiedTrees[i];
[13665]244          var originalQuality = qualities[i].Value;
245          double newQuality;
246          if (optimizeConstants) {
[13728]247            newQuality = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, pd, problemData.TrainingIndices, applyLinearScaling, constantOptimizationIterations, true, estimationLimits.Upper, estimationLimits.Lower);
[13665]248          } else {
249            var evaluator = EvaluatorParameter.ActualValue;
[13728]250            newQuality = evaluator.Evaluate(this.ExecutionContext, tree, pd, pd.TrainingIndices);
[13665]251          }
252          averageImpact += originalQuality - newQuality; // impact calculated this way may be negative
253        }
254        averageImpact /= indices.Count;
255        dataTable.Rows[variableName].Values.Add(averageImpact);
256      }
[12460]257
[13728]258      var averageVariableImpacts = new DoubleMatrix(dataTable.Rows.Count, 1);
259      var rowNames = dataTable.Rows.Select(x => x.Name).ToList();
260      averageVariableImpacts.RowNames = rowNames;
261      for (int i = 0; i < rowNames.Count; ++i) {
262        averageVariableImpacts[i, 0] = dataTable.Rows[rowNames[i]].Values.Last();
263      }
264      if (!results.ContainsKey(AverageVariableImpactsResultName)) {
265        results.Add(new Result(AverageVariableImpactsResultName, averageVariableImpacts));
266      } else {
267        results[AverageVariableImpactsResultName].Value = averageVariableImpacts;
268      }
[13665]269      return base.Apply();
270    }
[12460]271
[13665]272    private static bool ContainsVariable(ISymbolicExpressionTree tree, string variableName) {
273      return tree.IterateNodesPrefix().OfType<VariableTreeNode>().Any(x => x.VariableName == variableName);
[12460]274    }
[13665]275  }
276}
Note: See TracBrowser for help on using the repository browser.