Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationSolutionView.cs @ 5717

Last change on this file since 5717 was 5717, checked in by gkronber, 13 years ago

#1418 Implemented interactive simplifier views for symbolic classification and regression.

File size: 11.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using System.Windows.Forms.DataVisualization.Charting;
28using HeuristicLab.Common;
29using HeuristicLab.MainForm;
30using HeuristicLab.MainForm.WindowsForms;
31
32namespace HeuristicLab.Problems.DataAnalysis.Classification.Views {
33  [View("Discriminant function classification solution view")]
34  [Content(typeof(IDiscriminantFunctionClassificationSolution), true)]
35  public sealed partial class DiscriminantFunctionClassificationSolutionView : AsynchronousContentView {
36    private const double TrainingAxisValue = 0.0;
37    private const double TestAxisValue = 10.0;
38    private const double TrainingTestBorder = (TestAxisValue - TrainingAxisValue) / 2;
39    private const string TrainingLabelText = "Training Samples";
40    private const string TestLabelText = "Test Samples";
41
42    public new IDiscriminantFunctionClassificationSolution Content {
43      get { return (IDiscriminantFunctionClassificationSolution)base.Content; }
44      set { base.Content = value; }
45    }
46
47    private Dictionary<double, Series> classValueSeriesMapping;
48    private Random random;
49    private bool updateInProgress;
50
51    public DiscriminantFunctionClassificationSolutionView()
52      : base() {
53      InitializeComponent();
54
55      classValueSeriesMapping = new Dictionary<double, Series>();
56      random = new Random();
57      updateInProgress = false;
58
59      this.chart.CustomizeAllChartAreas();
60      this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true;
61      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
62      this.chart.ChartAreas[0].AxisX.Minimum = TrainingAxisValue - TrainingTestBorder;
63      this.chart.ChartAreas[0].AxisX.Maximum = TestAxisValue + TrainingTestBorder;
64      AddCustomLabelToAxis(this.chart.ChartAreas[0].AxisX);
65
66      this.chart.ChartAreas[0].AxisY.Title = "Estimated Values";
67      this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true;
68      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
69    }
70
71    private void AddCustomLabelToAxis(Axis axis) {
72      CustomLabel trainingLabel = new CustomLabel();
73      trainingLabel.Text = TrainingLabelText;
74      trainingLabel.FromPosition = TrainingAxisValue - TrainingTestBorder;
75      trainingLabel.ToPosition = TrainingAxisValue + TrainingTestBorder;
76      axis.CustomLabels.Add(trainingLabel);
77
78      CustomLabel testLabel = new CustomLabel();
79      testLabel.Text = TestLabelText;
80      testLabel.FromPosition = TestAxisValue - TrainingTestBorder;
81      testLabel.ToPosition = TestAxisValue + TrainingTestBorder;
82      axis.CustomLabels.Add(testLabel);
83    }
84
85    protected override void RegisterContentEvents() {
86      base.RegisterContentEvents();
87      Content.ModelChanged += new EventHandler(Content_ModelChanged);
88      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
89    }
90    protected override void DeregisterContentEvents() {
91      base.DeregisterContentEvents();
92      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
93      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
94    }
95
96    private void Content_ProblemDataChanged(object sender, EventArgs e) {
97      UpdateChart();
98    }
99    private void Content_ModelChanged(object sender, EventArgs e) {
100      UpdateChart();
101    }
102    private void Content_ThresholdsChanged(object sender, EventArgs e) {
103      AddThresholds();
104    }
105    protected override void OnContentChanged() {
106      base.OnContentChanged();
107      UpdateChart();
108    }
109
110    private void UpdateChart() {
111      if (InvokeRequired) Invoke((Action)UpdateChart);
112      else if (!updateInProgress) {
113        updateInProgress = true;
114        chart.Series.Clear();
115        classValueSeriesMapping.Clear();
116        if (Content != null) {
117          IEnumerator<string> classNameEnumerator = Content.ProblemData.ClassNames.GetEnumerator();
118          IEnumerator<double> classValueEnumerator = Content.ProblemData.ClassValues.OrderBy(x => x).GetEnumerator();
119          while (classNameEnumerator.MoveNext() && classValueEnumerator.MoveNext()) {
120            Series series = new Series(classNameEnumerator.Current);
121            series.ChartType = SeriesChartType.FastPoint;
122            series.Tag = classValueEnumerator.Current;
123            chart.Series.Add(series);
124            classValueSeriesMapping.Add(classValueEnumerator.Current, series);
125            FillSeriesWithDataPoints(series);
126          }
127          AddThresholds();
128        }
129        chart.ChartAreas[0].RecalculateAxesScale();
130        updateInProgress = false;
131      }
132    }
133
134    private void FillSeriesWithDataPoints(Series series) {
135      List<double> estimatedValues = Content.EstimatedValues.ToList();
136      foreach (int row in Content.ProblemData.TrainingIndizes) {
137        double estimatedValue = estimatedValues[row];
138        double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable, row];
139        if (targetValue.IsAlmost((double)series.Tag)) {
140          double jitterValue = random.NextDouble() * 2.0 - 1.0;
141          DataPoint point = new DataPoint();
142          point.XValue = TrainingAxisValue + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9);
143          point.YValues[0] = estimatedValue;
144          point.Tag = new KeyValuePair<double, double>(TrainingAxisValue, jitterValue);
145          series.Points.Add(point);
146        }
147      }
148
149      foreach (int row in Content.ProblemData.TestIndizes) {
150        double estimatedValue = estimatedValues[row];
151        double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable, row];
152        if (targetValue == (double)series.Tag) {
153          double jitterValue = random.NextDouble() * 2.0 - 1.0;
154          DataPoint point = new DataPoint();
155          point.XValue = TestAxisValue + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9);
156          point.YValues[0] = estimatedValue;
157          point.Tag = new KeyValuePair<double, double>(TestAxisValue, jitterValue);
158          series.Points.Add(point);
159        }
160      }
161
162      UpdateCursorInterval();
163    }
164
165    private void AddThresholds() {
166      chart.Annotations.Clear();
167      int classIndex = 1;
168      foreach (double threshold in Content.Model.Thresholds) {
169        if (!double.IsInfinity(threshold)) {
170          HorizontalLineAnnotation annotation = new HorizontalLineAnnotation();
171          annotation.AllowMoving = true;
172          annotation.AllowResizing = false;
173          annotation.LineWidth = 2;
174          annotation.LineColor = Color.Red;
175
176          annotation.IsInfinitive = true;
177          annotation.ClipToChartArea = chart.ChartAreas[0].Name;
178          annotation.Tag = classIndex;  //save classIndex as Tag to avoid moving the threshold accross class bounderies
179
180          annotation.AxisX = chart.ChartAreas[0].AxisX;
181          annotation.AxisY = chart.ChartAreas[0].AxisY;
182          annotation.Y = threshold;
183
184          chart.Annotations.Add(annotation);
185          classIndex++;
186        }
187      }
188    }
189
190    private void JitterTrackBar_ValueChanged(object sender, EventArgs e) {
191      foreach (Series series in chart.Series) {
192        foreach (DataPoint point in series.Points) {
193          double value = ((KeyValuePair<double, double>)point.Tag).Key;
194          double jitterValue = ((KeyValuePair<double, double>)point.Tag).Value; ;
195          point.XValue = value + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9);
196        }
197      }
198    }
199
200    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
201      foreach (LegendItem legendItem in e.LegendItems) {
202        var series = chart.Series[legendItem.SeriesName];
203        if (series != null) {
204          bool seriesIsInvisible = series.Points.Count == 0;
205          foreach (LegendCell cell in legendItem.Cells)
206            cell.ForeColor = seriesIsInvisible ? Color.Gray : Color.Black;
207        }
208      }
209    }
210
211    private void chart_MouseMove(object sender, MouseEventArgs e) {
212      HitTestResult result = chart.HitTest(e.X, e.Y);
213      if (result.ChartElementType == ChartElementType.LegendItem)
214        this.Cursor = Cursors.Hand;
215      else
216        this.Cursor = Cursors.Default;
217    }
218
219    private void ToggleSeries(Series series) {
220      if (series.Points.Count == 0)
221        FillSeriesWithDataPoints(series);
222      else
223        series.Points.Clear();
224    }
225
226    private void chart_MouseDown(object sender, MouseEventArgs e) {
227      HitTestResult result = chart.HitTest(e.X, e.Y);
228      if (result.ChartElementType == ChartElementType.LegendItem) {
229        if (result.Series != null) ToggleSeries(result.Series);
230      }
231    }
232
233    private void chart_AnnotationPositionChanging(object sender, AnnotationPositionChangingEventArgs e) {
234      int classIndex = (int)e.Annotation.Tag;
235      double[] thresholds = Content.Model.Thresholds.ToArray();
236      double max = thresholds[classIndex + 1];
237      double min = thresholds[classIndex - 1];
238
239      if (e.NewLocationY >= max)
240        e.NewLocationY = max;
241
242      if (e.NewLocationY <= min)
243        e.NewLocationY = min;
244
245      thresholds[classIndex] = e.NewLocationY;
246      Content.Model.Thresholds = thresholds;
247    }
248
249    private void UpdateCursorInterval() {
250      Series series = chart.Series[0];
251      double[] xValues = (from point in series.Points
252                          where !point.IsEmpty
253                          select point.XValue)
254                    .DefaultIfEmpty(1.0)
255                    .ToArray();
256      double[] yValues = (from point in series.Points
257                          where !point.IsEmpty
258                          select point.YValues[0])
259                    .DefaultIfEmpty(1.0)
260                    .ToArray();
261
262      double xRange = xValues.Max() - xValues.Min();
263      double yRange = yValues.Max() - yValues.Min();
264      if (xRange.IsAlmost(0.0)) xRange = 1.0;
265      if (yRange.IsAlmost(0.0)) yRange = 1.0;
266      double xDigits = (int)Math.Log10(xRange) - 3;
267      double yDigits = (int)Math.Log10(yRange) - 3;
268      double xZoomInterval = Math.Pow(10, xDigits);
269      double yZoomInterval = Math.Pow(10, yDigits);
270      this.chart.ChartAreas[0].CursorX.Interval = xZoomInterval;
271      this.chart.ChartAreas[0].CursorY.Interval = yZoomInterval;
272    }
273  }
274}
Note: See TracBrowser for help on using the repository browser.