Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataPreprocessing Cleanup/HeuristicLab.DataPreprocessing/3.4/Data/PreprocessingData.cs @ 15291

Last change on this file since 15291 was 15291, checked in by pfleck, 7 years ago

#2809: Added (Double/String/DateTime)PreprocessingDataColumn. (experimental state)

File size: 36.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections;
24using System.Collections.Generic;
25using System.Globalization;
26using System.Linq;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32
33namespace HeuristicLab.DataPreprocessing {
34  [Item("PreprocessingData", "Represents data used for preprocessing.")]
35  [StorableClass]
36  public class PreprocessingData : NamedItem, IPreprocessingData {
37
38    [Storable]
39    protected List<PreprocessingDataColumn> dataColumns;
40
41    #region Constructor, Cloning & Persistence
/// <summary>
/// Creates preprocessing data named "Preprocessing Data" and fills it from the given problem data.
/// </summary>
/// <param name="problemData">Problem data that is handed to <see cref="Import"/>.</param>
public PreprocessingData(IDataAnalysisProblemData problemData)
  : base() {
  // The containers have to exist before Import runs, because Import clears and refills dataColumns.
  selection = new Dictionary<int, IList<int>>();
  Transformations = new List<ITransformation>();
  dataColumns = new List<PreprocessingDataColumn>();

  Name = "Preprocessing Data";

  Import(problemData);
  RegisterEventHandler();
}
54
/// <summary>
/// Cloning constructor: copies columns, partitions, transformations, variable roles and the
/// selection of <paramref name="original"/>.
/// </summary>
/// <param name="original">Instance to copy.</param>
/// <param name="cloner">Cloner used for the columns, partitions and transformations.</param>
protected PreprocessingData(PreprocessingData original, Cloner cloner)
  : base(original, cloner) {
  dataColumns = new List<PreprocessingDataColumn>(original.dataColumns.Select(cloner.Clone));
  TrainingPartition = cloner.Clone(original.TrainingPartition);
  TestPartition = cloner.Clone(original.TestPartition);
  Transformations = new List<ITransformation>(original.Transformations.Select(cloner.Clone));

  // InputVariables is not marked [Storable], so it can be null on a deserialized original.
  InputVariables = new List<string>(original.InputVariables ?? Enumerable.Empty<string>());
  TargetVariable = original.TargetVariable;

  // Copy the selection so that the clone never has a null selection
  // (GetValues indexes into it when considerSelection is true).
  selection = new Dictionary<int, IList<int>>();
  if (original.selection != null) {
    foreach (var entry in original.selection) {
      var rows = entry.Value != null ? new List<int>(entry.Value) : new List<int>();
      selection.Add(entry.Key, rows);
    }
  }

  RegisterEventHandler();
}
/// <summary>Creates a copy of this instance through the cloning constructor.</summary>
/// <param name="cloner">Cloner passed on to the cloning constructor.</param>
/// <returns>The newly created copy.</returns>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new PreprocessingData(this, cloner);
  return copy;
}
70
/// <summary>Constructor used by the persistence framework; state is restored from storable members.</summary>
[StorableConstructor]
protected PreprocessingData(bool deserializing)
  : base(deserializing) {
}

/// <summary>
/// Re-attaches the internal <see cref="Changed"/> handler, which is not part of the stored state.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  RegisterEventHandler();
}
78
/// <summary>
/// Subscribes a handler to <see cref="Changed"/> that clamps both partitions to the current
/// row count after DeleteRow, Any and Transformation events.
/// </summary>
private void RegisterEventHandler() {
  Changed += (sender, args) => {
    bool clampPartitions =
      args.Type == DataPreprocessingChangedEventType.DeleteRow ||
      args.Type == DataPreprocessingChangedEventType.Any ||
      args.Type == DataPreprocessingChangedEventType.Transformation;
    if (!clampPartitions) return;

    int upperBound = Math.Max(0, Rows);
    TrainingPartition.Start = Math.Min(TrainingPartition.Start, upperBound);
    TrainingPartition.End = Math.Min(TrainingPartition.End, upperBound);
    TestPartition.Start = Math.Min(TestPartition.Start, upperBound);
    TestPartition.End = Math.Min(TestPartition.End, upperBound);
  };
}
94    #endregion
95
96    #region Cells
/// <summary>Tells whether the cell holds no valid value according to its column.</summary>
/// <param name="columnIndex">Index of the column.</param>
/// <param name="rowIndex">Index of the row.</param>
/// <returns><c>true</c> if the column reports the value at <paramref name="rowIndex"/> as not valid.</returns>
public bool IsCellEmpty(int columnIndex, int rowIndex) {
  var column = dataColumns[columnIndex];
  bool hasValidValue = column.IsValidValue(rowIndex);
  return !hasValidValue;
}
100
/// <summary>
/// Looks up the column at <paramref name="columnIndex"/> and forwards to the column-based overload.
/// </summary>
private void ColumnTypeSwitchAction<T>(int columnIndex, T value, Action<DoublePreprocessingDataColumn, double?> doubleAction,
  Action<StringPreprocessingDataColumn, string> stringAction = null, Action<DateTimePreprocessingDataColumn, DateTime?> dateTimeAction = null) {
  PreprocessingDataColumn column = dataColumns[columnIndex];
  ColumnTypeSwitchAction<T>(column, value, doubleAction, stringAction, dateTimeAction);
}

/// <summary>
/// Invokes the action that matches the concrete type of <paramref name="column"/>, passing
/// <paramref name="value"/> cast to that column's element type. Nothing happens when the matching
/// action is null or the column is of none of the three known types.
/// </summary>
private void ColumnTypeSwitchAction<T>(PreprocessingDataColumn column, T value, Action<DoublePreprocessingDataColumn, double?> doubleAction,
  Action<StringPreprocessingDataColumn, string> stringAction = null, Action<DateTimePreprocessingDataColumn, DateTime?> dateTimeAction = null) {
  var asDouble = column as DoublePreprocessingDataColumn;
  var asString = column as StringPreprocessingDataColumn;
  var asDateTime = column as DateTimePreprocessingDataColumn;

  if (asDouble != null && doubleAction != null) {
    doubleAction(asDouble, Convert<double?>(value));
  }
  if (asString != null && stringAction != null) {
    stringAction(asString, Convert<string>(value));
  }
  if (asDateTime != null && dateTimeAction != null) {
    dateTimeAction(asDateTime, Convert<DateTime?>(value));
  }
}
114
/// <summary>
/// Looks up the column at <paramref name="columnIndex"/> and forwards to the column-based overload.
/// </summary>
private void ColumnTypeSwitchAction(int columnIndex, Action<DoublePreprocessingDataColumn> doubleAction,
  Action<StringPreprocessingDataColumn> stringAction = null, Action<DateTimePreprocessingDataColumn> dateTimeAction = null) {
  PreprocessingDataColumn column = dataColumns[columnIndex];
  ColumnTypeSwitchAction(column, doubleAction, stringAction, dateTimeAction);
}

