Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationSolutionThresholdView.cs @ 8840

Last change on this file since 8840 was 8840, checked in by sforsten, 11 years ago

#1949: class names are displayed on the left and right side of the diagram, above and beneath every threshold

File size: 14.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using System.Windows.Forms.DataVisualization.Charting;
28using HeuristicLab.Common;
29using HeuristicLab.MainForm;
30using HeuristicLab.MainForm.WindowsForms;
31
32namespace HeuristicLab.Problems.DataAnalysis.Views {
33  [View("Classification Threshold")]
34  [Content(typeof(IDiscriminantFunctionClassificationSolution), false)]
35  public sealed partial class DiscriminantFunctionClassificationSolutionThresholdView : DataAnalysisSolutionEvaluationView {
36    private const double TrainingAxisValue = 0.0;
37    private const double TestAxisValue = 10.0;
38    private const double TrainingTestBorder = (TestAxisValue - TrainingAxisValue) / 2;
39    private const string TrainingLabelText = "Training Samples";
40    private const string TestLabelText = "Test Samples";
41
42    private const double ClassNameLeftOffset = 10.0;
43    private const double ClassNameBottomOffset = 5.0;
44
45    public new IDiscriminantFunctionClassificationSolution Content {
46      get { return (IDiscriminantFunctionClassificationSolution)base.Content; }
47      set { base.Content = value; }
48    }
49
50    private Dictionary<double, Series> classValueSeriesMapping;
51    private Random random;
52    private bool updateInProgress, updateThresholds;
53
54    public DiscriminantFunctionClassificationSolutionThresholdView()
55      : base() {
56      InitializeComponent();
57
58      classValueSeriesMapping = new Dictionary<double, Series>();
59      random = new Random();
60      updateInProgress = false;
61
62      this.chart.CustomizeAllChartAreas();
63      this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true;
64      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
65      this.chart.ChartAreas[0].AxisX.Minimum = TrainingAxisValue - TrainingTestBorder;
66      this.chart.ChartAreas[0].AxisX.Maximum = TestAxisValue + TrainingTestBorder;
67      AddCustomLabelToAxis(this.chart.ChartAreas[0].AxisX);
68
69      this.chart.ChartAreas[0].AxisY.Title = "Estimated Values";
70      this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true;
71      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
72    }
73
74    private void AddCustomLabelToAxis(Axis axis) {
75      CustomLabel trainingLabel = new CustomLabel();
76      trainingLabel.Text = TrainingLabelText;
77      trainingLabel.FromPosition = TrainingAxisValue - TrainingTestBorder;
78      trainingLabel.ToPosition = TrainingAxisValue + TrainingTestBorder;
79      axis.CustomLabels.Add(trainingLabel);
80
81      CustomLabel testLabel = new CustomLabel();
82      testLabel.Text = TestLabelText;
83      testLabel.FromPosition = TestAxisValue - TrainingTestBorder;
84      testLabel.ToPosition = TestAxisValue + TrainingTestBorder;
85      axis.CustomLabels.Add(testLabel);
86    }
87
88    protected override void RegisterContentEvents() {
89      base.RegisterContentEvents();
90      Content.ModelChanged += new EventHandler(Content_ModelChanged);
91      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
92    }
93    protected override void DeregisterContentEvents() {
94      base.DeregisterContentEvents();
95      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
96      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
97    }
98
99    private void Content_ProblemDataChanged(object sender, EventArgs e) {
100      UpdateChart();
101    }
102    private void Content_ModelChanged(object sender, EventArgs e) {
103      Content.Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
104      UpdateChart();
105    }
106    private void Model_ThresholdsChanged(object sender, EventArgs e) {
107      AddThresholds();
108    }
109    protected override void OnContentChanged() {
110      base.OnContentChanged();
111      UpdateChart();
112    }
113
114    private void UpdateChart() {
115      if (InvokeRequired) Invoke((Action)UpdateChart);
116      else if (!updateInProgress) {
117        updateInProgress = true;
118        chart.Series.Clear();
119        classValueSeriesMapping.Clear();
120        if (Content != null) {
121          IEnumerator<string> classNameEnumerator = Content.ProblemData.ClassNames.GetEnumerator();
122          IEnumerator<double> classValueEnumerator = Content.ProblemData.ClassValues.OrderBy(x => x).GetEnumerator();
123          while (classNameEnumerator.MoveNext() && classValueEnumerator.MoveNext()) {
124            Series series = new Series(classNameEnumerator.Current);
125            series.ChartType = SeriesChartType.FastPoint;
126            series.Tag = classValueEnumerator.Current;
127            chart.Series.Add(series);
128            classValueSeriesMapping.Add(classValueEnumerator.Current, series);
129            FillSeriesWithDataPoints(series);
130          }
131          updateThresholds = true;
132        }
133        chart.ChartAreas[0].RecalculateAxesScale();
134        updateInProgress = false;
135      }
136    }
137
138    private void FillSeriesWithDataPoints(Series series) {
139      List<double> estimatedValues = Content.EstimatedValues.ToList();
140      var targetValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToList();
141
142      foreach (int row in Content.ProblemData.TrainingIndices) {
143        double estimatedValue = estimatedValues[row];
144        double targetValue = targetValues[row];
145        if (targetValue.IsAlmost((double)series.Tag)) {
146          double jitterValue = random.NextDouble() * 2.0 - 1.0;
147          DataPoint point = new DataPoint();
148          point.XValue = TrainingAxisValue + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9);
149          point.YValues[0] = estimatedValue;
150          point.Tag = new KeyValuePair<double, double>(TrainingAxisValue, jitterValue);
151          series.Points.Add(point);
152        }
153      }
154
155      foreach (int row in Content.ProblemData.TestIndices) {
156        double estimatedValue = estimatedValues[row];
157        double targetValue = targetValues[row];
158        if (targetValue.IsAlmost((double)series.Tag)) {
159          double jitterValue = random.NextDouble() * 2.0 - 1.0;
160          DataPoint point = new DataPoint();
161          point.XValue = TestAxisValue + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9);
162          point.YValues[0] = estimatedValue;
163          point.Tag = new KeyValuePair<double, double>(TestAxisValue, jitterValue);
164          series.Points.Add(point);
165        }
166      }
167
168      UpdateCursorInterval();
169    }
170
171    private void chart_PostPaint(object sender, ChartPaintEventArgs e) {
172      if (updateThresholds) {
173        AddThresholds();
174        updateThresholds = false;
175      }
176    }
177
178    private void AddThresholds() {
179      chart.Annotations.Clear();
180      int classIndex = 1;
181      SizeF textSizeInPixel;
182      IClassificationProblemData problemData = Content.ProblemData;
183      var classValues = Content.Model.ClassValues.ToArray();
184      Graphics g = chart.CreateGraphics();
185      Axis y = chart.ChartAreas[0].AxisY;
186      Axis x = chart.ChartAreas[0].AxisX;
187      string name;
188      foreach (double threshold in Content.Model.Thresholds) {
189        if (!double.IsInfinity(threshold)) {
190          HorizontalLineAnnotation annotation = new HorizontalLineAnnotation();
191          annotation.AllowMoving = true;
192          annotation.AllowResizing = false;
193          annotation.LineWidth = 2;
194          annotation.LineColor = Color.Red;
195          annotation.IsInfinitive = true;
196          annotation.ClipToChartArea = chart.ChartAreas[0].Name;
197          annotation.Tag = classIndex;  //save classIndex as Tag to avoid moving the threshold accross class bounderies
198          annotation.AxisX = chart.ChartAreas[0].AxisX;
199          annotation.AxisY = y;
200          annotation.Y = threshold;
201
202          name = problemData.GetClassName(classValues[classIndex - 1]);
203          TextAnnotation beneathLeft = CreateTextAnnotation(name, classIndex, x, y);
204          beneathLeft.Y = threshold;
205          beneathLeft.X = x.Minimum;
206          TextAnnotation beneathRigth = CreateTextAnnotation(name, classIndex, x, y);
207          beneathRigth.Y = threshold;
208          textSizeInPixel = g.MeasureString(beneathRigth.Text, beneathRigth.Font);
209          double textWidthPixelPos = x.ValueToPixelPosition(x.Maximum) - textSizeInPixel.Width - ClassNameLeftOffset;
210          //check if position is within the position boundary
211          beneathRigth.X = textWidthPixelPos < 0 || textWidthPixelPos > chart.Width ? double.NaN : x.PixelPositionToValue(textWidthPixelPos);
212
213          name = problemData.GetClassName(classValues[classIndex]);
214          TextAnnotation aboveLeft = CreateTextAnnotation(name, classIndex, x, y);
215          textSizeInPixel = g.MeasureString(aboveLeft.Text, aboveLeft.Font);
216          double textHeightPixelPos = y.ValueToPixelPosition(threshold) - textSizeInPixel.Height - ClassNameBottomOffset;
217          //check if position is within the position boundary
218          aboveLeft.Y = textHeightPixelPos < 0 || textHeightPixelPos > chart.Height ? double.NaN : y.PixelPositionToValue(textHeightPixelPos);
219          aboveLeft.X = x.Minimum;
220          TextAnnotation aboveRight = CreateTextAnnotation(name, classIndex, x, y);
221          aboveRight.Y = aboveLeft.Y;
222          textWidthPixelPos = x.ValueToPixelPosition(x.Maximum) - textSizeInPixel.Width - ClassNameLeftOffset;
223          //check if position is within the position boundary
224          aboveRight.X = textWidthPixelPos < 0 || textWidthPixelPos > chart.Width ? double.NaN : x.PixelPositionToValue(textWidthPixelPos);
225
226          chart.Annotations.Add(annotation);
227          chart.Annotations.Add(beneathLeft);
228          chart.Annotations.Add(aboveLeft);
229          chart.Annotations.Add(beneathRigth);
230          chart.Annotations.Add(aboveRight);
231          classIndex++;
232        }
233      }
234    }
235
236    private TextAnnotation CreateTextAnnotation(string name, int classIndex, Axis x, Axis y) {
237      TextAnnotation annotation = new TextAnnotation();
238      annotation.Text = name;
239      annotation.AllowMoving = true;
240      annotation.AllowResizing = false;
241      annotation.AllowSelecting = false;
242      annotation.ClipToChartArea = chart.ChartAreas[0].Name;
243      annotation.Tag = classIndex;
244      annotation.AxisX = chart.ChartAreas[0].AxisX;
245      annotation.AxisY = y;
246      return annotation;
247    }
248
249    private void JitterTrackBar_ValueChanged(object sender, EventArgs e) {
250      foreach (Series series in chart.Series) {
251        foreach (DataPoint point in series.Points) {
252          double value = ((KeyValuePair<double, double>)point.Tag).Key;
253          double jitterValue = ((KeyValuePair<double, double>)point.Tag).Value; ;
254          point.XValue = value + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9);
255        }
256      }
257    }
258
259    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
260      foreach (LegendItem legendItem in e.LegendItems) {
261        var series = chart.Series[legendItem.SeriesName];
262        if (series != null) {
263          bool seriesIsInvisible = series.Points.Count == 0;
264          foreach (LegendCell cell in legendItem.Cells)
265            cell.ForeColor = seriesIsInvisible ? Color.Gray : Color.Black;
266        }
267      }
268    }
269
270    private void chart_MouseMove(object sender, MouseEventArgs e) {
271      HitTestResult result = chart.HitTest(e.X, e.Y);
272      if (result.ChartElementType == ChartElementType.LegendItem)
273        this.Cursor = Cursors.Hand;
274      else
275        this.Cursor = Cursors.Default;
276    }
277
278    private void ToggleSeries(Series series) {
279      if (series.Points.Count == 0)
280        FillSeriesWithDataPoints(series);
281      else
282        series.Points.Clear();
283    }
284
285    private void chart_MouseDown(object sender, MouseEventArgs e) {
286      HitTestResult result = chart.HitTest(e.X, e.Y);
287      if (result.ChartElementType == ChartElementType.LegendItem) {
288        if (result.Series != null) ToggleSeries(result.Series);
289      }
290    }
291
292    private void chart_AnnotationPositionChanging(object sender, AnnotationPositionChangingEventArgs e) {
293      int classIndex = (int)e.Annotation.Tag;
294      double[] thresholds = Content.Model.Thresholds.ToArray();
295      thresholds[classIndex] = e.NewLocationY;
296      Array.Sort(thresholds);
297      Content.Model.SetThresholdsAndClassValues(thresholds, Content.Model.ClassValues);
298    }
299
300    private void UpdateCursorInterval() {
301      Series series = chart.Series[0];
302      double[] xValues = (from point in series.Points
303                          where !point.IsEmpty
304                          select point.XValue)
305                    .DefaultIfEmpty(1.0)
306                    .ToArray();
307      double[] yValues = (from point in series.Points
308                          where !point.IsEmpty
309                          select point.YValues[0])
310                    .DefaultIfEmpty(1.0)
311                    .ToArray();
312
313      double xRange = xValues.Max() - xValues.Min();
314      double yRange = yValues.Max() - yValues.Min();
315      if (xRange.IsAlmost(0.0)) xRange = 1.0;
316      if (yRange.IsAlmost(0.0)) yRange = 1.0;
317      double xDigits = (int)Math.Log10(xRange) - 3;
318      double yDigits = (int)Math.Log10(yRange) - 3;
319      double xZoomInterval = Math.Pow(10, xDigits);
320      double yZoomInterval = Math.Pow(10, yDigits);
321      this.chart.ChartAreas[0].CursorX.Interval = xZoomInterval;
322      this.chart.ChartAreas[0].CursorY.Interval = yZoomInterval;
323    }
324  }
325}
Note: See TracBrowser for help on using the repository browser.