source: branches/3136_Structural_GP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/StructuredSymbolicRegressionSingleObjectiveProblem.cs @ 18072

Last change on this file since 18072 was 18072, checked in by chaider, 8 months ago

#3136

  • Added info text in StructureTemplateView
  • Fixed cloning constructors
  • Added check if linear scaling nodes are set
File size: 9.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading.Tasks;
6using HeuristicLab.Core;
7using HeuristicLab.Optimization;
8using HEAL.Attic;
9using HeuristicLab.Common;
10using HeuristicLab.Problems.Instances;
11using HeuristicLab.Parameters;
12using HeuristicLab.Data;
13using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
14
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Single-objective symbolic regression problem whose model shape is given by a
  /// <see cref="StructureTemplate"/>. Every sub-function of the template is evolved as its own
  /// <see cref="SymbolicExpressionTreeEncoding"/> inside a <see cref="MultiEncoding"/>; an
  /// individual is decoded by splicing those sub-function trees into a clone of the template tree.
  /// Quality is the Pearson R² on the training partition (maximized).
  /// </summary>
  [StorableType("7464E84B-65CC-440A-91F0-9FA920D730F9")]
  [Item(Name = "Structured Symbolic Regression Single Objective Problem (single-objective)", Description = "A problem with a structural definition and unfixed subfunctions.")]
  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 150)]
  public class StructuredSymbolicRegressionSingleObjectiveProblem : SingleObjectiveBasicProblem<MultiEncoding>, IRegressionProblem, IProblemInstanceConsumer<RegressionProblemData> {

    #region Constants
    private const string ProblemDataParameterName = "ProblemData";
    private const string StructureDefinitionParameterName = "Structure Definition";
    private const string StructureTemplateParameterName = "Structure Template";
    private const string BestTreeResultName = "Best Tree";

    private const string StructureTemplateDescriptionText =
      "Enter your expression as string in infix format into the empty input field.\n" +
      "By checking the \"Apply Linear Scaling\" checkbox you can add the relevant scaling terms to your expression.\n" +
      "After entering the expression click parse to build the tree.\n" +
      "To edit the defined sub-functions, click on the corresponding colored node in the tree view.";
    #endregion

    #region Parameters
    public IValueParameter<IRegressionProblemData> ProblemDataParameter => (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName];
    // NOTE(review): the constructor below never adds a parameter named "Structure Definition",
    // so this lookup (and the StructureDefinition property) looks like it will fail at runtime.
    // Confirm whether this parameter is still needed or should be registered in the constructor.
    public IFixedValueParameter<StringValue> StructureDefinitionParameter => (IFixedValueParameter<StringValue>)Parameters[StructureDefinitionParameterName];
    public IFixedValueParameter<StructureTemplate> StructureTemplateParameter => (IFixedValueParameter<StructureTemplate>)Parameters[StructureTemplateParameterName];
    #endregion

    #region Properties
    /// <summary>
    /// The regression data. Setting it through this property raises <see cref="ProblemDataChanged"/>.
    /// </summary>
    public IRegressionProblemData ProblemData {
      get => ProblemDataParameter.Value;
      set {
        ProblemDataParameter.Value = value;
        ProblemDataChanged?.Invoke(this, EventArgs.Empty);
      }
    }

    public string StructureDefinition {
      get => StructureDefinitionParameter.Value.Value;
      set => StructureDefinitionParameter.Value.Value = value;
    }

    /// <summary>The template that defines the fixed model structure and its sub-functions.</summary>
    public StructureTemplate StructureTemplate {
      get => StructureTemplateParameter.Value;
    }

    public ISymbolicDataAnalysisExpressionTreeInterpreter Interpreter { get; } = new SymbolicDataAnalysisExpressionTreeInterpreter();

    IParameter IDataAnalysisProblem.ProblemDataParameter => ProblemDataParameter;
    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData => ProblemData;

    public override bool Maximization => true;
    #endregion

    #region EventHandlers
    public event EventHandler ProblemDataChanged;
    #endregion

    #region Constructors & Cloning
    public StructuredSymbolicRegressionSingleObjectiveProblem() {
      var problemData = new ShapeConstrainedRegressionProblemData();
      var structureTemplate = new StructureTemplate();

      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, problemData));
      Parameters.Add(new FixedValueParameter<StructureTemplate>(StructureTemplateParameterName,
        StructureTemplateDescriptionText, structureTemplate));

      RegisterEventHandlers();
    }

    public StructuredSymbolicRegressionSingleObjectiveProblem(StructuredSymbolicRegressionSingleObjectiveProblem original,
      Cloner cloner) : base(original, cloner) {
      // The base constructor clones the parameters, so the clone owns a new template instance.
      // The subscription made on the original's template does not carry over and must be renewed.
      RegisterEventHandlers();
    }

    [StorableConstructor]
    protected StructuredSymbolicRegressionSingleObjectiveProblem(StorableConstructorFlag _) : base(_) { }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Event subscriptions are not persisted; restore them once the parameters are available.
      RegisterEventHandlers();
    }

    /// <summary>Subscribes to template changes so the encodings are rebuilt whenever the template changes.</summary>
    private void RegisterEventHandlers() {
      StructureTemplate.Changed += OnTemplateChanged;
    }
    #endregion

    #region Cloning
    public override IDeepCloneable Clone(Cloner cloner) =>
      new StructuredSymbolicRegressionSingleObjectiveProblem(this, cloner);
    #endregion

    private void OnTemplateChanged(object sender, EventArgs args) {
      SetupStructureTemplate();
    }

    /// <summary>
    /// Rebuilds the multi-encoding: removes all existing encodings, then adds one
    /// <see cref="SymbolicExpressionTreeEncoding"/> per distinct sub-function name of the template.
    /// </summary>
    private void SetupStructureTemplate() {
      foreach (var e in Encoding.Encodings.ToArray())
        Encoding.Remove(e);

      foreach (var f in StructureTemplate.SubFunctions.Values) {
        SetupVariables(f);
        if (!Encoding.Encodings.Any(x => x.Name == f.Name)) // to prevent the same encoding twice
          Encoding.Add(new SymbolicExpressionTreeEncoding(f.Name, f.Grammar, f.MaximumSymbolicExpressionTreeLength, f.MaximumSymbolicExpressionTreeDepth));
      }
    }

    /// <summary>
    /// Stores the decoded (and, if enabled, linearly scaled) tree of the best individual of the
    /// current population in the "Best Tree" result, creating the result on first use.
    /// </summary>
    public override void Analyze(Individual[] individuals, double[] qualities, ResultCollection results, IRandom random) {
      base.Analyze(individuals, qualities, results, random);
      if (qualities.Length == 0) return; // nothing to report for an empty population

      int bestIdx = 0;
      double bestQuality = Maximization ? double.MinValue : double.MaxValue;
      for (int idx = 0; idx < qualities.Length; ++idx) {
        bool isBetter = Maximization ? qualities[idx] > bestQuality : qualities[idx] < bestQuality;
        if (isBetter) {
          bestQuality = qualities[idx];
          bestIdx = idx;
        }
      }

      var tree = BuildTree(individuals[bestIdx]);
      if (StructureTemplate.ApplyLinearScaling)
        AdjustLinearScalingParams(tree, Interpreter);

      if (results.TryGetValue(BestTreeResultName, out IResult result))
        result.Value = tree;
      else
        results.Add(new Result(BestTreeResultName, tree));
    }

    /// <summary>
    /// Decodes <paramref name="individual"/> into a full tree, optionally fits the linear-scaling
    /// constants, and returns the Pearson R² on the training indices. Estimates are limited to the
    /// target variable's range.
    /// </summary>
    public override double Evaluate(Individual individual, IRandom random) {
      var tree = BuildTree(individual);
      if (StructureTemplate.ApplyLinearScaling)
        AdjustLinearScalingParams(tree, Interpreter);

      var estimationInterval = ProblemData.VariableRanges.GetInterval(ProblemData.TargetVariable);
      return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(
        Interpreter, tree,
        estimationInterval.LowerBound, estimationInterval.UpperBound,
        ProblemData, ProblemData.TrainingIndices, false);
    }

    /// <summary>
    /// Fits the offset and scaling constants of a linearly scaled template tree by least squares on
    /// the training partition and writes them into the tree in place.
    /// </summary>
    /// <remarks>
    /// Expects the first node below the start symbol to be the offset node. That node has a constant
    /// child (the offset) and a non-constant child (the scaling node), and the scaling node has a
    /// constant child (the scale). If the calculator reports an error, the constants are left at
    /// offset = 0 and scale = 1.
    /// </remarks>
    private void AdjustLinearScalingParams(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter) {
      var offsetNode = tree.Root.GetSubtree(0).GetSubtree(0);
      var scalingNode = offsetNode.Subtrees.First(x => !(x is ConstantTreeNode));

      var offsetConstantNode = offsetNode.Subtrees.OfType<ConstantTreeNode>().First();
      var scalingConstantNode = scalingNode.Subtrees.OfType<ConstantTreeNode>().First();

      // Neutralize the scaling terms first: alpha/beta are fitted against the *unscaled* model output,
      // so any non-neutral values inherited from the template would otherwise skew the fit.
      offsetConstantNode.Value = 0.0;
      scalingConstantNode.Value = 1.0;

      var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, ProblemData.Dataset, ProblemData.TrainingIndices);
      var targetValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);

      OnlineLinearScalingParameterCalculator.Calculate(estimatedValues, targetValues, out double a, out double b, out OnlineCalculatorError error);
      if (error == OnlineCalculatorError.None) {
        offsetConstantNode.Value = a;
        scalingConstantNode.Value = b;
      }
    }

    /// <summary>
    /// Builds the full expression tree for <paramref name="individual"/>. It clones the template tree
    /// and attaches, below each sub-function node, a copy of the body of the individual's tree with
    /// the same name.
    /// </summary>
    private ISymbolicExpressionTree BuildTree(Individual individual) {
      var templateTree = (ISymbolicExpressionTree)StructureTemplate.Tree.Clone();

      // Snapshot the node list because subtrees are attached while iterating.
      foreach (var n in templateTree.IterateNodesPrefix().ToArray()) {
        if (!(n.Symbol is SubFunctionSymbol)) continue;

        var subFunctionTreeNode = (SubFunctionTreeNode)n;
        var subFunctionTree = individual.SymbolicExpressionTree(subFunctionTreeNode.Name);

        var body = subFunctionTree.Root.GetSubtree(0)  // Start symbol
                                       .GetSubtree(0); // first symbol of the sub-function body
        // Attach a copy: AddSubtree re-parents the node, which would otherwise corrupt the
        // individual's own tree and make the returned tree alias the genotype.
        subFunctionTreeNode.AddSubtree((ISymbolicExpressionTreeNode)body.Clone());
      }
      return templateTree;
    }

    /// <summary>
    /// Configures the Variable symbol of the sub-function's grammar, adding one if the grammar has
    /// none. All input-variable names become AllVariableNames. The allowed VariableNames are all
    /// inputs except the target when the arguments contain "_", otherwise exactly the listed arguments.
    /// </summary>
    /// <exception cref="AggregateException">
    /// One <see cref="ArgumentException"/> per argument that is not a valid (non-target) input variable.
    /// </exception>
    private void SetupVariables(SubFunction subFunction) {
      var varSym = (Variable)subFunction.Grammar.GetSymbol("Variable");
      if (varSym == null) {
        varSym = new Variable();
        subFunction.Grammar.AddSymbol(varSym);
      }

      // Materialize once; the lists are read several times below.
      var allVariables = ProblemData.InputVariables.Select(x => x.Value).ToList();
      var allInputs = allVariables.Where(x => x != ProblemData.TargetVariable).ToList();

      // set all variables
      varSym.AllVariableNames = allVariables;

      // set all allowed variables
      if (subFunction.Arguments.Contains("_")) {
        varSym.VariableNames = allInputs;
      } else {
        var validInputs = new HashSet<string>(allInputs);
        var vars = new List<string>();
        var exceptions = new List<Exception>();
        foreach (var arg in subFunction.Arguments) {
          if (validInputs.Contains(arg))
            vars.Add(arg);
          else
            exceptions.Add(new ArgumentException($"The argument '{arg}' for sub-function '{subFunction.Name}' is not a valid variable."));
        }
        if (exceptions.Count > 0)
          throw new AggregateException(exceptions);
        varSym.VariableNames = vars;
      }

      varSym.Enabled = true;
    }

    /// <summary>Loads a problem instance and rebuilds the sub-function encodings for its variables.</summary>
    public void Load(RegressionProblemData data) {
      ProblemData = data;
      SetupStructureTemplate();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.