/// <summary>
/// Invokes the action that matches the concrete type of <paramref name="column"/>. Nothing happens
/// when the matching action is null or the column is of none of the three known types.
/// </summary>
private void ColumnTypeSwitchAction(PreprocessingDataColumn column, Action<DoublePreprocessingDataColumn> doubleAction,
  Action<StringPreprocessingDataColumn> stringAction = null, Action<DateTimePreprocessingDataColumn> dateTimeAction = null) {
  var asDouble = column as DoublePreprocessingDataColumn;
  var asString = column as StringPreprocessingDataColumn;
  var asDateTime = column as DateTimePreprocessingDataColumn;

  if (asDouble != null && doubleAction != null) {
    doubleAction(asDouble);
  }
  if (asString != null && stringAction != null) {
    stringAction(asString);
  }
  if (asDateTime != null && dateTimeAction != null) {
    dateTimeAction(asDateTime);
  }
}
128
129
/// <summary>
/// Evaluates the function that matches the type of the column at <paramref name="columnIndex"/>
/// and casts its result to <typeparamref name="T"/>.
/// </summary>
/// <exception cref="InvalidOperationException">No non-null function matches the column type.</exception>
private T ColumnTypeSwitchFunc<T>(int columnIndex, Func<DoublePreprocessingDataColumn, double?> doubleFunc,
  Func<StringPreprocessingDataColumn, string> stringFunc = null, Func<DateTimePreprocessingDataColumn, DateTime?> dateTimeFunc = null) {
  PreprocessingDataColumn column = dataColumns[columnIndex];

  var asDouble = column as DoublePreprocessingDataColumn;
  if (asDouble != null && doubleFunc != null) {
    return Convert<T>(doubleFunc(asDouble));
  }

  var asString = column as StringPreprocessingDataColumn;
  if (asString != null && stringFunc != null) {
    return Convert<T>(stringFunc(asString));
  }

  var asDateTime = column as DateTimePreprocessingDataColumn;
  if (asDateTime != null && dateTimeFunc != null) {
    return Convert<T>(dateTimeFunc(asDateTime));
  }

  throw new InvalidOperationException("Invalid data column type.");
}
140
/// <summary>
/// Evaluates the function that matches the type of the column at <paramref name="columnIndex"/>
/// and returns its result unchanged.
/// </summary>
/// <exception cref="InvalidOperationException">No non-null function matches the column type.</exception>
private T ColumnTypeSwitchFuncResult<T>(int columnIndex, Func<DoublePreprocessingDataColumn, T> doubleFunc,
  Func<StringPreprocessingDataColumn, T> stringFunc = null, Func<DateTimePreprocessingDataColumn, T> dateTimeFunc = null) {
  PreprocessingDataColumn column = dataColumns[columnIndex];

  var asDouble = column as DoublePreprocessingDataColumn;
  if (asDouble != null && doubleFunc != null) {
    return doubleFunc(asDouble);
  }

  var asString = column as StringPreprocessingDataColumn;
  if (asString != null && stringFunc != null) {
    return stringFunc(asString);
  }

  var asDateTime = column as DateTimePreprocessingDataColumn;
  if (asDateTime != null && dateTimeFunc != null) {
    return dateTimeFunc(asDateTime);
  }

  throw new InvalidOperationException("Invalid data column type.");
}
/// <summary>
/// Evaluates the function that matches the type of the column at <paramref name="columnIndex"/>,
/// passing <paramref name="value"/> cast to that column's element type.
/// </summary>
/// <exception cref="InvalidOperationException">No non-null function matches the column type.</exception>
private TOut ColumnTypeSwitchFuncResult<TIn, TOut>(int columnIndex, TIn value, Func<DoublePreprocessingDataColumn, double?, TOut> doubleFunc,
  Func<StringPreprocessingDataColumn, string, TOut> stringFunc = null, Func<DateTimePreprocessingDataColumn, DateTime?, TOut> dateTimeFunc = null) {
  PreprocessingDataColumn column = dataColumns[columnIndex];

  var asDouble = column as DoublePreprocessingDataColumn;
  if (asDouble != null && doubleFunc != null) {
    return doubleFunc(asDouble, Convert<double?>(value));
  }

  var asString = column as StringPreprocessingDataColumn;
  if (asString != null && stringFunc != null) {
    return stringFunc(asString, Convert<string>(value));
  }

  var asDateTime = column as DateTimePreprocessingDataColumn;
  if (asDateTime != null && dateTimeFunc != null) {
    return dateTimeFunc(asDateTime, Convert<DateTime?>(value));
  }

  throw new InvalidOperationException("Invalid data column type.");
}
161
/// <summary>
/// Evaluates the list-producing function that matches the type of the column at
/// <paramref name="columnIndex"/> and casts the resulting list to <c>IList&lt;T&gt;</c>.
/// </summary>
/// <exception cref="InvalidOperationException">No non-null function matches the column type.</exception>
private IList<T> ColumnTypeSwitchFuncList<T>(int columnIndex, Func<DoublePreprocessingDataColumn, IList<double>> doubleFunc,
  Func<StringPreprocessingDataColumn, IList<string>> stringFunc = null, Func<DateTimePreprocessingDataColumn, IList<DateTime>> dateTimeFunc = null) {
  PreprocessingDataColumn column = dataColumns[columnIndex];

  var asDouble = column as DoublePreprocessingDataColumn;
  if (asDouble != null && doubleFunc != null) {
    return Convert<IList<T>>(doubleFunc(asDouble));
  }

  var asString = column as StringPreprocessingDataColumn;
  if (asString != null && stringFunc != null) {
    return Convert<IList<T>>(stringFunc(asString));
  }

  var asDateTime = column as DateTimePreprocessingDataColumn;
  if (asDateTime != null && dateTimeFunc != null) {
    return Convert<IList<T>>(dateTimeFunc(asDateTime));
  }

  throw new InvalidOperationException("Invalid data column type.");
}
/// <summary>Casts (and, for value types, unboxes) <paramref name="obj"/> to <typeparamref name="T"/>.</summary>
private static T Convert<T>(object obj) {
  var converted = (T)obj;
  return converted;
}
173
174
/// <summary>Reads a single cell and casts it to <typeparamref name="T"/>.</summary>
/// <remarks>
/// When a missing (null) cell is requested as non-nullable <c>double</c> or <c>DateTime</c>,
/// <c>double.NaN</c> respectively <c>DateTime.MinValue</c> is returned. These are the same
/// substitutes GetValues uses, and they avoid unboxing null into a value type.
/// </remarks>
/// <param name="columnIndex">Index of the column.</param>
/// <param name="rowIndex">Index of the row.</param>
/// <exception cref="InvalidOperationException">The column is of no known column type.</exception>
public T GetCell<T>(int columnIndex, int rowIndex) {
  bool asPlainDouble = typeof(T) == typeof(double);
  bool asPlainDateTime = typeof(T) == typeof(DateTime);

  return ColumnTypeSwitchFunc<T>(columnIndex,
    c => asPlainDouble ? (c[rowIndex] ?? double.NaN) : c[rowIndex],
    c => c[rowIndex],
    c => asPlainDateTime ? (c[rowIndex] ?? DateTime.MinValue) : c[rowIndex]);
}
181
/// <summary>
/// Writes <paramref name="value"/> into the given cell, growing the table first when the cell
/// lies beyond the current column or row count.
/// </summary>
/// <param name="columnIndex">Index of the column; missing columns are appended with their index as name.</param>
/// <param name="rowIndex">Index of the row; missing rows are appended.</param>
/// <param name="value">Value to store; it is cast to the element type of the target column.</param>
public void SetCell<T>(int columnIndex, int rowIndex, T value) {
  SaveSnapshot(DataPreprocessingChangedEventType.ChangeItem, columnIndex, rowIndex);

  // Grow columns before rows: InsertRow only touches existing columns, so on a table
  // without columns the rows would otherwise never be created and the write below would fail.
  for (int i = Columns; i <= columnIndex; i++)
    InsertColumn<T>(i.ToString(), i);
  for (int i = Rows; i <= rowIndex; i++)
    InsertRow(i);

  ColumnTypeSwitchAction<T>(columnIndex, value,
    (c, v) => c[rowIndex] = v,
    (c, v) => c[rowIndex] = v,
    (c, v) => c[rowIndex] = v);

  if (!IsInTransaction)
    OnChanged(DataPreprocessingChangedEventType.ChangeItem, columnIndex, rowIndex);
}
198
/// <summary>Returns the value the column reports for the given row via its <c>GetValue</c> method.</summary>
/// <param name="columnIndex">Index of the column.</param>
/// <param name="rowIndex">Index of the row.</param>
public string GetCellAsString(int columnIndex, int rowIndex) {
  var column = dataColumns[columnIndex];
  return column.GetValue(rowIndex);
}
202
/// <summary>Returns the values of one column, optionally restricted to the selected rows.</summary>
/// <param name="columnIndex">Index of the column.</param>
/// <param name="considerSelection">
/// When <c>true</c>, only the rows stored in the selection for this column are read (via GetCell).
/// When <c>false</c>, the whole column is returned; missing doubles become <c>double.NaN</c> and
/// missing dates become <c>DateTime.MinValue</c>.
/// </param>
public IList<T> GetValues<T>(int columnIndex, bool considerSelection) {
  if (!considerSelection) {
    return ColumnTypeSwitchFuncList<T>(columnIndex,
      doubles => doubles.Values.Select(v => v ?? double.NaN).ToList(),
      strings => strings.Values,
      dates => dates.Values.Select(v => v ?? DateTime.MinValue).ToList());
  }

  var selectedValues = new List<T>();
  foreach (var selectedRow in selection[columnIndex])
    selectedValues.Add(GetCell<T>(columnIndex, selectedRow));
  return selectedValues;
}
219
/// <summary>Replaces a whole column with a new column built from <paramref name="values"/>, keeping its name.</summary>
/// <param name="columnIndex">Index of the column to replace.</param>
/// <param name="values">New values; <typeparamref name="T"/> has to match the column's value type.</param>
/// <exception cref="ArgumentException">
/// <typeparamref name="T"/> does not match the column, or the column is of an unknown type.
/// </exception>
public void SetValues<T>(int columnIndex, IList<T> values) {
  SaveSnapshot(DataPreprocessingChangedEventType.ChangeColumn, columnIndex, -1);

  if (!VariableHasType<T>(columnIndex)) {
    throw new ArgumentException("The datatype of column " + columnIndex + " must be of type " + dataColumns[columnIndex].GetType().Name + " but was " + typeof(T).Name);
  }

  var current = dataColumns[columnIndex];
  string columnName = current.Name;

  PreprocessingDataColumn replacement;
  if (current.IsType<double>()) {
    replacement = new DoublePreprocessingDataColumn(columnName, (IEnumerable<double>)values);
  } else if (current.IsType<string>()) {
    replacement = new StringPreprocessingDataColumn(columnName, (IEnumerable<string>)values);
  } else if (current.IsType<DateTime>()) {
    replacement = new DateTimePreprocessingDataColumn(columnName, (IEnumerable<DateTime>)values);
  } else {
    throw new ArgumentException("Unknown column type");
  }
  dataColumns[columnIndex] = replacement;

  if (!IsInTransaction)
    OnChanged(DataPreprocessingChangedEventType.ChangeColumn, columnIndex, -1);
}
239
/// <summary>
/// Parses <paramref name="value"/> according to the column's type and stores it in the cell
/// if parsing succeeds.
/// </summary>
/// <remarks>
/// Double columns: empty or whitespace text is stored as <c>double.NaN</c>. String columns: any
/// non-null text is accepted. A ChangeColumn event is raised afterwards (outside of transactions)
/// whether or not the value was valid.
/// </remarks>
/// <returns><c>true</c> if the text was valid for the column and has been stored.</returns>
/// <exception cref="ArgumentException">The column has a type other than double, string or DateTime.</exception>
public bool SetValue(string value, int columnIndex, int rowIndex) {
  bool valid;

  if (VariableHasType<double>(columnIndex)) {
    double number = double.NaN;
    // Blank input short-circuits and keeps NaN; otherwise the parse result decides.
    valid = string.IsNullOrWhiteSpace(value) || double.TryParse(value, out number);
    if (valid) {
      SetCell(columnIndex, rowIndex, number);
    }
  } else if (VariableHasType<string>(columnIndex)) {
    valid = value != null;
    if (valid) {
      SetCell(columnIndex, rowIndex, value);
    }
  } else if (VariableHasType<DateTime>(columnIndex)) {
    DateTime parsedDate;
    valid = DateTime.TryParse(value, out parsedDate);
    if (valid) {
      SetCell(columnIndex, rowIndex, parsedDate);
    }
  } else {
    throw new ArgumentException("column " + columnIndex + " contains a non supported type.");
  }

  if (!IsInTransaction) {
    OnChanged(DataPreprocessingChangedEventType.ChangeColumn, columnIndex, -1);
  }

  return valid;
}
270
/// <summary>Number of data columns.</summary>
public int Columns {
  get {
    int columnCount = dataColumns.Count;
    return columnCount;
  }
}
274
/// <summary>Number of rows, taken from the length of the first column; 0 when there are no columns.</summary>
public int Rows {
  get {
    if (dataColumns.Count > 0) {
      return dataColumns[0].Length;
    }
    return 0;
  }
}
278    #endregion
279
280    #region Rows
/// <summary>
/// Inserts an empty (null) row at <paramref name="rowIndex"/> into every column and adjusts the
/// training and test partitions.
/// </summary>
/// <remarks>
/// The partition whose [Start, End] range contains the row grows by one. The other partition is
/// moved by one when it starts at or after the end of the grown partition.
/// </remarks>
/// <param name="rowIndex">Position of the new row.</param>
public void InsertRow(int rowIndex) {
  // The snapshot stores DeleteRow because Undo re-raises the stored event type.
  SaveSnapshot(DataPreprocessingChangedEventType.DeleteRow, -1, rowIndex);
  foreach (var column in dataColumns) {
    ColumnTypeSwitchAction(column,
      c => c.Values.Insert(rowIndex, null),
      c => c.Values.Insert(rowIndex, null),
      c => c.Values.Insert(rowIndex, null));
  }

  if (TrainingPartition.Start <= rowIndex && rowIndex <= TrainingPartition.End) {
    // Decide before growing: comparing after End++ would miss a directly adjacent test partition.
    bool testFollows = TrainingPartition.End <= TestPartition.Start;
    TrainingPartition.End++;
    if (testFollows) {
      TestPartition.Start++;
      TestPartition.End++;
    }
  } else if (TestPartition.Start <= rowIndex && rowIndex <= TestPartition.End) {
    bool trainingFollows = TestPartition.End <= TrainingPartition.Start;
    TestPartition.End++;
    if (trainingFollows) {
      // The training partition follows the test partition, so it is the one that moves.
      TrainingPartition.Start++;
      TrainingPartition.End++;
    }
  }

  if (!IsInTransaction)
    OnChanged(DataPreprocessingChangedEventType.AddRow, -1, rowIndex);
}
/// <summary>
/// Removes the row at <paramref name="rowIndex"/> from every column and adjusts the training and
/// test partitions.
/// </summary>
/// <remarks>
/// The partition whose [Start, End] range contains the row shrinks by one. The other partition is
/// moved down by one when it starts at or after the new end of the shrunken partition.
/// </remarks>
/// <param name="rowIndex">Index of the row to remove.</param>
public void DeleteRow(int rowIndex) {
  // The snapshot stores AddRow because Undo re-raises the stored event type.
  SaveSnapshot(DataPreprocessingChangedEventType.AddRow, -1, rowIndex);
  foreach (var column in dataColumns) {
    ColumnTypeSwitchAction(column,
      c => c.Values.RemoveAt(rowIndex),
      c => c.Values.RemoveAt(rowIndex),
      c => c.Values.RemoveAt(rowIndex));
  }

  if (TrainingPartition.Start <= rowIndex && rowIndex <= TrainingPartition.End) {
    TrainingPartition.End--;
    if (TrainingPartition.End <= TestPartition.Start) {
      TestPartition.Start--;
      TestPartition.End--;
    }
  } else if (TestPartition.Start <= rowIndex && rowIndex <= TestPartition.End) {
    TestPartition.End--;
    if (TestPartition.End <= TrainingPartition.Start) {
      // The training partition follows the test partition, so it is the one that moves.
      TrainingPartition.Start--;
      TrainingPartition.End--;
    }
  }

  if (!IsInTransaction)
    OnChanged(DataPreprocessingChangedEventType.DeleteRow, -1, rowIndex);
}
/// <summary>
/// Removes all rows with the given indices from every column and adjusts the training and test
/// partitions for each removed row.
/// </summary>
/// <param name="rows">Indices of the rows to remove; duplicates are ignored.</param>
public void DeleteRowsWithIndices(IEnumerable<int> rows) {
  // The snapshot stores AddRow because Undo re-raises the stored event type.
  SaveSnapshot(DataPreprocessingChangedEventType.AddRow, -1, -1);

  // Highest index first keeps the remaining indices valid. Distinct prevents a repeated index
  // from removing the row that moved into the freed position.
  foreach (int rowIndex in rows.Distinct().OrderByDescending(x => x)) {
    foreach (var column in dataColumns) {
      ColumnTypeSwitchAction(column,
        c => c.Values.RemoveAt(rowIndex),
        c => c.Values.RemoveAt(rowIndex),
        c => c.Values.RemoveAt(rowIndex));
    }

    if (TrainingPartition.Start <= rowIndex && rowIndex <= TrainingPartition.End) {
      TrainingPartition.End--;
      if (TrainingPartition.End <= TestPartition.Start) {
        TestPartition.Start--;
        TestPartition.End--;
      }
    } else if (TestPartition.Start <= rowIndex && rowIndex <= TestPartition.End) {
      TestPartition.End--;
      if (TestPartition.End <= TrainingPartition.Start) {
        // The training partition follows the test partition, so it is the one that moves.
        TrainingPartition.Start--;
        TrainingPartition.End--;
      }
    }
  }

  if (!IsInTransaction)
    OnChanged(DataPreprocessingChangedEventType.DeleteRow, -1, -1);
}
359
/// <summary>
/// Inserts a new column at <paramref name="columnIndex"/> that holds one missing (null) value per
/// existing row.
/// </summary>
/// <typeparam name="T">Value type of the column: double, string or DateTime.</typeparam>
/// <param name="variableName">Name of the new column.</param>
/// <param name="columnIndex">Position of the new column, between 0 and <see cref="Columns"/> inclusive.</param>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="columnIndex"/> is out of range.</exception>
/// <exception cref="ArgumentException"><typeparamref name="T"/> is not a supported type.</exception>
public void InsertColumn<T>(string variableName, int columnIndex) {
  // Validate everything before the snapshot so a rejected call leaves no undo entry behind.
  if (columnIndex < 0 || columnIndex > dataColumns.Count)
    throw new ArgumentOutOfRangeException("columnIndex");

  PreprocessingDataColumn newColumn;
  if (typeof(T) == typeof(double)) {
    newColumn = new DoublePreprocessingDataColumn(variableName, Enumerable.Repeat<double?>(null, Rows));
  } else if (typeof(T) == typeof(string)) {
    newColumn = new StringPreprocessingDataColumn(variableName, Enumerable.Repeat<string>(null, Rows));
  } else if (typeof(T) == typeof(DateTime)) {
    newColumn = new DateTimePreprocessingDataColumn(variableName, Enumerable.Repeat<DateTime?>(null, Rows));
  } else {
    throw new ArgumentException("The datatype of column " + variableName + " must be of type double, string or DateTime");
  }

  // The snapshot stores DeleteColumn because Undo re-raises the stored event type.
  SaveSnapshot(DataPreprocessingChangedEventType.DeleteColumn, columnIndex, -1);

  // Every column type is inserted at the requested position (string and DateTime columns
  // used to be appended at the end regardless of columnIndex).
  dataColumns.Insert(columnIndex, newColumn);

  if (!IsInTransaction)
    OnChanged(DataPreprocessingChangedEventType.AddColumn, columnIndex, -1);
}
378
/// <summary>Removes the column at <paramref name="columnIndex"/>.</summary>
/// <param name="columnIndex">Index of the column to remove.</param>
public void DeleteColumn(int columnIndex) {
  // The snapshot stores AddColumn because Undo re-raises the stored event type.
  SaveSnapshot(DataPreprocessingChangedEventType.AddColumn, columnIndex, -1);

  dataColumns.RemoveAt(columnIndex);

  if (IsInTransaction) return;
  OnChanged(DataPreprocessingChangedEventType.DeleteColumn, columnIndex, -1);
}
386
/// <summary>Changes the name of a single column.</summary>
/// <param name="columnIndex">Index of the column to rename.</param>
/// <param name="name">New column name.</param>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="columnIndex"/> is not a valid column index.</exception>
public void RenameColumn(int columnIndex, string name) {
  // Validate before saving the snapshot so a rejected call leaves no undo entry behind.
  // The upper bound is exclusive: an index equal to Count is not a valid column.
  if (columnIndex < 0 || columnIndex >= dataColumns.Count)
    throw new ArgumentOutOfRangeException("columnIndex");

  SaveSnapshot(DataPreprocessingChangedEventType.ChangeColumn, columnIndex, -1);
  dataColumns[columnIndex].Name = name;

  if (!IsInTransaction)
    OnChanged(DataPreprocessingChangedEventType.ChangeColumn, -1, -1);
}
396
/// <summary>Assigns new names to all columns, in column order.</summary>
/// <param name="names">One name per column.</param>
/// <exception cref="ArgumentNullException"><paramref name="names"/> is null.</exception>
/// <exception cref="ArgumentException">The number of names differs from the number of columns.</exception>
public void RenameColumns(IList<string> names) {
  if (names == null) {
    throw new ArgumentNullException("names");
  }
  if (names.Count != dataColumns.Count) {
    throw new ArgumentException("number of names must match the number of columns.", "names");
  }

  SaveSnapshot(DataPreprocessingChangedEventType.ChangeColumn, -1, -1);

  int position = 0;
  while (position < names.Count) {
    dataColumns[position].Name = names[position];
    position++;
  }

  if (IsInTransaction) return;
  OnChanged(DataPreprocessingChangedEventType.ChangeColumn, -1, -1);
}
408
/// <summary>Checks whether every listed column is a string column.</summary>
/// <param name="columnIndices">Indices of the columns to check.</param>
/// <returns><c>true</c> if all listed columns have type string (also for an empty list).</returns>
public bool AreAllStringColumns(IEnumerable<int> columnIndices) {
  foreach (int columnIndex in columnIndices) {
    if (!VariableHasType<string>(columnIndex)) {
      return false;
    }
  }
  return true;
}
412    #endregion
413
414    #region Variables
/// <summary>Names of all columns, in column order (evaluated lazily).</summary>
public IEnumerable<string> VariableNames {
  get {
    return from column in dataColumns
           select column.Name;
  }
}
418
/// <summary>Names of all double columns, in column order (evaluated lazily).</summary>
public IEnumerable<string> GetDoubleVariableNames() {
  return from doubleColumn in dataColumns.OfType<DoublePreprocessingDataColumn>()
         select doubleColumn.Name;
}
422
/// <summary>Returns the name of the column at <paramref name="columnIndex"/>.</summary>
public string GetVariableName(int columnIndex) {
  var column = dataColumns[columnIndex];
  return column.Name;
}
426
/// <summary>Finds the index of the first column named <paramref name="variableName"/>.</summary>
/// <returns>The column index, or -1 if no column has that name.</returns>
public int GetColumnIndex(string variableName) {
  for (int index = 0; index < dataColumns.Count; index++) {
    if (dataColumns[index].Name == variableName) {
      return index;
    }
  }
  return -1;
}
430
/// <summary>Asks the column at <paramref name="columnIndex"/> whether it is of type <typeparamref name="T"/>.</summary>
public bool VariableHasType<T>(int columnIndex) {
  var column = dataColumns[columnIndex];
  return column.IsType<T>();
}
434
/// <summary>Returns the value type reported by the column at <paramref name="columnIndex"/>.</summary>
public Type GetVariableType(int columnIndex) {
  var column = dataColumns[columnIndex];
  return column.GetValueType();
}
438
439    public IList<string> InputVariables { get; private set; }
440    public string TargetVariable { get; private set; } // optional
441    #endregion
442
443    #region Partitions
444    [Storable]
445    public IntRange TrainingPartition { get; set; }
446    [Storable]
447    public IntRange TestPartition { get; set; }
448    #endregion
449
450    #region Transformations
451    [Storable]
452    public IList<ITransformation> Transformations { get; protected set; }
453    #endregion
454
455    #region Validation
/// <summary>
/// Checks whether <paramref name="value"/> can be stored in the given column without changing any data.
/// </summary>
/// <param name="value">Text to validate.</param>
/// <param name="errorMessage">Empty when the value is valid, otherwise a description of the problem.</param>
/// <param name="columnIndex">Index of the column the value is meant for.</param>
/// <returns><c>true</c> if the value is valid for the column.</returns>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="columnIndex"/> is not a valid column index.</exception>
/// <exception cref="ArgumentException">The column has a type other than double, string or DateTime.</exception>
public bool Validate(string value, out string errorMessage, int columnIndex) {
  // The upper bound is exclusive: an index equal to the column count is not a valid column.
  if (columnIndex < 0 || columnIndex >= dataColumns.Count) {
    throw new ArgumentOutOfRangeException("columnIndex", "column index is out of range");
  }

  bool valid = false;
  errorMessage = string.Empty;
  if (VariableHasType<double>(columnIndex)) {
    if (string.IsNullOrWhiteSpace(value)) {
      // Blank text is accepted for double columns (SetValue stores it as NaN).
      valid = true;
    } else {
      double val;
      valid = double.TryParse(value, out val);
      if (!valid) {
        errorMessage = "Invalid Value (Valid Value Format: \"" + FormatPatterns.GetDoubleFormatPattern() + "\")";
      }
    }
  } else if (VariableHasType<string>(columnIndex)) {
    valid = value != null;
    if (!valid) {
      errorMessage = "Invalid Value (string must not be null)";
    }
  } else if (VariableHasType<DateTime>(columnIndex)) {
    DateTime date;
    valid = DateTime.TryParse(value, out date);
    if (!valid) {
      // DateTimeFormatInfo has no useful ToString(); show the culture's short date pattern instead.
      errorMessage = "Invalid Value (Valid Value Format: \"" + CultureInfo.CurrentCulture.DateTimeFormat.ShortDatePattern + "\")";
    }
  } else {
    throw new ArgumentException("column " + columnIndex + " contains a non supported type.");
  }

  return valid;
}
490    #endregion
491
492    #region Import & Export
/// <summary>
/// Replaces the current columns, variable roles and partitions with those of <paramref name="problemData"/>.
/// </summary>
/// <remarks>
/// The target variable is taken from regression or classification problem data and is null for
/// any other kind of problem data.
/// </remarks>
/// <exception cref="ArgumentException">A dataset variable is neither double, string nor DateTime.</exception>
public void Import(IDataAnalysisProblemData problemData) {
  var dataset = problemData.Dataset;
  InputVariables = new List<string>(problemData.AllowedInputVariables);

  var regressionData = problemData as IRegressionProblemData;
  var classificationData = problemData as IClassificationProblemData;
  if (regressionData != null) {
    TargetVariable = regressionData.TargetVariable;
  } else if (classificationData != null) {
    TargetVariable = classificationData.TargetVariable;
  } else {
    TargetVariable = null;
  }

  dataColumns.Clear();
  foreach (var variableName in problemData.Dataset.VariableNames) {
    PreprocessingDataColumn imported;
    if (dataset.VariableHasType<double>(variableName)) {
      imported = new DoublePreprocessingDataColumn(variableName, dataset.GetDoubleValues(variableName));
    } else if (dataset.VariableHasType<string>(variableName)) {
      imported = new StringPreprocessingDataColumn(variableName, dataset.GetStringValues(variableName));
    } else if (dataset.VariableHasType<DateTime>(variableName)) {
      imported = new DateTimePreprocessingDataColumn(variableName, dataset.GetDateTimeValues(variableName));
    } else {
      throw new ArgumentException("The datatype of column " + variableName + " must be of type double, string or DateTime");
    }
    dataColumns.Add(imported);
  }

  // Fresh ranges, so later partition changes do not write into the problem data's own ranges.
  var sourceTraining = problemData.TrainingPartition;
  var sourceTest = problemData.TestPartition;
  TrainingPartition = new IntRange(sourceTraining.Start, sourceTraining.End);
  TestPartition = new IntRange(sourceTest.Start, sourceTest.End);
}
516
/// <summary>Builds a <c>Dataset</c> from the current columns and their names.</summary>
/// <remarks>
/// Missing doubles are exported as <c>double.NaN</c> and missing dates as <c>DateTime.MinValue</c>;
/// string values are copied as they are.
/// </remarks>
/// <exception cref="InvalidOperationException">A column is of none of the three known column types.</exception>
public Dataset ExportToDataset() {
  IList<IList> exported = new List<IList>();

  foreach (var column in dataColumns) {
    var asDouble = column as DoublePreprocessingDataColumn;
    if (asDouble != null) {
      exported.Add(new List<double>(asDouble.Values.Select(v => v ?? double.NaN)));
      continue;
    }

    var asString = column as StringPreprocessingDataColumn;
    if (asString != null) {
      exported.Add(new List<string>(asString.Values));
      continue;
    }

    var asDateTime = column as DateTimePreprocessingDataColumn;
    if (asDateTime != null) {
      exported.Add(new List<DateTime>(asDateTime.Values.Select(v => v ?? DateTime.MinValue)));
      continue;
    }

    throw new InvalidOperationException("Column type not supported for export");
  }

  return new Dataset(VariableNames, exported);
}
532    #endregion
533
534    #region Selection
// Backing store of Selection: maps a column index to the row indices selected in that column.
[Storable]
protected IDictionary<int, IList<int>> selection;

