#region License Information /* HeuristicLab * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL) * * This file is part of HeuristicLab. * * HeuristicLab is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * HeuristicLab is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with HeuristicLab. If not, see . */ #endregion using System; using System.Collections.Generic; using System.Drawing; using System.Linq; using System.Windows.Forms; using System.Windows.Forms.DataVisualization.Charting; using HeuristicLab.Common; using HeuristicLab.MainForm; using HeuristicLab.MainForm.WindowsForms; namespace HeuristicLab.Problems.DataAnalysis.Views { [View("Classification Threshold")] [Content(typeof(IDiscriminantFunctionClassificationSolution), false)] public sealed partial class DiscriminantFunctionClassificationSolutionThresholdView : DataAnalysisSolutionEvaluationView { private const double TrainingAxisValue = 0.0; private const double TestAxisValue = 10.0; private const double TrainingTestBorder = (TestAxisValue - TrainingAxisValue) / 2; private const string TrainingLabelText = "Training Samples"; private const string TestLabelText = "Test Samples"; public new IDiscriminantFunctionClassificationSolution Content { get { return (IDiscriminantFunctionClassificationSolution)base.Content; } set { base.Content = value; } } private Dictionary classValueSeriesMapping; private Random random; private bool updateInProgress; public DiscriminantFunctionClassificationSolutionThresholdView() : base() { InitializeComponent(); classValueSeriesMapping = new Dictionary(); random = new Random(); updateInProgress = false; this.chart.CustomizeAllChartAreas(); this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true; this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true; this.chart.ChartAreas[0].AxisX.Minimum = TrainingAxisValue - TrainingTestBorder; this.chart.ChartAreas[0].AxisX.Maximum = TestAxisValue + TrainingTestBorder; AddCustomLabelToAxis(this.chart.ChartAreas[0].AxisX); this.chart.ChartAreas[0].AxisY.Title = "Estimated Values"; this.chart.ChartAreas[0].AxisY.IsStartedFromZero = false; this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true; this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true; } private void AddCustomLabelToAxis(Axis axis) { CustomLabel trainingLabel = new CustomLabel(); trainingLabel.Text = TrainingLabelText; trainingLabel.FromPosition = TrainingAxisValue - TrainingTestBorder; trainingLabel.ToPosition = TrainingAxisValue + TrainingTestBorder; axis.CustomLabels.Add(trainingLabel); CustomLabel testLabel = new CustomLabel(); testLabel.Text = TestLabelText; testLabel.FromPosition = TestAxisValue - TrainingTestBorder; testLabel.ToPosition = TestAxisValue + TrainingTestBorder; axis.CustomLabels.Add(testLabel); } protected override void RegisterContentEvents() { base.RegisterContentEvents(); Content.ModelChanged += new EventHandler(Content_ModelChanged); Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged); } protected override void DeregisterContentEvents() { base.DeregisterContentEvents(); Content.ModelChanged -= new EventHandler(Content_ModelChanged); Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged); } private void Content_ProblemDataChanged(object sender, EventArgs e) { UpdateChart(); } private void Content_ModelChanged(object sender, EventArgs e) { Content.Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); UpdateChart(); } private void Model_ThresholdsChanged(object sender, EventArgs e) { AddThresholds(); } protected override void OnContentChanged() { base.OnContentChanged(); UpdateChart(); } private void UpdateChart() { if (InvokeRequired) Invoke((Action)UpdateChart); else if (!updateInProgress) { updateInProgress = true; chart.Series.Clear(); classValueSeriesMapping.Clear(); if (Content != null) { IEnumerator classNameEnumerator = Content.ProblemData.ClassNames.GetEnumerator(); IEnumerator classValueEnumerator = Content.ProblemData.ClassValues.OrderBy(x => x).GetEnumerator(); while (classNameEnumerator.MoveNext() && classValueEnumerator.MoveNext()) { Series series = new Series(classNameEnumerator.Current); series.ChartType = SeriesChartType.FastPoint; series.Tag = classValueEnumerator.Current; chart.Series.Add(series); classValueSeriesMapping.Add(classValueEnumerator.Current, series); FillSeriesWithDataPoints(series); } AddThresholds(); } chart.ChartAreas[0].RecalculateAxesScale(); updateInProgress = false; } } private void FillSeriesWithDataPoints(Series series) { List estimatedValues = Content.EstimatedValues.ToList(); var targetValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToList(); foreach (int row in Content.ProblemData.TrainingIndices) { double estimatedValue = estimatedValues[row]; double targetValue = targetValues[row]; if (targetValue.IsAlmost((double)series.Tag)) { double jitterValue = random.NextDouble() * 2.0 - 1.0; DataPoint point = new DataPoint(); point.XValue = TrainingAxisValue + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9); point.YValues[0] = estimatedValue; point.Tag = new KeyValuePair(TrainingAxisValue, jitterValue); series.Points.Add(point); } } foreach (int row in Content.ProblemData.TestIndices) { double estimatedValue = estimatedValues[row]; double targetValue = targetValues[row]; if (targetValue.IsAlmost((double)series.Tag)) { double jitterValue = random.NextDouble() * 2.0 - 1.0; DataPoint point = new DataPoint(); point.XValue = TestAxisValue + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9); point.YValues[0] = estimatedValue; point.Tag = new KeyValuePair(TestAxisValue, jitterValue); series.Points.Add(point); } } UpdateCursorInterval(); } private void AddThresholds() { chart.Annotations.Clear(); int classIndex = 1; IClassificationProblemData problemData = Content.ProblemData; var classValues = Content.Model.ClassValues.ToArray(); Axis y = chart.ChartAreas[0].AxisY; Axis x = chart.ChartAreas[0].AxisX; string name; foreach (double threshold in Content.Model.Thresholds) { if (!double.IsInfinity(threshold)) { HorizontalLineAnnotation annotation = new HorizontalLineAnnotation(); annotation.AllowMoving = true; annotation.AllowResizing = false; annotation.LineWidth = 2; annotation.LineColor = Color.Red; annotation.IsInfinitive = true; annotation.ClipToChartArea = chart.ChartAreas[0].Name; annotation.Tag = classIndex; //save classIndex as Tag to avoid moving the threshold accross class bounderies annotation.AxisX = chart.ChartAreas[0].AxisX; annotation.AxisY = y; annotation.Y = threshold; name = problemData.GetClassName(classValues[classIndex - 1]); TextAnnotation beneathLeft = CreateTextAnnotation(name, classIndex, x, y, x.Minimum, threshold, ContentAlignment.TopLeft); TextAnnotation beneathRight = CreateTextAnnotation(name, classIndex, x, y, x.Maximum, threshold, ContentAlignment.TopRight); name = problemData.GetClassName(classValues[classIndex]); TextAnnotation aboveLeft = CreateTextAnnotation(name, classIndex, x, y, x.Minimum, threshold, ContentAlignment.BottomLeft); TextAnnotation aboveRight = CreateTextAnnotation(name, classIndex, x, y, x.Maximum, threshold, ContentAlignment.BottomRight); chart.Annotations.Add(annotation); chart.Annotations.Add(beneathLeft); chart.Annotations.Add(aboveLeft); chart.Annotations.Add(beneathRight); chart.Annotations.Add(aboveRight); beneathLeft.ResizeToContent(); beneathRight.ResizeToContent(); aboveLeft.ResizeToContent(); aboveRight.ResizeToContent(); beneathRight.Width = -beneathRight.Width; aboveLeft.Height = -aboveLeft.Height; aboveRight.Height = -aboveRight.Height; aboveRight.Width = -aboveRight.Width; classIndex++; } } } private TextAnnotation CreateTextAnnotation(string name, int classIndex, Axis axisX, Axis axisY, double x, double y, ContentAlignment alignment) { TextAnnotation annotation = new TextAnnotation(); annotation.Text = name; annotation.AllowMoving = true; annotation.AllowResizing = false; annotation.AllowSelecting = false; annotation.IsSizeAlwaysRelative = true; annotation.ClipToChartArea = chart.ChartAreas[0].Name; annotation.Tag = classIndex; annotation.AxisX = axisX; annotation.AxisY = axisY; annotation.Alignment = alignment; annotation.X = x; annotation.Y = y; return annotation; } private void JitterTrackBar_ValueChanged(object sender, EventArgs e) { foreach (Series series in chart.Series) { foreach (DataPoint point in series.Points) { double value = ((KeyValuePair)point.Tag).Key; double jitterValue = ((KeyValuePair)point.Tag).Value; ; point.XValue = value + 0.01 * jitterValue * JitterTrackBar.Value * (TrainingTestBorder * 0.9); } } } private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) { foreach (LegendItem legendItem in e.LegendItems) { var series = chart.Series[legendItem.SeriesName]; if (series != null) { bool seriesIsInvisible = series.Points.Count == 0; foreach (LegendCell cell in legendItem.Cells) cell.ForeColor = seriesIsInvisible ? Color.Gray : Color.Black; } } } private void chart_MouseMove(object sender, MouseEventArgs e) { HitTestResult result = chart.HitTest(e.X, e.Y); if (result.ChartElementType == ChartElementType.LegendItem) this.Cursor = Cursors.Hand; else this.Cursor = Cursors.Default; } private void ToggleSeries(Series series) { if (series.Points.Count == 0) FillSeriesWithDataPoints(series); else series.Points.Clear(); } private void chart_MouseDown(object sender, MouseEventArgs e) { HitTestResult result = chart.HitTest(e.X, e.Y); if (result.ChartElementType == ChartElementType.LegendItem) { if (result.Series != null) ToggleSeries(result.Series); } } private void chart_AnnotationPositionChanging(object sender, AnnotationPositionChangingEventArgs e) { int classIndex = (int)e.Annotation.Tag; double[] thresholds = Content.Model.Thresholds.ToArray(); thresholds[classIndex] = e.NewLocationY; Array.Sort(thresholds); Content.Model.SetThresholdsAndClassValues(thresholds, Content.Model.ClassValues); } private void UpdateCursorInterval() { Series series = chart.Series[0]; double[] xValues = (from point in series.Points where !point.IsEmpty select point.XValue) .DefaultIfEmpty(1.0) .ToArray(); double[] yValues = (from point in series.Points where !point.IsEmpty select point.YValues[0]) .DefaultIfEmpty(1.0) .ToArray(); double xRange = xValues.Max() - xValues.Min(); double yRange = yValues.Max() - yValues.Min(); if (xRange.IsAlmost(0.0)) xRange = 1.0; if (yRange.IsAlmost(0.0)) yRange = 1.0; double xDigits = (int)Math.Log10(xRange) - 3; double yDigits = (int)Math.Log10(yRange) - 3; double xZoomInterval = Math.Pow(10, xDigits); double yZoomInterval = Math.Pow(10, yDigits); this.chart.ChartAreas[0].CursorX.Interval = xZoomInterval; this.chart.ChartAreas[0].CursorY.Interval = yZoomInterval; } } }