Changeset 15614


Ignore:
Timestamp:
01/15/18 08:21:48 (18 months ago)
Author:
bwerth
Message:

#2847 made changes to M5 according to review comments

Location:
branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis
Files:
11 added
12 deleted
32 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis

  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4

    • Property svn:mergeinfo set to (toggle deleted branches)
      /stable/HeuristicLab.Algorithms.DataAnalysis/3.4 (merged: eligible)
      /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 (merged: eligible)
      /branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.410321-10322
      /branches/Async/HeuristicLab.Algorithms.DataAnalysis/3.413329-15286
      /branches/Benchmarking/sources/HeuristicLab.Algorithms.DataAnalysis/3.46917-7005
      /branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.49070-13099
      /branches/CloningRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.44656-4721
      /branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.45471-5808
      /branches/DataAnalysis SolutionEnsembles/HeuristicLab.Algorithms.DataAnalysis/3.45815-6180
      /branches/DataAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.44458-4459,4462,4464
      /branches/DataPreprocessing/HeuristicLab.Algorithms.DataAnalysis/3.410085-11101
      /branches/GP.Grammar.Editor/HeuristicLab.Algorithms.DataAnalysis/3.46284-6795
      /branches/GP.Symbols (TimeLag, Diff, Integral)/HeuristicLab.Algorithms.DataAnalysis/3.45060
      /branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Algorithms.DataAnalysis/3.411570-12508
      /branches/HeuristicLab.Problems.Orienteering/HeuristicLab.Algorithms.DataAnalysis/3.411130-12721
      /branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.413819-14091
      /branches/HeuristicLab.TimeSeries/HeuristicLab.Algorithms.DataAnalysis/3.48116-8789
      /branches/LogResidualEvaluator/HeuristicLab.Algorithms.DataAnalysis/3.410202-10483
      /branches/NET40/sources/HeuristicLab.Algorithms.DataAnalysis/3.45138-5162
      /branches/ParallelEngine/HeuristicLab.Algorithms.DataAnalysis/3.45175-5192
      /branches/ProblemInstancesRegressionAndClassification/HeuristicLab.Algorithms.DataAnalysis/3.47773-7810
      /branches/QAPAlgorithms/HeuristicLab.Algorithms.DataAnalysis/3.46350-6627
      /branches/Restructure trunk solution/HeuristicLab.Algorithms.DataAnalysis/3.46828
      /branches/SpectralKernelForGaussianProcesses/HeuristicLab.Algorithms.DataAnalysis/3.410204-10479
      /branches/SuccessProgressAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.45370-5682
      /branches/Trunk/HeuristicLab.Algorithms.DataAnalysis/3.46829-6865
      /branches/VNS/HeuristicLab.Algorithms.DataAnalysis/3.45594-5752
      /branches/Weighted TSNE/3.415451-15531
      /branches/histogram/HeuristicLab.Algorithms.DataAnalysis/3.45959-6341
      /branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.414232-14825
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessRegression.cs

    r15430 r15614  
    3737  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 160)]
    3838  [StorableClass]
    39   public sealed class GaussianProcessRegression : GaussianProcessBase, IStorableContent {
     39  public sealed class GaussianProcessRegression : GaussianProcessBase, IStorableContent, IDataAnalysisAlgorithm<IRegressionProblem> {
    4040    public string Filename { get; set; }
    4141
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r15470 r15614  
    143143      <SpecificVersion>False</SpecificVersion>
    144144      <HintPath>..\..\..\..\trunk\sources\bin\HeuristicLab.Data-3.3.dll</HintPath>
     145    </Reference>
     146    <Reference Include="HeuristicLab.Encodings.PermutationEncoding-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
     147      <SpecificVersion>False</SpecificVersion>
     148      <HintPath>..\..\..\..\trunk\sources\bin\HeuristicLab.Encodings.PermutationEncoding-3.3.dll</HintPath>
    145149    </Reference>
    146150    <Reference Include="HeuristicLab.Encodings.RealVectorEncoding-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
     
    361365    <Compile Include="Linear\MultinomialLogitModel.cs" />
    362366    <Compile Include="Linear\Scaling.cs" />
    363     <Compile Include="M5Regression\Interfaces\ISplitType.cs" />
    364     <Compile Include="M5Regression\Interfaces\IM5MetaModel.cs" />
    365     <Compile Include="M5Regression\Interfaces\ILeafType.cs" />
    366     <Compile Include="M5Regression\Interfaces\IPruningType.cs" />
     367    <Compile Include="M5Regression\Interfaces\ISpliter.cs" />
     368    <Compile Include="M5Regression\Interfaces\IM5Model.cs" />
     369    <Compile Include="M5Regression\Interfaces\ILeafModel.cs" />
     370    <Compile Include="M5Regression\Interfaces\IPruning.cs" />
    367371    <Compile Include="M5Regression\LeafTypes\ComplexLeaf.cs" />
    368372    <Compile Include="M5Regression\LeafTypes\ComponentReductionLinearLeaf.cs" />
     
    374378    <Compile Include="M5Regression\M5Utilities\M5StaticUtilities.cs" />
    375379    <Compile Include="M5Regression\M5Utilities\M5Analyzer.cs" />
    376     <Compile Include="M5Regression\M5Utilities\M5CreationParameters.cs" />
    377     <Compile Include="M5Regression\M5Utilities\M5UpdateParameters.cs" />
     380    <Compile Include="M5Regression\M5Utilities\M5Parameters.cs" />
    378381    <Compile Include="M5Regression\MetaModels\ComponentReducedLinearModel.cs" />
    379382    <Compile Include="M5Regression\MetaModels\M5NodeModel.cs" />
     
    383386    <Compile Include="M5Regression\MetaModels\DampenedLinearModel.cs" />
    384387    <Compile Include="M5Regression\MetaModels\PreconstructedLinearModel.cs" />
    385     <Compile Include="M5Regression\Pruning\HoldoutLinearPruning.cs" />
    386     <Compile Include="M5Regression\Pruning\HoldoutLeafPruning.cs" />
    387     <Compile Include="M5Regression\Pruning\M5LinearPruning.cs" />
    388     <Compile Include="M5Regression\Pruning\PruningBase.cs" />
     388    <Compile Include="M5Regression\Pruning\M5LinearBottomUpPruning.cs" />
     389    <Compile Include="M5Regression\Pruning\BottomUpPruningBase.cs" />
    389390    <Compile Include="M5Regression\Pruning\NoPruning.cs" />
    390     <Compile Include="M5Regression\Pruning\M5LeafPruning.cs" />
     391    <Compile Include="M5Regression\Pruning\M5LeafBottomUpPruning.cs" />
    391392    <Compile Include="M5Regression\Spliting\OrderImpurityCalculator.cs" />
    392     <Compile Include="M5Regression\Spliting\OrderSplitType.cs" />
     393    <Compile Include="M5Regression\Spliting\OptimumSearchingSpliter.cs" />
     394    <Compile Include="M5Regression\Spliting\M5Spliter.cs" />
    393395    <Compile Include="Nca\Initialization\INcaInitializer.cs" />
    394396    <Compile Include="Nca\Initialization\LdaInitializer.cs" />
     
    448450    <Compile Include="TSNE\Distances\IndexedItemDistance.cs" />
    449451    <Compile Include="TSNE\Distances\ManhattanDistance.cs" />
     452    <Compile Include="TSNE\Distances\WeightedEuclideanDistance.cs" />
    450453    <Compile Include="TSNE\Distances\IDistance.cs" />
    451454    <Compile Include="TSNE\PriorityQueue.cs" />
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/LeafTypes/ComplexLeaf.cs

    r15430 r15614  
    3232  [StorableClass]
    3333  [Item("ComplexLeaf", "A leaf type that uses an arbitriary RegressionAlgorithm to create leaf models")]
    34   public class ComplexLeaf : ParameterizedNamedItem, ILeafType<IRegressionModel> {
     34  public class ComplexLeaf : ParameterizedNamedItem, ILeafModel {
    3535    public const string RegressionParameterName = "Regression";
    3636    public IValueParameter<IDataAnalysisAlgorithm<IRegressionProblem>> RegressionParameter {
     
    5555
    5656    #region IModelType
    57     public IRegressionModel BuildModel(IRegressionProblemData pd, IRandom random, CancellationToken cancellation, out int noParameters) {
     57    public bool ProvidesConfidence {
     58      get { return false; }
     59    }
     60    public IRegressionModel Build(IRegressionProblemData pd, IRandom random, CancellationToken cancellationToken, out int noParameters) {
    5861      if (pd.Dataset.Rows < MinLeafSize(pd)) throw new ArgumentException("The number of training instances is too small to create a linear model");
    5962      noParameters = pd.Dataset.Rows + 1;
    6063      Regression.Problem = new RegressionProblem {ProblemData = pd};
    61       var res = M5StaticUtilities.RunSubAlgorithm(Regression, random.Next(), cancellation);
     64      var res = M5StaticUtilities.RunSubAlgorithm(Regression, random.Next(), cancellationToken);
    6265      var t = res.Select(x => x.Value).OfType<IRegressionSolution>().FirstOrDefault();
    6366      if (t == null) throw new ArgumentException("No RegressionSolution was provided by the algorithm");
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/LeafTypes/ComponentReductionLinearLeaf.cs

    r15470 r15614  
    3434  [StorableClass]
    3535  [Item("ComponentReductionLinearLeaf", "A leaf type that uses principle component analysis to create smaller linear models as leaf models")]
    36   public class ComponentReductionLinearLeaf : ParameterizedNamedItem, ILeafType<IConfidenceRegressionModel> {
     36  public class ComponentReductionLinearLeaf : ParameterizedNamedItem, ILeafModel {
    3737    public const string NoComponentsParameterName = "NoComponents";
    3838    public IFixedValueParameter<IntValue> NoComponentsParameter {
     
    5656
    5757    #region IModelType
    58     public IConfidenceRegressionModel BuildModel(IRegressionProblemData pd, IRandom random,
    59       CancellationToken cancellation, out int noParameters) {
     58    public bool ProvidesConfidence {
     59      get { return true; }
     60    }
     61    public IRegressionModel Build(IRegressionProblemData pd, IRandom random,
     62      CancellationToken cancellationToken, out int noParameters) {
    6063      var pca = PrincipleComponentTransformation.CreateProjection(pd.Dataset, pd.TrainingIndices, pd.AllowedInputVariables, true);
    6164      var pcdata = pca.TransformProblemData(pd);
     
    6467      noParameters = 1;
    6568      for (var i = 1; i <= Math.Min(NoComponents, pd.AllowedInputVariables.Count()); i++) {
    66         var pd2 = (IRegressionProblemData) pcdata.Clone();
     69        var pd2 = (IRegressionProblemData)pcdata.Clone();
    6770        var inputs = new HashSet<string>(pca.ComponentNames.Take(i));
    6871        foreach (var v in pd2.InputVariables.CheckedItems.ToArray())
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/LeafTypes/ConstantLeaf.cs

    r15430 r15614  
    3131  [StorableClass]
    3232  [Item("ConstantLeaf", "A leaf type that uses constant models as leaf models")]
    33   public class ConstantLeaf : ParameterizedNamedItem, ILeafType<IRegressionModel> {
     33  public class ConstantLeaf : ParameterizedNamedItem, ILeafModel {
    3434    #region Constructors & Cloning
    3535    [StorableConstructor]
     
    4343
    4444    #region IModelType
    45     public IRegressionModel BuildModel(IRegressionProblemData pd, IRandom random, CancellationToken cancellation, out int noParameters) {
     45    public bool ProvidesConfidence {
     46      get { return false; }
     47    }
     48    public IRegressionModel Build(IRegressionProblemData pd, IRandom random, CancellationToken cancellationToken, out int noParameters) {
    4649      if (pd.Dataset.Rows < MinLeafSize(pd)) throw new ArgumentException("The number of training instances is too small to create a linear model");
    4750      noParameters = 1;
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/LeafTypes/GaussianProcessLeaf.cs

    r15430 r15614  
    3333  [StorableClass]
    3434  [Item("GaussianProcessLeaf", "A leaf type that uses gaussian process models as leaf models.")]
    35   public class GaussianProcessLeaf : ParameterizedNamedItem, ILeafType<IGaussianProcessModel> {
     35  public class GaussianProcessLeaf : ParameterizedNamedItem, ILeafModel {
    3636    #region ParameterNames
    3737    public const string TriesParameterName = "Tries";
     
    7575
    7676    #region IModelType
    77     public IGaussianProcessModel BuildModel(IRegressionProblemData pd, IRandom random, CancellationToken cancellation, out int noParameters) {
    78       if (pd.Dataset.Rows < MinLeafSize(pd)) throw new ArgumentException("The number of training instances is too small to create a linear model");
     77    public bool ProvidesConfidence {
     78      get { return true; }
     79    }
     80    public IRegressionModel Build(IRegressionProblemData pd, IRandom random, CancellationToken cancellationToken, out int noParameters) {
     81      if (pd.Dataset.Rows < MinLeafSize(pd)) throw new ArgumentException("The number of training instances is too small to create a gaussian process model");
    7982      Regression.Problem = new RegressionProblem {ProblemData = pd};
    8083      var cvscore = double.MaxValue;
     
    8285
    8386      for (var i = 0; i < Tries; i++) {
    84         var res = M5StaticUtilities.RunSubAlgorithm(Regression, random.Next(), cancellation);
     87        var res = M5StaticUtilities.RunSubAlgorithm(Regression, random.Next(), cancellationToken);
    8588        var t = res.Select(x => x.Value).OfType<GaussianProcessRegressionSolution>().FirstOrDefault();
    86         var score = ((DoubleValue) res["Negative log pseudo-likelihood (LOO-CV)"].Value).Value;
     89        var score = ((DoubleValue)res["Negative log pseudo-likelihood (LOO-CV)"].Value).Value;
    8790        if (score >= cvscore || t == null || double.IsNaN(t.TrainingRSquared)) continue;
    8891        cvscore = score;
    8992        sol = t;
    9093      }
    91 
     94      Regression.Runs.Clear();
    9295      if (sol == null) throw new ArgumentException("Could not create Gaussian Process model");
    9396
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/LeafTypes/LinearLeaf.cs

    r15430 r15614  
    3131  [StorableClass]
    3232  [Item("LinearLeaf", "A leaf type that uses linear models as leaf models. This is the standard for M5' regression")]
    33   public class LinearLeaf : ParameterizedNamedItem, ILeafType<IConfidenceRegressionModel> {
     33  public class LinearLeaf : ParameterizedNamedItem, ILeafModel {
    3434    #region Constructors & Cloning
    3535    [StorableConstructor]
     
    4343
    4444    #region IModelType
    45     public IConfidenceRegressionModel BuildModel(IRegressionProblemData pd, IRandom random, CancellationToken cancellation, out int noParameters) {
     45    public bool ProvidesConfidence {
     46      get { return true; }
     47    }
     48    public IRegressionModel Build(IRegressionProblemData pd, IRandom random, CancellationToken cancellationToken, out int noParameters) {
    4649      if (pd.Dataset.Rows < MinLeafSize(pd)) throw new ArgumentException("The number of training instances is too small to create a linear model");
    4750      double rmse, cvRmse;
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/LeafTypes/LogisticLeaf.cs

    r15430 r15614  
    3333  [StorableClass]
    3434  [Item("LogisticLeaf", "A leaf type that uses linear models with a logistic dampening as leaf models. Dampening reduces prediction values far outside the observed target values.")]
    35   public class LogisticLeaf : ParameterizedNamedItem, ILeafType<IConfidenceRegressionModel> {
     35  public class LogisticLeaf : ParameterizedNamedItem, ILeafModel {
    3636    private const string DampeningParameterName = "Dampening";
    3737    public IFixedValueParameter<DoubleValue> DampeningParameter {
     
    5555
    5656    #region IModelType
    57     public IConfidenceRegressionModel BuildModel(IRegressionProblemData pd, IRandom random, CancellationToken cancellation, out int noParameters) {
     57    public bool ProvidesConfidence {
     58      get { return true; }
     59    }
     60    public IRegressionModel Build(IRegressionProblemData pd, IRandom random, CancellationToken cancellationToken, out int noParameters) {
    5861      if (pd.Dataset.Rows < MinLeafSize(pd)) throw new ArgumentException("The number of training instances is too small to create a linear model");
    5962      double rmse, cvRmse;
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/M5Regression.cs

    r15470 r15614  
    66using HeuristicLab.Core;
    77using HeuristicLab.Data;
     8using HeuristicLab.Encodings.PermutationEncoding;
    89using HeuristicLab.Optimization;
    910using HeuristicLab.Parameters;
     
    1617  [StorableClass]
    1718  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
    18   [Item("M5RegressionTree", "A M5 regression tree / rule set classifier")]
     19  [Item("M5RegressionTree", "A M5 regression tree / rule set")]
    1920  public sealed class M5Regression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    2021    #region Parametername
    2122    private const string GenerateRulesParameterName = "GenerateRules";
    22     private const string ImpurityParameterName = "Split";
     23    private const string HoldoutSizeParameterName = "HoldoutSize";
     24    private const string SpliterParameterName = "Spliter";
    2325    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    24     private const string ModelTypeParameterName = "ModelType";
     26    private const string LeafModelParameterName = "LeafModel";
    2527    private const string PruningTypeParameterName = "PruningType";
    2628    private const string SeedParameterName = "Seed";
    2729    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     30    private const string UseHoldoutParameterName = "UseHoldout";
    2831    #endregion
    2932
    3033    #region Parameter properties
    3134    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
    32       get { return Parameters[GenerateRulesParameterName] as IFixedValueParameter<BoolValue>; }
    33     }
    34     public IConstrainedValueParameter<ISplitType> ImpurityParameter {
    35       get { return Parameters[ImpurityParameterName] as IConstrainedValueParameter<ISplitType>; }
     35      get { return (IFixedValueParameter<BoolValue>)Parameters[GenerateRulesParameterName]; }
     36    }
     37    public IFixedValueParameter<PercentValue> HoldoutSizeParameter {
     38      get { return (IFixedValueParameter<PercentValue>)Parameters[HoldoutSizeParameterName]; }
     39    }
     40    public IConstrainedValueParameter<ISpliter> ImpurityParameter {
     41      get { return (IConstrainedValueParameter<ISpliter>)Parameters[SpliterParameterName]; }
    3642    }
    3743    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
    38       get { return (IFixedValueParameter<IntValue>) Parameters[MinimalNodeSizeParameterName]; }
    39     }
    40     public IConstrainedValueParameter<ILeafType<IRegressionModel>> ModelTypeParameter {
    41       get { return Parameters[ModelTypeParameterName] as IConstrainedValueParameter<ILeafType<IRegressionModel>>; }
    42     }
    43     public IConstrainedValueParameter<IPruningType> PruningTypeParameter {
    44       get { return Parameters[PruningTypeParameterName] as IConstrainedValueParameter<IPruningType>; }
     44      get { return (IFixedValueParameter<IntValue>)Parameters[MinimalNodeSizeParameterName]; }
     45    }
     46    public IConstrainedValueParameter<ILeafModel> LeafModelParameter {
     47      get { return (IConstrainedValueParameter<ILeafModel>)Parameters[LeafModelParameterName]; }
     48    }
     49    public IConstrainedValueParameter<IPruning> PruningTypeParameter {
     50      get { return (IConstrainedValueParameter<IPruning>)Parameters[PruningTypeParameterName]; }
    4551    }
    4652    public IFixedValueParameter<IntValue> SeedParameter {
    47       get { return Parameters[SeedParameterName] as IFixedValueParameter<IntValue>; }
     53      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    4854    }
    4955    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    50       get { return Parameters[SetSeedRandomlyParameterName] as IFixedValueParameter<BoolValue>; }
     56      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     57    }
     58    public IFixedValueParameter<BoolValue> UseHoldoutParameter {
     59      get { return (IFixedValueParameter<BoolValue>)Parameters[UseHoldoutParameterName]; }
    5160    }
    5261    #endregion
     
    5665      get { return GenerateRulesParameter.Value.Value; }
    5766    }
    58     public ISplitType Split {
     67    public double HoldoutSize {
     68      get { return HoldoutSizeParameter.Value.Value; }
     69    }
     70    public ISpliter Split {
    5971      get { return ImpurityParameter.Value; }
    6072    }
     
    6274      get { return MinimalNodeSizeParameter.Value.Value; }
    6375    }
    64     public ILeafType<IRegressionModel> LeafType {
    65       get { return ModelTypeParameter.Value; }
    66     }
    67     public IPruningType PruningType {
     76    public ILeafModel LeafModel {
     77      get { return LeafModelParameter.Value; }
     78    }
     79    public IPruning Pruning {
    6880      get { return PruningTypeParameter.Value; }
    6981    }
     
    7385    public bool SetSeedRandomly {
    7486      get { return SetSeedRandomlyParameter.Value.Value; }
     87    }
     88    public bool UseHoldout {
     89      get { return UseHoldoutParameter.Value.Value; }
    7590    }
    7691    #endregion
     
    8196    private M5Regression(M5Regression original, Cloner cloner) : base(original, cloner) { }
    8297    public M5Regression() {
    83       var modelSet = new ItemSet<ILeafType<IRegressionModel>>(ApplicationManager.Manager.GetInstances<ILeafType<IRegressionModel>>());
    84       var pruningSet = new ItemSet<IPruningType>(ApplicationManager.Manager.GetInstances<IPruningType>());
    85       var impuritySet = new ItemSet<ISplitType>(ApplicationManager.Manager.GetInstances<ISplitType>());
    86       Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created", new BoolValue(true)));
    87       Parameters.Add(new ConstrainedValueParameter<ISplitType>(ImpurityParameterName, "The type of split function used to create node splits", impuritySet, impuritySet.OfType<OrderSplitType>().First()));
     98      var modelSet = new ItemSet<ILeafModel>(ApplicationManager.Manager.GetInstances<ILeafModel>());
     99      var pruningSet = new ItemSet<IPruning>(ApplicationManager.Manager.GetInstances<IPruning>());
     100      var impuritySet = new ItemSet<ISpliter>(ApplicationManager.Manager.GetInstances<ISpliter>());
     101      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created", new BoolValue(false)));
     102      Parameters.Add(new FixedValueParameter<PercentValue>(HoldoutSizeParameterName, "How much of the training set shall be reserved for pruning", new PercentValue(0.2)));
     103      Parameters.Add(new ConstrainedValueParameter<ISpliter>(SpliterParameterName, "The type of split function used to create node splits", impuritySet, impuritySet.OfType<M5Spliter>().First()));
    88104      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node", new IntValue(1)));
    89       Parameters.Add(new ConstrainedValueParameter<ILeafType<IRegressionModel>>(ModelTypeParameterName, "The type of model used for the nodes", modelSet, modelSet.OfType<LinearLeaf>().First()));
    90       Parameters.Add(new ConstrainedValueParameter<IPruningType>(PruningTypeParameterName, "The type of pruning used", pruningSet, pruningSet.OfType<M5LeafPruning>().First()));
     105      Parameters.Add(new ConstrainedValueParameter<ILeafModel>(LeafModelParameterName, "The type of model used for the nodes", modelSet, modelSet.OfType<LinearLeaf>().First()));
     106      Parameters.Add(new ConstrainedValueParameter<IPruning>(PruningTypeParameterName, "The type of pruning used", pruningSet, pruningSet.OfType<M5LinearBottomUpPruning>().First()));
    91107      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    92108      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     109      Parameters.Add(new FixedValueParameter<BoolValue>(UseHoldoutParameterName, "True if a holdout set should be generated, false if splitting and pruning shall be performed on the same data ", new BoolValue(false)));
    93110      Problem = new RegressionProblem();
    94111    }
     
    102119      if (SetSeedRandomly) SeedParameter.Value.Value = new System.Random().Next();
    103120      random.Reset(Seed);
    104       var solution = CreateM5RegressionSolution(Problem.ProblemData, random, LeafType, Split, PruningType, cancellationToken, MinimalNodeSize, GenerateRules, Results);
     121      var solution = CreateM5RegressionSolution(Problem.ProblemData, random, LeafModel, Split, Pruning, UseHoldout, HoldoutSize, MinimalNodeSize, GenerateRules, Results, cancellationToken);
    105122      AnalyzeSolution(solution);
    106123    }
     
    108125    #region Static Interface
    109126    public static IRegressionSolution CreateM5RegressionSolution(IRegressionProblemData problemData, IRandom random,
    110       ILeafType<IRegressionModel> leafType = null, ISplitType splitType = null, IPruningType pruningType = null,
    111       CancellationToken? cancellationToken = null, int minNumInstances = 4, bool generateRules = false, ResultCollection results = null) {
     127      ILeafModel leafModel = null, ISpliter spliter = null, IPruning pruning = null,
     128      bool useHoldout = false, double holdoutSize = 0.2, int minNumInstances = 4, bool generateRules = false, ResultCollection results = null, CancellationToken? cancellationToken = null) {
    112129      //set default values
    113       if (leafType == null) leafType = new LinearLeaf();
    114       if (splitType == null) splitType = new OrderSplitType();
     130      if (leafModel == null) leafModel = new LinearLeaf();
     131      if (spliter == null) spliter = new M5Spliter();
    115132      if (cancellationToken == null) cancellationToken = CancellationToken.None;
    116       if (pruningType == null) pruningType = new M5LeafPruning();
    117 
     133      if (pruning == null) pruning = new M5LeafBottomUpPruning();
    118134
    119135      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
    120136      var vars = problemData.AllowedInputVariables.Concat(new[] {problemData.TargetVariable}).ToArray();
    121       if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("M5 regression does not support non-double valued input or output features.");
     137      if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("M5 regression supports only double valued input or output features.");
    122138
    123139      var values = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
    124140      if (values.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
    125141        throw new NotSupportedException("M5 regression does not support NaN or infinity values in the input dataset.");
     142
    126143      var trainingData = new Dataset(vars, values);
    127144      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
     
    130147
    131148      //create & build Model
    132       var m5Params = new M5CreationParameters(pruningType, minNumInstances, leafType, pd, random, splitType, results);
    133 
    134       IReadOnlyList<int> t, h;
    135       pruningType.GenerateHoldOutSet(problemData.TrainingIndices.ToArray(), random, out t, out h);
    136 
    137       if (generateRules) {
    138         IM5MetaModel model = M5RuleSetModel.CreateRuleModel(problemData.TargetVariable, m5Params);
    139         model.BuildClassifier(t, h, m5Params, cancellationToken.Value);
    140         return model.CreateRegressionSolution(problemData);
     149      var m5Params = new M5Parameters(pruning, minNumInstances, leafModel, pd, random, spliter, results);
     150
     151      IReadOnlyList<int> trainingRows, pruningRows;
     152      GeneratePruningSet(problemData.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingRows, out pruningRows);
     153
     154      IM5Model model;
     155      if (generateRules)
     156        model = M5RuleSetModel.CreateRuleModel(problemData.TargetVariable, m5Params);
     157      else
     158        model = M5TreeModel.CreateTreeModel(problemData.TargetVariable, m5Params);
     159
     160      model.Build(trainingRows, pruningRows, m5Params, cancellationToken.Value);
     161      return model.CreateRegressionSolution(problemData);
     162    }
     163
     164    public static void UpdateM5Model(IRegressionModel model, IRegressionProblemData problemData, IRandom random,
     165      ILeafModel leafModel, CancellationToken? cancellationToken = null) {
     166      var m5Model = model as IM5Model;
     167      if (m5Model == null) throw new ArgumentException("This type of model can not be updated");
     168      UpdateM5Model(m5Model, problemData, random, leafModel, cancellationToken);
     169    }
     170
     171    private static void UpdateM5Model(IM5Model model, IRegressionProblemData problemData, IRandom random,
     172      ILeafModel leafModel = null, CancellationToken? cancellationToken = null) {
     173      if (cancellationToken == null) cancellationToken = CancellationToken.None;
     174      var m5Params = new M5Parameters(leafModel, problemData, random);
     175      model.Update(problemData.TrainingIndices.ToList(), m5Params, cancellationToken.Value);
     176    }
     177    #endregion
     178
     179    #region Helpers
     180    private static void GeneratePruningSet(IReadOnlyList<int> allrows, IRandom random, bool useHoldout, double holdoutSize, out IReadOnlyList<int> training, out IReadOnlyList<int> pruning) {
     181      if (!useHoldout) {
     182        training = allrows;
     183        pruning = allrows;
     184        return;
     185      }
     186      var perm = new Permutation(PermutationTypes.Absolute, allrows.Count, random);
     187      var cut = (int)(holdoutSize * allrows.Count);
     188      pruning = perm.Take(cut).Select(i => allrows[i]).ToArray();
     189      training = perm.Take(cut).Select(i => allrows[i]).ToArray();
     190    }
     191
     192    private void AnalyzeSolution(IRegressionSolution solution) {
     193      Results.Add(new Result("RegressionSolution", (IItem)solution.Clone()));
     194
     195      Dictionary<string, int> frequencies;
     196      if (!GenerateRules) {
     197        Results.Add(M5Analyzer.CreateLeafDepthHistogram((M5TreeModel)solution.Model));
     198        frequencies = M5Analyzer.GetTreeVariableFrequences((M5TreeModel)solution.Model);
    141199      }
    142200      else {
    143         IM5MetaModel model = M5TreeModel.CreateTreeModel(problemData.TargetVariable, m5Params);
    144         model.BuildClassifier(t, h, m5Params, cancellationToken.Value);
    145         return model.CreateRegressionSolution(problemData);
    146       }
    147     }
    148 
    149     public static void UpdateM5Model(M5TreeModel model, IRegressionProblemData problemData, IRandom random,
    150       ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
    151       UpdateM5Model(model as IM5MetaModel, problemData, random, leafType, cancellationToken);
    152     }
    153 
    154     public static void UpdateM5Model(M5RuleSetModel model, IRegressionProblemData problemData, IRandom random,
    155       ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
    156       UpdateM5Model(model as IM5MetaModel, problemData, random, leafType, cancellationToken);
    157     }
    158 
    159     private static void UpdateM5Model(IM5MetaModel model, IRegressionProblemData problemData, IRandom random,
    160       ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
    161       if (cancellationToken == null) cancellationToken = CancellationToken.None;
    162       var m5Params = new M5UpdateParameters(leafType, problemData, random);
    163       model.UpdateModel(problemData.TrainingIndices.ToList(), m5Params, cancellationToken.Value);
    164     }
    165     #endregion
    166 
    167     #region Helpers
    168     private void AnalyzeSolution(IRegressionSolution solution) {
    169       Results.Add(new Result("RegressionSolution", (IItem) solution.Clone()));
    170 
    171       Dictionary<string, int> frequencies;
    172       if (!GenerateRules) {
    173         Results.Add(M5Analyzer.CreateLeafDepthHistogram((M5TreeModel) solution.Model));
    174         frequencies = M5Analyzer.GetTreeVariableFrequences((M5TreeModel) solution.Model);
    175       }
    176       else {
    177         Results.Add(M5Analyzer.CreateRulesResult((M5RuleSetModel) solution.Model, Problem.ProblemData, "M5TreeResult", true));
    178         frequencies = M5Analyzer.GetRuleVariableFrequences((M5RuleSetModel) solution.Model);
    179         Results.Add(M5Analyzer.CreateCoverageDiagram((M5RuleSetModel) solution.Model, Problem.ProblemData));
     201        Results.Add(M5Analyzer.CreateRulesResult((M5RuleSetModel)solution.Model, Problem.ProblemData, "M5TreeResult", true));
     202        frequencies = M5Analyzer.GetRuleVariableFrequences((M5RuleSetModel)solution.Model);
     203        Results.Add(M5Analyzer.CreateCoverageDiagram((M5RuleSetModel)solution.Model, Problem.ProblemData));
    180204      }
    181205
     
    183207      var sum = frequencies.Values.Sum();
    184208      sum = sum == 0 ? 1 : sum;
    185       var impactArray = new DoubleArray(frequencies.Select(i => (double) i.Value / sum).ToArray()) {
     209      var impactArray = new DoubleArray(frequencies.Select(i => (double)i.Value / sum).ToArray()) {
    186210        ElementNames = frequencies.Select(i => i.Key)
    187211      };
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/M5Utilities/M5Analyzer.cs

    r15470 r15614  
    3737      var res = ruleSetModel.VariablesUsedForPrediction.ToDictionary(x => x, x => 0);
    3838      foreach (var rule in ruleSetModel.Rules)
    39       foreach (var att in rule.SplitAtts)
     39      foreach (var att in rule.SplitAttributes)
    4040        res[att]++;
    4141      return res;
     
    4646      var root = treeModel.Root;
    4747      foreach (var cur in root.EnumerateNodes().Where(x => !x.IsLeaf))
    48         res[cur.SplitAttr]++;
     48        res[cur.SplitAttribute]++;
    4949      return res;
    5050    }
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/M5Utilities/M5StaticUtilities.cs

    r15549 r15614  
    2121
    2222using System;
     23using System.Collections.Generic;
     24using System.Linq;
    2325using System.Threading;
     26using HeuristicLab.Common;
    2427using HeuristicLab.Core;
    2528using HeuristicLab.Data;
    2629using HeuristicLab.Optimization;
     30using HeuristicLab.Problems.DataAnalysis;
    2731
    2832namespace HeuristicLab.Algorithms.DataAnalysis {
    2933  internal static class M5StaticUtilities {
    30     public static ResultCollection RunSubAlgorithm(IAlgorithm alg, int random, CancellationToken cancellation) {
     34    public static ResultCollection RunSubAlgorithm(IAlgorithm alg, int random, CancellationToken cancellationToken) {
    3135      if (alg.Parameters.ContainsKey("SetSeedRandomly") && alg.Parameters.ContainsKey("Seed")) {
    3236        var seed = alg.Parameters["Seed"].ActualValue as IntValue;
     
    3842      }
    3943      if (alg.ExecutionState != ExecutionState.Paused) alg.Prepare();
    40       alg.Start(cancellation);
     44      alg.Start(cancellationToken);
    4145      return alg.Results;
     46    }
     47
    /// <summary>
    /// Partitions row indices by a numeric split: rows whose attribute value lies below the
    /// split value go left, rows above go right. Rows whose value equals the split value
    /// (within tolerance) are currently placed in BOTH partitions.
    /// TODO check and revert: points at the border are now used multiple times.
    /// </summary>
    public static void SplitRows(IReadOnlyList<int> rows, IDataset data, string splitAttr, double splitValue, out IReadOnlyList<int> leftRows, out IReadOnlyList<int> rightRows) {
      var values = data.GetDoubleValues(splitAttr, rows).ToArray();
      var left = new List<int>();
      var right = new List<int>();
      for (var k = 0; k < values.Length; k++) {
        var v = values[k];
        var onBorder = v.IsAlmost(splitValue);
        if (onBorder || v < splitValue) left.Add(rows[k]);
        if (onBorder || v >= splitValue) right.Add(rows[k]);
      }
      leftRows = left;
      rightRows = right;
    }
     54
    /// <summary>
    /// Trains a leaf model on the given subset of rows. The dataset is reduced to the
    /// allowed input variables plus the target, all reduced rows are used as training
    /// partition and the test partition is left empty.
    /// </summary>
    public static IRegressionModel BuildModel(IReadOnlyList<int> rows, M5Parameters parameters, ILeafModel leafModel, CancellationToken cancellation, out int numParams) {
      var inputs = parameters.AllowedInputVariables.ToArray();
      var reduced = ReduceDataset(parameters.Data, rows, inputs, parameters.TargetVariable);
      var problemData = new RegressionProblemData(reduced, inputs, parameters.TargetVariable);
      // train on every reduced row; start==end makes the test partition empty
      problemData.TrainingPartition.Start = 0;
      problemData.TrainingPartition.End = reduced.Rows;
      problemData.TestPartition.Start = reduced.Rows;
      problemData.TestPartition.End = reduced.Rows;

      var model = leafModel.Build(problemData, parameters.Random, cancellation, out numParams);
      cancellation.ThrowIfCancellationRequested();
      return model;
    }
     67
    /// <summary>
    /// Creates a new dataset containing only the given rows and only the input columns
    /// plus the target column.
    /// </summary>
    public static IDataset ReduceDataset(IDataset data, IReadOnlyList<int> rows, IReadOnlyList<string> inputVariables, string target) {
      var columns = inputVariables.Concat(new[] {target}).ToList();
      return new Dataset(columns, columns.Select(x => data.GetDoubleValues(x, rows).ToList()));
    }
    4371  }
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/MetaModels/M5NodeModel.cs

    r15470 r15614  
    3737    internal bool IsLeaf { get; private set; }
    3838    [Storable]
    39     internal IRegressionModel NodeModel { get; private set; }
     39    internal IRegressionModel Model { get; set; }
    4040    [Storable]
    41     internal string SplitAttr { get; private set; }
     41    internal string SplitAttribute { get; private set; }
    4242    [Storable]
    4343    internal double SplitValue { get; private set; }
     
    4747    internal M5NodeModel Right { get; private set; }
    4848    [Storable]
    49     internal M5NodeModel Parent { get; set; }
     49    internal M5NodeModel Parent { get; private set; }
    5050    [Storable]
    5151    internal int NumSamples { get; private set; }
    5252    [Storable]
    53     internal int NumParam { get; set; }
    54     [Storable]
    55     internal int NodeModelParams { get; set; }
    56     [Storable]
    57     private IReadOnlyList<string> Variables { get; set; }
     53    private IReadOnlyList<string> variables;
    5854    #endregion
    5955
     
    6359    protected M5NodeModel(M5NodeModel original, Cloner cloner) : base(original, cloner) {
    6460      IsLeaf = original.IsLeaf;
    65       NodeModel = cloner.Clone(original.NodeModel);
     61      Model = cloner.Clone(original.Model);
    6662      SplitValue = original.SplitValue;
    67       SplitAttr = original.SplitAttr;
     63      SplitAttribute = original.SplitAttribute;
    6864      Left = cloner.Clone(original.Left);
    6965      Right = cloner.Clone(original.Right);
    7066      Parent = cloner.Clone(original.Parent);
    71       NumParam = original.NumParam;
    7267      NumSamples = original.NumSamples;
    73       Variables = original.Variables != null ? original.Variables.ToList() : null;
     68      variables = original.variables != null ? original.variables.ToList() : null;
    7469    }
    75     protected M5NodeModel(string targetAttr) : base(targetAttr) { }
    76     protected M5NodeModel(M5NodeModel parent) : base(parent.TargetVariable) {
     70    private M5NodeModel(string targetAttr) : base(targetAttr) { }
     71    private M5NodeModel(M5NodeModel parent) : this(parent.TargetVariable) {
    7772      Parent = parent;
    7873    }
     
    8075      return new M5NodeModel(this, cloner);
    8176    }
    82     public static M5NodeModel CreateNode(string targetAttr, M5CreationParameters m5CreationParams) {
    83       return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5NodeModel(targetAttr) : new M5NodeModel(targetAttr);
     77    public static M5NodeModel CreateNode(string targetAttr, M5Parameters m5Params) {
     78      return m5Params.LeafModel.ProvidesConfidence ? new ConfidenceM5NodeModel(targetAttr) : new M5NodeModel(targetAttr);
    8479    }
    85     private static M5NodeModel CreateNode(M5NodeModel parent, M5CreationParameters m5CreationParams) {
    86       return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5NodeModel(parent) : new M5NodeModel(parent);
     80    private static M5NodeModel CreateNode(M5NodeModel parent, M5Parameters m5Params) {
     81      return m5Params.LeafModel.ProvidesConfidence ? new ConfidenceM5NodeModel(parent) : new M5NodeModel(parent);
    8782    }
    8883    #endregion
     
    9085    #region RegressionModel
    9186    public override IEnumerable<string> VariablesUsedForPrediction {
    92       get { return Variables; }
     87      get { return variables; }
    9388    }
    9489    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    9590      if (!IsLeaf) return rows.Select(row => GetEstimatedValue(dataset, row));
    96       if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
    97       return NodeModel.GetEstimatedValues(dataset, rows);
     91      if (Model == null) throw new NotSupportedException("The model has not been built correctly");
     92      return Model.GetEstimatedValues(dataset, rows);
    9893    }
    9994    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     
    10297    #endregion
    10398
    104     internal void Split(IReadOnlyList<int> rows, M5CreationParameters m5CreationParams, double globalStdDev) {
    105       Variables = m5CreationParams.AllowedInputVariables.ToArray();
     99    internal void Split(IReadOnlyList<int> rows, M5Parameters m5Params) {
     100      variables = m5Params.AllowedInputVariables.ToArray();
    106101      NumSamples = rows.Count;
    107102      Right = null;
    108103      Left = null;
    109       NodeModel = null;
    110       SplitAttr = null;
     104      Model = null;
     105      SplitAttribute = null;
    111106      SplitValue = double.NaN;
    112107      string attr;
    113108      double splitValue;
    114       //IsLeaf = m5CreationParams.Data.GetDoubleValues(TargetVariable, rows).StandardDeviation() < globalStdDev * DevFraction;
    115       //if (IsLeaf) return;
    116       IsLeaf = !m5CreationParams.Split.Split(new RegressionProblemData(ReduceDataset(m5CreationParams.Data, rows), Variables, TargetVariable), m5CreationParams.MinLeafSize, out attr, out splitValue);
     109      IsLeaf = !m5Params.Spliter.Split(new RegressionProblemData(M5StaticUtilities.ReduceDataset(m5Params.Data, rows, variables, TargetVariable), variables, TargetVariable), m5Params.MinLeafSize, out attr, out splitValue);
    117110      if (IsLeaf) return;
    118111
    119112      //split Dataset
    120113      IReadOnlyList<int> leftRows, rightRows;
    121       SplitRows(rows, m5CreationParams.Data, attr, splitValue, out leftRows, out rightRows);
     114      M5StaticUtilities.SplitRows(rows, m5Params.Data, attr, splitValue, out leftRows, out rightRows);
    122115
    123       if (leftRows.Count < m5CreationParams.MinLeafSize || rightRows.Count < m5CreationParams.MinLeafSize) {
     116      if (leftRows.Count < m5Params.MinLeafSize || rightRows.Count < m5Params.MinLeafSize) {
    124117        IsLeaf = true;
    125118        return;
    126119      }
    127       SplitAttr = attr;
     120      SplitAttribute = attr;
    128121      SplitValue = splitValue;
    129122
    130123      //create subtrees
    131       Left = CreateNode(this, m5CreationParams);
    132       Left.Split(leftRows, m5CreationParams, globalStdDev);
    133       Right = CreateNode(this, m5CreationParams);
    134       Right.Split(rightRows, m5CreationParams, globalStdDev);
     124      Left = CreateNode(this, m5Params);
     125      Left.Split(leftRows, m5Params);
     126      Right = CreateNode(this, m5Params);
     127      Right.Split(rightRows, m5Params);
    135128    }
    136129
    137     internal bool Prune(IReadOnlyList<int> trainingRows, IReadOnlyList<int> testRows, M5CreationParameters m5CreationParams, CancellationToken cancellation, double globalStdDev) {
    138       if (IsLeaf) {
    139         BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);
    140         NumParam = NodeModelParams;
    141         return true;
    142       }
    143       //split training & holdout data
    144       IReadOnlyList<int> leftTest, rightTest;
    145       SplitRows(testRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTest, out rightTest);
    146       IReadOnlyList<int> leftTraining, rightTraining;
    147       SplitRows(trainingRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTraining, out rightTraining);
    148 
    149       //prune children first
    150       var lpruned = Left.Prune(leftTraining, leftTest, m5CreationParams, cancellation, globalStdDev);
    151       var rpruned = Right.Prune(rightTraining, rightTest, m5CreationParams, cancellation, globalStdDev);
    152       NumParam = Left.NumParam + Right.NumParam + 1;
    153 
    154       //TODO check if this reduces quality. It reduces training effort (considerably for some pruningTypes)
    155       if (!lpruned && !rpruned) return false;
    156 
    157       BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);
    158 
    159       //check if children will be pruned
    160       if (!((PruningBase) m5CreationParams.Pruningtype).Prune(this, m5CreationParams, testRows, globalStdDev)) return false;
    161 
    162       //convert to leafNode
    163       ((IntValue) m5CreationParams.Results[M5RuleModel.NoCurrentLeafesResultName].Value).Value -= EnumerateNodes().Count(x => x.IsLeaf) - 1;
     130    internal void ToLeaf() {
    164131      IsLeaf = true;
    165132      Right = null;
    166133      Left = null;
    167       NumParam = NodeModelParams;
    168       return true;
    169134    }
    170135
    171     internal void InstallModels(IReadOnlyList<int> rows, IRandom random, IDataset data, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
     136    internal void BuildLeafModels(IReadOnlyList<int> rows, M5Parameters parameters, CancellationToken cancellationToken) {
    172137      if (!IsLeaf) {
    173138        IReadOnlyList<int> leftRows, rightRows;
    174         SplitRows(rows, data, SplitAttr, SplitValue, out leftRows, out rightRows);
    175         Left.InstallModels(leftRows, random, data, leafType, cancellation);
    176         Right.InstallModels(rightRows, random, data, leafType, cancellation);
     139        M5StaticUtilities.SplitRows(rows, parameters.Data, SplitAttribute, SplitValue, out leftRows, out rightRows);
     140        Left.BuildLeafModels(leftRows, parameters, cancellationToken);
     141        Right.BuildLeafModels(rightRows, parameters, cancellationToken);
    177142        return;
    178143      }
    179       BuildModel(rows, data, random, leafType, cancellation);
     144      int numP;
     145      Model = M5StaticUtilities.BuildModel(rows, parameters, parameters.LeafModel, cancellationToken, out numP);
    180146    }
    181147
     
    192158    }
    193159
    194     internal void ToRuleNode() {
    195       Parent = null;
    196     }
    197 
    198160    #region Helpers
    199161    private double GetEstimatedValue(IDataset dataset, int row) {
    200       if (!IsLeaf) return (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right).GetEstimatedValue(dataset, row);
    201       if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
    202       return NodeModel.GetEstimatedValues(dataset, new[] {row}).First();
    203     }
    204 
    205     private void BuildModel(IReadOnlyList<int> rows, IDataset data, IRandom random, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
    206       var reducedData = ReduceDataset(data, rows);
    207       var pd = new RegressionProblemData(reducedData, VariablesUsedForPrediction, TargetVariable);
    208       pd.TrainingPartition.Start = 0;
    209       pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = reducedData.Rows;
    210 
    211       int noparams;
    212       NodeModel = leafType.BuildModel(pd, random, cancellation, out noparams);
    213       NodeModelParams = noparams;
    214       cancellation.ThrowIfCancellationRequested();
    215     }
    216 
    217     private IDataset ReduceDataset(IDataset data, IReadOnlyList<int> rows) {
    218       return new Dataset(VariablesUsedForPrediction.Concat(new[] {TargetVariable}), VariablesUsedForPrediction.Concat(new[] {TargetVariable}).Select(x => data.GetDoubleValues(x, rows).ToList()));
    219     }
    220 
    221     private static void SplitRows(IReadOnlyList<int> rows, IDataset data, string splitAttr, double splitValue, out IReadOnlyList<int> leftRows, out IReadOnlyList<int> rightRows) {
    222       var assignment = data.GetDoubleValues(splitAttr, rows).Select(x => x <= splitValue).ToArray();
    223       leftRows = rows.Zip(assignment, (i, b) => new {i, b}).Where(x => x.b).Select(x => x.i).ToList();
    224       rightRows = rows.Zip(assignment, (i, b) => new {i, b}).Where(x => !x.b).Select(x => x.i).ToList();
     162      if (!IsLeaf) return (dataset.GetDoubleValue(SplitAttribute, row) <= SplitValue ? Left : Right).GetEstimatedValue(dataset, row);
     163      if (Model == null) throw new NotSupportedException("The model has not been built correctly");
     164      return Model.GetEstimatedValues(dataset, new[] {row}).First();
    225165    }
    226166    #endregion
     
    240180
    241181      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    242         return IsLeaf ? ((IConfidenceRegressionModel) NodeModel).GetEstimatedVariances(dataset, rows) : rows.Select(row => GetEstimatedVariance(dataset, row));
     182        return IsLeaf ? ((IConfidenceRegressionModel)Model).GetEstimatedVariances(dataset, rows) : rows.Select(row => GetEstimatedVariance(dataset, row));
    243183      }
    244184
    245185      private double GetEstimatedVariance(IDataset dataset, int row) {
    246186        if (!IsLeaf)
    247           return ((IConfidenceRegressionModel) (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right)).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
    248         return ((IConfidenceRegressionModel) NodeModel).GetEstimatedVariances(dataset, new[] {row}).First();
     187          return ((IConfidenceRegressionModel)(dataset.GetDoubleValue(SplitAttribute, row) <= SplitValue ? Left : Right)).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
     188        return ((IConfidenceRegressionModel)Model).GetEstimatedVariances(dataset, new[] {row}).First();
    249189      }
    250190
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/MetaModels/M5RuleModel.cs

    r15430 r15614  
    3232namespace HeuristicLab.Algorithms.DataAnalysis {
    3333  [StorableClass]
    34   internal class M5RuleModel : RegressionModel, IM5MetaModel {
    35     internal const string NoCurrentLeafesResultName = "Number of current Leafs";
    36 
     34  internal class M5RuleModel : RegressionModel {
    3735    #region Properties
    3836    [Storable]
    39     internal string[] SplitAtts { get; private set; }
     37    internal string[] SplitAttributes { get; private set; }
    4038    [Storable]
    41     private double[] SplitVals { get; set; }
     39    private double[] splitValues;
    4240    [Storable]
    43     private RelOp[] RelOps { get; set; }
     41    private Comparison[] comparisons;
    4442    [Storable]
    4543    protected IRegressionModel RuleModel { get; set; }
    4644    [Storable]
    47     private IReadOnlyList<string> Variables { get; set; }
     45    private IReadOnlyList<string> variables;
    4846    #endregion
    4947
     
    5250    protected M5RuleModel(bool deserializing) : base(deserializing) { }
    5351    protected M5RuleModel(M5RuleModel original, Cloner cloner) : base(original, cloner) {
    54       if (original.SplitAtts != null) SplitAtts = original.SplitAtts.ToArray();
    55       if (original.SplitVals != null) SplitVals = original.SplitVals.ToArray();
    56       if (original.RelOps != null) RelOps = original.RelOps.ToArray();
     52      if (original.SplitAttributes != null) SplitAttributes = original.SplitAttributes.ToArray();
     53      if (original.splitValues != null) splitValues = original.splitValues.ToArray();
     54      if (original.comparisons != null) comparisons = original.comparisons.ToArray();
    5755      RuleModel = cloner.Clone(original.RuleModel);
    58       if (original.Variables != null) Variables = original.Variables.ToList();
     56      if (original.variables != null) variables = original.variables.ToList();
    5957    }
    6058    private M5RuleModel(string target) : base(target) { }
     
    6462    #endregion
    6563
    66     internal static M5RuleModel CreateRuleModel(string target, M5CreationParameters m5CreationParams) {
    67       return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5RuleModel(target) : new M5RuleModel(target);
     64    internal static M5RuleModel CreateRuleModel(string target, M5Parameters m5Params) {
     65      return m5Params.LeafModel.ProvidesConfidence ? new ConfidenceM5RuleModel(target) : new M5RuleModel(target);
    6866    }
    6967
    7068    #region IRegressionModel
    7169    public override IEnumerable<string> VariablesUsedForPrediction {
    72       get { return Variables; }
     70      get { return variables; }
    7371    }
    7472
    7573    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    76       if (RuleModel == null) throw new NotSupportedException("M5P has not been built correctly");
     74      if (RuleModel == null) throw new NotSupportedException("The model has not been built correctly");
    7775      return RuleModel.GetEstimatedValues(dataset, rows);
    7876    }
     
    8381    #endregion
    8482
    85     #region IM5Component
    86     public void BuildClassifier(IReadOnlyList<int> trainingRows, IReadOnlyList<int> holdoutRows, M5CreationParameters m5CreationParams, CancellationToken cancellation) {
    87       Variables = m5CreationParams.AllowedInputVariables.ToList();
    88       var tree = M5TreeModel.CreateTreeModel(m5CreationParams.TargetVariable, m5CreationParams);
    89       ((IM5MetaModel) tree).BuildClassifier(trainingRows, holdoutRows, m5CreationParams, cancellation);
     83
     84    public void Build(IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, M5Parameters m5Params, CancellationToken cancellationToken) {
     85      variables = m5Params.AllowedInputVariables.ToList();
     86      var tree = M5TreeModel.CreateTreeModel(m5Params.TargetVariable, m5Params);
     87      tree.Build(trainingRows, pruningRows, m5Params, cancellationToken);
    9088      var nodeModel = tree.Root.EnumerateNodes().Where(x => x.IsLeaf).MaxItems(x => x.NumSamples).First();
    9189
    9290      var satts = new List<string>();
    9391      var svals = new List<double>();
    94       var reops = new List<RelOp>();
     92      var reops = new List<Comparison>();
    9593
    9694      //extract Splits
    9795      for (var temp = nodeModel; temp.Parent != null; temp = temp.Parent) {
    98         satts.Add(temp.Parent.SplitAttr);
     96        satts.Add(temp.Parent.SplitAttribute);
    9997        svals.Add(temp.Parent.SplitValue);
    100         reops.Add(temp.Parent.Left == temp ? RelOp.Lessequal : RelOp.Greater);
     98        reops.Add(temp.Parent.Left == temp ? Comparison.LessEqual : Comparison.Greater);
    10199      }
    102       nodeModel.ToRuleNode();
    103       RuleModel = nodeModel.NodeModel;
    104       RelOps = reops.ToArray();
    105       SplitAtts = satts.ToArray();
    106       SplitVals = svals.ToArray();
     100      RuleModel = nodeModel.Model;
     101      comparisons = reops.ToArray();
     102      SplitAttributes = satts.ToArray();
     103      splitValues = svals.ToArray();
    107104    }
    108105
    109     public void UpdateModel(IReadOnlyList<int> rows, M5UpdateParameters m5UpdateParameters, CancellationToken cancellation) {
    110       BuildModel(rows, m5UpdateParameters.Random, m5UpdateParameters.Data, m5UpdateParameters.LeafType, cancellation);
     106    public void Update(IReadOnlyList<int> rows, M5Parameters m5Parameters, CancellationToken cancellationToken) {
     107      BuildModel(rows, m5Parameters.Random, m5Parameters.Data, m5Parameters.LeafModel, cancellationToken);
    111108    }
    112     #endregion
    113109
    114110    public bool Covers(IDataset dataset, int row) {
    115       return !SplitAtts.Where((t, i) => !RelOps[i].Compare(dataset.GetDoubleValue(t, row), SplitVals[i])).Any();
     111      return !SplitAttributes.Where((t, i) => !comparisons[i].Compare(dataset.GetDoubleValue(t, row), splitValues[i])).Any();
    116112    }
    117113
     
    119115      var mins = new Dictionary<string, double>();
    120116      var maxs = new Dictionary<string, double>();
    121       for (var i = 0; i < SplitAtts.Length; i++) {
    122         var n = SplitAtts[i];
    123         var v = SplitVals[i];
     117      for (var i = 0; i < SplitAttributes.Length; i++) {
     118        var n = SplitAttributes[i];
     119        var v = splitValues[i];
    124120        if (!mins.ContainsKey(n)) mins.Add(n, double.NegativeInfinity);
    125121        if (!maxs.ContainsKey(n)) maxs.Add(n, double.PositiveInfinity);
    126         if (RelOps[i] == RelOp.Lessequal) maxs[n] = Math.Min(maxs[n], v);
     122        if (comparisons[i] == Comparison.LessEqual) maxs[n] = Math.Min(maxs[n], v);
    127123        else mins[n] = Math.Max(mins[n], v);
    128124      }
     
    136132
    137133    #region Helpers
    138     private void BuildModel(IReadOnlyList<int> rows, IRandom random, IDataset data, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
     134    private void BuildModel(IReadOnlyList<int> rows, IRandom random, IDataset data, ILeafModel leafModel, CancellationToken cancellationToken) {
    139135      var reducedData = new Dataset(VariablesUsedForPrediction.Concat(new[] {TargetVariable}), VariablesUsedForPrediction.Concat(new[] {TargetVariable}).Select(x => data.GetDoubleValues(x, rows).ToList()));
    140136      var pd = new RegressionProblemData(reducedData, VariablesUsedForPrediction, TargetVariable);
     
    143139
    144140      int noparams;
    145       RuleModel = leafType.BuildModel(pd, random, cancellation, out noparams);
    146       cancellation.ThrowIfCancellationRequested();
     141      RuleModel = leafModel.Build(pd, random, cancellationToken, out noparams);
     142      cancellationToken.ThrowIfCancellationRequested();
    147143    }
    148144    #endregion
     
    161157
    162158      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    163         return ((IConfidenceRegressionModel) RuleModel).GetEstimatedVariances(dataset, rows);
     159        return ((IConfidenceRegressionModel)RuleModel).GetEstimatedVariances(dataset, rows);
    164160      }
    165161
     
    170166  }
    171167
    172   internal enum RelOp {
    173     Lessequal,
     168  internal enum Comparison {
     169    LessEqual,
    174170    Greater
    175171  }
    176172
    177   internal static class RelOpExtentions {
    178     public static bool Compare(this RelOp op, double x, double y) {
     173  internal static class ComparisonExtentions {
     174    public static bool Compare(this Comparison op, double x, double y) {
    179175      switch (op) {
    180         case RelOp.Greater:
     176        case Comparison.Greater:
    181177          return x > y;
    182         case RelOp.Lessequal:
     178        case Comparison.LessEqual:
    183179          return x <= y;
    184180        default:
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/MetaModels/M5RuleSetModel.cs

    r15430 r15614  
    3232namespace HeuristicLab.Algorithms.DataAnalysis {
    3333  [StorableClass]
    34   public class M5RuleSetModel : RegressionModel, IM5MetaModel {
    35     private const string NoRulesResultName = "Number of Rules";
    36     private const string CoveredInstancesResultName = "Covered Instances";
     34  internal class M5RuleSetModel : RegressionModel, IM5Model {
     35    private const string NumRulesResultName = "Number of rules";
     36    private const string CoveredInstancesResultName = "Covered instances";
    3737
    3838    #region Properties
     
    5353    #endregion
    5454
    55     internal static M5RuleSetModel CreateRuleModel(string targetAttr, M5CreationParameters m5CreationParams) {
    56       return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5RuleSetModel(targetAttr) : new M5RuleSetModel(targetAttr);
     55    internal static M5RuleSetModel CreateRuleModel(string targetAttr, M5Parameters m5Params) {
     56      return m5Params.LeafModel.ProvidesConfidence ? new ConfidenceM5RuleSetModel(targetAttr) : new M5RuleSetModel(targetAttr);
    5757    }
    5858
     
    6565    }
    6666    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    67       if (Rules == null) throw new NotSupportedException("The classifier has not been built yet");
     67      if (Rules == null) throw new NotSupportedException("The model has not been built yet");
    6868      return rows.Select(row => GetEstimatedValue(dataset, row));
    6969    }
     
    7373    #endregion
    7474
    75     #region IM5Component
    76     void IM5MetaModel.BuildClassifier(IReadOnlyList<int> trainingRows, IReadOnlyList<int> holdoutRows, M5CreationParameters m5CreationParams, CancellationToken cancellation) {
     75    #region IM5Model
     76    public void Build(IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, M5Parameters m5Params, CancellationToken cancellationToken) {
    7777      Rules = new List<M5RuleModel>();
    7878      var tempTraining = trainingRows;
    79       var tempHoldout = holdoutRows;
     79      var tempPruning = pruningRows;
    8080      do {
    81         var tempRule = M5RuleModel.CreateRuleModel(m5CreationParams.TargetVariable, m5CreationParams);
    82         cancellation.ThrowIfCancellationRequested();
     81        var tempRule = M5RuleModel.CreateRuleModel(m5Params.TargetVariable, m5Params);
     82        cancellationToken.ThrowIfCancellationRequested();
    8383
    84         if (!m5CreationParams.Results.ContainsKey(NoRulesResultName)) m5CreationParams.Results.Add(new Result(NoRulesResultName, new IntValue(0)));
    85         if (!m5CreationParams.Results.ContainsKey(CoveredInstancesResultName)) m5CreationParams.Results.Add(new Result(CoveredInstancesResultName, new IntValue(0)));
     84        if (!m5Params.Results.ContainsKey(NumRulesResultName)) m5Params.Results.Add(new Result(NumRulesResultName, new IntValue(0)));
     85        if (!m5Params.Results.ContainsKey(CoveredInstancesResultName)) m5Params.Results.Add(new Result(CoveredInstancesResultName, new IntValue(0)));
    8686
    8787        var t1 = tempTraining.Count;
    88         tempRule.BuildClassifier(tempTraining, tempHoldout, m5CreationParams, cancellation);
    89         tempTraining = tempTraining.Where(i => !tempRule.Covers(m5CreationParams.Data, i)).ToArray();
    90         tempHoldout = tempHoldout.Where(i => !tempRule.Covers(m5CreationParams.Data, i)).ToArray();
     88        tempRule.Build(tempTraining, tempPruning, m5Params, cancellationToken);
     89        tempTraining = tempTraining.Where(i => !tempRule.Covers(m5Params.Data, i)).ToArray();
     90        tempPruning = tempPruning.Where(i => !tempRule.Covers(m5Params.Data, i)).ToArray();
    9191        Rules.Add(tempRule);
    92         ((IntValue) m5CreationParams.Results[NoRulesResultName].Value).Value++;
    93         ((IntValue) m5CreationParams.Results[CoveredInstancesResultName].Value).Value += t1 - tempTraining.Count;
     92        ((IntValue)m5Params.Results[NumRulesResultName].Value).Value++;
     93        ((IntValue)m5Params.Results[CoveredInstancesResultName].Value).Value += t1 - tempTraining.Count;
    9494      }
    9595      while (tempTraining.Count > 0);
    9696    }
    9797
    98     void IM5MetaModel.UpdateModel(IReadOnlyList<int> rows, M5UpdateParameters m5UpdateParameters, CancellationToken cancellation) {
    99       foreach (var rule in Rules) rule.UpdateModel(rows, m5UpdateParameters, cancellation);
     98    public void Update(IReadOnlyList<int> rows, M5Parameters m5Parameters, CancellationToken cancellationToken) {
     99      foreach (var rule in Rules) rule.Update(rows, m5Parameters, cancellationToken);
    100100    }
    101101    #endregion
     
    104104    private double GetEstimatedValue(IDataset dataset, int row) {
    105105      foreach (var rule in Rules) {
    106         var prediction = rule.GetEstimatedValues(dataset, row.ToEnumerable()).Single();
    107         if (rule.Covers(dataset, row)) return prediction;
     106        if (rule.Covers(dataset, row))
     107          return rule.GetEstimatedValues(dataset, row.ToEnumerable()).Single();
    108108      }
    109109      throw new ArgumentException("Instance is not covered by any rule");
     
    125125      #region IConfidenceRegressionModel
    126126      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    127         if (Rules == null) throw new NotSupportedException("The classifier has not been built yet");
     127        if (Rules == null) throw new NotSupportedException("The model has not been built yet");
    128128        return rows.Select(row => GetEstimatedVariance(dataset, row));
    129129      }
     
    133133      private double GetEstimatedVariance(IDataset dataset, int row) {
    134134        foreach (var rule in Rules) {
    135           var prediction = ((IConfidenceRegressionModel) rule).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
    136           if (rule.Covers(dataset, row)) return prediction;
     135          if (rule.Covers(dataset, row)) return ((IConfidenceRegressionModel)rule).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
    137136        }
    138137        throw new ArgumentException("Instance is not covered by any rule");
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/MetaModels/M5TreeModel.cs

    r15430 r15614  
    3232namespace HeuristicLab.Algorithms.DataAnalysis {
    3333  [StorableClass]
    34   public class M5TreeModel : RegressionModel, IM5MetaModel {
    35     private const string NoCurrentLeafesResultName = "Number of current Leafs";
     34  internal class M5TreeModel : RegressionModel, IM5Model {
     35    public const string NumCurrentLeafsResultName = "Number of current leafs";
    3636    #region Properties
    3737    [Storable]
    3838    internal M5NodeModel Root { get; private set; }
    39     //[Storable]
    40     //private M5Parameters M5Params { get; set; }
    4139    #endregion
    4240
     
    5351    #endregion
    5452
    55     internal static M5TreeModel CreateTreeModel(string targetAttr, M5CreationParameters m5CreationParams) {
    56       return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5TreeModel(targetAttr) : new M5TreeModel(targetAttr);
     53    internal static M5TreeModel CreateTreeModel(string targetAttr, M5Parameters m5Params) {
     54      return m5Params.LeafModel.ProvidesConfidence ? new ConfidenceM5TreeModel(targetAttr) : new M5TreeModel(targetAttr);
    5755    }
    5856
     
    6260    }
    6361    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    64       if (Root == null) throw new NotSupportedException("The classifier has not been built yet");
     62      if (Root == null) throw new NotSupportedException("The model has not been built yet");
    6563      return Root.GetEstimatedValues(dataset, rows);
    6664    }
     
    7068    #endregion
    7169
    72     #region IM5Component
    73     void IM5MetaModel.BuildClassifier(IReadOnlyList<int> trainingRows, IReadOnlyList<int> holdoutRows, M5CreationParameters m5CreationParams, CancellationToken cancellation) {
    74       Root = null;
    75       var globalStdDev = m5CreationParams.Data.GetDoubleValues(m5CreationParams.TargetVariable, trainingRows).StandardDeviationPop();
    76       Root = M5NodeModel.CreateNode(m5CreationParams.TargetVariable, m5CreationParams);
    77       Root.Split(trainingRows, m5CreationParams, globalStdDev);
    78       InitializeLeafCounter(m5CreationParams);
    79       if (!(m5CreationParams.Pruningtype is NoPruning)) Root.Prune(trainingRows, holdoutRows, m5CreationParams, cancellation, globalStdDev);
    80       Root.InstallModels(trainingRows.Union(holdoutRows).ToArray(), m5CreationParams.Random, m5CreationParams.Data, m5CreationParams.LeafType, cancellation);
     70    #region IM5Model
     71    public void Build(IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, M5Parameters m5Params, CancellationToken cancellationToken) {
     72      Root = M5NodeModel.CreateNode(m5Params.TargetVariable, m5Params);
     73      Root.Split(trainingRows, m5Params);
     74
     75      InitializeLeafCounter(m5Params);
     76
     77      var buPruner = m5Params.Pruning as BottomUpPruningBase;
     78      if (buPruner != null) buPruner.Prune(this, trainingRows, pruningRows, m5Params, cancellationToken);
     79
     80      Root.BuildLeafModels(trainingRows.Union(pruningRows).ToArray(), m5Params, cancellationToken);
    8181    }
    8282
    83     void IM5MetaModel.UpdateModel(IReadOnlyList<int> rows, M5UpdateParameters m5UpdateParameters, CancellationToken cancellation) {
    84       Root.InstallModels(rows, m5UpdateParameters.Random, m5UpdateParameters.Data, m5UpdateParameters.LeafType, cancellation);
     83    public void Update(IReadOnlyList<int> rows, M5Parameters m5Parameters, CancellationToken cancellationToken) {
     84      Root.BuildLeafModels(rows, m5Parameters, cancellationToken);
    8585    }
    8686    #endregion
    8787
    8888    #region Helpers
    89     private void InitializeLeafCounter(M5CreationParameters m5CreationParams) {
    90       if (!m5CreationParams.Results.ContainsKey(NoCurrentLeafesResultName))
    91         m5CreationParams.Results.Add(new Result(NoCurrentLeafesResultName, new IntValue(Root.EnumerateNodes().Count(x => x.IsLeaf))));
    92       else ((IntValue) m5CreationParams.Results[NoCurrentLeafesResultName].Value).Value = Root.EnumerateNodes().Count(x => x.IsLeaf);
     89    private void InitializeLeafCounter(M5Parameters m5Params) {
     90      if (!m5Params.Results.ContainsKey(NumCurrentLeafsResultName))
     91        m5Params.Results.Add(new Result(NumCurrentLeafsResultName, new IntValue(Root.EnumerateNodes().Count(x => x.IsLeaf))));
     92      else ((IntValue)m5Params.Results[NumCurrentLeafsResultName].Value).Value = Root.EnumerateNodes().Count(x => x.IsLeaf);
    9393    }
    9494    #endregion
     
    107107
    108108      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    109         if (Root == null) throw new NotSupportedException("The classifier has not been built yet");
    110         return ((IConfidenceRegressionModel) Root).GetEstimatedVariances(dataset, rows);
     109        if (Root == null) throw new NotSupportedException("The model has not been built yet");
     110        return ((IConfidenceRegressionModel)Root).GetEstimatedVariances(dataset, rows);
    111111      }
    112112      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/Pruning/NoPruning.cs

    r15470 r15614  
    2929  [StorableClass]
    3030  [Item("NoPruning", "No pruning")]
    31   public class NoPruning : PruningBase {
     31  public class NoPruning : ParameterizedNamedItem, IPruning {
    3232    #region Constructors & Cloning
    3333    [StorableConstructor]
    3434    private NoPruning(bool deserializing) : base(deserializing) { }
    3535    private NoPruning(NoPruning original, Cloner cloner) : base(original, cloner) { }
    36     public NoPruning() {
    37       PruningStrengthParameter.Hidden = true;
    38     }
     36    public NoPruning() { }
    3937    public override IDeepCloneable Clone(Cloner cloner) {
    4038      return new NoPruning(this, cloner);
    4139    }
    42     #endregion
    43 
    44     #region IPruningType
    45     public override ILeafType<IRegressionModel> ModelType(ILeafType<IRegressionModel> leafType) {
    46       return null;
    47     }
    48 
    49     public override void GenerateHoldOutSet(IReadOnlyList<int> allrows, IRandom random, out IReadOnlyList<int> training, out IReadOnlyList<int> holdout) {
    50       training = allrows;
    51       holdout = allrows;
    52     }
    53     internal override bool Prune(M5NodeModel node, M5CreationParameters m5CreationParams, IReadOnlyList<int> testRows, double globalStdDev) {
    54       return false;
     40    public int MinLeafSize(IRegressionProblemData pd, ILeafModel leafModel) {
     41      return 0;
    5542    }
    5643    #endregion
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/Spliting/OrderImpurityCalculator.cs

    r15470 r15614  
    2828  /// <summary>
    2929  /// Helper class for incremental split calculation.
    30   /// Used while moving a potential Split along the ordered training Instances
     30  /// Used while moving a potential Spliter along the ordered training Instances
    3131  /// </summary>
    3232  internal class OrderImpurityCalculator {
     
    105105      VarRight = NoRight <= 0 ? 0 : Math.Abs(NoRight * SqSumRight - SumRight * SumRight) / (NoRight * NoRight);
    106106
    107       if (Order <= 0) throw new ArgumentException("Split order must be larger than 0");
     107      if (Order <= 0) throw new ArgumentException("Spliter order must be larger than 0");
    108108      if (Order.IsAlmost(1)) {
    109109        y = VarTotal;
     
    117117      }
    118118      var t = NoRight + NoLeft;
    119       if (NoLeft <= 0.0 || NoRight <= 0.0) Impurity = double.MinValue; //Split = 0;
    120       else Impurity = y - NoLeft / t * yl - NoRight / t * yr; //  Split = y - NoLeft / NoRight * yl - NoRight / NoLeft * yr
     119      if (NoLeft <= 0.0 || NoRight <= 0.0) Impurity = double.MinValue; //Spliter = 0;
     120      else Impurity = y - NoLeft / t * yl - NoRight / t * yr; //  Spliter = y - NoLeft / NoRight * yl - NoRight / NoLeft * yr
    121121    }
    122122    #endregion
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/NonlinearRegression.cs

    r14826 r15614  
    5151    private const string SeedParameterName = "Seed";
    5252    private const string InitParamsRandomlyParameterName = "InitializeParametersRandomly";
     53    private const string ApplyLinearScalingParameterName = "Apply linear scaling";
    5354
    5455    public IFixedValueParameter<StringValue> ModelStructureParameter {
     
    7374    public IFixedValueParameter<BoolValue> InitParametersRandomlyParameter {
    7475      get { return (IFixedValueParameter<BoolValue>)Parameters[InitParamsRandomlyParameterName]; }
     76    }
     77
     78    public IFixedValueParameter<BoolValue> ApplyLinearScalingParameter {
     79      get { return (IFixedValueParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
    7580    }
    7681
     
    103108      get { return InitParametersRandomlyParameter.Value.Value; }
    104109      set { InitParametersRandomlyParameter.Value.Value = value; }
     110    }
     111
     112    public bool ApplyLinearScaling {
     113      get { return ApplyLinearScalingParameter.Value.Value; }
     114      set { ApplyLinearScalingParameter.Value.Value = value; }
    105115    }
    106116
     
    119129      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "Switch to determine if the random number seed should be initialized randomly.", new BoolValue(true)));
    120130      Parameters.Add(new FixedValueParameter<BoolValue>(InitParamsRandomlyParameterName, "Switch to determine if the real-valued model parameters should be initialized randomly in each restart.", new BoolValue(false)));
     131      Parameters.Add(new FixedValueParameter<BoolValue>(ApplyLinearScalingParameterName, "Switch to determine if linear scaling terms should be added to the model", new BoolValue(true)));
    121132
    122133      SetParameterHiddenState();
     
    146157      if (!Parameters.ContainsKey(InitParamsRandomlyParameterName))
    147158        Parameters.Add(new FixedValueParameter<BoolValue>(InitParamsRandomlyParameterName, "Switch to determine if the numeric parameters of the model should be initialized randomly.", new BoolValue(false)));
     159      if (!Parameters.ContainsKey(ApplyLinearScalingParameterName))
     160        Parameters.Add(new FixedValueParameter<BoolValue>(ApplyLinearScalingParameterName, "Switch to determine if linear scaling terms should be added to the model", new BoolValue(true)));
     161
    148162
    149163      SetParameterHiddenState();
     
    174188        if (SetSeedRandomly) Seed = (new System.Random()).Next();
    175189        var rand = new MersenneTwister((uint)Seed);
    176         bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand);
     190        bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, ApplyLinearScaling, rand);
    177191        trainRMSERow.Values.Add(bestSolution.TrainingRootMeanSquaredError);
    178192        testRMSERow.Values.Add(bestSolution.TestRootMeanSquaredError);
    179193        for (int r = 0; r < Restarts; r++) {
    180           var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand);
     194          var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, ApplyLinearScaling, rand);
    181195          trainRMSERow.Values.Add(solution.TrainingRootMeanSquaredError);
    182196          testRMSERow.Values.Add(solution.TestRootMeanSquaredError);
     
    186200        }
    187201      } else {
    188         bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations);
     202        bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, ApplyLinearScaling);
    189203      }
    190204
     
    206220    /// <param name="random">Optional random number generator for random initialization of numeric constants.</param>
    207221    /// <returns></returns>
    208     public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations, IRandom rand = null) {
     222    public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations, bool applyLinearScaling, IRandom rand = null) {
    209223      var parser = new InfixExpressionParser();
    210224      var tree = parser.Parse(modelStructure);
     
    262276
    263277      SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
    264         applyLinearScaling: false, maxIterations: maxIterations,
     278        applyLinearScaling: applyLinearScaling, maxIterations: maxIterations,
    265279        updateVariableWeights: false, updateConstantsInTree: true);
    266280
    267       var scaledModel = new SymbolicRegressionModel(problemData.TargetVariable, tree, (ISymbolicDataAnalysisExpressionTreeInterpreter)interpreter.Clone());
    268       scaledModel.Scale(problemData);
    269       SymbolicRegressionSolution solution = new SymbolicRegressionSolution(scaledModel, (IRegressionProblemData)problemData.Clone());
     281      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, (ISymbolicDataAnalysisExpressionTreeInterpreter)interpreter.Clone());
     282      if (applyLinearScaling)
     283        model.Scale(problemData);
     284
     285      SymbolicRegressionSolution solution = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    270286      solution.Model.Name = "Regression Model";
    271287      solution.Name = "Regression Solution";
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/Plugin.cs.frame

    r14195 r15614  
    3737  [PluginDependency("HeuristicLab.Core", "3.3")]
    3838  [PluginDependency("HeuristicLab.Data", "3.3")]
     39  [PluginDependency("HeuristicLab.Encodings.PermutationEncoding", "3.3")]
    3940  [PluginDependency("HeuristicLab.Encodings.RealVectorEncoding", "3.3")]
    4041  [PluginDependency("HeuristicLab.Encodings.SymbolicExpressionTreeEncoding", "3.4")]
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r14523 r15614  
    152152    public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    153153      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    154       var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     154      var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed,
     155        out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    155156      return new RandomForestClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    156157    }
     
    158159    public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    159160      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    160       return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     161      return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed,
     162       rmsError: out rmsError, relClassificationError: out relClassificationError, outOfBagRmsError: out outOfBagRmsError, outOfBagRelClassificationError: out outOfBagRelClassificationError);
    161163    }
    162164    #endregion
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r14843 r15614  
    288288    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    289289      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
    290       return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError);
     290      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
     291       rmsError: out rmsError, outOfBagRmsError: out outOfBagRmsError, avgRelError: out avgRelError, outOfBagAvgRelError: out outOfBagAvgRelError);
    291292    }
    292293
     
    300301
    301302      rmsError = rep.rmserror;
     303      outOfBagRmsError = rep.oobrmserror;
    302304      avgRelError = rep.avgrelerror;
    303305      outOfBagAvgRelError = rep.oobavgrelerror;
    304       outOfBagRmsError = rep.oobrmserror;
    305306
    306307      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
     
    309310    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    310311      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
    311       return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
     312      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
     313        out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    312314    }
    313315
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r14523 r15614  
    160160      double r, double m, int seed,
    161161      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    162       return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     162      return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed,
     163        rmsError: out rmsError, avgRelError: out avgRelError, outOfBagRmsError: out outOfBagRmsError, outOfBagAvgRelError: out outOfBagAvgRelError);
    163164    }
    164165
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/Distances/CosineDistance.cs

    r15234 r15614  
    2222using System;
    2323using System.Collections.Generic;
    24 using System.Linq;
    2524using HeuristicLab.Common;
    2625using HeuristicLab.Core;
     
    2827
    2928namespace HeuristicLab.Algorithms.DataAnalysis {
    30 
    3129  /// <summary>
    3230  /// The angular distance as defined as a normalized distance measure dependent on the angle between two vectors.
     
    3533  [Item("CosineDistance", "The angular distance as defined as a normalized distance measure dependent on the angle between two vectors.")]
    3634  public class CosineDistance : DistanceBase<IEnumerable<double>> {
    37 
    3835    #region HLConstructors & Cloning
    3936    [StorableConstructor]
     
    4845
    4946    #region statics
    50     public static double GetDistance(IReadOnlyList<double> point1, IReadOnlyList<double> point2) {
    51       if (point1.Count != point2.Count) throw new ArgumentException("Cosine distance not defined on vectors of different length");
    52       var innerprod = 0.0;
    53       var length1 = 0.0;
    54       var length2 = 0.0;
    55 
    56       for (var i = 0; i < point1.Count; i++) {
    57         double d1 = point1[i], d2 = point2[i];
    58         innerprod += d1 * d2;
    59         length1 += d1 * d1;
    60         length2 += d2 * d2;
     47    public static double GetDistance(IEnumerable<double> point1, IEnumerable<double> point2) {
     48      using (IEnumerator<double> p1Enum = point1.GetEnumerator(), p2Enum = point2.GetEnumerator()) {
     49        var innerprod = 0.0;
     50        var length1 = 0.0;
     51        var length2 = 0.0;
     52        while (p1Enum.MoveNext() & p2Enum.MoveNext()) {
     53          double d1 = p1Enum.Current, d2 = p2Enum.Current;
     54          innerprod += d1 * d2;
     55          length1 += d1 * d1;
     56          length2 += d2 * d2;
     57        }
     58        var divisor = Math.Sqrt(length1 * length2);
     59        if (divisor.IsAlmost(0)) throw new ArgumentException("Cosine distance is not defined on vectors of length 0");
     60        if (p1Enum.MoveNext() || p2Enum.MoveNext()) throw new ArgumentException("Cosine distance not defined on vectors of different length");
     61        return 1 - innerprod / divisor;
    6162      }
    62       var l = Math.Sqrt(length1 * length2);
    63       if (l.IsAlmost(0)) throw new ArgumentException("Cosine distance is not defined on vectors of length 0");
    64       return 1 - innerprod / l;
    6563    }
    6664    #endregion
    6765    public override double Get(IEnumerable<double> a, IEnumerable<double> b) {
    68       return GetDistance(a.ToArray(), b.ToArray());
     66      return GetDistance(a, b);
    6967    }
    7068  }
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/Distances/DistanceBase.cs

    r15207 r15614  
    2929  [StorableClass]
    3030  public abstract class DistanceBase<T> : Item, IDistance<T> {
    31 
    3231    #region HLConstructors & Cloning
    3332    [StorableConstructor]
     
    4443
    4544    public double Get(object x, object y) {
    46       return Get((T)x, (T)y);
     45      return Get((T) x, (T) y);
    4746    }
    4847
    4948    public IComparer GetDistanceComparer(object item) {
    50       return new DistanceComparer((T)item, this);
     49      return new DistanceComparer((T) item, this);
    5150    }
    5251
    53     private class DistanceComparer : IComparer<T>, IComparer {
     52    internal class DistanceComparer : IComparer<T>, IComparer {
    5453      private readonly T item;
    5554      private readonly IDistance<T> dist;
     
    6564
    6665      public int Compare(object x, object y) {
    67         return Compare((T)x, (T)y);
     66        return Compare((T) x, (T) y);
    6867      }
    6968    }
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/Distances/EuclideanDistance.cs

    r15207 r15614  
    3131  [Item("EuclideanDistance", "A norm function that uses Euclidean distance")]
    3232  public class EuclideanDistance : DistanceBase<IEnumerable<double>> {
    33 
    3433    #region HLConstructors & Cloning
    3534    [StorableConstructor]
    3635    protected EuclideanDistance(bool deserializing) : base(deserializing) { }
    3736    protected EuclideanDistance(EuclideanDistance original, Cloner cloner) : base(original, cloner) { }
    38     public override IDeepCloneable Clone(Cloner cloner) { return new EuclideanDistance(this, cloner); }
     37    public override IDeepCloneable Clone(Cloner cloner) {
     38      return new EuclideanDistance(this, cloner);
     39    }
    3940    public EuclideanDistance() { }
    4041    #endregion
    4142
    42     public static double GetDistance(IReadOnlyList<double> point1, IReadOnlyList<double> point2) {
    43       if (point1.Count != point2.Count) throw new ArgumentException("Euclidean distance not defined on vectors of different length");
    44       var sum = 0.0;
    45       for (var i = 0; i < point1.Count; i++) {
    46         var d = point1[i] - point2[i];
    47         sum += d * d;
     43    public static double GetDistance(IEnumerable<double> point1, IEnumerable<double> point2) {
     44      using (IEnumerator<double> p1Enum = point1.GetEnumerator(), p2Enum = point2.GetEnumerator()) {
     45        var sum = 0.0;
     46        while (p1Enum.MoveNext() & p2Enum.MoveNext()) {
     47          var d = p1Enum.Current - p2Enum.Current;
     48          sum += d * d;
     49        }
     50        if (p1Enum.MoveNext() || p2Enum.MoveNext()) throw new ArgumentException("Euclidean distance not defined on vectors of different length");
     51        return Math.Sqrt(sum);
    4852      }
    49 
    50       return Math.Sqrt(sum);
    5153    }
    5254
    5355    public override double Get(IEnumerable<double> a, IEnumerable<double> b) {
    54       return GetDistance(a.ToArray(), b.ToArray());
     56      return GetDistance(a, b);
    5557    }
    5658  }
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/Distances/ManhattanDistance.cs

    r15207 r15614  
    3131  [Item("ManhattanDistance", "A distance function that uses block distance")]
    3232  public class ManhattanDistance : DistanceBase<IEnumerable<double>> {
    33 
    3433    #region HLConstructors & Cloning
    3534    [StorableConstructor]
     
    4544    #endregion
    4645
    47     public static double GetDistance(double[] point1, double[] point2) {
    48       if (point1.Length != point2.Length) throw new ArgumentException("Manhattan distance not defined on vectors of different length");
    49       var sum = 0.0;
    50       for (var i = 0; i < point1.Length; i++)
    51         sum += Math.Abs(point1[i] + point2[i]);
    52       return sum;
     46    public static double GetDistance(IEnumerable<double> point1, IEnumerable<double> point2) {
     47      using (IEnumerator<double> p1Enum = point1.GetEnumerator(), p2Enum = point2.GetEnumerator()) {
     48        var sum = 0.0;
     49        while (p1Enum.MoveNext() & p2Enum.MoveNext())
     50          sum += Math.Abs(p1Enum.Current - p2Enum.Current);
     51        if (p1Enum.MoveNext() || p2Enum.MoveNext()) throw new ArgumentException("Manhattan distance not defined on vectors of different length");
     52        return sum;
     53      }
    5354    }
    5455
    5556    public override double Get(IEnumerable<double> a, IEnumerable<double> b) {
    56       return GetDistance(a.ToArray(), b.ToArray());
     57      return GetDistance(a, b);
    5758    }
    5859  }
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs

    r15428 r15614  
    3838namespace HeuristicLab.Algorithms.DataAnalysis {
    3939  /// <summary>
    40   /// t-distributed stochastic neighbourhood embedding (tSNE) projects the data in a low dimensional
     40  /// t-Distributed Stochastic Neighbor Embedding (tSNE) projects the data in a low dimensional
    4141  /// space to allow visual cluster identification.
    4242  /// </summary>
    43   [Item("tSNE", "t-distributed stochastic neighbourhood embedding projects the data in a low " +
    44                 "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
     43  [Item("t-Distributed Stochastic Neighbor Embedding (tSNE)", "t-Distributed Stochastic Neighbor Embedding projects the data in a low " +
     44                                                              "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
    4545  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
    4646  [StorableClass]
     
    5757    }
    5858
    59     #region parameter names
     59    #region Parameter names
    6060    private const string DistanceFunctionParameterName = "DistanceFunction";
    6161    private const string PerplexityParameterName = "Perplexity";
     
    7272    private const string ClassesNameParameterName = "ClassesName";
    7373    private const string NormalizationParameterName = "Normalization";
     74    private const string RandomInitializationParameterName = "RandomInitialization";
    7475    private const string UpdateIntervalParameterName = "UpdateInterval";
    7576    #endregion
    7677
    77     #region result names
     78    #region Result names
    7879    private const string IterationResultName = "Iteration";
    7980    private const string ErrorResultName = "Error";
     
    8384    #endregion
    8485
    85     #region parameter properties
     86    #region Parameter properties
    8687    public IFixedValueParameter<DoubleValue> PerplexityParameter {
    87       get { return Parameters[PerplexityParameterName] as IFixedValueParameter<DoubleValue>; }
     88      get { return (IFixedValueParameter<DoubleValue>)Parameters[PerplexityParameterName]; }
    8889    }
    8990    public IFixedValueParameter<PercentValue> ThetaParameter {
    90       get { return Parameters[ThetaParameterName] as IFixedValueParameter<PercentValue>; }
     91      get { return (IFixedValueParameter<PercentValue>)Parameters[ThetaParameterName]; }
    9192    }
    9293    public IFixedValueParameter<IntValue> NewDimensionsParameter {
    93       get { return Parameters[NewDimensionsParameterName] as IFixedValueParameter<IntValue>; }
     94      get { return (IFixedValueParameter<IntValue>)Parameters[NewDimensionsParameterName]; }
    9495    }
    9596    public IConstrainedValueParameter<IDistance<double[]>> DistanceFunctionParameter {
    96       get { return Parameters[DistanceFunctionParameterName] as IConstrainedValueParameter<IDistance<double[]>>; }
     97      get { return (IConstrainedValueParameter<IDistance<double[]>>)Parameters[DistanceFunctionParameterName]; }
    9798    }
    9899    public IFixedValueParameter<IntValue> MaxIterationsParameter {
    99       get { return Parameters[MaxIterationsParameterName] as IFixedValueParameter<IntValue>; }
     100      get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
    100101    }
    101102    public IFixedValueParameter<IntValue> StopLyingIterationParameter {
    102       get { return Parameters[StopLyingIterationParameterName] as IFixedValueParameter<IntValue>; }
     103      get { return (IFixedValueParameter<IntValue>)Parameters[StopLyingIterationParameterName]; }
    103104    }
    104105    public IFixedValueParameter<IntValue> MomentumSwitchIterationParameter {
    105       get { return Parameters[MomentumSwitchIterationParameterName] as IFixedValueParameter<IntValue>; }
     106      get { return (IFixedValueParameter<IntValue>)Parameters[MomentumSwitchIterationParameterName]; }
    106107    }
    107108    public IFixedValueParameter<DoubleValue> InitialMomentumParameter {
    108       get { return Parameters[InitialMomentumParameterName] as IFixedValueParameter<DoubleValue>; }
     109      get { return (IFixedValueParameter<DoubleValue>)Parameters[InitialMomentumParameterName]; }
    109110    }
    110111    public IFixedValueParameter<DoubleValue> FinalMomentumParameter {
    111       get { return Parameters[FinalMomentumParameterName] as IFixedValueParameter<DoubleValue>; }
     112      get { return (IFixedValueParameter<DoubleValue>)Parameters[FinalMomentumParameterName]; }
    112113    }
    113114    public IFixedValueParameter<DoubleValue> EtaParameter {
    114       get { return Parameters[EtaParameterName] as IFixedValueParameter<DoubleValue>; }
     115      get { return (IFixedValueParameter<DoubleValue>)Parameters[EtaParameterName]; }
    115116    }
    116117    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    117       get { return Parameters[SetSeedRandomlyParameterName] as IFixedValueParameter<BoolValue>; }
     118      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    118119    }
    119120    public IFixedValueParameter<IntValue> SeedParameter {
    120       get { return Parameters[SeedParameterName] as IFixedValueParameter<IntValue>; }
     121      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    121122    }
    122123    public IConstrainedValueParameter<StringValue> ClassesNameParameter {
    123       get { return Parameters[ClassesNameParameterName] as IConstrainedValueParameter<StringValue>; }
     124      get { return (IConstrainedValueParameter<StringValue>)Parameters[ClassesNameParameterName]; }
    124125    }
    125126    public IFixedValueParameter<BoolValue> NormalizationParameter {
    126       get { return Parameters[NormalizationParameterName] as IFixedValueParameter<BoolValue>; }
     127      get { return (IFixedValueParameter<BoolValue>)Parameters[NormalizationParameterName]; }
     128    }
     129    public IFixedValueParameter<BoolValue> RandomInitializationParameter {
     130      get { return (IFixedValueParameter<BoolValue>)Parameters[RandomInitializationParameterName]; }
    127131    }
    128132    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
    129       get { return Parameters[UpdateIntervalParameterName] as IFixedValueParameter<IntValue>; }
     133      get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
    130134    }
    131135    #endregion
     
    187191      set { NormalizationParameter.Value.Value = value; }
    188192    }
    189 
     193    public bool RandomInitialization {
     194      get { return RandomInitializationParameter.Value.Value; }
     195      set { RandomInitializationParameter.Value.Value = value; }
     196    }
    190197    public int UpdateInterval {
    191198      get { return UpdateIntervalParameter.Value.Value; }
     
    194201    #endregion
    195202
      203    #region Storable properties
     204    [Storable]
     205    private Dictionary<string, IList<int>> dataRowIndices;
     206    [Storable]
     207    private TSNEStatic<double[]>.TSNEState state;
     208    #endregion
     209
    196210    #region Constructors & Cloning
    197211    [StorableConstructor]
    198212    private TSNEAlgorithm(bool deserializing) : base(deserializing) { }
    199213
     214    [StorableHook(HookType.AfterDeserialization)]
     215    private void AfterDeserialization() {
     216      if (!Parameters.ContainsKey(RandomInitializationParameterName))
      217        Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Whether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
     218      RegisterParameterEvents();
     219    }
    200220    private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
    201       if (original.dataRowNames != null)
    202         this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
    203       if (original.dataRows != null)
    204         this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
     221      if (original.dataRowIndices != null)
     222        dataRowIndices = new Dictionary<string, IList<int>>(original.dataRowIndices);
    205223      if (original.state != null)
    206         this.state = cloner.Clone(original.state);
    207       this.iter = original.iter;
    208     }
    209     public override IDeepCloneable Clone(Cloner cloner) { return new TSNEAlgorithm(this, cloner); }
     224        state = cloner.Clone(original.state);
     225      RegisterParameterEvents();
     226    }
     227    public override IDeepCloneable Clone(Cloner cloner) {
     228      return new TSNEAlgorithm(this, cloner);
     229    }
    210230    public TSNEAlgorithm() {
    211231      var distances = new ItemSet<IDistance<double[]>>(ApplicationManager.Manager.GetInstances<IDistance<double[]>>());
     
    213233      Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25)));
    214234      Parameters.Add(new FixedValueParameter<PercentValue>(ThetaParameterName, "Value describing how much appoximated " +
    215                                                                               "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
    216                                                                               "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
    217                                                                               "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
    218                                                                               "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
    219                                                                               " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0)));
      235                                                                               "gradients may differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
     236                                                                               "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
     237                                                                               "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
     238                                                                               "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
      239                                                                               " poor performance on very small data sets (it is better to use a standard t-SNE implementation on such data).", new PercentValue(0)));
    220240      Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis)", new IntValue(2)));
    221241      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent.", new IntValue(1000)));
     
    230250      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
    231251      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "The interval after which the results will be updated.", new IntValue(50)));
    232       Parameters[UpdateIntervalParameterName].Hidden = true;
    233 
      252      Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Whether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
     253
     254      UpdateIntervalParameter.Hidden = true;
    234255      MomentumSwitchIterationParameter.Hidden = true;
    235256      InitialMomentumParameter.Hidden = true;
     
    238259      EtaParameter.Hidden = false;
    239260      Problem = new RegressionProblem();
    240     }
    241     #endregion
    242 
    243     [Storable]
    244     private Dictionary<string, List<int>> dataRowNames;
    245     [Storable]
    246     private Dictionary<string, ScatterPlotDataRow> dataRows;
    247     [Storable]
    248     private TSNEStatic<double[]>.TSNEState state;
    249     [Storable]
    250     private int iter;
     261      RegisterParameterEvents();
     262    }
     263    #endregion
    251264
    252265    public override void Prepare() {
    253266      base.Prepare();
    254       dataRowNames = null;
    255       dataRows = null;
     267      dataRowIndices = null;
    256268      state = null;
    257269    }
     
    259271    protected override void Run(CancellationToken cancellationToken) {
    260272      var problemData = Problem.ProblemData;
    261       // set up and initialized everything if necessary
     273      // set up and initialize everything if necessary
     274      var wdist = DistanceFunction as WeightedEuclideanDistance;
     275      if (wdist != null) wdist.Initialize(problemData);
    262276      if (state == null) {
    263277        if (SetSeedRandomly) Seed = new System.Random().Next();
     
    265279        var dataset = problemData.Dataset;
    266280        var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
    267         var data = new double[dataset.Rows][];
    268         for (var row = 0; row < dataset.Rows; row++)
    269           data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
    270 
    271         if (Normalization) data = NormalizeData(data);
    272 
    273         state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta,
    274           StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
    275 
    276         SetUpResults(data);
    277         iter = 0;
    278       }
    279       for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
    280         if (iter % UpdateInterval == 0)
    281           Analyze(state);
     281        var allindices = Problem.ProblemData.AllIndices.ToArray();
     282
     283        // jagged array is required to meet the static method declarations of TSNEStatic<T>
     284        var data = Enumerable.Range(0, dataset.Rows).Select(x => new double[allowedInputVariables.Length]).ToArray();
     285        var col = 0;
     286        foreach (var s in allowedInputVariables) {
     287          var row = 0;
     288          foreach (var d in dataset.GetDoubleValues(s)) {
     289            data[row][col] = d;
     290            row++;
     291          }
     292          col++;
     293        }
     294        if (Normalization) data = NormalizeInputData(data);
     295        state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, RandomInitialization);
     296        SetUpResults(allindices);
     297      }
     298      while (state.iter < MaxIterations && !cancellationToken.IsCancellationRequested) {
     299        if (state.iter % UpdateInterval == 0) Analyze(state);
    282300        TSNEStatic<double[]>.Iterate(state);
    283301      }
     
    294312    protected override void RegisterProblemEvents() {
    295313      base.RegisterProblemEvents();
     314      if (Problem == null) return;
    296315      Problem.ProblemDataChanged += OnProblemDataChanged;
    297     }
     316      if (Problem.ProblemData == null) return;
     317      Problem.ProblemData.Changed += OnPerplexityChanged;
     318      Problem.ProblemData.Changed += OnColumnsChanged;
     319      if (Problem.ProblemData.Dataset == null) return;
     320      Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
     321      Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
     322    }
     323
    298324    protected override void DeregisterProblemEvents() {
    299325      base.DeregisterProblemEvents();
     326      if (Problem == null) return;
    300327      Problem.ProblemDataChanged -= OnProblemDataChanged;
     328      if (Problem.ProblemData == null) return;
     329      Problem.ProblemData.Changed -= OnPerplexityChanged;
     330      Problem.ProblemData.Changed -= OnColumnsChanged;
     331      if (Problem.ProblemData.Dataset == null) return;
     332      Problem.ProblemData.Dataset.RowsChanged -= OnPerplexityChanged;
     333      Problem.ProblemData.Dataset.ColumnsChanged -= OnColumnsChanged;
     334    }
     335
     336    protected override void OnStopped() {
     337      base.OnStopped();
     338      //bwerth: state objects can be very large; avoid state serialization
     339      state = null;
     340      dataRowIndices = null;
    301341    }
    302342
    303343    private void OnProblemDataChanged(object sender, EventArgs args) {
    304344      if (Problem == null || Problem.ProblemData == null) return;
     345      OnPerplexityChanged(this, null);
     346      OnColumnsChanged(this, null);
     347      Problem.ProblemData.Changed += OnPerplexityChanged;
     348      Problem.ProblemData.Changed += OnColumnsChanged;
     349      if (Problem.ProblemData.Dataset == null) return;
     350      Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
     351      Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
    305352      if (!Parameters.ContainsKey(ClassesNameParameterName)) return;
    306353      ClassesNameParameter.ValidValues.Clear();
     
    308355    }
    309356
     357    private void OnColumnsChanged(object sender, EventArgs e) {
     358      if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(DistanceFunctionParameterName)) return;
     359      DistanceFunctionParameter.ValidValues.OfType<WeightedEuclideanDistance>().Single().AdaptToProblemData(Problem.ProblemData);
     360    }
     361
     362    private void RegisterParameterEvents() {
     363      PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
     364    }
     365
     366    private void OnPerplexityChanged(object sender, EventArgs e) {
     367      if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(PerplexityParameterName)) return;
     368      PerplexityParameter.Value.Value = Math.Max(1, Math.Min((Problem.ProblemData.Dataset.Rows - 1) / 3.0, Perplexity));
     369    }
    310370    #endregion
    311371
    312372    #region Helpers
    313     private void SetUpResults(IReadOnlyCollection<double[]> data) {
     373    private void SetUpResults(IReadOnlyList<int> allIndices) {
    314374      if (Results == null) return;
    315375      var results = Results;
    316       dataRowNames = new Dictionary<string, List<int>>();
    317       dataRows = new Dictionary<string, ScatterPlotDataRow>();
     376      dataRowIndices = new Dictionary<string, IList<int>>();
    318377      var problemData = Problem.ProblemData;
    319378
    320       //color datapoints acording to classes variable (be it double or string)
    321       if (problemData.Dataset.VariableNames.Contains(ClassesName)) {
    322         if ((problemData.Dataset as Dataset).VariableHasType<string>(ClassesName)) {
    323           var classes = problemData.Dataset.GetStringValues(ClassesName).ToArray();
    324           for (var i = 0; i < classes.Length; i++) {
    325             if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    326             dataRowNames[classes[i]].Add(i);
     379      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     380      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
     381      if (!results.ContainsKey(ScatterPlotResultName)) results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
     382      if (!results.ContainsKey(DataResultName)) results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     383      if (!results.ContainsKey(ErrorPlotResultName)) {
     384        var errortable = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent") {
     385          VisualProperties = {
      386            XAxisTitle = "UpdateInterval",
     387            YAxisTitle = "Error",
     388            YAxisLogScale = true
    327389          }
    328         } else if ((problemData.Dataset as Dataset).VariableHasType<double>(ClassesName)) {
    329           var classValues = problemData.Dataset.GetDoubleValues(ClassesName).ToArray();
    330           var max = classValues.Max() + 0.1;
    331           var min = classValues.Min() - 0.1;
    332           const int contours = 8;
    333           for (var i = 0; i < contours; i++) {
    334             var contourname = GetContourName(i, min, max, contours);
    335             dataRowNames.Add(contourname, new List<int>());
    336             dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
    337             dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
    338             dataRows[contourname].VisualProperties.PointSize = i + 3;
    339           }
    340           for (var i = 0; i < classValues.Length; i++) {
    341             dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i);
    342           }
    343         }
     390        };
     391        errortable.Rows.Add(new DataRow("Errors"));
     392        errortable.Rows["Errors"].VisualProperties.StartIndexZero = true;
     393        results.Add(new Result(ErrorPlotResultName, errortable));
     394      }
     395
      396      //color datapoints according to classes variable (be it double, datetime or string)
     397      if (!problemData.Dataset.VariableNames.Contains(ClassesName)) {
     398        dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
     399        dataRowIndices.Add("Test", problemData.TestIndices.ToList());
     400        return;
     401      }
     402
     403      var classificationData = problemData as ClassificationProblemData;
     404      if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) {
     405        var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n);
     406        var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray();
     407        for (var i = 0; i < classes.Length; i++) {
     408          if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
     409          dataRowIndices[classes[i]].Add(i);
     410        }
     411      } else if (((Dataset)problemData.Dataset).VariableHasType<string>(ClassesName)) {
     412        var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();
     413        for (var i = 0; i < classes.Length; i++) {
     414          if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
     415          dataRowIndices[classes[i]].Add(i);
     416        }
     417      } else if (((Dataset)problemData.Dataset).VariableHasType<double>(ClassesName)) {
     418        var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
     419        const int contours = 8;
     420        Dictionary<int, string> contourMap;
     421        IClusteringModel clusterModel;
     422        double[][] borders;
     423        CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
     424        var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
     425        for (var i = 0; i < contours; i++) {
     426          var c = contourorder[i];
     427          var contourname = contourMap[c];
     428          dataRowIndices.Add(contourname, new List<int>());
     429          var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
     430          ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
     431        }
     432        var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
     433        for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
     434      } else if (((Dataset)problemData.Dataset).VariableHasType<DateTime>(ClassesName)) {
     435        var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
     436        const int contours = 8;
     437        Dictionary<int, string> contourMap;
     438        IClusteringModel clusterModel;
     439        double[][] borders;
     440        CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
     441        var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
     442        for (var i = 0; i < contours; i++) {
     443          var c = contourorder[i];
     444          var contourname = contourMap[c];
     445          dataRowIndices.Add(contourname, new List<int>());
     446          var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
     447          row.VisualProperties.PointSize = 8;
     448          ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
     449        }
     450        var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
     451        for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
    344452      } else {
    345         dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
    346         dataRowNames.Add("Test", problemData.TestIndices.ToList());
    347       }
    348 
    349       if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
    350       else ((IntValue)results[IterationResultName].Value).Value = 0;
    351 
    352       if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
    353       else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
    354 
    355       if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
    356       else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
    357 
    358       var plot = results[ErrorPlotResultName].Value as DataTable;
    359       if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
    360 
    361       if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
    362       plot.Rows["errors"].Values.Clear();
    363       plot.Rows["errors"].VisualProperties.StartIndexZero = true;
    364 
    365       results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
    366       results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     453        dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
     454        dataRowIndices.Add("Test", problemData.TestIndices.ToList());
     455      }
    367456    }
    368457
     
    372461      var plot = results[ErrorPlotResultName].Value as DataTable;
    373462      if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
    374       var errors = plot.Rows["errors"].Values;
     463      var errors = plot.Rows["Errors"].Values;
    375464      var c = tsneState.EvaluateError();
    376465      errors.Add(c);
     
    378467      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
    379468
    380       var ndata = Normalize(tsneState.newData);
     469      var ndata = NormalizeProjectedData(tsneState.newData);
    381470      results[DataResultName].Value = new DoubleMatrix(ndata);
    382471      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
    383       FillScatterPlot(ndata, splot);
    384     }
    385 
    386     private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
    387       foreach (var rowName in dataRowNames.Keys) {
    388         if (!plot.Rows.ContainsKey(rowName))
    389           plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
    390         plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
    391       }
    392     }
    393 
    394     private static double[,] Normalize(double[,] data) {
     472      FillScatterPlot(ndata, splot, dataRowIndices);
     473    }
     474
     475    private static void FillScatterPlot(double[,] lowDimData, ScatterPlot plot, Dictionary<string, IList<int>> dataRowIndices) {
     476      foreach (var rowName in dataRowIndices.Keys) {
     477        if (!plot.Rows.ContainsKey(rowName)) {
     478          plot.Rows.Add(new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
     479          plot.Rows[rowName].VisualProperties.PointSize = 8;
     480        }
     481        plot.Rows[rowName].Points.Replace(dataRowIndices[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     482      }
     483    }
     484
     485    private static double[,] NormalizeProjectedData(double[,] data) {
    395486      var max = new double[data.GetLength(1)];
    396487      var min = new double[data.GetLength(1)];
     
    398489      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
    399490      for (var i = 0; i < data.GetLength(0); i++)
    400         for (var j = 0; j < data.GetLength(1); j++) {
    401           var v = data[i, j];
    402           max[j] = Math.Max(max[j], v);
    403           min[j] = Math.Min(min[j], v);
    404         }
     491      for (var j = 0; j < data.GetLength(1); j++) {
     492        var v = data[i, j];
     493        max[j] = Math.Max(max[j], v);
     494        min[j] = Math.Min(min[j], v);
     495      }
    405496      for (var i = 0; i < data.GetLength(0); i++) {
    406497        for (var j = 0; j < data.GetLength(1); j++) {
    407498          var d = max[j] - min[j];
    408           var s = data[i, j] - (max[j] + min[j]) / 2;  //shift data
    409           if (d.IsAlmost(0)) res[i, j] = data[i, j];   //no scaling possible
    410           else res[i, j] = s / d;  //scale data
     499          var s = data[i, j] - (max[j] + min[j]) / 2; //shift data
     500          if (d.IsAlmost(0)) res[i, j] = data[i, j]; //no scaling possible
     501          else res[i, j] = s / d; //scale data
    411502        }
    412503      }
     
    414505    }
    415506
    416     private static double[][] NormalizeData(IReadOnlyList<double[]> data) {
     507    private static double[][] NormalizeInputData(IReadOnlyList<IReadOnlyList<double>> data) {
    417508      // as in tSNE implementation by van der Maaten
    418       var n = data[0].Length;
     509      var n = data[0].Count;
    419510      var mean = new double[n];
    420511      var max = new double[n];
     
    426517      for (var i = 0; i < data.Count; i++) {
    427518        nData[i] = new double[n];
    428         for (var j = 0; j < n; j++) nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
     519        for (var j = 0; j < n; j++)
     520          nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
    429521      }
    430522      return nData;
     
    432524
    433525    private static Color GetHeatMapColor(int contourNr, int noContours) {
    434       var q = (double)contourNr / noContours;  // q in [0,1]
    435       var c = q < 0.5 ? Color.FromArgb((int)(q * 2 * 255), 255, 0) : Color.FromArgb(255, (int)((1 - q) * 2 * 255), 0);
    436       return c;
    437     }
    438 
    439     private static string GetContourName(double value, double min, double max, int noContours) {
    440       var size = (max - min) / noContours;
    441       var contourNr = (int)((value - min) / size);
    442       return GetContourName(contourNr, min, max, noContours);
    443     }
    444 
    445     private static string GetContourName(int i, double min, double max, int noContours) {
    446       var size = (max - min) / noContours;
    447       return "[" + (min + i * size) + ";" + (min + (i + 1) * size) + ")";
     526      return ConvertTotalToRgb(0, noContours, contourNr);
     527    }
     528
     529    private static void CreateClusters(IDataset data, string target, int contours, out IClusteringModel contourCluster, out Dictionary<int, string> contourNames, out double[][] borders) {
     530      var cpd = new ClusteringProblemData((Dataset)data, new[] {target});
     531      contourCluster = KMeansClustering.CreateKMeansSolution(cpd, contours, 3).Model;
     532
     533      borders = Enumerable.Range(0, contours).Select(x => new[] {double.MaxValue, double.MinValue}).ToArray();
     534      var clusters = contourCluster.GetClusterValues(cpd.Dataset, cpd.AllIndices).ToArray();
     535      var targetvalues = cpd.Dataset.GetDoubleValues(target).ToArray();
     536      foreach (var i in cpd.AllIndices) {
     537        var cl = clusters[i] - 1;
     538        var clv = targetvalues[i];
     539        if (borders[cl][0] > clv) borders[cl][0] = clv;
     540        if (borders[cl][1] < clv) borders[cl][1] = clv;
     541      }
     542
     543      contourNames = new Dictionary<int, string>();
     544      for (var i = 0; i < contours; i++)
     545        contourNames.Add(i, "[" + borders[i][0] + ";" + borders[i][1] + "]");
     546    }
     547
     548    private static Color ConvertTotalToRgb(double low, double high, double cell) {
     549      var colorGradient = ColorGradient.Colors;
     550      var range = high - low;
     551      var h = Math.Min(cell / range * colorGradient.Count, colorGradient.Count - 1);
     552      return colorGradient[(int)h];
    448553    }
    449554    #endregion
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEStatic.cs

    r15207 r15614  
    6565  [StorableClass]
    6666  public class TSNEStatic<T> {
    67 
    6867    [StorableClass]
    6968    public sealed class TSNEState : DeepCloneable {
     
    170169      [StorableConstructor]
    171170      public TSNEState(bool deserializing) { }
    172       public TSNEState(T[] data, IDistance<T> distance, IRandom random, int newDimensions, double perplexity, double theta, int stopLyingIter, int momSwitchIter, double momentum, double finalMomentum, double eta) {
     171
     172      public TSNEState(IReadOnlyList<T> data, IDistance<T> distance, IRandom random, int newDimensions, double perplexity,
     173        double theta, int stopLyingIter, int momSwitchIter, double momentum, double finalMomentum, double eta, bool randomInit) {
    173174        this.distance = distance;
    174175        this.random = random;
     
    183184
    184185        // initialize
    185         noDatapoints = data.Length;
     186        noDatapoints = data.Count;
    186187        if (noDatapoints - 1 < 3 * perplexity)
    187188          throw new ArgumentException("Perplexity too large for the number of data points!");
     
    193194        gains = new double[noDatapoints, newDimensions];
    194195        for (var i = 0; i < noDatapoints; i++)
    195           for (var j = 0; j < newDimensions; j++)
    196             gains[i, j] = 1.0;
     196        for (var j = 0; j < newDimensions; j++)
     197          gains[i, j] = 1.0;
    197198
    198199        p = null;
     
    212213        var rand = new NormalDistributedRandom(random, 0, 1);
    213214        for (var i = 0; i < noDatapoints; i++)
    214           for (var j = 0; j < newDimensions; j++)
    215             newData[i, j] = rand.NextDouble() * .0001;
     215        for (var j = 0; j < newDimensions; j++)
     216          newData[i, j] = rand.NextDouble() * .0001;
     217
     218        if (!(data[0] is IReadOnlyList<double>) || randomInit) return;
     219        for (var i = 0; i < noDatapoints; i++)
     220        for (var j = 0; j < newDimensions; j++) {
     221          var row = (IReadOnlyList<double>) data[i];
     222          newData[i, j] = row[j % row.Count];
     223        }
    216224      }
    217225      #endregion
    218226
    219227      public double EvaluateError() {
    220         return exact ?
    221           EvaluateErrorExact(p, newData, noDatapoints, newDimensions) :
    222           EvaluateErrorApproximate(rowP, colP, valP, newData, theta);
     228        return exact ? EvaluateErrorExact(p, newData, noDatapoints, newDimensions) : EvaluateErrorApproximate(rowP, colP, valP, newData, theta);
    223229      }
    224230
    225231      #region Helpers
    226       private static void CalculateApproximateSimilarities(T[] data, IDistance<T> distance, double perplexity, out int[] rowP, out int[] colP, out double[] valP) {
     232      private static void CalculateApproximateSimilarities(IReadOnlyList<T> data, IDistance<T> distance, double perplexity, out int[] rowP, out int[] colP, out double[] valP) {
    227233        // Compute asymmetric pairwise input similarities
    228         ComputeGaussianPerplexity(data, distance, out rowP, out colP, out valP, perplexity, (int)(3 * perplexity));
     234        ComputeGaussianPerplexity(data, distance, out rowP, out colP, out valP, perplexity, (int) (3 * perplexity));
    229235        // Symmetrize input similarities
    230236        int[] sRowP, symColP;
     
    235241        valP = sValP;
    236242        var sumP = .0;
    237         for (var i = 0; i < rowP[data.Length]; i++) sumP += valP[i];
    238         for (var i = 0; i < rowP[data.Length]; i++) valP[i] /= sumP;
    239       }
    240 
    241       private static double[,] CalculateExactSimilarites(T[] data, IDistance<T> distance, double perplexity) {
     243        for (var i = 0; i < rowP[data.Count]; i++) sumP += valP[i];
     244        for (var i = 0; i < rowP[data.Count]; i++) valP[i] /= sumP;
     245      }
     246      private static double[,] CalculateExactSimilarites(IReadOnlyList<T> data, IDistance<T> distance, double perplexity) {
    242247        // Compute similarities
    243         var p = new double[data.Length, data.Length];
     248        var p = new double[data.Count, data.Count];
    244249        ComputeGaussianPerplexity(data, distance, p, perplexity);
    245250        // Symmetrize input similarities
    246         for (var n = 0; n < data.Length; n++) {
    247           for (var m = n + 1; m < data.Length; m++) {
     251        for (var n = 0; n < data.Count; n++) {
     252          for (var m = n + 1; m < data.Count; m++) {
    248253            p[n, m] += p[m, n];
    249254            p[m, n] = p[n, m];
     
    251256        }
    252257        var sumP = .0;
    253         for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) sumP += p[i, j];
    254         for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) p[i, j] /= sumP;
     258        for (var i = 0; i < data.Count; i++) {
     259          for (var j = 0; j < data.Count; j++) {
     260            sumP += p[i, j];
     261          }
     262        }
     263        for (var i = 0; i < data.Count; i++) {
     264          for (var j = 0; j < data.Count; j++) {
     265            p[i, j] /= sumP;
     266          }
     267        }
    255268        return p;
    256269      }
    257 
    258270      private static void ComputeGaussianPerplexity(IReadOnlyList<T> x, IDistance<T> distance, out int[] rowP, out int[] colP, out double[] valP, double perplexity, int k) {
    259271        if (perplexity > k) throw new ArgumentException("Perplexity should be lower than k!");
     
    290302
    291303          // Iterate until we found a good perplexity
    292           var iter = 0; double sumP = 0;
     304          var iter = 0;
     305          double sumP = 0;
    293306          while (!found && iter < 200) {
    294 
    295307            // Compute Gaussian kernel row
    296308            for (var m = 0; m < k; m++) curP[m] = Math.Exp(-beta * distances[m + 1]);
     
    307319            if (hdiff < tol && -hdiff < tol) {
    308320              found = true;
    309             } else {
     321            }
     322            else {
    310323              if (hdiff > 0) {
    311324                minBeta = beta;
     
    314327                else
    315328                  beta = (beta + maxBeta) / 2.0;
    316               } else {
     329              }
     330              else {
    317331                maxBeta = beta;
    318332                if (minBeta.IsAlmost(double.MinValue) || minBeta.IsAlmost(double.MaxValue))
     
    335349        }
    336350      }
    337       private static void ComputeGaussianPerplexity(T[] x, IDistance<T> distance, double[,] p, double perplexity) {
     351      private static void ComputeGaussianPerplexity(IReadOnlyList<T> x, IDistance<T> distance, double[,] p, double perplexity) {
    338352        // Compute the distance matrix
    339353        var dd = ComputeDistances(x, distance);
    340354
    341         var n = x.Length;
     355        var n = x.Count;
    342356        // Compute the Gaussian kernel row by row
    343357        for (var i = 0; i < n; i++) {
     
    352366          // Iterate until we found a good perplexity
    353367          var iter = 0;
    354           while (!found && iter < 200) {      // 200 iterations as in tSNE implementation by van der Maaten
     368          while (!found && iter < 200) { // 200 iterations as in tSNE implementation by van der Maaten
    355369
    356370            // Compute Gaussian kernel row
     
    369383            if (hdiff < tol && -hdiff < tol) {
    370384              found = true;
    371             } else {
     385            }
     386            else {
    372387              if (hdiff > 0) {
    373388                minBeta = beta;
     
    376391                else
    377392                  beta = (beta + maxBeta) / 2.0;
    378               } else {
     393              }
     394              else {
    379395                maxBeta = beta;
    380396                if (minBeta.IsAlmost(double.MinValue) || minBeta.IsAlmost(double.MaxValue))
     
    393409        }
    394410      }
    395 
    396       private static double[][] ComputeDistances(T[] x, IDistance<T> distance) {
    397         var res = new double[x.Length][];
    398         for (var r = 0; r < x.Length; r++) {
    399           var rowV = new double[x.Length];
     411      private static double[][] ComputeDistances(IReadOnlyList<T> x, IDistance<T> distance) {
     412        var res = new double[x.Count][];
     413        for (var r = 0; r < x.Count; r++) {
     414          var rowV = new double[x.Count];
    400415          // all distances must be symmetric
    401416          for (var c = 0; c < r; c++) {
     
    403418          }
    404419          rowV[r] = 0.0; // distance to self is zero for all distances
    405           for (var c = r + 1; c < x.Length; c++) {
     420          for (var c = r + 1; c < x.Count; c++) {
    406421            rowV[c] = distance.Get(x[r], x[c]);
    407422          }
     
    411426        // return x.Select(m => x.Select(n => distance.Get(m, n)).ToArray()).ToArray();
    412427      }
    413 
    414428      private static double EvaluateErrorExact(double[,] p, double[,] y, int n, int d) {
    415429        // Compute the squared Euclidean distance matrix
     
    425439              q[n1, m] = 1 / (1 + dd[n1, m]);
    426440              sumQ += q[n1, m];
    427             } else q[n1, m] = double.Epsilon;
     441            }
     442            else q[n1, m] = double.Epsilon;
    428443          }
    429444        }
     
    433448        var c = .0;
    434449        for (var i = 0; i < n; i++)
    435           for (var j = 0; j < n; j++) {
    436             c += p[i, j] * Math.Log((p[i, j] + float.Epsilon) / (q[i, j] + float.Epsilon));
    437           }
     450        for (var j = 0; j < n; j++) {
     451          c += p[i, j] * Math.Log((p[i, j] + float.Epsilon) / (q[i, j] + float.Epsilon));
     452        }
    438453        return c;
    439454      }
    440 
    441455      private static double EvaluateErrorApproximate(IReadOnlyList<int> rowP, IReadOnlyList<int> colP, IReadOnlyList<double> valP, double[,] y, double theta) {
    442456        // Get estimate of normalization term
     
    463477      }
    464478      private static void SymmetrizeMatrix(IReadOnlyList<int> rowP, IReadOnlyList<int> colP, IReadOnlyList<double> valP, out int[] symRowP, out int[] symColP, out double[] symValP) {
    465 
    466479        // Count number of elements and row counts of symmetric matrix
    467480        var n = rowP.Count - 1;
     
    469482        for (var j = 0; j < n; j++) {
    470483          for (var i = rowP[j]; i < rowP[j + 1]; i++) {
    471 
    472484            // Check whether element (col_P[i], n) is present
    473485            var present = false;
     
    497509        var offset = new int[n];
    498510        for (var j = 0; j < n; j++) {
    499           for (var i = rowP[j]; i < rowP[j + 1]; i++) {                                  // considering element(n, colP[i])
     511          for (var i = rowP[j]; i < rowP[j + 1]; i++) { // considering element(n, colP[i])
    500512
    501513            // Check whether element (col_P[i], n) is present
     
    549561    public static double[,] Run(T[] data, IDistance<T> distance, IRandom random,
    550562      int newDimensions = 2, double perplexity = 25, int iterations = 1000,
    551       double theta = 0,
    552       int stopLyingIter = 0, int momSwitchIter = 0, double momentum = .5,
     563      double theta = 0, int stopLyingIter = 0, int momSwitchIter = 0, double momentum = .5,
    553564      double finalMomentum = .8, double eta = 10.0
    554       ) {
     565    ) {
    555566      var state = CreateState(data, distance, random, newDimensions, perplexity,
    556567        theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta);
     
    565576      int newDimensions = 2, double perplexity = 25, double theta = 0,
    566577      int stopLyingIter = 0, int momSwitchIter = 0, double momentum = .5,
    567       double finalMomentum = .8, double eta = 10.0
    568       ) {
    569       return new TSNEState(data, distance, random, newDimensions, perplexity, theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta);
     578      double finalMomentum = .8, double eta = 10.0, bool randomInit = true
     579    ) {
     580      return new TSNEState(data, distance, random, newDimensions, perplexity, theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta, randomInit);
    570581    }
    571582
     
    580591        for (var j = 0; j < state.newDimensions; j++) {
    581592          state.gains[i, j] = Math.Sign(state.dY[i, j]) != Math.Sign(state.uY[i, j])
    582             ? state.gains[i, j] + .2  // +0.2 and *0.8 are used in two separate implementations of tSNE -> seems to be correct
     593            ? state.gains[i, j] + .2 // +0.2 and *0.8 are used in two separate implementations of tSNE -> seems to be correct
    583594            : state.gains[i, j] * .8;
    584 
    585595          if (state.gains[i, j] < .01) state.gains[i, j] = .01;
    586596        }
    587597      }
    588 
    589598
    590599      // Perform gradient update (with momentum and gains)
    591600      for (var i = 0; i < state.noDatapoints; i++)
    592         for (var j = 0; j < state.newDimensions; j++)
    593           state.uY[i, j] = state.currentMomentum * state.uY[i, j] - state.eta * state.gains[i, j] * state.dY[i, j];
     601      for (var j = 0; j < state.newDimensions; j++)
     602        state.uY[i, j] = state.currentMomentum * state.uY[i, j] - state.eta * state.gains[i, j] * state.dY[i, j];
    594603
    595604      for (var i = 0; i < state.noDatapoints; i++)
    596         for (var j = 0; j < state.newDimensions; j++)
    597           state.newData[i, j] = state.newData[i, j] + state.uY[i, j];
     605      for (var j = 0; j < state.newDimensions; j++)
     606        state.newData[i, j] = state.newData[i, j] + state.uY[i, j];
    598607
    599608      // Make solution zero-mean
     
    604613        if (state.exact)
    605614          for (var i = 0; i < state.noDatapoints; i++)
    606             for (var j = 0; j < state.noDatapoints; j++)
    607               state.p[i, j] /= 12.0;
     615          for (var j = 0; j < state.noDatapoints; j++)
     616            state.p[i, j] /= 12.0;
    608617        else
    609618          for (var i = 0; i < state.rowP[state.noDatapoints]; i++)
     
    634643      // Compute final t-SNE gradient
    635644      for (var i = 0; i < n; i++)
    636         for (var j = 0; j < d; j++) {
    637           dC[i, j] = posF[i, j] - negF[i, j] / sumQ;
    638         }
     645      for (var j = 0; j < d; j++) {
     646        dC[i, j] = posF[i, j] - negF[i, j] / sumQ;
     647      }
    639648    }
    640649
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEUtils.cs

    r14414 r15614  
    3535    }
    3636
    37     internal static IList<T> Swap<T>(this IList<T> list, int indexA, int indexB) {
     37    internal static void Swap<T>(this IList<T> list, int indexA, int indexB) {
    3838      var tmp = list[indexA];
    3939      list[indexA] = list[indexB];
    4040      list[indexB] = tmp;
    41       return list;
    4241    }
    4342
    44     internal static int Partition<T>(this IList<T> list, int left, int right, int pivotindex, IComparer<T> comparer) {
     43    private static int Partition<T>(this IList<T> list, int left, int right, int pivotindex, IComparer<T> comparer) {
    4544      var pivotValue = list[pivotindex];
    4645      list.Swap(pivotindex, right);
     
    6766    /// <param name="comparer">comparer for list elements</param>
    6867    /// <returns></returns>
    69     internal static T NthElement<T>(this IList<T> list, int left, int right, int n, IComparer<T> comparer) {
     68    internal static void PartialSort<T>(this IList<T> list, int left, int right, int n, IComparer<T> comparer) {
    7069      while (true) {
    71         if (left == right) return list[left];
    72         var pivotindex = left + (int)Math.Floor(new System.Random().Next() % (right - (double)left + 1));
     70        if (left == right) return;
     71        var pivotindex = left + (int) Math.Floor(new System.Random().Next() % (right - (double) left + 1));
    7372        pivotindex = list.Partition(left, right, pivotindex, comparer);
    74         if (n == pivotindex) return list[n];
     73        if (n == pivotindex) return;
    7574        if (n < pivotindex) right = pivotindex - 1;
    7675        else left = pivotindex + 1;
  • branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/VantagePointTree.cs

    r15207 r15614  
    139139      // Partition around the median distance
    140140      var median = (upper + lower) / 2;
    141       items.NthElement(lower + 1, upper - 1, median, distance.GetDistanceComparer(items[lower]));
     141      items.PartialSort(lower + 1, upper - 1, median, distance.GetDistanceComparer(items[lower]));
    142142
    143143      // Threshold of the new node will be the distance to the median
Note: See TracChangeset for help on using the changeset viewer.