/// <summary>
/// Selected rows per column index. Assigning a new dictionary raises <see cref="SelectionChanged"/>.
/// </summary>
public IDictionary<int, IList<int>> Selection {
  get {
    return selection;
  }
  set {
    selection = value;
    OnSelectionChanged();
  }
}
/// <summary>Replaces the selection with an empty one, which raises <see cref="SelectionChanged"/>.</summary>
public void ClearSelection() {
  var emptySelection = new Dictionary<int, IList<int>>();
  Selection = emptySelection;
}
547
548    public event EventHandler SelectionChanged;
/// <summary>Raises <see cref="SelectionChanged"/> if anyone is subscribed.</summary>
protected void OnSelectionChanged() {
  // Copy the delegate first so the null check and the invocation see the same subscriber list.
  EventHandler handler = SelectionChanged;
  if (handler == null) return;
  handler(this, EventArgs.Empty);
}
553    #endregion
554
555    #region Transactions
// Snapshot/History are not storable/cloneable on purpose
/// <summary>
/// State captured by SaveSnapshot and restored by Undo, together with the change event that
/// Undo raises after restoring it.
/// </summary>
private class Snapshot {
  // Captured data
  public List<PreprocessingDataColumn> DataColumns { get; set; }
  public IList<ITransformation> Transformations { get; set; }
  public IntRange TrainingPartition { get; set; }
  public IntRange TestPartition { get; set; }

