Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2870_AutoDiff-nuget/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Views/3.4/SymbolicRegressionSolutionResponseFunctionView.cs @ 17298

Last change on this file since 17298 was 15583, checked in by swagner, 7 years ago

#2640: Updated year of copyrights in license headers

File size: 10.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using HeuristicLab.Common;
28using HeuristicLab.Core.Views;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.MainForm;
31using HeuristicLab.MainForm.WindowsForms;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Views {
  /// <summary>
  /// Shows the response of a symbolic regression model along one selectable ("free") input variable.
  /// Every other variable referenced by the model is replaced by an adjustable constant that is
  /// initialized to the variable's median and controlled by a slider. The chart overlays the model
  /// response on the training and test samples of the target variable.
  /// </summary>
  [View("Response Function View")]
  [Content(typeof(ISymbolicRegressionSolution), false)]
  public partial class SymbolicRegressionSolutionResponseFunctionView : ItemView {
    // Constant nodes in clonedTree that replaced a fixed variable, keyed by that variable's name.
    private Dictionary<string, List<ISymbolicExpressionTreeNode>> variableNodes;
    // Clone of the model tree in which every fixed-variable occurrence is replaced by (weight * constant).
    private ISymbolicExpressionTree clonedTree;
    // Median of each referenced variable, taken from Dataset.GetDoubleValues(name) without a row filter.
    private Dictionary<string, double> medianValues;

    public SymbolicRegressionSolutionResponseFunctionView() {
      InitializeComponent();
      variableNodes = new Dictionary<string, List<ISymbolicExpressionTreeNode>>();
      medianValues = new Dictionary<string, double>();
      Caption = "Response Function View";
    }

    public new ISymbolicRegressionSolution Content {
      get { return (ISymbolicRegressionSolution)base.Content; }
      set { base.Content = value; }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    private void Content_ModelChanged(object sender, EventArgs e) {
      OnModelChanged();
    }
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      OnProblemDataChanged();
    }

    protected virtual void OnModelChanged() {
      this.UpdateView();
    }

    protected virtual void OnProblemDataChanged() {
      this.UpdateView();
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      this.UpdateView();
    }

    /// <summary>
    /// Rebuilds the free-variable selection from the current content.
    /// The combo box receives every variable referenced by the model, in natural sort order. Selecting
    /// the first entry raises the combo box's selection change, which builds the sliders and the chart.
    /// If there is no content, model, problem data or referenced variable, the view is cleared instead.
    /// </summary>
    private void UpdateView() {
      if (Content == null || Content.Model == null || Content.ProblemData == null) {
        ClearResponseView();
        return;
      }

      var referencedVariables = Content.Model.SymbolicExpressionTree.IterateNodesPrefix()
        .OfType<VariableTreeNode>()
        .Select(varNode => varNode.VariableName)
        .Distinct()
        .OrderBy(x => x, new NaturalStringComparer())
        .ToList();

      if (referencedVariables.Count == 0) {
        // Nothing to vary (e.g. a constant model); SelectedIndex = 0 would throw on an empty list.
        ClearResponseView();
        return;
      }

      medianValues.Clear();
      foreach (var variableName in referencedVariables) {
        medianValues.Add(variableName, Content.ProblemData.Dataset.GetDoubleValues(variableName).Median());
      }

      comboBox.Items.Clear();
      comboBox.Items.AddRange(referencedVariables.ToArray());
      comboBox.SelectedIndex = 0;
    }

    /// <summary>
    /// Removes all variables, sliders, cached state and chart points, so no stale results remain visible.
    /// </summary>
    private void ClearResponseView() {
      comboBox.Items.Clear();
      medianValues.Clear();
      variableNodes.Clear();
      clonedTree = null;
      ClearSliders();
      foreach (var series in responseChart.Series)
        series.Points.Clear();
    }

    /// <summary>
    /// Removes and disposes all sliders.
    /// Control.ControlCollection.Clear() neither disposes the removed controls nor detaches our
    /// ValueChanged handler, so without this the old trackbars would leak on every rebuild.
    /// </summary>
    private void ClearSliders() {
      var oldControls = flowLayoutPanel.Controls.Cast<Control>().ToList();
      flowLayoutPanel.Controls.Clear();
      foreach (var control in oldControls) {
        var trackbar = control as VariableTrackbar;
        if (trackbar != null)
          trackbar.ValueChanged -= TrackBarValueChanged;
        control.Dispose();
      }
    }

    /// <summary>Creates one slider per given (fixed) variable, in the given order.</summary>
    /// <param name="variableNames">Names of the variables that are held constant.</param>
    private void CreateSliders(IEnumerable<string> variableNames) {
      ClearSliders();

      foreach (var variableName in variableNames) {
        var variableTrackbar = new VariableTrackbar(variableName,
                                                    Content.ProblemData.Dataset.GetDoubleValues(variableName));
        // 23 px less than the panel height -- presumably room for the panel's scroll bar; TODO confirm.
        variableTrackbar.Size = new Size(variableTrackbar.Size.Width, flowLayoutPanel.Size.Height - 23);
        variableTrackbar.ValueChanged += TrackBarValueChanged;
        flowLayoutPanel.Controls.Add(variableTrackbar);
      }
    }

    private void TrackBarValueChanged(object sender, EventArgs e) {
      var trackBar = (VariableTrackbar)sender;
      string variableName = trackBar.VariableName;
      ChangeVariableValue(variableName, trackBar.Value);
    }

    /// <summary>
    /// Sets all constants that replaced <paramref name="variableName"/> in the cloned tree to
    /// <paramref name="value"/> and redraws the model response. Unknown names are ignored.
    /// </summary>
    private void ChangeVariableValue(string variableName, double value) {
      List<ISymbolicExpressionTreeNode> constantNodes;
      if (!variableNodes.TryGetValue(variableName, out constantNodes))
        return;

      foreach (var constNode in constantNodes.Cast<ConstantTreeNode>())
        constNode.Value = value;

      UpdateResponseSeries();
    }

    /// <summary>
    /// Redraws the target-versus-free-variable scatter series and rescales the chart axes.
    /// The "Training Data" and "Test Data" series contain only rows whose fixed variables all lie within
    /// +/-10% of their medians. The "(edge)" series contain all rows of the respective partition.
    /// </summary>
    private void UpdateScatterPlot() {
      string freeVariable = (string)comboBox.SelectedItem;
      // Materialized once: the row filter below is evaluated for every row.
      var fixedVariables = comboBox.Items.OfType<string>()
        .Except(new string[] { freeVariable })
        .ToList();

      var problemData = Content.ProblemData;
      var dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;

      // True if every fixed variable lies strictly inside (median - 10% |median|, median + 10% |median|).
      // Written as !(a && b) so that NaN values are rejected.
      Func<int, bool> isNearMedian = row => {
        foreach (var fixedVariable in fixedVariables) {
          double median = medianValues[fixedVariable];
          double tolerance = 0.1 * Math.Abs(median);
          double value = dataset.GetDoubleValue(fixedVariable, row);
          if (!(value < median + tolerance && value > median - tolerance))
            return false;
        }
        return true;
      };
      Func<IEnumerable<int>, double[]> freeValuesOf = rows => dataset.GetDoubleValues(freeVariable, rows).ToArray();
      Func<IEnumerable<int>, double[]> targetValuesOf = rows => dataset.GetDoubleValues(targetVariable, rows).ToArray();

      var trainingRows = problemData.TrainingIndices.ToArray();
      var testRows = problemData.TestIndices.ToArray();

      var nearMedianTrainingRows = trainingRows.Where(isNearMedian).ToArray();
      var nearMedianTestRows = testRows.Where(isNearMedian).ToArray();
      BindSortedSeries("Training Data", freeValuesOf(nearMedianTrainingRows), targetValuesOf(nearMedianTrainingRows));
      BindSortedSeries("Test Data", freeValuesOf(nearMedianTestRows), targetValuesOf(nearMedianTestRows));

      var allTrainingX = freeValuesOf(trainingRows);
      var allTrainingY = targetValuesOf(trainingRows);
      var allTestX = freeValuesOf(testRows);
      var allTestY = targetValuesOf(testRows);
      BindSortedSeries("Training Data (edge)", allTrainingX, allTrainingY);
      BindSortedSeries("Test Data (edge)", allTestX, allTestY);

      // Scale both axes to ALL training and test samples. An empty partition must not throw, and the
      // x-axis must not be limited to the test partition only.
      var chartArea = responseChart.ChartAreas[0];
      double min, max;
      if (TryGetAxisBounds(allTrainingX.Concat(allTestX), out min, out max)) {
        chartArea.AxisX.Maximum = max;
        chartArea.AxisX.Minimum = min;
      }
      if (TryGetAxisBounds(allTrainingY.Concat(allTestY), out min, out max)) {
        chartArea.AxisY.Maximum = max;
        chartArea.AxisY.Minimum = min;
      }
    }

    /// <summary>Sorts both arrays by x (in place) and binds them to the named chart series.</summary>
    private void BindSortedSeries(string seriesName, double[] xValues, double[] yValues) {
      Array.Sort(xValues, yValues);
      responseChart.Series[seriesName].Points.DataBindXY(xValues, yValues);
    }

    /// <summary>
    /// Computes integral axis bounds (floor of the minimum, ceiling of the maximum) for the given values.
    /// </summary>
    /// <param name="values">The values the axis must cover.</param>
    /// <param name="min">Floor of the smallest value; NaN if there are no values.</param>
    /// <param name="max">Ceiling of the largest value; NaN if there are no values.</param>
    /// <returns><c>false</c> if <paramref name="values"/> is empty; the axis should then be left unchanged.</returns>
    private static bool TryGetAxisBounds(IEnumerable<double> values, out double min, out double max) {
      var valueArray = values.ToArray();
      if (valueArray.Length == 0) {
        min = max = double.NaN;
        return false;
      }
      min = Math.Floor(valueArray.Min());
      max = Math.Ceiling(valueArray.Max());
      // A constant integral column yields min == max, which is not a usable axis range.
      if (max <= min) max = min + 1;
      return true;
    }

    /// <summary>
    /// Evaluates the cloned tree (fixed variables replaced by constants) on the training rows and binds the
    /// result, sorted by the free variable, to the "Model Response" series.
    /// </summary>
    private void UpdateResponseSeries() {
      if (clonedTree == null) return;
      string freeVariable = (string)comboBox.SelectedItem;
      var dataset = Content.ProblemData.Dataset;
      var trainingRows = Content.ProblemData.TrainingIndices.ToArray();

      var freeVariableValues = dataset.GetDoubleValues(freeVariable, trainingRows).ToArray();
      var responseValues = Content.Model.Interpreter.GetSymbolicExpressionTreeValues(clonedTree, dataset, trainingRows)
                                                    .ToArray();
      Array.Sort(freeVariableValues, responseValues);
      responseChart.Series["Model Response"].Points.DataBindXY(freeVariableValues, responseValues);
    }

    /// <summary>
    /// Reacts to a new free variable.
    /// Clones the model tree and replaces every occurrence of a fixed variable by
    /// (weight * constant), where the constant starts at the variable's median. It then rebuilds the
    /// sliders, the scatter plot and the model response.
    /// </summary>
    private void ComboBoxSelectedIndexChanged(object sender, EventArgs e) {
      string freeVariable = (string)comboBox.SelectedItem;
      // SelectedIndex may be -1 (no selection), or the content may be gone; nothing to show then.
      if (freeVariable == null || Content == null || Content.Model == null || Content.ProblemData == null)
        return;

      // The ordered list keeps the slider order; the set gives O(1) membership tests below.
      var fixedVariables = comboBox.Items.OfType<string>()
        .Except(new string[] { freeVariable })
        .ToList();
      var fixedVariableSet = new HashSet<string>(fixedVariables);

      variableNodes.Clear();
      clonedTree = (ISymbolicExpressionTree)Content.Model.SymbolicExpressionTree.Clone();

      // Snapshot the variable nodes first, because the loop replaces them inside the tree.
      var treeVariableNodes = clonedTree.IterateNodesPrefix().OfType<VariableTreeNode>().ToList();
      foreach (var varNode in treeVariableNodes) {
        if (!fixedVariableSet.Contains(varNode.VariableName)) continue;

        List<ISymbolicExpressionTreeNode> replacements;
        if (!variableNodes.TryGetValue(varNode.VariableName, out replacements)) {
          replacements = new List<ISymbolicExpressionTreeNode>();
          variableNodes.Add(varNode.VariableName, replacements);
        }

        var parent = varNode.Parent;
        int childIndex = parent.IndexOfSubtree(varNode);
        var replacementNode = MakeConstantTreeNode(medianValues[varNode.VariableName]);
        parent.RemoveSubtree(childIndex);
        // Keep the variable's weight: weight * x becomes weight * c.
        parent.InsertSubtree(childIndex, MakeProduct(replacementNode, varNode.Weight));
        replacements.Add(replacementNode);
      }

      CreateSliders(fixedVariables);
      UpdateScatterPlot();
      UpdateResponseSeries();
    }

    /// <summary>Builds a multiplication node with the children (constant <paramref name="weight"/>, <paramref name="c"/>).</summary>
    private ISymbolicExpressionTreeNode MakeProduct(ConstantTreeNode c, double weight) {
      var mul = new Multiplication();
      var prod = mul.CreateTreeNode();
      prod.AddSubtree(MakeConstantTreeNode(weight));
      prod.AddSubtree(c);
      return prod;
    }

    /// <summary>
    /// Creates a constant node with the given value. The constant symbol's MinValue/MaxValue are set to
    /// value -/+ 1 (presumably its initialization range; TODO confirm).
    /// </summary>
    private ConstantTreeNode MakeConstantTreeNode(double value) {
      Constant constant = new Constant();
      constant.MinValue = value - 1;
      constant.MaxValue = value + 1;
      ConstantTreeNode constantTreeNode = (ConstantTreeNode)constant.CreateTreeNode();
      constantTreeNode.Value = value;
      return constantTreeNode;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.