Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationSolutionThresholdView.cs @ 15401

Last change on this file since 15401 was 14255, checked in by pfleck, 8 years ago

#2632

  • Added the name of the target variable to plots and charts (scatter, line, ...).
  • Renamed MathSymbolicDataAnalysisModelView and added two subclasses, one for regression and one for classification, that show the name of the target variable in the equation. (Added a new Format method to the LatexFormatter that uses the actual target name when encountering the StartSymbol, and used it in these views.)
File size: 13.3 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using System.Windows.Forms.DataVisualization.Charting;
28using HeuristicLab.Common;
29using HeuristicLab.MainForm;
30using HeuristicLab.MainForm.WindowsForms;
31
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Visualizes the estimated values of an <see cref="IDiscriminantFunctionClassificationSolution"/>.
  /// The chart has one point series per class. Training samples are drawn around x = 0 and test
  /// samples around x = 10, horizontally jittered according to the jitter track bar. The model's
  /// finite thresholds are drawn as movable horizontal lines; dragging a line writes the new
  /// threshold back into the model. Clicking a legend entry toggles the points of that class.
  /// </summary>
  [View("Classification Threshold")]
  [Content(typeof(IDiscriminantFunctionClassificationSolution), false)]
  public sealed partial class DiscriminantFunctionClassificationSolutionThresholdView : DataAnalysisSolutionEvaluationView {
    private const double TrainingAxisValue = 0.0;
    private const double TestAxisValue = 10.0;
    // Half the distance between the training and the test column; used for label extents and jitter width.
    private const double TrainingTestBorder = (TestAxisValue - TrainingAxisValue) / 2;
    private const string TrainingLabelText = "Training Samples";
    private const string TestLabelText = "Test Samples";

    public new IDiscriminantFunctionClassificationSolution Content {
      get { return (IDiscriminantFunctionClassificationSolution)base.Content; }
      set { base.Content = value; }
    }

    private readonly Dictionary<double, Series> classValueSeriesMapping;
    private readonly System.Random random;
    // Prevents UpdateChart from re-entering itself while it is rebuilding the series.
    private bool updateInProgress;
    // Model whose ThresholdsChanged event is currently observed. It is remembered so the handler can be
    // detached again when the model is replaced or the content is deregistered. Otherwise the old model
    // keeps a reference to this view and repeated ModelChanged events stack up duplicate handlers.
    private IDiscriminantFunctionClassificationModel observedModel;

    public DiscriminantFunctionClassificationSolutionThresholdView()
      : base() {
      InitializeComponent();

      classValueSeriesMapping = new Dictionary<double, Series>();
      random = new System.Random();
      updateInProgress = false;

      this.chart.CustomizeAllChartAreas();
      this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisX.Minimum = TrainingAxisValue - TrainingTestBorder;
      this.chart.ChartAreas[0].AxisX.Maximum = TestAxisValue + TrainingTestBorder;
      AddCustomLabelToAxis(this.chart.ChartAreas[0].AxisX);

      this.chart.ChartAreas[0].AxisY.Title = "Estimated Values";
      this.chart.ChartAreas[0].AxisY.IsStartedFromZero = false;
      this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true;
      // Y selection is enabled above, so the Y axis must be zoomable as well (this previously set AxisX twice).
      this.chart.ChartAreas[0].AxisY.ScaleView.Zoomable = true;
    }

    /// <summary>
    /// Replaces the numeric x-axis labels with "Training Samples" / "Test Samples", each spanning
    /// its half of the x range.
    /// </summary>
    private void AddCustomLabelToAxis(Axis axis) {
      CustomLabel trainingLabel = new CustomLabel();
      trainingLabel.Text = TrainingLabelText;
      trainingLabel.FromPosition = TrainingAxisValue - TrainingTestBorder;
      trainingLabel.ToPosition = TrainingAxisValue + TrainingTestBorder;
      axis.CustomLabels.Add(trainingLabel);

      CustomLabel testLabel = new CustomLabel();
      testLabel.Text = TestLabelText;
      testLabel.FromPosition = TestAxisValue - TrainingTestBorder;
      testLabel.ToPosition = TestAxisValue + TrainingTestBorder;
      axis.CustomLabels.Add(testLabel);
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
      // NOTE(review): the initial model's ThresholdsChanged event is not observed here; it is only
      // attached once ModelChanged fires. Kept as-is to avoid changing behavior while a threshold is dragged.
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
      DetachFromObservedModel();
    }

    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      UpdateChart();
    }
    private void Content_ModelChanged(object sender, EventArgs e) {
      // Stop listening to the previous model before subscribing to the new one.
      DetachFromObservedModel();
      observedModel = Content.Model;
      observedModel.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
      UpdateChart();
    }
    private void Model_ThresholdsChanged(object sender, EventArgs e) {
      AddThresholds();
    }
    protected override void OnContentChanged() {
      base.OnContentChanged();
      UpdateChart();
    }

    /// <summary>Removes the ThresholdsChanged handler from the currently observed model, if any.</summary>
    private void DetachFromObservedModel() {
      if (observedModel == null) return;
      observedModel.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
      observedModel = null;
    }

    /// <summary>
    /// Rebuilds all series (one per class) and the threshold annotations from the current content.
    /// Marshals to the UI thread if necessary and ignores re-entrant calls.
    /// </summary>
    private void UpdateChart() {
      if (InvokeRequired) Invoke((Action)UpdateChart);
      else if (!updateInProgress) {
        updateInProgress = true;
        chart.Series.Clear();
        classValueSeriesMapping.Clear();
        if (Content != null) {
          // Class names are paired positionally with the class values sorted in ascending order.
          IEnumerator<string> classNameEnumerator = Content.ProblemData.ClassNames.GetEnumerator();
          IEnumerator<double> classValueEnumerator = Content.ProblemData.ClassValues.OrderBy(x => x).GetEnumerator();
          while (classNameEnumerator.MoveNext() && classValueEnumerator.MoveNext()) {
            Series series = new Series(Content.Model.TargetVariable + ": " + classNameEnumerator.Current);
            series.ChartType = SeriesChartType.FastPoint;
            series.Tag = classValueEnumerator.Current;
            chart.Series.Add(series);
            classValueSeriesMapping.Add(classValueEnumerator.Current, series);
            FillSeriesWithDataPoints(series);
          }
          AddThresholds();
        }
        chart.ChartAreas[0].RecalculateAxesScale();
        updateInProgress = false;
      }
    }

    /// <summary>
    /// Adds one point per training and test row whose target value equals the class value stored in
    /// <paramref name="series"/>.Tag. The y value is the estimated value; the x value is the jittered column.
    /// </summary>
    private void FillSeriesWithDataPoints(Series series) {
      List<double> estimatedValues = Content.EstimatedValues.ToList();
      List<double> targetValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToList();
      double classValue = (double)series.Tag;

      AddDataPoints(series, classValue, Content.ProblemData.TrainingIndices, TrainingAxisValue, estimatedValues, targetValues);
      AddDataPoints(series, classValue, Content.ProblemData.TestIndices, TestAxisValue, estimatedValues, targetValues);

      UpdateCursorInterval();
    }

    /// <summary>
    /// Appends a point to <paramref name="series"/> for every row in <paramref name="rows"/> whose target
    /// value is (almost) equal to <paramref name="classValue"/>, placed in the column at <paramref name="axisValue"/>.
    /// </summary>
    private void AddDataPoints(Series series, double classValue, IEnumerable<int> rows, double axisValue,
                               IList<double> estimatedValues, IList<double> targetValues) {
      foreach (int row in rows) {
        double estimatedValue = estimatedValues[row];
        double targetValue = targetValues[row];
        if (!targetValue.IsAlmost(classValue)) continue;

        // Random factor in [-1, 1); kept per point so the jitter width can be rescaled later.
        double jitterValue = random.NextDouble() * 2.0 - 1.0;
        DataPoint point = new DataPoint();
        point.XValue = GetJitteredXValue(axisValue, jitterValue);
        point.YValues[0] = estimatedValue;
        // Tag = (column position, jitter factor); read back by JitterTrackBar_ValueChanged.
        point.Tag = new KeyValuePair<double, double>(axisValue, jitterValue);
        series.Points.Add(point);
      }
    }

    /// <summary>
    /// Horizontal position of a point in the column at <paramref name="axisValue"/> for the given jitter
    /// factor, scaled by the current jitter track bar value (at most 90% of the half column width per 100 units).
    /// </summary>
    private double GetJitteredXValue(double axisValue, double jitterValue) {
      return axisValue + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9);
    }

    /// <summary>
    /// Redraws every finite model threshold as a movable red horizontal line. Each line is labeled at
    /// both chart edges with the class name beneath it and the class name above it.
    /// </summary>
    private void AddThresholds() {
      // Threshold changes may be raised by the model outside the UI thread.
      if (InvokeRequired) {
        Invoke((Action)AddThresholds);
        return;
      }
      chart.Annotations.Clear();
      if (Content == null) return;

      int classIndex = 1;
      IClassificationProblemData problemData = Content.ProblemData;
      var classValues = Content.Model.ClassValues.ToArray();
      Axis y = chart.ChartAreas[0].AxisY;
      Axis x = chart.ChartAreas[0].AxisX;
      foreach (double threshold in Content.Model.Thresholds) {
        if (double.IsInfinity(threshold)) continue;

        HorizontalLineAnnotation annotation = new HorizontalLineAnnotation();
        annotation.AllowMoving = true;
        annotation.AllowResizing = false;
        annotation.LineWidth = 2;
        annotation.LineColor = Color.Red;
        annotation.IsInfinitive = true;
        annotation.ClipToChartArea = chart.ChartAreas[0].Name;
        // Store the threshold's index as Tag; chart_AnnotationPositionChanging uses it to update that threshold.
        annotation.Tag = classIndex;
        annotation.AxisX = chart.ChartAreas[0].AxisX;
        annotation.AxisY = y;
        annotation.Y = threshold;

        string nameBeneath = problemData.GetClassName(classValues[classIndex - 1]);
        TextAnnotation beneathLeft = CreateTextAnnotation(nameBeneath, classIndex, x, y, x.Minimum, threshold, ContentAlignment.TopLeft);
        TextAnnotation beneathRight = CreateTextAnnotation(nameBeneath, classIndex, x, y, x.Maximum, threshold, ContentAlignment.TopRight);

        string nameAbove = problemData.GetClassName(classValues[classIndex]);
        TextAnnotation aboveLeft = CreateTextAnnotation(nameAbove, classIndex, x, y, x.Minimum, threshold, ContentAlignment.BottomLeft);
        TextAnnotation aboveRight = CreateTextAnnotation(nameAbove, classIndex, x, y, x.Maximum, threshold, ContentAlignment.BottomRight);

        chart.Annotations.Add(annotation);
        chart.Annotations.Add(beneathLeft);
        chart.Annotations.Add(aboveLeft);
        chart.Annotations.Add(beneathRight);
        chart.Annotations.Add(aboveRight);

        // Size the labels to their text (only possible once they belong to the chart), then flip
        // width/height so each label extends away from its anchor corner on the threshold line.
        beneathLeft.ResizeToContent();
        beneathRight.ResizeToContent();
        aboveLeft.ResizeToContent();
        aboveRight.ResizeToContent();

        beneathRight.Width = -beneathRight.Width;
        aboveLeft.Height = -aboveLeft.Height;
        aboveRight.Height = -aboveRight.Height;
        aboveRight.Width = -aboveRight.Width;

        classIndex++;
      }
    }

    /// <summary>
    /// Creates a non-resizable, non-selectable text label anchored at (<paramref name="x"/>, <paramref name="y"/>)
    /// in axis coordinates, tagged with <paramref name="classIndex"/>.
    /// </summary>
    private TextAnnotation CreateTextAnnotation(string name, int classIndex, Axis axisX, Axis axisY, double x, double y, ContentAlignment alignment) {
      TextAnnotation annotation = new TextAnnotation();
      annotation.Text = name;
      annotation.AllowMoving = true;
      annotation.AllowResizing = false;
      annotation.AllowSelecting = false;
      annotation.IsSizeAlwaysRelative = true;
      annotation.ClipToChartArea = chart.ChartAreas[0].Name;
      annotation.Tag = classIndex;
      annotation.AxisX = axisX;
      annotation.AxisY = axisY;
      annotation.Alignment = alignment;
      annotation.X = x;
      annotation.Y = y;
      return annotation;
    }

    /// <summary>Re-positions all points using their stored column and jitter factor and the new track bar value.</summary>
    private void JitterTrackBar_ValueChanged(object sender, EventArgs e) {
      foreach (Series series in chart.Series) {
        foreach (DataPoint point in series.Points) {
          var tag = (KeyValuePair<double, double>)point.Tag;
          point.XValue = GetJitteredXValue(tag.Key, tag.Value);
        }
      }
    }

    /// <summary>Greys out legend entries of series that are currently hidden (i.e. have no points).</summary>
    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
      foreach (LegendItem legendItem in e.LegendItems) {
        var series = chart.Series[legendItem.SeriesName];
        if (series != null) {
          bool seriesIsInvisible = series.Points.Count == 0;
          foreach (LegendCell cell in legendItem.Cells)
            cell.ForeColor = seriesIsInvisible ? Color.Gray : Color.Black;
        }
      }
    }

    /// <summary>Shows a hand cursor over legend items to indicate that they are clickable.</summary>
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem)
        this.Cursor = Cursors.Hand;
      else
        this.Cursor = Cursors.Default;
    }

    /// <summary>Hides a series by removing its points, or shows it again by refilling them.</summary>
    private void ToggleSeries(Series series) {
      if (series.Points.Count == 0)
        FillSeriesWithDataPoints(series);
      else
        series.Points.Clear();
    }

    private void chart_MouseDown(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem) {
        if (result.Series != null) ToggleSeries(result.Series);
      }
    }

    /// <summary>
    /// Writes the new y position of a dragged annotation back into the model's thresholds. The thresholds
    /// are re-sorted so that they stay in ascending order.
    /// </summary>
    private void chart_AnnotationPositionChanging(object sender, AnnotationPositionChangingEventArgs e) {
      // NOTE(review): the Tag counts only finite thresholds starting at 1; this assumes Thresholds[0] is the
      // (undrawn) infinite lower bound so that the Tag equals the index into Thresholds — confirm.
      int classIndex = (int)e.Annotation.Tag;
      double[] thresholds = Content.Model.Thresholds.ToArray();
      thresholds[classIndex] = e.NewLocationY;
      Array.Sort(thresholds);
      Content.Model.SetThresholdsAndClassValues(thresholds, Content.Model.ClassValues);
    }

    /// <summary>
    /// Sets the cursor selection intervals to roughly 1/1000 of the value range of the first series
    /// (rounded to a power of ten), so zoom selections snap at a sensible resolution.
    /// </summary>
    private void UpdateCursorInterval() {
      Series series = chart.Series[0];
      double[] xValues = (from point in series.Points
                          where !point.IsEmpty
                          select point.XValue)
                    .DefaultIfEmpty(1.0)
                    .ToArray();
      double[] yValues = (from point in series.Points
                          where !point.IsEmpty
                          select point.YValues[0])
                    .DefaultIfEmpty(1.0)
                    .ToArray();

      double xRange = xValues.Max() - xValues.Min();
      double yRange = yValues.Max() - yValues.Min();
      if (xRange.IsAlmost(0.0)) xRange = 1.0;
      if (yRange.IsAlmost(0.0)) yRange = 1.0;
      double xDigits = (int)Math.Log10(xRange) - 3;
      double yDigits = (int)Math.Log10(yRange) - 3;
      double xZoomInterval = Math.Pow(10, xDigits);
      double yZoomInterval = Math.Pow(10, yDigits);
      this.chart.ChartAreas[0].CursorX.Interval = xZoomInterval;
      this.chart.ChartAreas[0].CursorY.Interval = yZoomInterval;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.