  // Arguments of the change event raised on undo
  public DataPreprocessingChangedEventType ChangedType { get; set; }
  public int ChangedColumn { get; set; }
  public int ChangedRow { get; set; }
}
568
569    public event DataPreprocessingChangedEventHandler Changed;
/// <summary>Raises <see cref="Changed"/> with the given event type, column and row if anyone is subscribed.</summary>
/// <param name="type">Kind of change.</param>
/// <param name="column">Affected column index as passed by the caller (callers use -1 when no single column applies).</param>
/// <param name="row">Affected row index as passed by the caller (callers use -1 when no single row applies).</param>
protected virtual void OnChanged(DataPreprocessingChangedEventType type, int column, int row) {
  // Copy the delegate first so the null check and the invocation see the same subscriber list.
  var handler = Changed;
  if (handler == null) return;

  var args = new DataPreprocessingChangedEventArgs(type, column, row);
  handler(this, args);
}
574
575    private const int MaxUndoDepth = 5;
576
577    private readonly IList<Snapshot> undoHistory = new List<Snapshot>();
578    private readonly Stack<DataPreprocessingChangedEventType> eventStack = new Stack<DataPreprocessingChangedEventType>();
579
/// <summary><c>true</c> while at least one transaction started by BeginTransaction has not been ended.</summary>
public bool IsInTransaction {
  get {
    return eventStack.Count != 0;
  }
}
581
/// <summary>
/// Pushes a copy of the current columns, partitions and transformation list onto the undo
/// history. Does nothing while a transaction is open. The history is limited to
/// <see cref="MaxUndoDepth"/> entries; the oldest entry is dropped when the limit is reached.
/// </summary>
/// <param name="changedType">Event type that Undo raises after restoring this snapshot.</param>
/// <param name="column">Column index passed along with that event.</param>
/// <param name="row">Row index passed along with that event.</param>
private void SaveSnapshot(DataPreprocessingChangedEventType changedType, int column, int row) {
  if (IsInTransaction) return;

  var cloner = new Cloner();
  var columnCopies = new List<PreprocessingDataColumn>();
  foreach (var column_ in dataColumns) {
    columnCopies.Add(cloner.Clone(column_));
  }

  var snapshot = new Snapshot();
  snapshot.DataColumns = columnCopies;
  snapshot.TrainingPartition = new IntRange(TrainingPartition.Start, TrainingPartition.End);
  snapshot.TestPartition = new IntRange(TestPartition.Start, TestPartition.End);
  // Only the list is copied; the transformation objects themselves are shared.
  snapshot.Transformations = new List<ITransformation>(Transformations);
  snapshot.ChangedType = changedType;
  snapshot.ChangedColumn = column;
  snapshot.ChangedRow = row;

  if (undoHistory.Count >= MaxUndoDepth) {
    undoHistory.RemoveAt(0);
  }
  undoHistory.Add(snapshot);
}
601
/// <summary><c>true</c> if the undo history holds at least one snapshot.</summary>
public bool IsUndoAvailable {
  get {
    return undoHistory.Count != 0;
  }
}
605
/// <summary>
/// Restores the most recent snapshot (columns, partitions, transformations), removes it from the
/// undo history and raises the change event stored with it. Does nothing when no snapshot exists.
/// </summary>
public void Undo() {
  if (!IsUndoAvailable) return;

  int lastIndex = undoHistory.Count - 1;
  Snapshot latest = undoHistory[lastIndex];

  dataColumns = latest.DataColumns;
  TrainingPartition = latest.TrainingPartition;
  TestPartition = latest.TestPartition;
  Transformations = latest.Transformations;

  undoHistory.RemoveAt(lastIndex);

  OnChanged(latest.ChangedType, latest.ChangedColumn, latest.ChangedRow);
}
619
/// <summary>
/// Runs <paramref name="action"/> inside a transaction: one snapshot is saved up front, the
/// individual change events of the action are suppressed, and a single event of
/// <paramref name="type"/> is raised when the transaction ends.
/// </summary>
/// <param name="action">Work to perform inside the transaction.</param>
/// <param name="type">Event type raised when the transaction ends.</param>
public void InTransaction(Action action, DataPreprocessingChangedEventType type = DataPreprocessingChangedEventType.Any) {
  BeginTransaction(type);
  try {
    action();
  } finally {
    // Always close the transaction. If the action throws and the transaction stays open,
    // IsInTransaction remains true and all later snapshots and change events are suppressed.
    EndTransaction();
  }
}
625
/// <summary>
/// Opens a transaction. A snapshot is saved first (only effective for the outermost transaction,
/// since SaveSnapshot does nothing while a transaction is open), then the event type is remembered
/// for EndTransaction.
/// </summary>
/// <param name="type">Event type to raise when this transaction ends.</param>
public void BeginTransaction(DataPreprocessingChangedEventType type) {
  // Order matters: the snapshot has to be taken before the stack marks the transaction as open.
  SaveSnapshot(type, -1, -1);

  eventStack.Push(type);
}
630
/// <summary>Closes the most recently opened transaction and raises its change event.</summary>
/// <exception cref="InvalidOperationException">No transaction is open.</exception>
public void EndTransaction() {
  bool noOpenTransaction = eventStack.Count == 0;
  if (noOpenTransaction) {
    throw new InvalidOperationException("There is no open transaction that can be ended.");
  }

  DataPreprocessingChangedEventType finishedType = eventStack.Pop();
  OnChanged(finishedType, -1, -1);
}
638    #endregion
639
640    #region Statistics
/// <summary>Smallest non-missing value of a column.</summary>
/// <param name="columnIndex">Index of the column.</param>
/// <param name="considerSelection">Passed on to GetValuesWithoutMissingValues.</param>
/// <param name="emptyValue">Returned when there are no non-missing values.</param>
public T GetMin<T>(int columnIndex, bool considerSelection = false, T emptyValue = default(T)) {
  var presentValues = GetValuesWithoutMissingValues<T>(columnIndex, considerSelection);
  if (!presentValues.Any()) {
    return emptyValue;
  }
  return presentValues.Min();
}
645
/// <summary>Largest non-missing value of a column.</summary>
/// <param name="columnIndex">Index of the column.</param>
/// <param name="considerSelection">Passed on to GetValuesWithoutMissingValues.</param>
/// <param name="emptyValue">Returned when there are no non-missing values.</param>
public T GetMax<T>(int columnIndex, bool considerSelection = false, T emptyValue = default(T)) {
  var presentValues = GetValuesWithoutMissingValues<T>(columnIndex, considerSelection);
  if (!presentValues.Any()) {
    return emptyValue;
  }
  return presentValues.Max();
}
650
/// <summary>Mean of the non-missing values of a column.</summary>
/// <remarks>
/// Supported for double and DateTime. For string the empty string is returned without reading the column.
/// </remarks>
/// <param name="columnIndex">Index of the column.</param>
/// <param name="considerSelection">Passed on to GetValuesWithoutMissingValues.</param>
/// <param name="emptyValue">Returned when there are no non-missing values.</param>
/// <exception cref="InvalidOperationException"><typeparamref name="T"/> is not double, string or DateTime.</exception>
public T GetMean<T>(int columnIndex, bool considerSelection = false, T emptyValue = default(T)) {
  Type requested = typeof(T);

  if (requested == typeof(string)) {
    return Convert<T>(string.Empty);
  }

  if (requested == typeof(double)) {
    var doubles = GetValuesWithoutMissingValues<double>(columnIndex, considerSelection);
    if (!doubles.Any()) return emptyValue;
    return Convert<T>(doubles.Average());
  }

  if (requested == typeof(DateTime)) {
    var dates = GetValuesWithoutMissingValues<DateTime>(columnIndex, considerSelection);
    if (!dates.Any()) return emptyValue;
    return Convert<T>(AggregateAsDouble(dates, Enumerable.Average));
  }

  throw new InvalidOperationException(typeof(T) + " not supported");
}
666
/// <summary>Median of the non-missing values of a column.</summary>
/// <param name="columnIndex">Index of the column.</param>
/// <param name="considerSelection">Passed on to GetValuesWithoutMissingValues.</param>
/// <param name="emptyValue">Returned when there are no non-missing values.</param>
public T GetMedian<T>(int columnIndex, bool considerSelection = false, T emptyValue = default(T)) where T : IComparable<T> {
  if (typeof(T) != typeof(double)) {
    var presentValues = GetValuesWithoutMissingValues<T>(columnIndex, considerSelection);
    if (!presentValues.Any()) return emptyValue;
    return presentValues.Quantile(0.5);
  }

  // IEnumerable<double> is faster
  var doubles = GetValuesWithoutMissingValues<double>(columnIndex, considerSelection);
  if (!doubles.Any()) return emptyValue;
  return Convert<T>(doubles.Median());
}
675
676    public T GetMode<T>(int columnIndex, bool considerSelection = false, T emptyValue = default(T)) where T : IEquatable<T> {
677      var values = GetValuesWithoutMissingValues<T>(columnIndex, considerSelection);
678      return values.Any() ? values.GroupBy(x => x).OrderByDescending(g => g.Count()).Select(g => g.Key).First() : emptyValue;
679    }
680
681    public T GetStandardDeviation<T>(int columnIndex, bool considerSelection = false, T emptyValue = default(T)) {
682      if (typeof(T) == typeof(double)) {
683        var values = GetValuesWithoutMissingValues<double>(columnIndex, considerSelection);
684        return values.Any() ? Convert<T>(values.StandardDeviation()) : emptyValue;
685      }
686      // For DateTime, std.dev / variance would have to be TimeSpan
687      //if (typeof(T) == typeof(DateTime)) {
688      //  var values = GetValuesWithoutMissingValues<DateTime>(columnIndex, considerSelection);
689      //  return values.Any() ? Convert<T>(AggregateAsDouble(values, EnumerableStatisticExtensions.StandardDeviation)) : emptyValue;
690      //}
691      return default(T);
692    }
693
694    public T GetVariance<T>(int columnIndex, bool considerSelection = false, T emptyValue = default(T)) {
695      if (typeof(T) == typeof(double)) {
696        var values = GetValuesWithoutMissingValues<double>(columnIndex, considerSelection);
697        return values.Any() ? Convert<T>(values.Variance()) : emptyValue;
698      }
699      // DateTime variance often overflows long, thus the corresponding DateTime is invalid
700      //if (typeof(T) == typeof(DateTime)) {
701      //  var values = GetValuesWithoutMissingValues<DateTime>(columnIndex, considerSelection);
702      //  return values.Any() ? Convert<T>(AggregateAsDouble(values, EnumerableStatisticExtensions.Variance)) : emptyValue;
703      //}
704      return default(T);
705    }
706
707    public T GetQuantile<T>(double alpha, int columnIndex, bool considerSelection = false, T emptyValue = default(T)) where T : IComparable<T> {
708      if (typeof(T) == typeof(double)) {// IEnumerable<double> is faster 
709        var doubleValues = GetValuesWithoutMissingValues<double>(columnIndex, considerSelection);
710        return doubleValues.Any() ? Convert<T>(doubleValues.Quantile(alpha)) : emptyValue;
711      }
712      var values = GetValuesWithoutMissingValues<T>(columnIndex, considerSelection);
713      return values.Any() ? values.Quantile(alpha) : emptyValue;
714    }
715
716    public int GetDistinctValues<T>(int columnIndex, bool considerSelection = false) {
717      var values = GetValuesWithoutMissingValues<T>(columnIndex, considerSelection);
718      return values.GroupBy(x => x).Count();
719    }
720
721    private IEnumerable<T> GetValuesWithoutMissingValues<T>(int columnIndex, bool considerSelection) {
722      //var doubleColumn = dataColumns[columnIndex] as DoublePreprocessingDataColumn;
723      //var stringColumn = dataColumns[columnIndex] as StringPreprocessingDataColumn;
724      //var dateTimeColumn = dataColumns[columnIndex] as DateTimePreprocessingDataColumn;
725      //return GetValues<T>(columnIndex, considerSelection).Where(x =>
726      //  doubleColumn != null ? doubleColumn.IsValidValue(Convert<double>(x))
727      //  : stringColumn != null ? stringColumn.IsValidValue(Convert<string>(x))
728      //  : dateTimeColumn != null ? dateTimeColumn.IsValidValue(Convert<DateTime>(x))
729      //  : false);
730      //!IsMissingValue(x));
731
732      return GetValues<T>(columnIndex, considerSelection).Where(x =>
733        ColumnTypeSwitchFuncResult<T, bool>(columnIndex, x,
734          (c, v) => v.HasValue && c.IsValidValue(v.Value),
735          (c, v) => c.IsValidValue(v),
736          (c, v) => v.HasValue && c.IsValidValue(v.Value)
737      ));
738    }
739
740    private static DateTime AggregateAsDouble(IEnumerable<DateTime> values, Func<IEnumerable<double>, double> func) {
741      return new DateTime((long)(func(values.Select(x => (double)x.Ticks / TimeSpan.TicksPerSecond)) * TimeSpan.TicksPerSecond));
742    }
743
744    public int GetMissingValueCount() {
745      int count = 0;
746      for (int i = 0; i < Columns; ++i) {
747        count += GetMissingValueCount(i);
748      }
749      return count;
750    }
751    public int GetMissingValueCount(int columnIndex) {
752      int sum = 0;
753      for (int i = 0; i < Rows; i++) {
754        if (IsCellEmpty(columnIndex, i))
755          sum++;
756      }
757      return sum;
758    }
759    public int GetRowMissingValueCount(int rowIndex) {
760      int sum = 0;
761      for (int i = 0; i < Columns; i++) {
762        if (IsCellEmpty(i, rowIndex))
763          sum++;
764      }
765      return sum;
766    }
767    #endregion
768
769    #region Helpers
770    private static IList<IList> CopyVariableValues(IList<IList> original) {
771      var copy = new List<IList>(original);
772      for (int i = 0; i < original.Count; ++i) {
773        copy[i] = (IList)Activator.CreateInstance(original[i].GetType(), original[i]);
774      }
775      return copy;
776    }
777    #endregion
778  }
779
780  // Adapted from HeuristicLab.Common.EnumerableStatisticExtensions
781  internal static class EnumerableExtensions {
782    public static T Quantile<T>(this IEnumerable<T> values, double alpha) where T : IComparable<T> {
783      T[] valuesArr = values.ToArray();
784      int n = valuesArr.Length;
785      if (n == 0) throw new InvalidOperationException("Enumeration contains no elements.");
786
787      var pos = n * alpha;
788
789      return Select((int)Math.Ceiling(pos) - 1, valuesArr);
790
791    }
792
793    private static T Select<T>(int k, T[] arr) where T : IComparable<T> {
794      int i, ir, j, l, mid, n = arr.Length;
795      T a;
796      l = 0;
797      ir = n - 1;
798      for (;;) {
799        if (ir <= l + 1) {
800          // Active partition contains 1 or 2 elements.
801          if (ir == l + 1 && arr[ir].CompareTo(arr[l]) < 0) {
802            // Case of 2 elements.
803            Swap(arr, l, ir);
804          }
805          return arr[k];
806        } else {
807          mid = (l + ir) >> 1; // Choose median of left, center, and right elements
808          Swap(arr, mid, l + 1); // as partitioning element a. Also
809
810          if (arr[l].CompareTo(arr[ir]) > 0) {  // rearrange so that arr[l] arr[ir] <= arr[l+1],
811            Swap(arr, l, ir); // . arr[ir] >= arr[l+1]
812          }
813
814          if (arr[l + 1].CompareTo(arr[ir]) > 0) {
815            Swap(arr, l + 1, ir);
816          }
817          if (arr[l].CompareTo(arr[l + 1]) > 0) {
818            Swap(arr, l, l + 1);
819          }
820          i = l + 1; // Initialize pointers for partitioning.
821          j = ir;
822          a = arr[l + 1]; // Partitioning element.
823          for (;;) { // Beginning of innermost loop.
824            do i++; while (arr[i].CompareTo(a) < 0); // Scan up to find element > a.
825            do j--; while (arr[j].CompareTo(a) > 0); // Scan down to find element < a.
826            if (j < i) break; // Pointers crossed. Partitioning complete.
827            Swap(arr, i, j);
828          } // End of innermost loop.
829          arr[l + 1] = arr[j]; // Insert partitioning element.
830          arr[j] = a;
831          if (j >= k) ir = j - 1; // Keep active the partition that contains the
832          if (j <= k) l = i; // kth element.
833        }
834      }
835    }
836
837    private static void Swap<T>(T[] arr, int i, int j) {
838      T temp = arr[i];
839      arr[i] = arr[j];
840      arr[j] = temp;
841    }
842  }
843}
Note: See TracBrowser for help on using the repository browser.