Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/14/11 13:59:25 (13 years ago)
Author:
epitzer
Message:

#1530 integrate changes from trunk

Location:
branches/PersistenceSpeedUp
Files:
41 edited
6 copied

Legend:

Unmodified
Added
Removed
  • branches/PersistenceSpeedUp

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.3/HeuristicLabProblemsDataAnalysisPlugin.cs.frame

    r6099 r6760  
    2626
    2727namespace HeuristicLab.Problems.DataAnalysis {
    28   [Plugin("HeuristicLab.Problems.DataAnalysis", "3.3.4.$WCREV$")]
     28  [Plugin("HeuristicLab.Problems.DataAnalysis", "3.3.5.$WCREV$")]
    2929  [PluginFile("HeuristicLab.Problems.DataAnalysis-3.3.dll", PluginFileType.Assembly)]
    3030  [PluginDependency("HeuristicLab.ALGLIB", "3.1")]
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.3/Properties/AssemblyInfo.frame

    r6099 r6760  
    5353// by using the '*' as shown below:
    5454[assembly: AssemblyVersion("3.3.0.0")]
    55 [assembly: AssemblyFileVersion("3.3.4.$WCREV$")]
     55[assembly: AssemblyFileVersion("3.3.5.$WCREV$")]
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.3/Tests/Properties/AssemblyInfo.cs

    r5446 r6760  
    5252// by using the '*' as shown below:
    5353[assembly: AssemblyVersion("3.3.0.0")]
    54 [assembly: AssemblyFileVersion("3.3.3.0")]
     54[assembly: AssemblyFileVersion("3.3.5.0")]
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Dataset.cs

    r5847 r6760  
    2121
    2222using System;
     23using System.Collections;
    2324using System.Collections.Generic;
     25using System.Collections.ObjectModel;
    2426using System.Linq;
    2527using HeuristicLab.Common;
     
    3638    private Dataset(Dataset original, Cloner cloner)
    3739      : base(original, cloner) {
    38       variableNameToVariableIndexMapping = original.variableNameToVariableIndexMapping;
    39       data = original.data;
    40     }
    41     public override IDeepCloneable Clone(Cloner cloner) {
    42       return new Dataset(this, cloner);
    43     }
     40      variableValues = new Dictionary<string, IList>(original.variableValues);
     41      variableNames = new List<string>(original.variableNames);
     42      rows = original.rows;
     43    }
     44    public override IDeepCloneable Clone(Cloner cloner) { return new Dataset(this, cloner); }
    4445
    4546    public Dataset()
     
    4748      Name = "-";
    4849      VariableNames = Enumerable.Empty<string>();
    49       data = new double[0, 0];
    50     }
    51 
    52     public Dataset(IEnumerable<string> variableNames, double[,] data)
     50      variableValues = new Dictionary<string, IList>();
     51      rows = 0;
     52    }
     53
     54    public Dataset(IEnumerable<string> variableNames, IEnumerable<IList> variableValues)
    5355      : base() {
    5456      Name = "-";
    55       if (variableNames.Count() != data.GetLength(1)) {
    56         throw new ArgumentException("Number of variable names doesn't match the number of columns of data");
    57       }
    58       this.data = (double[,])data.Clone();
    59       VariableNames = variableNames;
    60     }
    61 
    62 
    63     private Dictionary<string, int> variableNameToVariableIndexMapping;
    64     private Dictionary<int, string> variableIndexToVariableNameMapping;
     57      if (!variableNames.Any()) {
     58        this.variableNames = Enumerable.Range(0, variableValues.Count()).Select(x => "Column " + x).ToList();
     59      } else if (variableNames.Count() != variableValues.Count()) {
     60        throw new ArgumentException("Number of variable names doesn't match the number of columns of variableValues");
     61      } else if (!variableValues.All(list => list.Count == variableValues.First().Count)) {
     62        throw new ArgumentException("The number of values must be equal for every variable");
     63      } else if (variableNames.Distinct().Count() != variableNames.Count()) {
     64        var duplicateVariableNames =
     65          variableNames.GroupBy(v => v).Where(g => g.Count() > 1).Select(g => g.Key).ToList();
     66        string message = "The dataset cannot contain duplicate variables names: " + Environment.NewLine;
     67        foreach (var duplicateVariableName in duplicateVariableNames)
     68          message += duplicateVariableName + Environment.NewLine;
     69        throw new ArgumentException(message);
     70      }
     71
     72      rows = variableValues.First().Count;
     73      this.variableNames = new List<string>(variableNames);
     74      this.variableValues = new Dictionary<string, IList>();
     75      for (int i = 0; i < this.variableNames.Count; i++) {
     76        var values = variableValues.ElementAt(i);
     77        IList clonedValues = null;
     78        if (values is List<double>)
     79          clonedValues = new List<double>(values.Cast<double>());
     80        else if (values is List<string>)
     81          clonedValues = new List<string>(values.Cast<string>());
     82        else if (values is List<DateTime>)
     83          clonedValues = new List<DateTime>(values.Cast<DateTime>());
     84        else {
     85          this.variableNames = new List<string>();
     86          this.variableValues = new Dictionary<string, IList>();
     87          throw new ArgumentException("The variable values must be of type List<double>, List<string> or List<DateTime>");
     88        }
     89        this.variableValues.Add(this.variableNames[i], clonedValues);
     90      }
     91    }
     92
     93    public Dataset(IEnumerable<string> variableNames, double[,] variableValues) {
     94      Name = "-";
     95      if (variableNames.Count() != variableValues.GetLength(1)) {
     96        throw new ArgumentException("Number of variable names doesn't match the number of columns of variableValues");
     97      }
     98      if (variableNames.Distinct().Count() != variableNames.Count()) {
     99        var duplicateVariableNames = variableNames.GroupBy(v => v).Where(g => g.Count() > 1).Select(g => g.Key).ToList();
     100        string message = "The dataset cannot contain duplicate variables names: " + Environment.NewLine;
     101        foreach (var duplicateVariableName in duplicateVariableNames)
     102          message += duplicateVariableName + Environment.NewLine;
     103        throw new ArgumentException(message);
     104      }
     105
     106      rows = variableValues.GetLength(0);
     107      this.variableNames = new List<string>(variableNames);
     108
     109      this.variableValues = new Dictionary<string, IList>();
     110      for (int col = 0; col < variableValues.GetLength(1); col++) {
     111        string columName = this.variableNames[col];
     112        var values = new List<double>();
     113        for (int row = 0; row < variableValues.GetLength(0); row++) {
     114          values.Add(variableValues[row, col]);
     115        }
     116        this.variableValues.Add(columName, values);
     117      }
     118    }
     119
     120    #region Backwards compatible code, remove with 3.5
     121    private double[,] storableData;
     122    //name alias used to support backwards compatibility
     123    [Storable(Name = "data", AllowOneWay = true)]
     124    private double[,] StorableData { set { storableData = value; } }
     125
     126    [StorableHook(HookType.AfterDeserialization)]
     127    private void AfterDeserialization() {
     128      if (variableValues == null) {
     129        rows = storableData.GetLength(0);
     130        variableValues = new Dictionary<string, IList>();
     131        for (int col = 0; col < storableData.GetLength(1); col++) {
     132          string columName = variableNames[col];
     133          var values = new List<double>();
     134          for (int row = 0; row < storableData.GetLength(0); row++) {
     135            values.Add(storableData[row, col]);
     136          }
     137          variableValues.Add(columName, values);
     138        }
     139        storableData = null;
     140      }
     141    }
     142    #endregion
     143
     144    [Storable(Name = "VariableValues")]
     145    private Dictionary<string, IList> variableValues;
     146
     147    private List<string> variableNames;
    65148    [Storable]
    66149    public IEnumerable<string> VariableNames {
    67       get {
    68         // convert KeyCollection to an array first for persistence
    69         return variableNameToVariableIndexMapping.Keys.ToArray();
    70       }
     150      get { return variableNames; }
    71151      private set {
    72         if (variableNameToVariableIndexMapping != null) throw new InvalidOperationException("VariableNames can only be set once.");
    73         this.variableNameToVariableIndexMapping = new Dictionary<string, int>();
    74         this.variableIndexToVariableNameMapping = new Dictionary<int, string>();
    75         int i = 0;
    76         foreach (string variableName in value) {
    77           this.variableNameToVariableIndexMapping.Add(variableName, i);
    78           this.variableIndexToVariableNameMapping.Add(i, variableName);
    79           i++;
    80         }
    81       }
    82     }
    83 
     152        if (variableNames != null) throw new InvalidOperationException();
     153        variableNames = new List<string>(value);
     154      }
     155    }
     156
     157    public IEnumerable<string> DoubleVariables {
     158      get { return variableValues.Where(p => p.Value is List<double>).Select(p => p.Key); }
     159    }
     160
     161    public IEnumerable<double> GetDoubleValues(string variableName) {
     162      IList list;
     163      if (!variableValues.TryGetValue(variableName, out list))
     164        throw new ArgumentException("The variable " + variableName + " does not exist in the dataset.");
     165      List<double> values = list as List<double>;
     166      if (values == null) throw new ArgumentException("The variable " + variableName + " is not a double variable.");
     167
     168      //mkommend yield return used to enable lazy evaluation
     169      foreach (double value in values)
     170        yield return value;
     171    }
     172    public ReadOnlyCollection<double> GetReadOnlyDoubleValues(string variableName) {
     173      IList list;
     174      if (!variableValues.TryGetValue(variableName, out list))
     175        throw new ArgumentException("The variable " + variableName + " does not exist in the dataset.");
     176      List<double> values = list as List<double>;
     177      if (values == null) throw new ArgumentException("The variable " + variableName + " is not a double variable.");
     178      return values.AsReadOnly();
     179    }
     180    public double GetDoubleValue(string variableName, int row) {
     181      IList list;
     182      if (!variableValues.TryGetValue(variableName, out list))
     183        throw new ArgumentException("The variable " + variableName + " does not exist in the dataset.");
     184      List<double> values = list as List<double>;
     185      if (values == null) throw new ArgumentException("The variable " + variableName + " is not a double variable.");
     186      return values[row];
     187    }
     188    public IEnumerable<double> GetDoubleValues(string variableName, IEnumerable<int> rows) {
     189      IList list;
     190      if (!variableValues.TryGetValue(variableName, out list))
     191        throw new ArgumentException("The variable " + variableName + " does not exist in the dataset.");
     192      List<double> values = list as List<double>;
     193      if (values == null) throw new ArgumentException("The varialbe " + variableName + " is not a double variable.");
     194
     195      foreach (int index in rows)
     196        yield return values[index];
     197    }
     198
     199    #region IStringConvertibleMatrix Members
    84200    [Storable]
    85     private double[,] data;
    86     private double[,] Data {
    87       get { return data; }
    88     }
    89 
    90     // elementwise access
    91     public double this[int rowIndex, int columnIndex] {
    92       get { return data[rowIndex, columnIndex]; }
    93     }
    94     public double this[string variableName, int rowIndex] {
    95       get {
    96         int columnIndex = GetVariableIndex(variableName);
    97         return data[rowIndex, columnIndex];
    98       }
    99     }
    100 
    101     public double[] GetVariableValues(int variableIndex) {
    102       return GetVariableValues(variableIndex, 0, Rows);
    103     }
    104     public double[] GetVariableValues(string variableName) {
    105       return GetVariableValues(GetVariableIndex(variableName), 0, Rows);
    106     }
    107     public double[] GetVariableValues(int variableIndex, int start, int end) {
    108       return GetEnumeratedVariableValues(variableIndex, start, end).ToArray();
    109     }
    110     public double[] GetVariableValues(string variableName, int start, int end) {
    111       return GetVariableValues(GetVariableIndex(variableName), start, end);
    112     }
    113 
    114     public IEnumerable<double> GetEnumeratedVariableValues(int variableIndex) {
    115       return GetEnumeratedVariableValues(variableIndex, 0, Rows);
    116     }
    117     public IEnumerable<double> GetEnumeratedVariableValues(int variableIndex, int start, int end) {
    118       if (start < 0 || !(start <= end))
    119         throw new ArgumentException("Start must be between 0 and end (" + end + ").");
    120       if (end > Rows || end < start)
    121         throw new ArgumentException("End must be between start (" + start + ") and dataset rows (" + Rows + ").");
    122 
    123       for (int i = start; i < end; i++)
    124         yield return data[i, variableIndex];
    125     }
    126     public IEnumerable<double> GetEnumeratedVariableValues(int variableIndex, IEnumerable<int> rows) {
    127       foreach (int row in rows)
    128         yield return data[row, variableIndex];
    129     }
    130 
    131     public IEnumerable<double> GetEnumeratedVariableValues(string variableName) {
    132       return GetEnumeratedVariableValues(GetVariableIndex(variableName), 0, Rows);
    133     }
    134     public IEnumerable<double> GetEnumeratedVariableValues(string variableName, int start, int end) {
    135       return GetEnumeratedVariableValues(GetVariableIndex(variableName), start, end);
    136     }
    137     public IEnumerable<double> GetEnumeratedVariableValues(string variableName, IEnumerable<int> rows) {
    138       return GetEnumeratedVariableValues(GetVariableIndex(variableName), rows);
    139     }
    140 
    141     public string GetVariableName(int variableIndex) {
    142       try {
    143         return variableIndexToVariableNameMapping[variableIndex];
    144       }
    145       catch (KeyNotFoundException ex) {
    146         throw new ArgumentException("The variable index " + variableIndex + " was not found.", ex);
    147       }
    148     }
    149     public int GetVariableIndex(string variableName) {
    150       try {
    151         return variableNameToVariableIndexMapping[variableName];
    152       }
    153       catch (KeyNotFoundException ex) {
    154         throw new ArgumentException("The variable name " + variableName + " was not found.", ex);
    155       }
    156     }
    157 
    158     #region IStringConvertibleMatrix Members
     201    private int rows;
    159202    public int Rows {
    160       get { return data.GetLength(0); }
     203      get { return rows; }
    161204      set { throw new NotSupportedException(); }
    162205    }
    163206    public int Columns {
    164       get { return data.GetLength(1); }
     207      get { return variableNames.Count; }
    165208      set { throw new NotSupportedException(); }
    166209    }
     
    184227
    185228    public string GetValue(int rowIndex, int columnIndex) {
    186       return data[rowIndex, columnIndex].ToString();
     229      return variableValues[variableNames[columnIndex]][rowIndex].ToString();
    187230    }
    188231    public bool SetValue(string value, int rowIndex, int columnIndex) {
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/HeuristicLab.Problems.DataAnalysis-3.4.csproj

    r6184 r6760  
    112112      <SubType>Code</SubType>
    113113    </Compile>
     114    <Compile Include="Implementation\Classification\ClassificationEnsembleSolution.cs" />
    114115    <Compile Include="Implementation\Classification\ClassificationProblemData.cs" />
    115116    <Compile Include="Implementation\Classification\ClassificationProblem.cs" />
    116117    <Compile Include="Implementation\Classification\ClassificationSolution.cs" />
     118    <Compile Include="Implementation\Classification\ClassificationEnsembleProblemData.cs" />
     119    <Compile Include="Implementation\Classification\ClassificationSolutionBase.cs" />
     120    <Compile Include="Implementation\Classification\DiscriminantFunctionClassificationSolutionBase.cs" />
    117121    <Compile Include="Implementation\Clustering\ClusteringProblem.cs" />
    118122    <Compile Include="Implementation\Clustering\ClusteringProblemData.cs" />
    119123    <Compile Include="Implementation\Clustering\ClusteringSolution.cs" />
     124    <Compile Include="Implementation\Regression\RegressionEnsembleProblemData.cs" />
    120125    <Compile Include="Implementation\Regression\RegressionEnsembleModel.cs">
    121126      <SubType>Code</SubType>
     
    125130      <SubType>Code</SubType>
    126131    </Compile>
    127     <Compile Include="Interfaces\Classification\IClassificationEnsembleSolution.cs" />
     132    <Compile Include="Interfaces\Classification\IClassificationEnsembleSolution.cs">
     133      <SubType>Code</SubType>
     134    </Compile>
    128135    <Compile Include="Interfaces\Classification\IDiscriminantFunctionThresholdCalculator.cs" />
    129136    <Compile Include="Interfaces\Regression\IRegressionEnsembleModel.cs">
     
    131138    </Compile>
    132139    <Compile Include="Interfaces\Regression\IRegressionEnsembleSolution.cs" />
     140    <Compile Include="Implementation\Regression\RegressionSolutionBase.cs" />
     141    <Compile Include="OnlineCalculators\OnlineMeanAbsoluteErrorCalculator.cs" />
    133142    <Compile Include="OnlineCalculators\OnlineLinearScalingParameterCalculator.cs" />
    134143    <Compile Include="Implementation\Classification\DiscriminantFunctionClassificationModel.cs" />
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/HeuristicLabProblemsDataAnalysisPlugin.cs.frame

    r5860 r6760  
    2626
    2727namespace HeuristicLab.Problems.DataAnalysis {
    28   [Plugin("HeuristicLab.Problems.DataAnalysis","Provides base classes for data analysis tasks.", "3.4.0.$WCREV$")]
     28  [Plugin("HeuristicLab.Problems.DataAnalysis","Provides base classes for data analysis tasks.", "3.4.1.$WCREV$")]
    2929  [PluginFile("HeuristicLab.Problems.DataAnalysis-3.4.dll", PluginFileType.Assembly)]
    3030  [PluginDependency("HeuristicLab.Collections", "3.3")]
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs

    r5809 r6760  
    3939      get { return new List<IClassificationModel>(models); }
    4040    }
     41
    4142    [StorableConstructor]
    4243    protected ClassificationEnsembleModel(bool deserializing) : base(deserializing) { }
     
    4546      this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    4647    }
     48
     49    public ClassificationEnsembleModel() : this(Enumerable.Empty<IClassificationModel>()) { }
    4750    public ClassificationEnsembleModel(IEnumerable<IClassificationModel> models)
    4851      : base() {
     
    5760
    5861    #region IClassificationEnsembleModel Members
     62    public void Add(IClassificationModel model) {
     63      models.Add(model);
     64    }
     65    public void Remove(IClassificationModel model) {
     66      models.Remove(model);
     67    }
    5968
    6069    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
     
    8594    }
    8695
     96    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
     97      return new ClassificationEnsembleSolution(models, problemData);
     98    }
    8799    #endregion
    88100  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r6184 r6760  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using HeuristicLab.Collections;
    2426using HeuristicLab.Common;
    2527using HeuristicLab.Core;
     28using HeuristicLab.Data;
    2629using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2730
     
    3235  [StorableClass]
    3336  [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
    34   // [Creatable("Data Analysis")]
    35   public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution {
     37  [Creatable("Data Analysis - Ensembles")]
     38  public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
     39    public new IClassificationEnsembleModel Model {
     40      get { return (IClassificationEnsembleModel)base.Model; }
     41    }
     42    public new ClassificationEnsembleProblemData ProblemData {
     43      get { return (ClassificationEnsembleProblemData)base.ProblemData; }
     44      set { base.ProblemData = value; }
     45    }
     46
     47    private readonly ItemCollection<IClassificationSolution> classificationSolutions;
     48    public IItemCollection<IClassificationSolution> ClassificationSolutions {
     49      get { return classificationSolutions; }
     50    }
    3651
    3752    [Storable]
    38     private List<IClassificationModel> models;
    39     public IEnumerable<IClassificationModel> Models {
    40       get { return new List<IClassificationModel>(models); }
    41     }
     53    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
     54    [Storable]
     55    private Dictionary<IClassificationModel, IntRange> testPartitions;
     56
    4257    [StorableConstructor]
    43     protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { }
    44     protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
     58    private ClassificationEnsembleSolution(bool deserializing)
     59      : base(deserializing) {
     60      classificationSolutions = new ItemCollection<IClassificationSolution>();
     61    }
     62    [StorableHook(HookType.AfterDeserialization)]
     63    private void AfterDeserialization() {
     64      foreach (var model in Model.Models) {
     65        IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();
     66        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
     67        problemData.TrainingPartition.End = trainingPartitions[model].End;
     68        problemData.TestPartition.Start = testPartitions[model].Start;
     69        problemData.TestPartition.End = testPartitions[model].End;
     70
     71        classificationSolutions.Add(model.CreateClassificationSolution(problemData));
     72      }
     73      RegisterClassificationSolutionsEventHandler();
     74    }
     75
     76    private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
    4577      : base(original, cloner) {
    46       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    47     }
    48     public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models)
    49       : base() {
    50       this.name = ItemName;
    51       this.description = ItemDescription;
    52       this.models = new List<IClassificationModel>(models);
     78      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     79      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     80      foreach (var pair in original.trainingPartitions) {
     81        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     82      }
     83      foreach (var pair in original.testPartitions) {
     84        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     85      }
     86
     87      classificationSolutions = cloner.Clone(original.classificationSolutions);
     88      RegisterClassificationSolutionsEventHandler();
     89    }
     90
     91    public ClassificationEnsembleSolution()
     92      : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
     93      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     94      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     95      classificationSolutions = new ItemCollection<IClassificationSolution>();
     96
     97      RegisterClassificationSolutionsEventHandler();
     98    }
     99
     100    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
     101      : this(models, problemData,
     102             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
     103             models.Select(m => (IntRange)problemData.TestPartition.Clone())
     104      ) { }
     105
     106    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
     107      : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
     108      this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     109      this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
     110      this.classificationSolutions = new ItemCollection<IClassificationSolution>();
     111
     112      List<IClassificationSolution> solutions = new List<IClassificationSolution>();
     113      var modelEnumerator = models.GetEnumerator();
     114      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
     115      var testPartitionEnumerator = testPartitions.GetEnumerator();
     116
     117      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
     118        var p = (IClassificationProblemData)problemData.Clone();
     119        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
     120        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
     121        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
     122        p.TestPartition.End = testPartitionEnumerator.Current.End;
     123
     124        solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
     125      }
     126      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
     127        throw new ArgumentException();
     128      }
     129
     130      RegisterClassificationSolutionsEventHandler();
     131      classificationSolutions.AddRange(solutions);
    53132    }
    54133
     
    56135      return new ClassificationEnsembleSolution(this, cloner);
    57136    }
    58 
    59     #region IClassificationEnsembleModel Members
     137    private void RegisterClassificationSolutionsEventHandler() {
     138      classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded);
     139      classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved);
     140      classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset);
     141    }
     142
     143    protected override void RecalculateResults() {
     144      CalculateResults();
     145    }
     146
     147    #region Evaluation
     148    public override IEnumerable<double> EstimatedTrainingClassValues {
     149      get {
     150        var rows = ProblemData.TrainingIndizes;
     151        var estimatedValuesEnumerators = (from model in Model.Models
     152                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     153                                         .ToList();
     154        var rowsEnumerator = rows.GetEnumerator();
     155        // aggregate to make sure that MoveNext is called for all enumerators
     156        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     157          int currentRow = rowsEnumerator.Current;
     158
     159          var selectedEnumerators = from pair in estimatedValuesEnumerators
     160                                    where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
     161                                    select pair.EstimatedValuesEnumerator;
     162          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     163        }
     164      }
     165    }
     166
     167    public override IEnumerable<double> EstimatedTestClassValues {
     168      get {
     169        var rows = ProblemData.TestIndizes;
     170        var estimatedValuesEnumerators = (from model in Model.Models
     171                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     172                                         .ToList();
     173        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     174        // aggregate to make sure that MoveNext is called for all enumerators
     175        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     176          int currentRow = rowsEnumerator.Current;
     177
     178          var selectedEnumerators = from pair in estimatedValuesEnumerators
     179                                    where RowIsTestForModel(currentRow, pair.Model)
     180                                    select pair.EstimatedValuesEnumerator;
     181
     182          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     183        }
     184      }
     185    }
     186
     187    private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) {
     188      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
     189              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
     190    }
     191
     192    private bool RowIsTestForModel(int currentRow, IClassificationModel model) {
     193      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     194              (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End);
     195    }
     196
     197    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     198      return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
     199             select AggregateEstimatedClassValues(xs);
     200    }
    60201
    61202    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
    62       var estimatedValuesEnumerators = (from model in models
     203      var estimatedValuesEnumerators = (from model in Model.Models
    63204                                        select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
    64205                                       .ToList();
     
    70211    }
    71212
     213    private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
     214      return estimatedClassValues
     215      .GroupBy(x => x)
     216      .OrderBy(g => -g.Count())
     217      .Select(g => g.Key)
     218      .DefaultIfEmpty(double.NaN)
     219      .First();
     220    }
    72221    #endregion
    73222
    74     #region IClassificationModel Members
    75 
    76     public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    77       foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) {
    78         // return the class which is most often occuring
    79         yield return
    80           estimatedValuesVector
    81           .GroupBy(x => x)
    82           .OrderBy(g => -g.Count())
    83           .Select(g => g.Key)
    84           .First();
    85       }
    86     }
    87 
    88     #endregion
            // Propagates the ensemble's problem data to the contained solutions and partitions
            // before delegating to the base implementation.
     223    protected override void OnProblemDataChanged() {
              // separate problem data instance with the same dataset, allowed inputs, target variable
              // and partition bounds; it is handed to the non-ensemble member solutions below
     224      IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
     225                                                                     ProblemData.AllowedInputVariables,
     226                                                                     ProblemData.TargetVariable);
     227      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
     228      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
     229      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
     230      problemData.TestPartition.End = ProblemData.TestPartition.End;
     231
              // nested ensemble solutions receive this ensemble's ProblemData instance itself,
              // all other solutions receive the copy created above
     232      foreach (var solution in ClassificationSolutions) {
     233        if (solution is ClassificationEnsembleSolution)
     234          solution.ProblemData = ProblemData;
     235        else
     236          solution.ProblemData = problemData;
     237      }
              // every recorded per-model partition is overwritten with the ensemble's partition bounds
              // NOTE(review): trainingPartitions/testPartitions are dereferenced without a null check
              // here, while RowIsTrainingForModel/RowIsTestForModel tolerate null - confirm they
              // cannot be null when this method runs
     238      foreach (var trainingPartition in trainingPartitions.Values) {
     239        trainingPartition.Start = ProblemData.TrainingPartition.Start;
     240        trainingPartition.End = ProblemData.TrainingPartition.End;
     241      }
     242      foreach (var testPartition in testPartitions.Values) {
     243        testPartition.Start = ProblemData.TestPartition.Start;
     244        testPartition.End = ProblemData.TestPartition.End;
     245      }
     246
     247      base.OnProblemDataChanged();
     248    }
     249
            // Adds the given solutions to the classificationSolutions collection.
            // NOTE(review): presumably the collection's ItemsAdded event then runs
            // classificationSolutions_ItemsAdded - the event wiring is not visible here; confirm.
     250    public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
     251      classificationSolutions.AddRange(solutions);
     252    }
            // Removes the given solutions from the classificationSolutions collection.
            // NOTE(review): presumably the collection's ItemsRemoved event then runs
            // classificationSolutions_ItemsRemoved - the event wiring is not visible here; confirm.
     253    public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
     254      classificationSolutions.RemoveRange(solutions);
     255    }
     256
     257    private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     258      foreach (var solution in e.Items) AddClassificationSolution(solution);
     259      RecalculateResults();
     260    }
     261    private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     262      foreach (var solution in e.Items) RemoveClassificationSolution(solution);
     263      RecalculateResults();
     264    }
     265    private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     266      foreach (var solution in e.OldItems) RemoveClassificationSolution(solution);
     267      foreach (var solution in e.Items) AddClassificationSolution(solution);
     268      RecalculateResults();
     269    }
     270
     271    private void AddClassificationSolution(IClassificationSolution solution) {
     272      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
     273      Model.Add(solution.Model);
     274      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
     275      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     276    }
     277
     278    private void RemoveClassificationSolution(IClassificationSolution solution) {
     279      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
     280      Model.Remove(solution.Model);
     281      trainingPartitions.Remove(solution.Model);
     282      testPartitions.Remove(solution.Model);
     283    }
    89284  }
    90285}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs

    r6232 r6760  
    3434  [Item("ClassificationProblemData", "Represents an item containing all data defining a classification problem.")]
    3535  public class ClassificationProblemData : DataAnalysisProblemData, IClassificationProblemData {
    36     private const string TargetVariableParameterName = "TargetVariable";
    37     private const string ClassNamesParameterName = "ClassNames";
    38     private const string ClassificationPenaltiesParameterName = "ClassificationPenalties";
    39     private const int MaximumNumberOfClasses = 20;
    40     private const int InspectedRowsToDetermineTargets = 500;
     36    protected const string TargetVariableParameterName = "TargetVariable";
     37    protected const string ClassNamesParameterName = "ClassNames";
     38    protected const string ClassificationPenaltiesParameterName = "ClassificationPenalties";
     39    protected const int MaximumNumberOfClasses = 20;
     40    protected const int InspectedRowsToDetermineTargets = 500;
    4141
    4242    #region default data
     
    171171     {1176881,7,5,3,7,4,10,7,5,5,4        }
    172172};
    173     private static Dataset defaultDataset;
    174     private static IEnumerable<string> defaultAllowedInputVariables;
    175     private static string defaultTargetVariable;
     173    private static readonly Dataset defaultDataset;
     174    private static readonly IEnumerable<string> defaultAllowedInputVariables;
     175    private static readonly string defaultTargetVariable;
     176
     177    private static readonly ClassificationProblemData emptyProblemData;
     178    public static ClassificationProblemData EmptyProblemData {
     179      get { return EmptyProblemData; }
     180    }
     181
    176182    static ClassificationProblemData() {
    177183      defaultDataset = new Dataset(defaultVariableNames, defaultData);
     
    181187      defaultAllowedInputVariables = defaultVariableNames.Except(new List<string>() { "sample", "class" });
    182188      defaultTargetVariable = "class";
     189
     190      var problemData = new ClassificationProblemData();
     191      problemData.Parameters.Clear();
     192      problemData.Name = "Empty Classification ProblemData";
     193      problemData.Description = "This ProblemData acts as place holder before the correct problem data is loaded.";
     194      problemData.isEmpty = true;
     195
     196      problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset()));
     197      problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, ""));
     198      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     199      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     200      problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>()));
     201      problemData.Parameters.Add(new FixedValueParameter<StringMatrix>(ClassNamesParameterName, "", new StringMatrix(0, 0).AsReadOnly()));
     202      problemData.Parameters.Add(new FixedValueParameter<DoubleMatrix>(ClassificationPenaltiesParameterName, "", (DoubleMatrix)new DoubleMatrix(0, 0).AsReadOnly()));
     203      emptyProblemData = problemData;
    183204    }
    184205    #endregion
    185206
    186207    #region parameter properties
    187     public IValueParameter<StringValue> TargetVariableParameter {
    188       get { return (IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
            // Typed accessor for the parameter stored under TargetVariableParameterName.
            // The direct cast fails with an InvalidCastException if that entry is not a
            // ConstrainedValueParameter<StringValue>.
     208    public ConstrainedValueParameter<StringValue> TargetVariableParameter {
     209      get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
    189210    }
    190211    public IFixedValueParameter<StringMatrix> ClassNamesParameter {
     
    205226      get {
    206227        if (classValues == null) {
    207           classValues = Dataset.GetEnumeratedVariableValues(TargetVariableParameter.Value.Value).Distinct().ToList();
     228          classValues = Dataset.GetDoubleValues(TargetVariableParameter.Value.Value).Distinct().ToList();
    208229          classValues.Sort();
    209230        }
     
    249270      RegisterParameterEvents();
    250271    }
    251     public override IDeepCloneable Clone(Cloner cloner) { return new ClassificationProblemData(this, cloner); }
     272    public override IDeepCloneable Clone(Cloner cloner) {
     273      if (this == emptyProblemData) return emptyProblemData;
     274      return new ClassificationProblemData(this, cloner);
     275    }
    252276
    253277    public ClassificationProblemData() : this(defaultDataset, defaultAllowedInputVariables, defaultTargetVariable) { }
     
    267291    private static IEnumerable<string> CheckVariablesForPossibleTargetVariables(Dataset dataset) {
    268292      int maxSamples = Math.Min(InspectedRowsToDetermineTargets, dataset.Rows);
    269       var validTargetVariables = from v in dataset.VariableNames
    270                                  let DistinctValues = dataset.GetVariableValues(v)
    271                                    .Take(maxSamples)
    272                                    .Distinct()
    273                                    .Count()
    274                                  where DistinctValues < MaximumNumberOfClasses
    275                                  select v;
     293      var validTargetVariables = (from v in dataset.DoubleVariables
     294                                  let distinctValues = dataset.GetDoubleValues(v)
     295                                    .Take(maxSamples)
     296                                    .Distinct()
     297                                    .Count()
     298                                  where distinctValues < MaximumNumberOfClasses
     299                                  select v).ToArray();
    276300
    277301      if (!validTargetVariables.Any())
     
    283307
    284308    private void ResetTargetVariableDependentMembers() {
    285       DergisterParameterEvents();
     309      DeregisterParameterEvents();
    286310
    287311      classNames = null;
     
    357381      ClassificationPenaltiesParameter.Value.ItemChanged += new EventHandler<EventArgs<int, int>>(MatrixParameter_ItemChanged);
    358382    }
    359     private void DergisterParameterEvents() {
     383    private void DeregisterParameterEvents() {
    360384      TargetVariableParameter.ValueChanged -= new EventHandler(TargetVariableParameter_ValueChanged);
    361385      ClassNamesParameter.Value.Reset -= new EventHandler(Parameter_ValueChanged);
     
    386410      dataset.Name = Path.GetFileName(fileName);
    387411
    388       ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset.VariableNames.Skip(1), dataset.VariableNames.First());
     412      ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First());
    389413      problemData.Name = "Data imported from " + Path.GetFileName(fileName);
    390414      return problemData;
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs

    r6184 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    26 using HeuristicLab.Data;
    27 using HeuristicLab.Optimization;
    2825using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2926
     
    3330  /// </summary>
    3431  [StorableClass]
    35   public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    36     private const string TrainingAccuracyResultName = "Accuracy (training)";
    37     private const string TestAccuracyResultName = "Accuracy (test)";
    38 
    39     public new IClassificationModel Model {
    40       get { return (IClassificationModel)base.Model; }
    41       protected set { base.Model = value; }
    42     }
    43 
    44     public new IClassificationProblemData ProblemData {
    45       get { return (IClassificationProblemData)base.ProblemData; }
    46       protected set { base.ProblemData = value; }
    47     }
    48 
    49     public double TrainingAccuracy {
    50       get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; }
    51       private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; }
    52     }
    53 
    54     public double TestAccuracy {
    55       get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; }
    56       private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; }
    57     }
     32  public abstract class ClassificationSolution : ClassificationSolutionBase {
     33    protected readonly Dictionary<int, double> evaluationCache;
    5834
    5935    [StorableConstructor]
    60     protected ClassificationSolution(bool deserializing) : base(deserializing) { }
     36    protected ClassificationSolution(bool deserializing)
     37      : base(deserializing) {
     38      evaluationCache = new Dictionary<int, double>();
     39    }
    6140    protected ClassificationSolution(ClassificationSolution original, Cloner cloner)
    6241      : base(original, cloner) {
     42      evaluationCache = new Dictionary<int, double>(original.evaluationCache);
    6343    }
    6444    public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
    6545      : base(model, problemData) {
    66       Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
    67       Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
    68       RecalculateResults();
     46      evaluationCache = new Dictionary<int, double>();
    6947    }
    7048
    71     public override IDeepCloneable Clone(Cloner cloner) {
    72       return new ClassificationSolution(this, cloner);
     49    public override IEnumerable<double> EstimatedClassValues {
     50      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     51    }
     52    public override IEnumerable<double> EstimatedTrainingClassValues {
     53      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     54    }
     55    public override IEnumerable<double> EstimatedTestClassValues {
     56      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    7357    }
    7458
    75     protected override void OnProblemDataChanged(EventArgs e) {
    76       base.OnProblemDataChanged(e);
    77       RecalculateResults();
     59    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     60      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     61      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     62      var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     63
     64      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     65        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     66      }
     67
     68      return rows.Select(row => evaluationCache[row]);
    7869    }
    7970
    80     protected override void OnModelChanged(EventArgs e) {
    81       base.OnModelChanged(e);
    82       RecalculateResults();
           // Discards all cached estimates before delegating to the base implementation;
           // the cache is keyed by row index only, so entries computed for the previous
           // problem data must not be reused.
     71    protected override void OnProblemDataChanged() {
     72      evaluationCache.Clear();
     73      base.OnProblemDataChanged();
    8374    }
    8475
    85     protected void RecalculateResults() {
    86       double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values
    87       IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    88       double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values
    89       IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    90 
    91       OnlineCalculatorError errorState;
    92       double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState);
    93       if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN;
    94       double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState);
    95       if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN;
    96 
    97       TrainingAccuracy = trainingAccuracy;
    98       TestAccuracy = testAccuracy;
    99     }
    100 
    101     public virtual IEnumerable<double> EstimatedClassValues {
    102       get {
    103         return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    104       }
    105     }
    106 
    107     public virtual IEnumerable<double> EstimatedTrainingClassValues {
    108       get {
    109         return GetEstimatedClassValues(ProblemData.TrainingIndizes);
    110       }
    111     }
    112 
    113     public virtual IEnumerable<double> EstimatedTestClassValues {
    114       get {
    115         return GetEstimatedClassValues(ProblemData.TestIndizes);
    116       }
    117     }
    118 
    119     public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    120       return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
           // Discards all cached estimates before delegating to the base implementation,
           // because they were produced by the previous model.
     76    protected override void OnModelChanged() {
     77      evaluationCache.Clear();
     78      base.OnModelChanged();
    12179    }
    12280  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs

    r5809 r6760  
    3333  [StorableClass]
    3434  [Item("DiscriminantFunctionClassificationModel", "Represents a classification model that uses a discriminant function and classification thresholds.")]
    35   public class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {
     35  public abstract class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {
    3636    [Storable]
    3737    private IRegressionModel model;
     
    7070    }
    7171
    72     public override IDeepCloneable Clone(Cloner cloner) {
    73       return new DiscriminantFunctionClassificationModel(this, cloner);
    74     }
    75 
    7672    public void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues) {
    7773      var classValuesArr = classValues.ToArray();
     
    106102    }
    107103    #endregion
     104
            // Factory method that concrete subclasses must implement; returns an
            // IDiscriminantFunctionClassificationSolution for the given problem data.
     105    public abstract IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData);
            // Factory method that concrete subclasses must implement; returns an
            // IClassificationSolution for the given problem data.
     106    public abstract IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData);
    108107  }
    109108}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs

    r5942 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
     
    2625using HeuristicLab.Core;
    2726using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    28 using HeuristicLab.Data;
    29 using HeuristicLab.Optimization;
    3027
    3128namespace HeuristicLab.Problems.DataAnalysis {
     
    3532  [StorableClass]
    3633  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
    37   public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution {
    38     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    39     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    40     private const string TrainingRSquaredResultName = "Pearson's R² (training)";
    41     private const string TestRSquaredResultName = "Pearson's R² (test)";
     34  public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase {
     35    protected readonly Dictionary<int, double> valueEvaluationCache;
     36    protected readonly Dictionary<int, double> classValueEvaluationCache;
    4237
    43     public new IDiscriminantFunctionClassificationModel Model {
    44       get { return (IDiscriminantFunctionClassificationModel)base.Model; }
    45       protected set {
    46         if (value != null && value != Model) {
    47           if (Model != null) {
    48             Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
    49           }
    50           value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    51           base.Model = value;
    52         }
    53       }
           // Deserialization constructor: both evaluation caches start out empty.
           // NOTE(review): the cache fields carry no [Storable] attribute in the visible code,
           // so presumably they are not persisted - confirm.
     38    [StorableConstructor]
     39    protected DiscriminantFunctionClassificationSolution(bool deserializing)
     40      : base(deserializing) {
     41      valueEvaluationCache = new Dictionary<int, double>();
     42      classValueEvaluationCache = new Dictionary<int, double>();
     43    }
           // Cloning constructor: the clone receives its own copies of both evaluation caches,
           // filled with the entries of the original.
     44    protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
     45      : base(original, cloner) {
     46      valueEvaluationCache = new Dictionary<int, double>(original.valueEvaluationCache);
     47      classValueEvaluationCache = new Dictionary<int, double>(original.classValueEvaluationCache);
     48    }
           // Creates a solution for the given model and problem data with empty evaluation caches
           // and then calls SetAccuracyMaximizingThresholds().
     49    protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
     50      : base(model, problemData) {
     51      valueEvaluationCache = new Dictionary<int, double>();
     52      classValueEvaluationCache = new Dictionary<int, double>();
     53
             // the caches are assigned before this call
             // NOTE(review): presumably the threshold calculation evaluates the model through the
             // cached accessors - confirm
     54      SetAccuracyMaximizingThresholds();
    5455    }
    5556
    56     public double TrainingMeanSquaredError {
    57       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    58       private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     57    public override IEnumerable<double> EstimatedClassValues {
     58      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     59    }
     60    public override IEnumerable<double> EstimatedTrainingClassValues {
     61      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     62    }
     63    public override IEnumerable<double> EstimatedTestClassValues {
     64      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    5965    }
    6066
    61     public double TestMeanSquaredError {
    62       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    63       private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     67    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     68      var rowsToEvaluate = rows.Except(classValueEvaluationCache.Keys);
     69      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     70      var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     71
     72      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     73        classValueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     74      }
     75
     76      return rows.Select(row => classValueEvaluationCache[row]);
    6477    }
    6578
    66     public double TrainingRSquared {
    67       get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }
    68       private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }
    69     }
    7079
    71     public double TestRSquared {
    72       get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; }
    73       private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; }
    74     }
    75 
    76     [StorableConstructor]
    77     protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { }
    78     protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
    79       : base(original, cloner) {
    80       RegisterEventHandler();
    81     }
    82     public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
    83       : this(new DiscriminantFunctionClassificationModel(model), problemData) {
    84     }
    85     public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
    86       : base(model, problemData) {
    87       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    88       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    89       Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    90       Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    91       RegisterEventHandler();
    92       SetAccuracyMaximizingThresholds();
    93       RecalculateResults();
    94     }
    95 
    96     [StorableHook(HookType.AfterDeserialization)]
    97     private void AfterDeserialization() {
    98       RegisterEventHandler();
    99     }
    100 
    101     protected new void RecalculateResults() {
    102       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    103       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    104       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    105       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    106 
    107       OnlineCalculatorError errorState;
    108       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    109       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    110       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    111       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    112 
    113       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    114       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    115       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    116       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    117     }
    118 
    119     private void RegisterEventHandler() {
    120       Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    121     }
    122     private void Model_ThresholdsChanged(object sender, EventArgs e) {
    123       OnModelThresholdsChanged(e);
    124     }
    125 
    126     public void SetAccuracyMaximizingThresholds() {
    127       double[] classValues;
    128       double[] thresholds;
    129       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    130       AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    131 
    132       Model.SetThresholdsAndClassValues(thresholds, classValues);
    133     }
    134 
    135     public void SetClassDistibutionCutPointThresholds() {
    136       double[] classValues;
    137       double[] thresholds;
    138       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    139       NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    140 
    141       Model.SetThresholdsAndClassValues(thresholds, classValues);
    142     }
    143 
    144     protected override void OnModelChanged(EventArgs e) {
    145       base.OnModelChanged(e);
    146       SetAccuracyMaximizingThresholds();
    147       RecalculateResults();
    148     }
    149 
    150     protected override void OnProblemDataChanged(EventArgs e) {
    151       base.OnProblemDataChanged(e);
    152       SetAccuracyMaximizingThresholds();
    153       RecalculateResults();
    154     }
    155     protected virtual void OnModelThresholdsChanged(EventArgs e) {
    156       base.OnModelChanged(e);
    157       RecalculateResults();
    158     }
    159 
    160     public IEnumerable<double> EstimatedValues {
     80    public override IEnumerable<double> EstimatedValues {
    16181      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    16282    }
    163 
    164     public IEnumerable<double> EstimatedTrainingValues {
     83    public override IEnumerable<double> EstimatedTrainingValues {
    16584      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
    16685    }
    167 
    168     public IEnumerable<double> EstimatedTestValues {
     86    public override IEnumerable<double> EstimatedTestValues {
    16987      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    17088    }
    17189
    172     public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    173       return Model.GetEstimatedValues(ProblemData.Dataset, rows);
     90    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     91      var rowsToEvaluate = rows.Except(valueEvaluationCache.Keys);
     92      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     93      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     94
     95      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     96        valueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     97      }
     98
     99      return rows.Select(row => valueEvaluationCache[row]);
     100    }
     101
     102    protected override void OnModelChanged() {
     103      valueEvaluationCache.Clear();
     104      classValueEvaluationCache.Clear();
     105      base.OnModelChanged();
     106    }
     107    protected override void OnModelThresholdsChanged(System.EventArgs e) {
     108      classValueEvaluationCache.Clear();
     109      base.OnModelThresholdsChanged(e);
     110    }
     111    protected override void OnProblemDataChanged() {
     112      valueEvaluationCache.Clear();
     113      classValueEvaluationCache.Clear();
     114      base.OnProblemDataChanged();
    174115    }
    175116  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Clustering/ClusteringProblemData.cs

    r6228 r6760  
    9595      dataset.Name = Path.GetFileName(fileName);
    9696
    97       ClusteringProblemData problemData = new ClusteringProblemData(dataset, dataset.VariableNames);
     97      ClusteringProblemData problemData = new ClusteringProblemData(dataset, dataset.DoubleVariables);
    9898      problemData.Name = "Data imported from " + Path.GetFileName(fileName);
    9999      return problemData;
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Clustering/ClusteringSolution.cs

    r6184 r6760  
    4545    }
    4646
     47    protected override void RecalculateResults() {
     48    }
     49
    4750    #region IClusteringSolution Members
    4851
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblem.cs

    r5809 r6760  
    4848    public T ProblemData {
    4949      get { return ProblemDataParameter.Value; }
    50       protected set { ProblemDataParameter.Value = value; }
     50      protected set {
     51        ProblemDataParameter.Value = value;
     52      }
    5153    }
    5254    #endregion
    5355    protected DataAnalysisProblem(DataAnalysisProblem<T> original, Cloner cloner)
    5456      : base(original, cloner) {
     57      RegisterEventHandlers();
    5558    }
    5659    [StorableConstructor]
     
    5962      : base() {
    6063      Parameters.Add(new ValueParameter<T>(ProblemDataParameterName, ProblemDataParameterDescription));
     64      RegisterEventHandlers();
     65    }
     66
     67    [StorableHook(HookType.AfterDeserialization)]
     68    private void AfterDeserialization() {
     69      RegisterEventHandlers();
    6170    }
    6271
    6372    private void RegisterEventHandlers() {
    64       ProblemDataParameter.Value.Changed += new EventHandler(ProblemDataParameter_ValueChanged);
     73      ProblemDataParameter.ValueChanged += new EventHandler(ProblemDataParameter_ValueChanged);
     74      if (ProblemDataParameter.Value != null) ProblemDataParameter.Value.Changed += new EventHandler(ProblemData_Changed);
    6575    }
     76
    6677    private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) {
     78      ProblemDataParameter.Value.Changed += new EventHandler(ProblemData_Changed);
    6779      OnProblemDataChanged();
     80      OnReset();
     81    }
     82
     83    private void ProblemData_Changed(object sender, EventArgs e) {
    6884      OnReset();
    6985    }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblemData.cs

    r5847 r6760  
    3333  [StorableClass]
    3434  public abstract class DataAnalysisProblemData : ParameterizedNamedItem, IDataAnalysisProblemData {
    35     private const string DatasetParameterName = "Dataset";
    36     private const string InputVariablesParameterName = "InputVariables";
    37     private const string TrainingPartitionParameterName = "TrainingPartition";
    38     private const string TestPartitionParameterName = "TestPartition";
     35    protected const string DatasetParameterName = "Dataset";
     36    protected const string InputVariablesParameterName = "InputVariables";
     37    protected const string TrainingPartitionParameterName = "TrainingPartition";
     38    protected const string TestPartitionParameterName = "TestPartition";
    3939
    4040    #region parameter properites
     
    5353    #endregion
    5454
    55     #region propeties
     55    #region properties
     56    protected bool isEmpty = false;
     57    public bool IsEmpty {
     58      get { return isEmpty; }
     59    }
    5660    public Dataset Dataset {
    5761      get { return DatasetParameter.Value; }
     
    7175    }
    7276
    73     public IEnumerable<int> TrainingIndizes {
     77    public virtual IEnumerable<int> TrainingIndizes {
    7478      get {
    7579        return Enumerable.Range(TrainingPartition.Start, TrainingPartition.End - TrainingPartition.Start)
    76                          .Where(i => i >= 0 && i < Dataset.Rows && (i < TestPartition.Start || TestPartition.End <= i));
     80                         .Where(IsTrainingSample);
    7781      }
    7882    }
    79     public IEnumerable<int> TestIndizes {
     83    public virtual IEnumerable<int> TestIndizes {
    8084      get {
    8185        return Enumerable.Range(TestPartition.Start, TestPartition.End - TestPartition.Start)
    82            .Where(i => i >= 0 && i < Dataset.Rows);
     86           .Where(IsTestSample);
    8387      }
     88    }
     89
     90    public virtual bool IsTrainingSample(int index) {
     91      return index >= 0 && index < Dataset.Rows &&
     92        TrainingPartition.Start <= index && index < TrainingPartition.End &&
     93        (index < TestPartition.Start || TestPartition.End <= index);
     94    }
     95
     96    public virtual bool IsTestSample(int index) {
     97      return index >= 0 && index < Dataset.Rows &&
     98             TestPartition.Start <= index && index < TestPartition.End;
    8499    }
    85100    #endregion
    86101
    87     protected DataAnalysisProblemData(DataAnalysisProblemData original, Cloner cloner) : base(original, cloner) { }
     102    protected DataAnalysisProblemData(DataAnalysisProblemData original, Cloner cloner)
     103      : base(original, cloner) {
     104      isEmpty = original.isEmpty;
     105      RegisterEventHandlers();
     106    }
    88107    [StorableConstructor]
    89108    protected DataAnalysisProblemData(bool deserializing) : base(deserializing) { }
     109    [StorableHook(HookType.AfterDeserialization)]
     110    private void AfterDeserialization() {
     111      RegisterEventHandlers();
     112    }
    90113
    91114    protected DataAnalysisProblemData(Dataset dataset, IEnumerable<string> allowedInputVariables) {
     
    93116      if (allowedInputVariables == null) throw new ArgumentNullException("The allowedInputVariables must not be null.");
    94117
    95       if (allowedInputVariables.Except(dataset.VariableNames).Any())
    96         throw new ArgumentException("All allowed input variables must be present in the dataset.");
     118      if (allowedInputVariables.Except(dataset.DoubleVariables).Any())
     119        throw new ArgumentException("All allowed input variables must be present in the dataset and of type double.");
    97120
    98       var inputVariables = new CheckedItemList<StringValue>(dataset.VariableNames.Select(x => new StringValue(x)));
     121      var inputVariables = new CheckedItemList<StringValue>(dataset.DoubleVariables.Select(x => new StringValue(x)));
    99122      foreach (StringValue x in inputVariables)
    100123        inputVariables.SetItemCheckedState(x, allowedInputVariables.Contains(x.Value));
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisSolution.cs

    r5914 r6760  
    4848          if (value != null) {
    4949            this[ModelResultName].Value = value;
    50             OnModelChanged(EventArgs.Empty);
     50            OnModelChanged();
    5151          }
    5252        }
     
    5656    public IDataAnalysisProblemData ProblemData {
    5757      get { return (IDataAnalysisProblemData)this[ProblemDataResultName].Value; }
    58       protected set {
     58      set {
    5959        if (this[ProblemDataResultName].Value != value) {
    6060          if (value != null) {
     
    6262            this[ProblemDataResultName].Value = value;
    6363            ProblemData.Changed += new EventHandler(ProblemData_Changed);
    64             OnProblemDataChanged(EventArgs.Empty);
     64            OnProblemDataChanged();
    6565          }
    6666        }
     
    8080      name = ItemName;
    8181      description = ItemDescription;
    82       Add(new Result(ModelResultName, "The symbolic data analysis model.", model));
    83       Add(new Result(ProblemDataResultName, "The symbolic data analysis problem data.", problemData));
     82      Add(new Result(ModelResultName, "The data analysis model.", model));
     83      Add(new Result(ProblemDataResultName, "The data analysis problem data.", problemData));
    8484
    8585      problemData.Changed += new EventHandler(ProblemData_Changed);
    8686    }
    8787
     88    protected abstract void RecalculateResults();
     89
    8890    private void ProblemData_Changed(object sender, EventArgs e) {
    89       OnProblemDataChanged(e);
     91      OnProblemDataChanged();
    9092    }
    9193
    9294    public event EventHandler ModelChanged;
    93     protected virtual void OnModelChanged(EventArgs e) {
     95    protected virtual void OnModelChanged() {
     96      RecalculateResults();
    9497      var listeners = ModelChanged;
    95       if (listeners != null) listeners(this, e);
     98      if (listeners != null) listeners(this, EventArgs.Empty);
    9699    }
    97100
    98101    public event EventHandler ProblemDataChanged;
    99     protected virtual void OnProblemDataChanged(EventArgs e) {
     102    protected virtual void OnProblemDataChanged() {
     103      RecalculateResults();
    100104      var listeners = ProblemDataChanged;
    101       if (listeners != null) listeners(this, e);
     105      if (listeners != null) listeners(this, EventArgs.Empty);
    102106    }
    103107
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r5809 r6760  
    3434  public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
    3535
    36     [Storable]
    3736    private List<IRegressionModel> models;
    3837    public IEnumerable<IRegressionModel> Models {
    3938      get { return new List<IRegressionModel>(models); }
    4039    }
     40
     41    [Storable(Name = "Models")]
     42    private IEnumerable<IRegressionModel> StorableModels {
     43      get { return models; }
     44      set { models = value.ToList(); }
     45    }
     46
     47    #region backwards compatiblity 3.3.5
     48    [Storable(Name = "models", AllowOneWay = true)]
     49    private List<IRegressionModel> OldStorableModels {
     50      set { models = value; }
     51    }
     52    #endregion
     53
    4154    [StorableConstructor]
    4255    protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
     
    4558      this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    4659    }
     60
     61    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }
    4762    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
    4863      : base() {
     
    5772
    5873    #region IRegressionEnsembleModel Members
     74
     75    public void Add(IRegressionModel model) {
     76      models.Add(model);
     77    }
     78    public void Remove(IRegressionModel model) {
     79      models.Remove(model);
     80    }
    5981
    6082    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows) {
     
    79101    }
    80102
     103    public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     104      return new RegressionEnsembleSolution(this.Models, problemData);
     105    }
     106    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
     107      return CreateRegressionSolution(problemData);
     108    }
     109
    81110    #endregion
    82111  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r6184 r6760  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using HeuristicLab.Collections;
    2426using HeuristicLab.Common;
    2527using HeuristicLab.Core;
     28using HeuristicLab.Data;
    2629using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    27 using System;
    28 using HeuristicLab.Data;
    2930
    3031namespace HeuristicLab.Problems.DataAnalysis {
     
    3435  [StorableClass]
    3536  [Item("Regression Ensemble Solution", "A regression solution that contains an ensemble of multiple regression models")]
    36   // [Creatable("Data Analysis")]
    37   public class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
     37  [Creatable("Data Analysis - Ensembles")]
     38  public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
    3839    public new IRegressionEnsembleModel Model {
    3940      get { return (IRegressionEnsembleModel)base.Model; }
     41    }
     42
     43    public new RegressionEnsembleProblemData ProblemData {
     44      get { return (RegressionEnsembleProblemData)base.ProblemData; }
     45      set { base.ProblemData = value; }
     46    }
     47
     48    private readonly ItemCollection<IRegressionSolution> regressionSolutions;
     49    public IItemCollection<IRegressionSolution> RegressionSolutions {
     50      get { return regressionSolutions; }
    4051    }
    4152
     
    4657
    4758    [StorableConstructor]
    48     protected RegressionEnsembleSolution(bool deserializing) : base(deserializing) { }
    49     protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
     59    private RegressionEnsembleSolution(bool deserializing)
     60      : base(deserializing) {
     61      regressionSolutions = new ItemCollection<IRegressionSolution>();
     62    }
     63    [StorableHook(HookType.AfterDeserialization)]
     64    private void AfterDeserialization() {
     65      foreach (var model in Model.Models) {
     66        IRegressionProblemData problemData = (IRegressionProblemData) ProblemData.Clone();
     67        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
     68        problemData.TrainingPartition.End = trainingPartitions[model].End;
     69        problemData.TestPartition.Start = testPartitions[model].Start;
     70        problemData.TestPartition.End = testPartitions[model].End;
     71
     72        regressionSolutions.Add(model.CreateRegressionSolution(problemData));
     73      }
     74      RegisterRegressionSolutionsEventHandler();
     75    }
     76
     77    private RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
    5078      : base(original, cloner) {
    51     }
    52     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    53       : base(new RegressionEnsembleModel(models), problemData) {
    5479      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    5580      testPartitions = new Dictionary<IRegressionModel, IntRange>();
    56       foreach (var model in models) {
    57         trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone();
    58         testPartitions[model] = (IntRange)problemData.TestPartition.Clone();
    59       }
    60       RecalculateResults();
    61     }
     81      foreach (var pair in original.trainingPartitions) {
     82        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     83      }
     84      foreach (var pair in original.testPartitions) {
     85        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     86      }
     87
     88      regressionSolutions = cloner.Clone(original.regressionSolutions);
     89      RegisterRegressionSolutionsEventHandler();
     90    }
     91
     92    public RegressionEnsembleSolution()
     93      : base(new RegressionEnsembleModel(), RegressionEnsembleProblemData.EmptyProblemData) {
     94      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
     95      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     96      regressionSolutions = new ItemCollection<IRegressionSolution>();
     97
     98      RegisterRegressionSolutionsEventHandler();
     99    }
     100
     101    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
     102      : this(models, problemData,
     103             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
     104             models.Select(m => (IntRange)problemData.TestPartition.Clone())
     105      ) { }
    62106
    63107    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    64       : base(new RegressionEnsembleModel(models), problemData) {
     108      : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) {
    65109      this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    66110      this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
     111      this.regressionSolutions = new ItemCollection<IRegressionSolution>();
     112
     113      List<IRegressionSolution> solutions = new List<IRegressionSolution>();
    67114      var modelEnumerator = models.GetEnumerator();
    68115      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
    69116      var testPartitionEnumerator = testPartitions.GetEnumerator();
     117
    70118      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
    71         this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone();
    72         this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone();
     119        var p = (IRegressionProblemData)problemData.Clone();
     120        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
     121        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
     122        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
     123        p.TestPartition.End = testPartitionEnumerator.Current.End;
     124
     125        solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p));
    73126      }
    74127      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
     
    76129      }
    77130
    78       RecalculateResults();
    79     }
    80 
    81     private void RecalculateResults() {
    82       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    83       var trainingIndizes = Enumerable.Range(ProblemData.TrainingPartition.Start,
    84         ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
    85       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, trainingIndizes);
    86       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    87       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    88 
    89       OnlineCalculatorError errorState;
    90       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    91       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    92       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    93       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    94 
    95       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    96       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    97       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    98       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    99 
    100       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    101       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    102       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    103       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    104 
    105       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    106       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    107       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    108       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
     131      RegisterRegressionSolutionsEventHandler();
     132      regressionSolutions.AddRange(solutions);
    109133    }
    110134
     
    112136      return new RegressionEnsembleSolution(this, cloner);
    113137    }
    114 
     138    private void RegisterRegressionSolutionsEventHandler() {
     139      regressionSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsAdded);
     140      regressionSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsRemoved);
     141      regressionSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_CollectionReset);
     142    }
     143
     144    protected override void RecalculateResults() {
     145      CalculateResults();
     146    }
     147
     148    #region Evaluation
    115149    public override IEnumerable<double> EstimatedTrainingValues {
    116150      get {
    117         var rows = Enumerable.Range(ProblemData.TrainingPartition.Start, ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
     151        var rows = ProblemData.TrainingIndizes;
    118152        var estimatedValuesEnumerators = (from model in Model.Models
    119153                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    120154                                         .ToList();
    121155        var rowsEnumerator = rows.GetEnumerator();
     156        // aggregate to make sure that MoveNext is called for all enumerators
    122157        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    123158          int currentRow = rowsEnumerator.Current;
    124159
    125160          var selectedEnumerators = from pair in estimatedValuesEnumerators
    126                                     where trainingPartitions == null || !trainingPartitions.ContainsKey(pair.Model) ||
    127                                          (trainingPartitions[pair.Model].Start <= currentRow && currentRow < trainingPartitions[pair.Model].End)
     161                                    where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    128162                                    select pair.EstimatedValuesEnumerator;
    129163          yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
     
    134168    public override IEnumerable<double> EstimatedTestValues {
    135169      get {
     170        var rows = ProblemData.TestIndizes;
    136171        var estimatedValuesEnumerators = (from model in Model.Models
    137                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, ProblemData.TestIndizes).GetEnumerator() })
     172                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    138173                                         .ToList();
    139174        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     175        // aggregate to make sure that MoveNext is called for all enumerators
    140176        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    141177          int currentRow = rowsEnumerator.Current;
    142178
    143179          var selectedEnumerators = from pair in estimatedValuesEnumerators
    144                                     where testPartitions == null || !testPartitions.ContainsKey(pair.Model) ||
    145                                       (testPartitions[pair.Model].Start <= currentRow && currentRow < testPartitions[pair.Model].End)
     180                                    where RowIsTestForModel(currentRow, pair.Model)
    146181                                    select pair.EstimatedValuesEnumerator;
    147182
     
    149184        }
    150185      }
     186    }
     187
     188    private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) {
     189      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
     190              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
     191    }
     192
     193    private bool RowIsTestForModel(int currentRow, IRegressionModel model) {
     194      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     195              (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End);
    151196    }
    152197
     
    168213
    169214    private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    170       return estimatedValues.Average();
    171     }
    172 
    173     //[Storable]
    174     //private string name;
    175     //public string Name {
    176     //  get {
    177     //    return name;
    178     //  }
    179     //  set {
    180     //    if (value != null && value != name) {
    181     //      var cancelEventArgs = new CancelEventArgs<string>(value);
    182     //      OnNameChanging(cancelEventArgs);
    183     //      if (cancelEventArgs.Cancel == false) {
    184     //        name = value;
    185     //        OnNamedChanged(EventArgs.Empty);
    186     //      }
    187     //    }
    188     //  }
    189     //}
    190 
    191     //public bool CanChangeName {
    192     //  get { return true; }
    193     //}
    194 
    195     //[Storable]
    196     //private string description;
    197     //public string Description {
    198     //  get {
    199     //    return description;
    200     //  }
    201     //  set {
    202     //    if (value != null && value != description) {
    203     //      description = value;
    204     //      OnDescriptionChanged(EventArgs.Empty);
    205     //    }
    206     //  }
    207     //}
    208 
    209     //public bool CanChangeDescription {
    210     //  get { return true; }
    211     //}
    212 
    213     //#region events
    214     //public event EventHandler<CancelEventArgs<string>> NameChanging;
    215     //private void OnNameChanging(CancelEventArgs<string> cancelEventArgs) {
    216     //  var listener = NameChanging;
    217     //  if (listener != null) listener(this, cancelEventArgs);
    218     //}
    219 
    220     //public event EventHandler NameChanged;
    221     //private void OnNamedChanged(EventArgs e) {
    222     //  var listener = NameChanged;
    223     //  if (listener != null) listener(this, e);
    224     //}
    225 
    226     //public event EventHandler DescriptionChanged;
    227     //private void OnDescriptionChanged(EventArgs e) {
    228     //  var listener = DescriptionChanged;
    229     //  if (listener != null) listener(this, e);
    230     //}
    231     // #endregion
     215      return estimatedValues.DefaultIfEmpty(double.NaN).Average();
     216    }
     217    #endregion
     218
     219    protected override void OnProblemDataChanged() {
     220      IRegressionProblemData problemData = new RegressionProblemData(ProblemData.Dataset,
     221                                                                     ProblemData.AllowedInputVariables,
     222                                                                     ProblemData.TargetVariable);
     223      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
     224      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
     225      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
     226      problemData.TestPartition.End = ProblemData.TestPartition.End;
     227
     228      foreach (var solution in RegressionSolutions) {
     229        if (solution is RegressionEnsembleSolution)
     230          solution.ProblemData = ProblemData;
     231        else
     232          solution.ProblemData = problemData;
     233      }
     234      foreach (var trainingPartition in trainingPartitions.Values) {
     235        trainingPartition.Start = ProblemData.TrainingPartition.Start;
     236        trainingPartition.End = ProblemData.TrainingPartition.End;
     237      }
     238      foreach (var testPartition in testPartitions.Values) {
     239        testPartition.Start = ProblemData.TestPartition.Start;
     240        testPartition.End = ProblemData.TestPartition.End;
     241      }
     242
     243      base.OnProblemDataChanged();
     244    }
     245
     246    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     247      regressionSolutions.AddRange(solutions);
     248    }
     249    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     250      regressionSolutions.RemoveRange(solutions);
     251    }
     252
     253    private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     254      foreach (var solution in e.Items) AddRegressionSolution(solution);
     255      RecalculateResults();
     256    }
     257    private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     258      foreach (var solution in e.Items) RemoveRegressionSolution(solution);
     259      RecalculateResults();
     260    }
     261    private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     262      foreach (var solution in e.OldItems) RemoveRegressionSolution(solution);
     263      foreach (var solution in e.Items) AddRegressionSolution(solution);
     264      RecalculateResults();
     265    }
     266
     267    private void AddRegressionSolution(IRegressionSolution solution) {
     268      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
     269      Model.Add(solution.Model);
     270      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
     271      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     272    }
     273
     274    private void RemoveRegressionSolution(IRegressionSolution solution) {
     275      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
     276      Model.Remove(solution.Model);
     277      trainingPartitions.Remove(solution.Model);
     278      testPartitions.Remove(solution.Model);
     279    }
    232280  }
    233281}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs

    r5809 r6760  
    3333  [StorableClass]
    3434  [Item("RegressionProblemData", "Represents an item containing all data defining a regression problem.")]
    35   public sealed class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData {
    36     private const string TargetVariableParameterName = "TargetVariable";
     35  public class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData {
     36    protected const string TargetVariableParameterName = "TargetVariable";
    3737
    3838    #region default data
     
    6464          {0.83763905,  0.468046718}
    6565    };
    66     private static Dataset defaultDataset;
    67     private static IEnumerable<string> defaultAllowedInputVariables;
    68     private static string defaultTargetVariable;
     66    private static readonly Dataset defaultDataset;
     67    private static readonly IEnumerable<string> defaultAllowedInputVariables;
     68    private static readonly string defaultTargetVariable;
     69
     70    private static readonly RegressionProblemData emptyProblemData;
     71    public static RegressionProblemData EmptyProblemData {
     72      get { return emptyProblemData; }
     73    }
    6974
    7075    static RegressionProblemData() {
     
    7479      defaultAllowedInputVariables = new List<string>() { "x" };
    7580      defaultTargetVariable = "y";
     81
     82      var problemData = new RegressionProblemData();
     83      problemData.Parameters.Clear();
     84      problemData.Name = "Empty Regression ProblemData";
     85      problemData.Description = "This ProblemData acts as place holder before the correct problem data is loaded.";
     86      problemData.isEmpty = true;
     87
     88      problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset()));
     89      problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, ""));
     90      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     91      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     92      problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>()));
     93      emptyProblemData = problemData;
    7694    }
    7795    #endregion
    7896
    79     public IValueParameter<StringValue> TargetVariableParameter {
    80       get { return (IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
     97    public ConstrainedValueParameter<StringValue> TargetVariableParameter {
     98      get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
    8199    }
    82100    public string TargetVariable {
     
    85103
    86104    [StorableConstructor]
    87     private RegressionProblemData(bool deserializing) : base(deserializing) { }
     105    protected RegressionProblemData(bool deserializing) : base(deserializing) { }
    88106    [StorableHook(HookType.AfterDeserialization)]
    89107    private void AfterDeserialization() {
     
    91109    }
    92110
    93 
    94     private RegressionProblemData(RegressionProblemData original, Cloner cloner)
     111    protected RegressionProblemData(RegressionProblemData original, Cloner cloner)
    95112      : base(original, cloner) {
    96113      RegisterParameterEvents();
    97114    }
    98     public override IDeepCloneable Clone(Cloner cloner) { return new RegressionProblemData(this, cloner); }
     115    public override IDeepCloneable Clone(Cloner cloner) {
     116      if (this == emptyProblemData) return emptyProblemData;
     117      return new RegressionProblemData(this, cloner);
     118    }
    99119
    100120    public RegressionProblemData()
     
    124144      dataset.Name = Path.GetFileName(fileName);
    125145
    126       RegressionProblemData problemData = new RegressionProblemData(dataset, dataset.VariableNames.Skip(1), dataset.VariableNames.First());
     146      RegressionProblemData problemData = new RegressionProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First());
    127147      problemData.Name = "Data imported from " + Path.GetFileName(fileName);
    128148      return problemData;
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs

    r6184 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    26 using HeuristicLab.Data;
    27 using HeuristicLab.Optimization;
    2825using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2926
     
    3330  /// </summary>
    3431  [StorableClass]
    35   public class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
    36     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    37     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    38     private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)";
    39     private const string TestSquaredCorrelationResultName = "Pearson's R² (test)";
    40     private const string TrainingRelativeErrorResultName = "Average relative error (training)";
    41     private const string TestRelativeErrorResultName = "Average relative error (test)";
    42     private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)";
    43     private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)";
     32  public abstract class RegressionSolution : RegressionSolutionBase {
     33    protected readonly Dictionary<int, double> evaluationCache;
    4434
    45     public new IRegressionModel Model {
    46       get { return (IRegressionModel)base.Model; }
    47       protected set { base.Model = value; }
     35    [StorableConstructor]
     36    protected RegressionSolution(bool deserializing)
     37      : base(deserializing) {
     38      evaluationCache = new Dictionary<int, double>();
     39    }
     40    protected RegressionSolution(RegressionSolution original, Cloner cloner)
     41      : base(original, cloner) {
     42      evaluationCache = new Dictionary<int, double>(original.evaluationCache);
     43    }
     44    protected RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
     45      : base(model, problemData) {
     46      evaluationCache = new Dictionary<int, double>();
    4847    }
    4948
    50     public new IRegressionProblemData ProblemData {
    51       get { return (IRegressionProblemData)base.ProblemData; }
    52       protected set { base.ProblemData = value; }
     49    protected override void RecalculateResults() {
     50      CalculateResults();
    5351    }
    5452
    55     public double TrainingMeanSquaredError {
    56       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    57       protected set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     53    public override IEnumerable<double> EstimatedValues {
     54      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     55    }
     56    public override IEnumerable<double> EstimatedTrainingValues {
     57      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
     58    }
     59    public override IEnumerable<double> EstimatedTestValues {
     60      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    5861    }
    5962
    60     public double TestMeanSquaredError {
    61       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    62       protected set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     63    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     64      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     65      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     66      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     67
     68      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     69        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     70      }
     71
     72      return rows.Select(row => evaluationCache[row]);
    6373    }
    6474
    65     public double TrainingRSquared {
    66       get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }
    67       protected set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
     75    protected override void OnProblemDataChanged() {
     76      evaluationCache.Clear();
     77      base.OnProblemDataChanged();
    6878    }
    6979
    70     public double TestRSquared {
    71       get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; }
    72       protected set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
    73     }
    74 
    75     public double TrainingRelativeError {
    76       get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; }
    77       protected set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
    78     }
    79 
    80     public double TestRelativeError {
    81       get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; }
    82       protected set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
    83     }
    84 
    85     public double TrainingNormalizedMeanSquaredError {
    86       get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; }
    87       protected set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    88     }
    89 
    90     public double TestNormalizedMeanSquaredError {
    91       get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; }
    92       protected set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    93     }
    94 
    95 
    96     [StorableConstructor]
    97     protected RegressionSolution(bool deserializing) : base(deserializing) { }
    98     protected RegressionSolution(RegressionSolution original, Cloner cloner)
    99       : base(original, cloner) {
    100     }
    101     public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
    102       : base(model, problemData) {
    103       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    104       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    105       Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    106       Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    107       Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue()));
    108       Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue()));
    109       Add(new Result(TrainingNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the training partition", new DoubleValue()));
    110       Add(new Result(TestNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the test partition", new DoubleValue()));
    111 
    112       RecalculateResults();
    113     }
    114 
    115     public override IDeepCloneable Clone(Cloner cloner) {
    116       return new RegressionSolution(this, cloner);
    117     }
    118 
    119     protected override void OnProblemDataChanged(EventArgs e) {
    120       base.OnProblemDataChanged(e);
    121       RecalculateResults();
    122     }
    123     protected override void OnModelChanged(EventArgs e) {
    124       base.OnModelChanged(e);
    125       RecalculateResults();
    126     }
    127 
    128     private void RecalculateResults() {
    129       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    130       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    131       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    132       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    133 
    134       OnlineCalculatorError errorState;
    135       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    136       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    137       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    138       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    139 
    140       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    141       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    142       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    143       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    144 
    145       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    146       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    147       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    148       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    149 
    150       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    151       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    152       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    153       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
    154     }
    155 
    156     public virtual IEnumerable<double> EstimatedValues {
    157       get {
    158         return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    159       }
    160     }
    161 
    162     public virtual IEnumerable<double> EstimatedTrainingValues {
    163       get {
    164         return GetEstimatedValues(ProblemData.TrainingIndizes);
    165       }
    166     }
    167 
    168     public virtual IEnumerable<double> EstimatedTestValues {
    169       get {
    170         return GetEstimatedValues(ProblemData.TestIndizes);
    171       }
    172     }
    173 
    174     public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    175       return Model.GetEstimatedValues(ProblemData.Dataset, rows);
     80    protected override void OnModelChanged() {
     81      evaluationCache.Clear();
     82      base.OnModelChanged();
    17683    }
    17784  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleModel.cs

    r5809 r6760  
    2323namespace HeuristicLab.Problems.DataAnalysis {
    2424  public interface IClassificationEnsembleModel : IClassificationModel {
     25    void Add(IClassificationModel model);
     26    void Remove(IClassificationModel model);
    2527    IEnumerable<IClassificationModel> Models { get; }
    2628    IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows);
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolution.cs

    r6184 r6760  
    2121
    2222using System.Collections.Generic;
     23using HeuristicLab.Core;
    2324namespace HeuristicLab.Problems.DataAnalysis {
    2425  public interface IClassificationEnsembleSolution : IClassificationSolution {
    25     IEnumerable<IClassificationModel> Models { get; }
     26    new IClassificationEnsembleModel Model { get; }
     27    IItemCollection<IClassificationSolution> ClassificationSolutions { get; }
    2628    IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows);
    2729  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationModel.cs

    r5809 r6760  
    2424  public interface IClassificationModel : IDataAnalysisModel {
    2525    IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows);
     26    IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData);
    2627  }
    2728}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationSolution.cs

    r5809 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423namespace HeuristicLab.Problems.DataAnalysis {
    2524  public interface IClassificationSolution : IDataAnalysisSolution {
    2625    new IClassificationModel Model { get; }
    27     new IClassificationProblemData ProblemData { get; }
     26    new IClassificationProblemData ProblemData { get; set; }
    2827
    2928    IEnumerable<double> EstimatedClassValues { get; }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationModel.cs

    r5809 r6760  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    23 using System;
    2424namespace HeuristicLab.Problems.DataAnalysis {
    2525  public interface IDiscriminantFunctionClassificationModel : IClassificationModel {
    2626    IEnumerable<double> Thresholds { get; }
    2727    IEnumerable<double> ClassValues { get; }
    28     // class values and thresholds can only be assigned simultaniously
    29     void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues); 
     28    // class values and thresholds can only be assigned simultaneously
     29    void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues);
    3030    IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows);
    3131
    3232    event EventHandler ThresholdsChanged;
     33
     34    IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData);
    3335  }
    3436}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Clustering/IClusteringSolution.cs

    r5809 r6760  
    2424  public interface IClusteringSolution : IDataAnalysisSolution {
    2525    new IClusteringModel Model { get; }
    26     new IClusteringProblemData ProblemData { get; }
     26    new IClusteringProblemData ProblemData { get; set; }
    2727
    2828    IEnumerable<int> ClusterValues { get; }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/IDataAnalysisProblemData.cs

    r5883 r6760  
    2727namespace HeuristicLab.Problems.DataAnalysis {
    2828  public interface IDataAnalysisProblemData : INamedItem {
     29    bool IsEmpty { get; }
     30
    2931    Dataset Dataset { get; }
    3032    ICheckedItemList<StringValue> InputVariables { get; }
     
    3739    IEnumerable<int> TestIndizes { get; }
    3840
     41    bool IsTrainingSample(int index);
     42    bool IsTestSample(int index);
     43
    3944    event EventHandler Changed;
    4045  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/IDataAnalysisSolution.cs

    r5909 r6760  
    2121
    2222using System;
     23using HeuristicLab.Common;
    2324using HeuristicLab.Core;
    24 using HeuristicLab.Common;
    2525
    2626namespace HeuristicLab.Problems.DataAnalysis {
    2727  public interface IDataAnalysisSolution : INamedItem, IStorableContent {
    2828    IDataAnalysisModel Model { get; }
    29     IDataAnalysisProblemData ProblemData { get; }
     29    IDataAnalysisProblemData ProblemData { get; set; }
    3030
    3131    event EventHandler ModelChanged;
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Regression/IRegressionEnsembleModel.cs

    r5809 r6760  
    2323namespace HeuristicLab.Problems.DataAnalysis {
    2424  public interface IRegressionEnsembleModel : IRegressionModel {
     25    void Add(IRegressionModel model);
     26    void Remove(IRegressionModel model);
    2527    IEnumerable<IRegressionModel> Models { get; }
    2628    IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows);
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Regression/IRegressionEnsembleSolution.cs

    r6184 r6760  
    2424namespace HeuristicLab.Problems.DataAnalysis {
    2525  public interface IRegressionEnsembleSolution : IRegressionSolution {
    26     IRegressionEnsembleModel Model { get; }
     26    new IRegressionEnsembleModel Model { get; }
     27    new RegressionEnsembleProblemData ProblemData { get; set; }
     28    IItemCollection<IRegressionSolution> RegressionSolutions { get; }
    2729    IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows);
    2830  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Regression/IRegressionModel.cs

    r5809 r6760  
    2424  public interface IRegressionModel : IDataAnalysisModel {
    2525    IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows);
     26    IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData);
    2627  }
    2728}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Regression/IRegressionSolution.cs

    r5829 r6760  
    2424  public interface IRegressionSolution : IDataAnalysisSolution {
    2525    new IRegressionModel Model { get; }
    26     new IRegressionProblemData ProblemData { get; }
     26    new IRegressionProblemData ProblemData { get; set; }
    2727
    2828    IEnumerable<double> EstimatedValues { get; }
     
    3333    double TrainingMeanSquaredError { get; }
    3434    double TestMeanSquaredError { get; }
     35    double TrainingMeanAbsoluteError { get; }
     36    double TestMeanAbsoluteError { get; }
    3537    double TrainingRSquared { get; }
    3638    double TestRSquared { get; }
    3739    double TrainingRelativeError { get; }
    3840    double TestRelativeError { get; }
     41    double TrainingNormalizedMeanSquaredError { get; }
     42    double TestNormalizedMeanSquaredError { get; }
    3943  }
    4044}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Properties/AssemblyInfo.frame

    r5860 r6760  
    5353// by using the '*' as shown below:
    5454[assembly: AssemblyVersion("3.4.0.0")]
    55 [assembly: AssemblyFileVersion("3.4.0.$WCREV$")]
     55[assembly: AssemblyFileVersion("3.4.1.$WCREV$")]
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/TableFileParser.cs

    r5809 r6760  
    2121
    2222using System;
     23using System.Collections;
    2324using System.Collections.Generic;
    2425using System.Globalization;
     
    3334    private readonly char[] POSSIBLE_SEPARATORS = new char[] { ',', ';', '\t' };
    3435    private Tokenizer tokenizer;
    35     private List<List<double>> rowValues;
     36    private List<List<object>> rowValues;
    3637
    3738    private int rows;
     
    4748    }
    4849
    49     private double[,] values;
    50     public double[,] Values {
     50    private List<IList> values;
     51    public List<IList> Values {
    5152      get {
    5253        return values;
     
    6970
    7071    public TableFileParser() {
    71       rowValues = new List<List<double>>();
     72      rowValues = new List<List<object>>();
    7273      variableNames = new List<string>();
    7374    }
     
    7576    public void Parse(string fileName) {
    7677      NumberFormatInfo numberFormat;
     78      DateTimeFormatInfo dateTimeFormatInfo;
    7779      char separator;
    78       DetermineFileFormat(fileName, out numberFormat, out separator);
     80      DetermineFileFormat(fileName, out numberFormat, out dateTimeFormatInfo, out separator);
    7981      using (StreamReader reader = new StreamReader(fileName)) {
    80         tokenizer = new Tokenizer(reader, numberFormat, separator);
     82        tokenizer = new Tokenizer(reader, numberFormat, dateTimeFormatInfo, separator);
    8183        // parse the file
    8284        Parse();
     
    8688      rows = rowValues.Count;
    8789      columns = rowValues[0].Count;
    88       values = new double[rows, columns];
    89 
    90       int rowIndex = 0;
    91       int columnIndex = 0;
    92       foreach (List<double> row in rowValues) {
    93         columnIndex = 0;
    94         foreach (double element in row) {
    95           values[rowIndex, columnIndex++] = element;
    96         }
    97         rowIndex++;
    98       }
    99     }
    100 
    101     private void DetermineFileFormat(string fileName, out NumberFormatInfo numberFormat, out char separator) {
     90      values = new List<IList>();
     91
     92      //create columns
     93      for (int col = 0; col < columns; col++) {
     94        var types = rowValues.Select(r => r[col]).Where(v => v != null && v as string != string.Empty).Take(10).Select(v => v.GetType());
     95        if (!types.Any()) {
     96          values.Add(new List<string>());
     97          continue;
     98        }
     99
     100        var columnType = types.GroupBy(v => v).OrderBy(v => v).Last().Key;
     101        if (columnType == typeof(double)) values.Add(new List<double>());
     102        else if (columnType == typeof(DateTime)) values.Add(new List<DateTime>());
     103        else if (columnType == typeof(string)) values.Add(new List<string>());
     104        else throw new InvalidOperationException();
     105      }
     106
     107
     108
     109      //fill with values
     110      foreach (List<object> row in rowValues) {
     111        int columnIndex = 0;
     112        foreach (object element in row) {
     113          //handle missing values with default values
     114          if (element as string == string.Empty) {
     115            if (values[columnIndex] is List<double>) values[columnIndex].Add(double.NaN);
     116            else if (values[columnIndex] is List<DateTime>) values[columnIndex].Add(DateTime.MinValue);
     117            else if (values[columnIndex] is List<string>) values[columnIndex].Add(string.Empty);
     118            else throw new InvalidOperationException();
     119          } else values[columnIndex].Add(element);
     120          columnIndex++;
     121        }
     122      }
     123    }
     124
     125    private void DetermineFileFormat(string fileName, out NumberFormatInfo numberFormat, out DateTimeFormatInfo dateTimeFormatInfo, out char separator) {
    102126      using (StreamReader reader = new StreamReader(fileName)) {
    103127        // skip first line
     
    123147        if (OccurrencesOf(charCounts, '.') > 10) {
    124148          numberFormat = NumberFormatInfo.InvariantInfo;
     149          dateTimeFormatInfo = DateTimeFormatInfo.InvariantInfo;
    125150          separator = POSSIBLE_SEPARATORS
    126151            .Where(c => OccurrencesOf(charCounts, c) > 10)
     
    139164            // English format (only integer values) with ',' as separator
    140165            numberFormat = NumberFormatInfo.InvariantInfo;
     166            dateTimeFormatInfo = DateTimeFormatInfo.InvariantInfo;
    141167            separator = ',';
    142168          } else {
     
    144170            // German format (real values)
    145171            numberFormat = NumberFormatInfo.GetInstance(new CultureInfo("de-DE"));
     172            dateTimeFormatInfo = DateTimeFormatInfo.GetInstance(new CultureInfo("de-DE"));
    146173            separator = POSSIBLE_SEPARATORS
    147174              .Except(disallowedSeparators)
     
    154181          // no points and no commas => English format
    155182          numberFormat = NumberFormatInfo.InvariantInfo;
     183          dateTimeFormatInfo = DateTimeFormatInfo.InvariantInfo;
    156184          separator = POSSIBLE_SEPARATORS
    157185            .Where(c => OccurrencesOf(charCounts, c) > 10)
     
    169197    #region tokenizer
    170198    internal enum TokenTypeEnum {
    171       NewLine, Separator, String, Double
     199      NewLine, Separator, String, Double, DateTime
    172200    }
    173201
     
    176204      public string stringValue;
    177205      public double doubleValue;
     206      public DateTime dateTimeValue;
    178207
    179208      public Token(TokenTypeEnum type, string value) {
    180209        this.type = type;
    181210        stringValue = value;
     211        dateTimeValue = DateTime.MinValue;
    182212        doubleValue = 0.0;
    183213      }
     
    193223      private List<Token> tokens;
    194224      private NumberFormatInfo numberFormatInfo;
     225      private DateTimeFormatInfo dateTimeFormatInfo;
    195226      private char separator;
    196227      private const string INTERNAL_SEPARATOR = "#";
     
    218249      }
    219250
    220       public Tokenizer(StreamReader reader, NumberFormatInfo numberFormatInfo, char separator) {
     251      public Tokenizer(StreamReader reader, NumberFormatInfo numberFormatInfo, DateTimeFormatInfo dateTimeFormatInfo, char separator) {
    221252        this.reader = reader;
    222253        this.numberFormatInfo = numberFormatInfo;
     254        this.dateTimeFormatInfo = dateTimeFormatInfo;
    223255        this.separator = separator;
    224256        separatorToken = new Token(TokenTypeEnum.Separator, INTERNAL_SEPARATOR);
     
    264296          token.type = TokenTypeEnum.Double;
    265297          return token;
    266         }
    267 
    268         // couldn't parse the token as an int or float number so return a string token
     298        } else if (DateTime.TryParse(strToken, out token.dateTimeValue)) {
     299          token.type = TokenTypeEnum.DateTime;
     300          return token;
     301        }
     302
     303        // couldn't parse the token as an int, a float number or a datetime value, so return a string token
    269304        return token;
    270305      }
     
    299334    private void ParseValues() {
    300335      while (tokenizer.HasNext()) {
    301         List<double> row = new List<double>();
    302         row.Add(NextValue(tokenizer));
    303         while (tokenizer.HasNext() && tokenizer.Peek() == tokenizer.SeparatorToken) {
    304           Expect(tokenizer.SeparatorToken);
    305           row.Add(NextValue(tokenizer));
    306         }
    307         Expect(tokenizer.NewlineToken);
    308         // all rows have to have the same number of values           
    309         // the first row defines how many samples are needed
    310         if (rowValues.Count > 0 && rowValues[0].Count != row.Count) {
    311           Error("The first row of the dataset has " + rowValues[0].Count + " columns." +
    312             "\nLine " + tokenizer.CurrentLineNumber + " has " + row.Count + " columns.", "", tokenizer.CurrentLineNumber);
    313         }
    314         // add the current row to the collection of rows and start a new row
    315         rowValues.Add(row);
    316         row = new List<double>();
    317       }
    318     }
    319 
    320     private double NextValue(Tokenizer tokenizer) {
    321       if (tokenizer.Peek() == tokenizer.SeparatorToken || tokenizer.Peek() == tokenizer.NewlineToken) return double.NaN;
     336        if (tokenizer.Peek() == tokenizer.NewlineToken) {
     337          tokenizer.Next();
     338        } else {
     339          List<object> row = new List<object>();
     340          object value = NextValue(tokenizer);
     341          row.Add(value);
     342          while (tokenizer.HasNext() && tokenizer.Peek() == tokenizer.SeparatorToken) {
     343            Expect(tokenizer.SeparatorToken);
     344            row.Add(NextValue(tokenizer));
     345          }
     346          Expect(tokenizer.NewlineToken);
     347          // all rows have to have the same number of values           
     348          // the first row defines how many samples are needed
     349          if (rowValues.Count > 0 && rowValues[0].Count != row.Count) {
     350            Error("The first row of the dataset has " + rowValues[0].Count + " columns." +
     351                  "\nLine " + tokenizer.CurrentLineNumber + " has " + row.Count + " columns.", "",
     352                  tokenizer.CurrentLineNumber);
     353          }
     354          rowValues.Add(row);
     355        }
     356      }
     357    }
     358
     359    private object NextValue(Tokenizer tokenizer) {
     360      if (tokenizer.Peek() == tokenizer.SeparatorToken || tokenizer.Peek() == tokenizer.NewlineToken) return string.Empty;
    322361      Token current = tokenizer.Next();
    323       if (current.type == TokenTypeEnum.Separator || current.type == TokenTypeEnum.String) {
     362      if (current.type == TokenTypeEnum.Separator) {
    324363        return double.NaN;
     364      } else if (current.type == TokenTypeEnum.String) {
     365        return current.stringValue;
    325366      } else if (current.type == TokenTypeEnum.Double) {
    326         // just take the value
    327367        return current.doubleValue;
     368      } else if (current.type == TokenTypeEnum.DateTime) {
     369        return current.dateTimeValue;
    328370      }
    329371      // found an unexpected token => throw error
     
    334376
    335377    private void ParseVariableNames() {
    336       // if the first line doesn't start with a double value then we assume that the
    337       // first line contains variable names
    338       if (tokenizer.HasNext() && tokenizer.Peek().type != TokenTypeEnum.Double) {
    339 
    340         List<Token> tokens = new List<Token>();
    341         Token valueToken;
     378      // if the first token is a double, no variable names are given
     379      if (tokenizer.Peek().type == TokenTypeEnum.Double) return;
     380
     381      // the first line must contain variable names
     382      List<Token> tokens = new List<Token>();
     383      Token valueToken;
     384      valueToken = tokenizer.Next();
     385      tokens.Add(valueToken);
     386      while (tokenizer.HasNext() && tokenizer.Peek() == tokenizer.SeparatorToken) {
     387        Expect(tokenizer.SeparatorToken);
    342388        valueToken = tokenizer.Next();
    343         tokens.Add(valueToken);
    344         while (tokenizer.HasNext() && tokenizer.Peek() == tokenizer.SeparatorToken) {
    345           Expect(tokenizer.SeparatorToken);
    346           valueToken = tokenizer.Next();
    347           if (valueToken != tokenizer.NewlineToken) {
    348             tokens.Add(valueToken);
    349           }
    350         }
    351389        if (valueToken != tokenizer.NewlineToken) {
    352           Expect(tokenizer.NewlineToken);
    353         }
    354         variableNames = tokens.Select(x => x.stringValue.Trim()).ToList();
    355       }
     390          tokens.Add(valueToken);
     391        }
     392      }
     393      if (valueToken != tokenizer.NewlineToken) {
     394        Expect(tokenizer.NewlineToken);
     395      }
     396      variableNames = tokens.Select(x => x.stringValue.Trim()).ToList();
    356397    }
    357398
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Tests/OnlineCalculatorPerformanceTest.cs

    r5963 r6760  
    8080      watch.Start();
    8181      for (int i = 0; i < Repetitions; i++) {
    82         double value = calculateFunc(dataset.GetEnumeratedVariableValues(0), dataset.GetEnumeratedVariableValues(1), out errorState);
     82        double value = calculateFunc(dataset.GetDoubleValues("y"), dataset.GetDoubleValues("x0"), out errorState);
    8383      }
    8484      Assert.AreEqual(errorState, OnlineCalculatorError.None);
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Tests/Properties/AssemblyInfo.cs

    r5809 r6760  
    5252// by using the '*' as shown below:
    5353[assembly: AssemblyVersion("3.4.0.0")]
    54 [assembly: AssemblyFileVersion("3.4.0.0")]
     54[assembly: AssemblyFileVersion("3.4.1.0")]
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Tests/StatisticCalculatorsTest.cs

    r6184 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
     
    124123        IEnumerable<double> x = from rows in Enumerable.Range(0, n)
    125124                                select testData[rows, c1] * c1Scale;
    126         IEnumerable<double> y = (new List<double>() { 150494407424305.44 })
     125        IEnumerable<double> y = (new List<double>() { 150494407424305.47 })
    127126          .Concat(Enumerable.Repeat(150494407424305.47, n - 1));
    128127        double[] xs = x.ToArray();
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Tests/TableFileParserTest.cs

    r5809 r6760  
    2121
    2222using System;
    23 using System.Collections.Generic;
    24 using System.Linq;
    25 using Microsoft.VisualStudio.TestTools.UnitTesting;
    2623using System.IO;
    2724using HeuristicLab.Problems.DataAnalysis;
     25using Microsoft.VisualStudio.TestTools.UnitTesting;
    2826namespace HeuristicLab.Problems.DataAnalysis_3_4.Tests {
    2927
     
    4644        Assert.AreEqual(6, parser.Rows);
    4745        Assert.AreEqual(4, parser.Columns);
    48         Assert.AreEqual(parser.Values[0, 3], 3.14);
     46        Assert.AreEqual(parser.Values[3][0], 3.14);
    4947      }
    5048      finally {
     
    6866        Assert.AreEqual(6, parser.Rows);
    6967        Assert.AreEqual(4, parser.Columns);
    70         Assert.AreEqual(parser.Values[0, 3], 3.14);
     68        Assert.AreEqual(parser.Values[3][0], 3.14);
    7169      }
    7270      finally {
     
    9088        Assert.AreEqual(6, parser.Rows);
    9189        Assert.AreEqual(4, parser.Columns);
    92         Assert.AreEqual(parser.Values[0, 3], 3.14);
     90        Assert.AreEqual(parser.Values[3][0], 3.14);
    9391      }
    9492      finally {
     
    113111        Assert.AreEqual(6, parser.Rows);
    114112        Assert.AreEqual(4, parser.Columns);
    115         Assert.AreEqual(parser.Values[0, 3], 3.14);
     113        Assert.AreEqual(parser.Values[3][0], 3.14);
    116114      }
    117115      finally {
     
    135133        Assert.AreEqual(6, parser.Rows);
    136134        Assert.AreEqual(4, parser.Columns);
    137         Assert.AreEqual(parser.Values[0, 3], 3);
     135        Assert.AreEqual((double)parser.Values[3][0], 3);
    138136      }
    139137      finally {
     
    157155        Assert.AreEqual(6, parser.Rows);
    158156        Assert.AreEqual(4, parser.Columns);
    159         Assert.AreEqual(parser.Values[0, 3], 3);
     157        Assert.AreEqual((double)parser.Values[3][0], 3);
    160158      }
    161159      finally {
     
    179177        Assert.AreEqual(6, parser.Rows);
    180178        Assert.AreEqual(4, parser.Columns);
    181         Assert.AreEqual(parser.Values[0, 3], 3);
     179        Assert.AreEqual((double)parser.Values[3][0], 3);
    182180      }
    183181      finally {
     
    202200        Assert.AreEqual(6, parser.Rows);
    203201        Assert.AreEqual(4, parser.Columns);
    204         Assert.AreEqual(parser.Values[0, 3], 3);
     202        Assert.AreEqual((double)parser.Values[3][0], 3);
    205203      }
    206204      finally {
     
    225223        Assert.AreEqual(6, parser.Rows);
    226224        Assert.AreEqual(4, parser.Columns);
    227         Assert.AreEqual(parser.Values[0, 3], 3.14);
     225        Assert.AreEqual((double)parser.Values[3][0], 3.14);
    228226      }
    229227      finally {
     
    248246        Assert.AreEqual(6, parser.Rows);
    249247        Assert.AreEqual(4, parser.Columns);
    250         Assert.AreEqual(parser.Values[0, 3], 3.14);
     248        Assert.AreEqual((double)parser.Values[3][0], 3.14);
    251249      }
    252250      finally {
     
    270268        Assert.AreEqual(6, parser.Rows);
    271269        Assert.AreEqual(4, parser.Columns);
    272         Assert.AreEqual(parser.Values[0, 3], 3.14);
     270        Assert.AreEqual((double)parser.Values[3][0], 3.14);
    273271      }
    274272      finally {
     
    292290        Assert.AreEqual(6, parser.Rows);
    293291        Assert.AreEqual(4, parser.Columns);
    294         Assert.AreEqual(parser.Values[0, 3], 3.14);
     292        Assert.AreEqual((double)parser.Values[3][0], 3.14);
    295293      }
    296294      finally {
     
    314312        Assert.AreEqual(6, parser.Rows);
    315313        Assert.AreEqual(4, parser.Columns);
    316         Assert.AreEqual(parser.Values[0, 3], 3);
     314        Assert.AreEqual((double)parser.Values[3][0], 3);
    317315      }
    318316      finally {
     
    336334        Assert.AreEqual(6, parser.Rows);
    337335        Assert.AreEqual(4, parser.Columns);
    338         Assert.AreEqual(parser.Values[0, 3], 3);
     336        Assert.AreEqual((double)parser.Values[3][0], 3);
     337      }
     338      finally {
     339        File.Delete(tempFileName);
     340      }
     341    }
     342
     343    [TestMethod]
     344    public void ParseWithEmtpyLines() {
     345      string tempFileName = Path.GetTempFileName();
     346      WriteToFile(tempFileName,
     347"x01\t x02\t x03\t x04" + Environment.NewLine +
     348"0\t 0\t 0\t 3" + Environment.NewLine +
     349 Environment.NewLine +
     350"0\t 0\t 0\t 0" + Environment.NewLine +
     351" " + Environment.NewLine +
     352"0\t 0\t 0\t 0" + Environment.NewLine +
     353"0\t 0\t 0\t 0" + Environment.NewLine + Environment.NewLine);
     354      TableFileParser parser = new TableFileParser();
     355      try {
     356        parser.Parse(tempFileName);
     357        Assert.AreEqual(4, parser.Rows);
     358        Assert.AreEqual(4, parser.Columns);
    339359      }
    340360      finally {
     
    358378        Assert.AreEqual(6, parser.Rows);
    359379        Assert.AreEqual(4, parser.Columns);
    360         Assert.AreEqual(parser.Values[0, 3], 3.14);
     380        Assert.AreEqual((double)parser.Values[3][0], 3.14);
    361381      }
    362382      finally {
     
    380400        Assert.AreEqual(6, parser.Rows);
    381401        Assert.AreEqual(4, parser.Columns);
    382         Assert.AreEqual(parser.Values[0, 3], 3.14);
     402        Assert.AreEqual((double)parser.Values[3][0], 3.14);
    383403      }
    384404      finally {
     
    402422        Assert.AreEqual(6, parser.Rows);
    403423        Assert.AreEqual(4, parser.Columns);
    404         Assert.AreEqual(parser.Values[0, 3], 3.14);
     424        Assert.AreEqual((double)parser.Values[3][0], 3.14);
    405425      }
    406426      finally {
     
    424444        Assert.AreEqual(6, parser.Rows);
    425445        Assert.AreEqual(4, parser.Columns);
    426         Assert.AreEqual(parser.Values[0, 3], 3.14);
     446        Assert.AreEqual((double)parser.Values[3][0], 3.14);
    427447      }
    428448      finally {
     
    446466        Assert.AreEqual(6, parser.Rows);
    447467        Assert.AreEqual(4, parser.Columns);
    448         Assert.AreEqual(parser.Values[0, 3], 3);
     468        Assert.AreEqual((double)parser.Values[3][0], 3);
    449469      }
    450470      finally {
     
    468488        Assert.AreEqual(6, parser.Rows);
    469489        Assert.AreEqual(4, parser.Columns);
    470         Assert.AreEqual(parser.Values[0, 3], 3);
     490        Assert.AreEqual((double)parser.Values[3][0], 3);
    471491      }
    472492      finally {
Note: See TracChangeset for help on using the changeset viewer.