Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3107_LearningALPS/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/Creators/NodeImpactsReinitializationStrategyController.cs @ 17917

Last change on this file since 17917 was 17917, checked in by bburlacu, 3 years ago

#3107: Add node impacts strategy controller.

File size: 5.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4
5using HEAL.Attic;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.Operators;
10using HeuristicLab.Optimization;
11using HeuristicLab.Parameters;
12
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Operator that re-weights the symbols of a grammar based on node impacts.
  /// For every symbol that labels an inner (non-leaf) node in the trees looked up from the scope tree,
  /// it averages the impact of those nodes on the training rows. It then stores the non-negative
  /// average as the <c>InitialFrequency</c> of the grammar symbol with the same name.
  /// </summary>
  [Item("NodeImpactsReinitializationStrategyController", "")]
  [StorableType("C015C8C3-283D-4F51-A582-367587596709")]
  public class NodeImpactsReinitializationStrategyController : InstrumentedOperator, IReinitializationStrategyController {
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string SymbolicExpressionTreeGrammarParameterName = "SymbolicExpressionTreeGrammar";
    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicDataAnalysisTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string EstimationLimitsParameterName = "EstimationLimits";

    #region Parameter Properties
    /// <summary>The trees (collected over the scope tree) whose node impacts are evaluated.</summary>
    public IScopeTreeLookupParameter<ISymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (IScopeTreeLookupParameter<ISymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }

    /// <summary>The grammar whose symbols' initial frequencies are updated.</summary>
    public IValueLookupParameter<ISymbolicExpressionGrammar> SymbolicExpressionTreeGrammarParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionGrammar>)Parameters[SymbolicExpressionTreeGrammarParameterName]; }
    }

    /// <summary>The interpreter used to evaluate the trees; a batch interpreter is created when none is found.</summary>
    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
    }

    /// <summary>The regression problem data; its target variable and training rows are used for impact calculation.</summary>
    public ILookupParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (ILookupParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }

    /// <summary>The lower/upper estimation limits passed to each regression model; unbounded when not available.</summary>
    public IValueLookupParameter<DoubleLimit> EstimationLimitsParameter {
      get { return (IValueLookupParameter<DoubleLimit>)Parameters[EstimationLimitsParameterName]; }
    }
    #endregion

    #region Constructors
    public NodeImpactsReinitializationStrategyController() {
      Parameters.Add(new ScopeTreeLookupParameter<ISymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees whose node impacts should be calculated."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionGrammar>(SymbolicExpressionTreeGrammarParameterName, "The tree grammar that defines the correct syntax of symbolic expression trees that should be created."));
      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName, "The symbolic data analysis tree interpreter for the symbolic expression tree."));
      Parameters.Add(new LookupParameter<IRegressionProblemData>(ProblemDataParameterName, "The problem data for the symbolic regression solution."));
      Parameters.Add(new ValueLookupParameter<DoubleLimit>(EstimationLimitsParameterName, "The lower and upper limit for the estimated values produced by the symbolic regression model."));
    }

    private NodeImpactsReinitializationStrategyController(NodeImpactsReinitializationStrategyController original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new NodeImpactsReinitializationStrategyController(this, cloner);
    }

    [StorableConstructor]
    private NodeImpactsReinitializationStrategyController(StorableConstructorFlag _) : base(_) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }
    #endregion

    /// <summary>
    /// Evaluates the impact of every inner node of every looked-up tree on the training rows.
    /// It averages the impacts per symbol name and assigns <c>Math.Max(0, average)</c> as the
    /// <c>InitialFrequency</c> of the grammar symbol with that name. Grammar symbols that never label
    /// an inner node keep their current frequency, and so do symbols whose average impact is not a
    /// finite number.
    /// </summary>
    /// <returns>The operation returned by the base implementation.</returns>
    public override IOperation InstrumentedApply() {
      var trees = SymbolicExpressionTreeParameter.ActualValue;
      var grammar = SymbolicExpressionTreeGrammarParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue
        ?? new SymbolicDataAnalysisExpressionTreeBatchInterpreter();

      // Evaluate the models unbounded when no estimation limits are available.
      var estimationLimits = EstimationLimitsParameter.ActualValue;
      var lowerLimit = estimationLimits?.Lower ?? double.MinValue;
      var upperLimit = estimationLimits?.Upper ?? double.MaxValue;

      // Materialize once: the same rows are handed to the impact calculator for every node.
      var trainingRows = problemData.TrainingIndices.ToArray();

      var symbolImpacts = new Dictionary<string, double>();
      var symbolCounts = new Dictionary<string, int>();
      var impactValuesCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();

      foreach (var tree in trees) {
        var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, interpreter, lowerLimit, upperLimit);

        // Skip the ProgramRoot/Start wrapper nodes; stop early if a wrapper has no child.
        var root = tree.Root;
        while ((root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) && root.SubtreeCount > 0) {
          root = root.GetSubtree(0);
        }

        // Only inner nodes (with at least one subtree) are evaluated; leaves are not.
        foreach (var node in root.IterateNodesPrefix().Where(x => x.SubtreeCount > 0)) {
          impactValuesCalculator.CalculateImpactAndReplacementValues(model, node, problemData, trainingRows, out double impactValue, out _, out _);

          var name = node.Symbol.Name;
          symbolCounts.TryGetValue(name, out int count);
          symbolImpacts.TryGetValue(name, out double sum);
          symbolCounts[name] = count + 1;
          symbolImpacts[name] = sum + impactValue;
        }
      }

      foreach (var symbol in grammar.AllowedSymbols) {
        if (!symbolImpacts.TryGetValue(symbol.Name, out double impactSum)) {
          continue;
        }

        var averageImpact = impactSum / symbolCounts[symbol.Name];
        // Math.Max(0, NaN) returns NaN, so non-finite averages must be filtered explicitly
        // instead of being written into the symbol's frequency.
        if (double.IsNaN(averageImpact) || double.IsInfinity(averageImpact)) {
          continue;
        }

        // TODO: placeholder mapping from impact to frequency (original note: "do something clever here").
        // NOTE(review): symbols with an average impact <= 0 get frequency 0; if this holds for every
        // observed function symbol, tree creation may degenerate — confirm this is acceptable.
        symbol.InitialFrequency = Math.Max(0, averageImpact);
      }

      return base.InstrumentedApply();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.