Changeset 15451


Ignore:
Timestamp:
11/06/17 15:57:55 (22 months ago)
Author:
bwerth
Message:

#2850 created branch & added WeightedEuclideanDistance

Location:
branches/Weighted TSNE
Files:
5 added
7 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/Weighted TSNE/3.4/FixedDataAnalysisAlgorithm.cs

    r15287 r15451  
    3030namespace HeuristicLab.Algorithms.DataAnalysis {
    3131  [StorableClass]
    32   public abstract class FixedDataAnalysisAlgorithm<T> : BasicAlgorithm where T : class, IDataAnalysisProblem {
     32  public abstract class FixedDataAnalysisAlgorithm<T> : BasicAlgorithm, IDataAnalysisAlgorithm<T> where T : class, IDataAnalysisProblem {
    3333    #region Properties
    3434    public override Type ProblemType {
  • branches/Weighted TSNE/3.4/GaussianProcess/GaussianProcessRegression.cs

    r14185 r15451  
    3939  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 160)]
    4040  [StorableClass]
    41   public sealed class GaussianProcessRegression : GaussianProcessBase, IStorableContent {
     41  public sealed class GaussianProcessRegression : GaussianProcessBase, IStorableContent, IDataAnalysisAlgorithm<IRegressionProblem> {
    4242    public string Filename { get; set; }
    4343
  • branches/Weighted TSNE/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r15209 r15451  
    4343    <DebugType>full</DebugType>
    4444    <Optimize>false</Optimize>
    45     <OutputPath>$(SolutionDir)\bin\</OutputPath>
     45    <OutputPath>..\..\..\trunk\sources\bin\</OutputPath>
    4646    <DefineConstants>DEBUG;TRACE</DefineConstants>
    4747    <ErrorReport>prompt</ErrorReport>
     
    4949    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    5050    <Prefer32Bit>false</Prefer32Bit>
     51    <LangVersion>5</LangVersion>
    5152  </PropertyGroup>
    5253  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
     
    107108  </PropertyGroup>
    108109  <ItemGroup>
    109     <Reference Include="ALGLIB-3.7.0, Version=3.7.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
    110       <HintPath>..\..\bin\ALGLIB-3.7.0.dll</HintPath>
    111       <Private>False</Private>
    112     </Reference>
    113     <Reference Include="LibSVM-3.12, Version=3.12.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
    114       <HintPath>..\..\bin\LibSVM-3.12.dll</HintPath>
    115       <Private>False</Private>
     110    <Reference Include="ALGLIB-3.7.0">
     111      <HintPath>..\..\..\trunk\sources\bin\ALGLIB-3.7.0.dll</HintPath>
     112    </Reference>
     113    <Reference Include="HeuristicLab.Algorithms.GradientDescent-3.3">
     114      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Algorithms.GradientDescent-3.3.dll</HintPath>
     115    </Reference>
     116    <Reference Include="HeuristicLab.Algorithms.OffspringSelectionGeneticAlgorithm-3.3">
     117      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Algorithms.OffspringSelectionGeneticAlgorithm-3.3.dll</HintPath>
     118    </Reference>
     119    <Reference Include="HeuristicLab.Analysis-3.3">
     120      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Analysis-3.3.dll</HintPath>
     121    </Reference>
     122    <Reference Include="HeuristicLab.Collections-3.3">
     123      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Collections-3.3.dll</HintPath>
     124    </Reference>
     125    <Reference Include="HeuristicLab.Common-3.3">
     126      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Common-3.3.dll</HintPath>
     127    </Reference>
     128    <Reference Include="HeuristicLab.Common.Resources-3.3">
     129      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Common.Resources-3.3.dll</HintPath>
     130    </Reference>
     131    <Reference Include="HeuristicLab.Core-3.3">
     132      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Core-3.3.dll</HintPath>
     133    </Reference>
     134    <Reference Include="HeuristicLab.Data-3.3">
     135      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Data-3.3.dll</HintPath>
     136    </Reference>
     137    <Reference Include="HeuristicLab.Encodings.RealVectorEncoding-3.3">
     138      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Encodings.RealVectorEncoding-3.3.dll</HintPath>
     139    </Reference>
     140    <Reference Include="HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4">
     141      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4.dll</HintPath>
     142    </Reference>
     143    <Reference Include="HeuristicLab.LibSVM-3.12">
     144      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.LibSVM-3.12.dll</HintPath>
     145    </Reference>
     146    <Reference Include="HeuristicLab.Operators-3.3">
     147      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Operators-3.3.dll</HintPath>
     148    </Reference>
     149    <Reference Include="HeuristicLab.Optimization-3.3">
     150      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Optimization-3.3.dll</HintPath>
     151    </Reference>
     152    <Reference Include="HeuristicLab.Parameters-3.3">
     153      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Parameters-3.3.dll</HintPath>
     154    </Reference>
     155    <Reference Include="HeuristicLab.Persistence-3.3">
     156      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Persistence-3.3.dll</HintPath>
     157    </Reference>
     158    <Reference Include="HeuristicLab.PluginInfrastructure-3.3">
     159      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.PluginInfrastructure-3.3.dll</HintPath>
     160    </Reference>
     161    <Reference Include="HeuristicLab.Problems.DataAnalysis-3.4">
     162      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.DataAnalysis-3.4.dll</HintPath>
     163    </Reference>
     164    <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic-3.4">
     165      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.DataAnalysis.Symbolic-3.4.dll</HintPath>
     166    </Reference>
     167    <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic.Classification-3.4">
     168      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.DataAnalysis.Symbolic.Classification-3.4.dll</HintPath>
     169    </Reference>
     170    <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4">
     171      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.dll</HintPath>
     172    </Reference>
     173    <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis-3.4">
     174      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis-3.4.dll</HintPath>
     175    </Reference>
     176    <Reference Include="HeuristicLab.Problems.Instances-3.3">
     177      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.Instances-3.3.dll</HintPath>
     178    </Reference>
     179    <Reference Include="HeuristicLab.Random-3.3">
     180      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Random-3.3.dll</HintPath>
     181    </Reference>
     182    <Reference Include="HeuristicLab.Selection-3.3">
     183      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Selection-3.3.dll</HintPath>
     184    </Reference>
     185    <Reference Include="LibSVM-3.12">
     186      <HintPath>..\..\..\trunk\sources\bin\LibSVM-3.12.dll</HintPath>
    116187    </Reference>
    117188    <Reference Include="System" />
     
    317388    <Compile Include="TSNE\Distances\CosineDistance.cs" />
    318389    <Compile Include="TSNE\Distances\DistanceBase.cs" />
     390    <Compile Include="TSNE\Distances\WeightedEuclideanDistance.cs" />
    319391    <Compile Include="TSNE\Distances\EuclideanDistance.cs" />
    320392    <Compile Include="TSNE\Distances\IndexedItemDistance.cs" />
     
    327399    <Compile Include="TSNE\TSNEUtils.cs" />
    328400    <Compile Include="TSNE\VantagePointTree.cs" />
    329   </ItemGroup>
    330   <ItemGroup>
    331     <ProjectReference Include="..\..\HeuristicLab.Algorithms.GradientDescent\3.3\HeuristicLab.Algorithms.GradientDescent-3.3.csproj">
    332       <Project>{1256B945-EEA9-4BE4-9880-76B5B113F089}</Project>
    333       <Name>HeuristicLab.Algorithms.GradientDescent-3.3</Name>
    334       <Private>False</Private>
    335     </ProjectReference>
    336     <ProjectReference Include="..\..\HeuristicLab.Algorithms.OffspringSelectionGeneticAlgorithm\3.3\HeuristicLab.Algorithms.OffspringSelectionGeneticAlgorithm-3.3.csproj">
    337       <Project>{F409DD9E-1E9C-4EB1-AA3A-9F6E987C6E58}</Project>
    338       <Name>HeuristicLab.Algorithms.OffspringSelectionGeneticAlgorithm-3.3</Name>
    339     </ProjectReference>
    340     <ProjectReference Include="..\..\HeuristicLab.Analysis\3.3\HeuristicLab.Analysis-3.3.csproj">
    341       <Project>{887425B4-4348-49ED-A457-B7D2C26DDBF9}</Project>
    342       <Name>HeuristicLab.Analysis-3.3</Name>
    343       <Private>False</Private>
    344     </ProjectReference>
    345     <ProjectReference Include="..\..\HeuristicLab.Collections\3.3\HeuristicLab.Collections-3.3.csproj">
    346       <Project>{958B43BC-CC5C-4FA2-8628-2B3B01D890B6}</Project>
    347       <Name>HeuristicLab.Collections-3.3</Name>
    348       <Private>False</Private>
    349     </ProjectReference>
    350     <ProjectReference Include="..\..\HeuristicLab.Common.Resources\3.3\HeuristicLab.Common.Resources-3.3.csproj">
    351       <Project>{0E27A536-1C4A-4624-A65E-DC4F4F23E3E1}</Project>
    352       <Name>HeuristicLab.Common.Resources-3.3</Name>
    353       <Private>False</Private>
    354     </ProjectReference>
    355     <ProjectReference Include="..\..\HeuristicLab.Common\3.3\HeuristicLab.Common-3.3.csproj">
    356       <Project>{A9AD58B9-3EF9-4CC1-97E5-8D909039FF5C}</Project>
    357       <Name>HeuristicLab.Common-3.3</Name>
    358       <Private>False</Private>
    359     </ProjectReference>
    360     <ProjectReference Include="..\..\HeuristicLab.Core\3.3\HeuristicLab.Core-3.3.csproj">
    361       <Project>{C36BD924-A541-4A00-AFA8-41701378DDC5}</Project>
    362       <Name>HeuristicLab.Core-3.3</Name>
    363       <Private>False</Private>
    364     </ProjectReference>
    365     <ProjectReference Include="..\..\HeuristicLab.Data\3.3\HeuristicLab.Data-3.3.csproj">
    366       <Project>{BBAB9DF5-5EF3-4BA8-ADE9-B36E82114937}</Project>
    367       <Name>HeuristicLab.Data-3.3</Name>
    368       <Private>False</Private>
    369     </ProjectReference>
    370     <ProjectReference Include="..\..\HeuristicLab.Encodings.RealVectorEncoding\3.3\HeuristicLab.Encodings.RealVectorEncoding-3.3.csproj">
    371       <Project>{BB6D334A-4BB6-4674-9883-31A6EBB32CAB}</Project>
    372       <Name>HeuristicLab.Encodings.RealVectorEncoding-3.3</Name>
    373       <Private>False</Private>
    374     </ProjectReference>
    375     <ProjectReference Include="..\..\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding\3.4\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4.csproj">
    376       <Project>{06D4A186-9319-48A0-BADE-A2058D462EEA}</Project>
    377       <Name>HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4</Name>
    378       <Private>False</Private>
    379     </ProjectReference>
    380     <ProjectReference Include="..\..\HeuristicLab.Operators\3.3\HeuristicLab.Operators-3.3.csproj">
    381       <Project>{23DA7FF4-D5B8-41B6-AA96-F0561D24F3EE}</Project>
    382       <Name>HeuristicLab.Operators-3.3</Name>
    383       <Private>False</Private>
    384     </ProjectReference>
    385     <ProjectReference Include="..\..\HeuristicLab.Optimization\3.3\HeuristicLab.Optimization-3.3.csproj">
    386       <Project>{14AB8D24-25BC-400C-A846-4627AA945192}</Project>
    387       <Name>HeuristicLab.Optimization-3.3</Name>
    388       <Private>False</Private>
    389     </ProjectReference>
    390     <ProjectReference Include="..\..\HeuristicLab.Parameters\3.3\HeuristicLab.Parameters-3.3.csproj">
    391       <Project>{56F9106A-079F-4C61-92F6-86A84C2D84B7}</Project>
    392       <Name>HeuristicLab.Parameters-3.3</Name>
    393       <Private>False</Private>
    394     </ProjectReference>
    395     <ProjectReference Include="..\..\HeuristicLab.Persistence\3.3\HeuristicLab.Persistence-3.3.csproj">
    396       <Project>{102BC7D3-0EF9-439C-8F6D-96FF0FDB8E1B}</Project>
    397       <Name>HeuristicLab.Persistence-3.3</Name>
    398       <Private>False</Private>
    399     </ProjectReference>
    400     <ProjectReference Include="..\..\HeuristicLab.PluginInfrastructure\3.3\HeuristicLab.PluginInfrastructure-3.3.csproj">
    401       <Project>{94186A6A-5176-4402-AE83-886557B53CCA}</Project>
    402       <Name>HeuristicLab.PluginInfrastructure-3.3</Name>
    403       <Private>False</Private>
    404     </ProjectReference>
    405     <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis.Symbolic.Classification\3.4\HeuristicLab.Problems.DataAnalysis.Symbolic.Classification-3.4.csproj">
    406       <Project>{05BAE4E1-A9FA-4644-AA77-42558720159E}</Project>
    407       <Name>HeuristicLab.Problems.DataAnalysis.Symbolic.Classification-3.4</Name>
    408       <Private>False</Private>
    409     </ProjectReference>
    410     <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis.Symbolic.Regression\3.4\HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.csproj">
    411       <Project>{5AC82412-911B-4FA2-A013-EDC5E3F3FCC2}</Project>
    412       <Name>HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4</Name>
    413       <Private>False</Private>
    414     </ProjectReference>
    415     <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis\3.4\HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis-3.4.csproj">
    416       <Project>{07486E68-1517-4B9D-A58D-A38E99AE71AB}</Project>
    417       <Name>HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis-3.4</Name>
    418     </ProjectReference>
    419     <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis.Symbolic\3.4\HeuristicLab.Problems.DataAnalysis.Symbolic-3.4.csproj">
    420       <Project>{3D28463F-EC96-4D82-AFEE-38BE91A0CA00}</Project>
    421       <Name>HeuristicLab.Problems.DataAnalysis.Symbolic-3.4</Name>
    422       <Private>False</Private>
    423     </ProjectReference>
    424     <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis\3.4\HeuristicLab.Problems.DataAnalysis-3.4.csproj">
    425       <Project>{DF87C13E-A889-46FF-8153-66DCAA8C5674}</Project>
    426       <Name>HeuristicLab.Problems.DataAnalysis-3.4</Name>
    427       <Private>False</Private>
    428     </ProjectReference>
    429     <ProjectReference Include="..\..\HeuristicLab.Problems.Instances\3.3\HeuristicLab.Problems.Instances-3.3.csproj">
    430       <Project>{3540E29E-4793-49E7-8EE2-FEA7F61C3994}</Project>
    431       <Name>HeuristicLab.Problems.Instances-3.3</Name>
    432       <Private>False</Private>
    433     </ProjectReference>
    434     <ProjectReference Include="..\..\HeuristicLab.Random\3.3\HeuristicLab.Random-3.3.csproj">
    435       <Project>{F4539FB6-4708-40C9-BE64-0A1390AEA197}</Project>
    436       <Name>HeuristicLab.Random-3.3</Name>
    437       <Private>False</Private>
    438     </ProjectReference>
    439     <ProjectReference Include="..\..\HeuristicLab.Selection\3.3\HeuristicLab.Selection-3.3.csproj">
    440       <Project>{2C36CD4F-E5F5-43A4-801A-201EA895FE17}</Project>
    441       <Name>HeuristicLab.Selection-3.3</Name>
    442     </ProjectReference>
    443401  </ItemGroup>
    444402  <ItemGroup>
  • branches/Weighted TSNE/3.4/KernelRidgeRegression/KernelRidgeRegression.cs

    r15248 r15451  
    3636  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 100)]
    3737  [StorableClass]
    38   public sealed class KernelRidgeRegression : BasicAlgorithm {
     38  public sealed class KernelRidgeRegression : BasicAlgorithm, IDataAnalysisAlgorithm<IRegressionProblem> {
    3939    private const string SolutionResultName = "Kernel ridge regression solution";
    4040
  • branches/Weighted TSNE/3.4/TSNE/Distances/DistanceBase.cs

    r15207 r15451  
    2828namespace HeuristicLab.Algorithms.DataAnalysis {
    2929  [StorableClass]
    30   public abstract class DistanceBase<T> : Item, IDistance<T> {
     30  public abstract class DistanceBase<T> : ParameterizedNamedItem, IDistance<T> {
    3131
    3232    #region HLConstructors & Cloning
  • branches/Weighted TSNE/3.4/TSNE/TSNEAlgorithm.cs

    r15234 r15451  
    5353    }
    5454    public new IDataAnalysisProblem Problem {
    55       get { return (IDataAnalysisProblem)base.Problem; }
     55      get { return (IDataAnalysisProblem) base.Problem; }
    5656      set { base.Problem = value; }
    5757    }
     
    7272    private const string ClassesNameParameterName = "ClassesName";
    7373    private const string NormalizationParameterName = "Normalization";
     74    private const string RandomInitializationParameterName = "RandomInitialization";
    7475    private const string UpdateIntervalParameterName = "UpdateInterval";
    7576    #endregion
     
    126127      get { return Parameters[NormalizationParameterName] as IFixedValueParameter<BoolValue>; }
    127128    }
     129    public IFixedValueParameter<BoolValue> RandomInitializationParameter {
     130      get { return Parameters[RandomInitializationParameterName] as IFixedValueParameter<BoolValue>; }
     131    }
    128132    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
    129133      get { return Parameters[UpdateIntervalParameterName] as IFixedValueParameter<IntValue>; }
     
    187191      set { NormalizationParameter.Value.Value = value; }
    188192    }
     193    public bool RandomInitialization {
     194      get { return RandomInitializationParameter.Value.Value; }
     195      set { RandomInitializationParameter.Value.Value = value; }
     196    }
    189197
    190198    public int UpdateInterval {
     
    200208    private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
    201209      if (original.dataRowNames != null)
    202         this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
     210        dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
    203211      if (original.dataRows != null)
    204         this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
     212        dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
    205213      if (original.state != null)
    206         this.state = cloner.Clone(original.state);
    207       this.iter = original.iter;
    208     }
    209     public override IDeepCloneable Clone(Cloner cloner) { return new TSNEAlgorithm(this, cloner); }
     214        state = cloner.Clone(original.state);
     215      iter = original.iter;
     216    }
     217    public override IDeepCloneable Clone(Cloner cloner) {
     218      return new TSNEAlgorithm(this, cloner);
     219    }
    210220    public TSNEAlgorithm() {
    211221      var distances = new ItemSet<IDistance<double[]>>(ApplicationManager.Manager.GetInstances<IDistance<double[]>>());
     
    213223      Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25)));
    214224      Parameters.Add(new FixedValueParameter<PercentValue>(ThetaParameterName, "Value describing how much appoximated " +
    215                                                                               "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
    216                                                                               "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
    217                                                                               "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
    218                                                                               "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
    219                                                                               " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0)));
     225                                                                               "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
     226                                                                               "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
     227                                                                               "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
     228                                                                               "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
     229                                                                               " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0)));
    220230      Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis)", new IntValue(2)));
    221231      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent.", new IntValue(1000)));
     
    230240      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
    231241      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "The interval after which the results will be updated.", new IntValue(50)));
     242      Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
     243
    232244      Parameters[UpdateIntervalParameterName].Hidden = true;
    233245
     
    259271    protected override void Run(CancellationToken cancellationToken) {
    260272      var problemData = Problem.ProblemData;
    261       // set up and initialized everything if necessary
     273      // set up and initialize everything if necessary
    262274      if (state == null) {
    263275        if (SetSeedRandomly) Seed = new System.Random().Next();
    264         var random = new MersenneTwister((uint)Seed);
     276        var random = new MersenneTwister((uint) Seed);
    265277        var dataset = problemData.Dataset;
    266278        var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
    267         var data = new double[dataset.Rows][];
    268         for (var row = 0; row < dataset.Rows; row++)
    269           data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
    270 
    271         if (Normalization) data = NormalizeData(data);
    272 
    273         state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta,
    274           StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
    275 
    276         SetUpResults(data);
     279        var allindices = Problem.ProblemData.AllIndices.ToArray();
     280        var data = allindices.Select(row => allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray()).ToArray();
     281        if (Normalization) data = NormalizeInputData(data);
     282        state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, RandomInitialization);
     283        SetUpResults(allindices);
    277284        iter = 0;
    278285      }
     
    283290      }
    284291      Analyze(state);
     292      dataRowNames = null;
     293      dataRows = null;
     294      state = null;
    285295    }
    286296
     
    307317      foreach (var input in Problem.ProblemData.InputVariables) ClassesNameParameter.ValidValues.Add(input);
    308318    }
    309 
    310319    #endregion
    311320
    312321    #region Helpers
    313     private void SetUpResults(IReadOnlyCollection<double[]> data) {
     322    private void SetUpResults(IReadOnlyList<int> allIndices) {
    314323      if (Results == null) return;
    315324      var results = Results;
     
    320329      //color datapoints acording to classes variable (be it double or string)
    321330      if (problemData.Dataset.VariableNames.Contains(ClassesName)) {
    322         if ((problemData.Dataset as Dataset).VariableHasType<string>(ClassesName)) {
    323           var classes = problemData.Dataset.GetStringValues(ClassesName).ToArray();
     331        var classificationData = problemData as ClassificationProblemData;
     332        if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) {
     333          var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n);
     334          var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray();
    324335          for (var i = 0; i < classes.Length; i++) {
    325336            if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    326337            dataRowNames[classes[i]].Add(i);
    327338          }
    328         } else if ((problemData.Dataset as Dataset).VariableHasType<double>(ClassesName)) {
    329           var classValues = problemData.Dataset.GetDoubleValues(ClassesName).ToArray();
    330           var max = classValues.Max() + 0.1;
    331           var min = classValues.Min() - 0.1;
     339        }
     340        else if (((Dataset) problemData.Dataset).VariableHasType<string>(ClassesName)) {
     341          var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();
     342          for (var i = 0; i < classes.Length; i++) {
     343            if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
     344            dataRowNames[classes[i]].Add(i);
     345          }
     346        }
     347        else if (((Dataset) problemData.Dataset).VariableHasType<double>(ClassesName)) {
     348          var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
    332349          const int contours = 8;
     350          Dictionary<int, string> contourMap;
     351          IClusteringModel clusterModel;
     352          double[][] borders;
     353          CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
     354          var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
    333355          for (var i = 0; i < contours; i++) {
    334             var contourname = GetContourName(i, min, max, contours);
     356            var c = contourorder[i];
     357            var contourname = contourMap[c];
    335358            dataRowNames.Add(contourname, new List<int>());
    336359            dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
     
    338361            dataRows[contourname].VisualProperties.PointSize = i + 3;
    339362          }
    340           for (var i = 0; i < classValues.Length; i++) {
    341             dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i);
     363          var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
     364          for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
     365        }
     366        else if (((Dataset) problemData.Dataset).VariableHasType<DateTime>(ClassesName)) {
     367          var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
     368          const int contours = 8;
     369          Dictionary<int, string> contourMap;
     370          IClusteringModel clusterModel;
     371          double[][] borders;
     372          CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
     373          var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
     374          for (var i = 0; i < contours; i++) {
     375            var c = contourorder[i];
     376            var contourname = contourMap[c];
     377            dataRowNames.Add(contourname, new List<int>());
     378            dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
     379            dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
     380            dataRows[contourname].VisualProperties.PointSize = i + 3;
    342381          }
    343         }
    344       } else {
    345         dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
    346         dataRowNames.Add("Test", problemData.TestIndices.ToList());
    347       }
    348 
    349       if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
    350       else ((IntValue)results[IterationResultName].Value).Value = 0;
    351 
    352       if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
    353       else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
    354 
    355       if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
    356       else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
    357 
    358       var plot = results[ErrorPlotResultName].Value as DataTable;
    359       if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
    360 
    361       if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
    362       plot.Rows["errors"].Values.Clear();
    363       plot.Rows["errors"].VisualProperties.StartIndexZero = true;
    364 
    365       results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
    366       results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     382          var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
     383          for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
     384        }
     385        else {
     386          dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
     387          dataRowNames.Add("Test", problemData.TestIndices.ToList());
     388        }
     389
     390        if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     391        else ((IntValue) results[IterationResultName].Value).Value = 0;
     392
     393        if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
     394        else ((DoubleValue) results[ErrorResultName].Value).Value = 0;
     395
     396        if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
     397        else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
     398
     399        var plot = results[ErrorPlotResultName].Value as DataTable;
     400        if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
     401
     402        if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
     403        plot.Rows["errors"].Values.Clear();
     404        plot.Rows["errors"].VisualProperties.StartIndexZero = true;
     405
     406        results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
     407        results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     408      }
    367409    }
    368410
     
    375417      var c = tsneState.EvaluateError();
    376418      errors.Add(c);
    377       ((IntValue)results[IterationResultName].Value).Value = tsneState.iter;
    378       ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
    379 
    380       var ndata = Normalize(tsneState.newData);
     419      ((IntValue) results[IterationResultName].Value).Value = tsneState.iter;
     420      ((DoubleValue) results[ErrorResultName].Value).Value = errors.Last();
     421
     422      var ndata = NormalizeProjectedData(tsneState.newData);
    381423      results[DataResultName].Value = new DoubleMatrix(ndata);
    382424      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
     
    386428    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
    387429      foreach (var rowName in dataRowNames.Keys) {
    388         if (!plot.Rows.ContainsKey(rowName))
     430        if (!plot.Rows.ContainsKey(rowName)) {
    389431          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
     432          plot.Rows[rowName].VisualProperties.PointSize = 6;
     433        }
    390434        plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
    391435      }
    392436    }
    393437
    394     private static double[,] Normalize(double[,] data) {
     438    private static double[,] NormalizeProjectedData(double[,] data) {
    395439      var max = new double[data.GetLength(1)];
    396440      var min = new double[data.GetLength(1)];
     
    398442      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
    399443      for (var i = 0; i < data.GetLength(0); i++)
    400         for (var j = 0; j < data.GetLength(1); j++) {
    401           var v = data[i, j];
    402           max[j] = Math.Max(max[j], v);
    403           min[j] = Math.Min(min[j], v);
    404         }
     444      for (var j = 0; j < data.GetLength(1); j++) {
     445        var v = data[i, j];
     446        max[j] = Math.Max(max[j], v);
     447        min[j] = Math.Min(min[j], v);
     448      }
    405449      for (var i = 0; i < data.GetLength(0); i++) {
    406450        for (var j = 0; j < data.GetLength(1); j++) {
    407451          var d = max[j] - min[j];
    408           var s = data[i, j] - (max[j] + min[j]) / 2;  //shift data
    409           if (d.IsAlmost(0)) res[i, j] = data[i, j];   //no scaling possible
    410           else res[i, j] = s / d;  //scale data
     452          var s = data[i, j] - (max[j] + min[j]) / 2; //shift data
     453          if (d.IsAlmost(0)) res[i, j] = data[i, j]; //no scaling possible
     454          else res[i, j] = s / d; //scale data
    411455        }
    412456      }
     
    414458    }
    415459
    416     private static double[][] NormalizeData(IReadOnlyList<double[]> data) {
     460    private static double[][] NormalizeInputData(IReadOnlyList<IReadOnlyList<double>> data) {
    417461      // as in tSNE implementation by van der Maaten
    418       var n = data[0].Length;
     462      var n = data[0].Count;
    419463      var mean = new double[n];
    420464      var max = new double[n];
     
    432476
    433477    private static Color GetHeatMapColor(int contourNr, int noContours) {
    434       var q = (double)contourNr / noContours;  // q in [0,1]
    435       var c = q < 0.5 ? Color.FromArgb((int)(q * 2 * 255), 255, 0) : Color.FromArgb(255, (int)((1 - q) * 2 * 255), 0);
    436       return c;
    437     }
    438 
    439     private static string GetContourName(double value, double min, double max, int noContours) {
    440       var size = (max - min) / noContours;
    441       var contourNr = (int)((value - min) / size);
    442       return GetContourName(contourNr, min, max, noContours);
    443     }
    444 
    445     private static string GetContourName(int i, double min, double max, int noContours) {
    446       var size = (max - min) / noContours;
    447       return "[" + (min + i * size) + ";" + (min + (i + 1) * size) + ")";
     478      return ConvertTotalToRgb(0, noContours, contourNr);
     479    }
     480
     481    private static void CreateClusters(IDataset data, string target, int contours, out IClusteringModel contourCluster, out Dictionary<int, string> contourNames, out double[][] borders) {
     482      var cpd = new ClusteringProblemData((Dataset) data, new[] {target});
     483      contourCluster = KMeansClustering.CreateKMeansSolution(cpd, contours, 3).Model;
     484
     485      borders = Enumerable.Range(0, contours).Select(x => new[] {double.MaxValue, double.MinValue}).ToArray();
     486      var clusters = contourCluster.GetClusterValues(cpd.Dataset, cpd.AllIndices).ToArray();
     487      var targetvalues = cpd.Dataset.GetDoubleValues(target).ToArray();
     488      foreach (var i in cpd.AllIndices) {
     489        var cl = clusters[i] - 1;
     490        var clv = targetvalues[i];
     491        if (borders[cl][0] > clv) borders[cl][0] = clv;
     492        if (borders[cl][1] < clv) borders[cl][1] = clv;
     493      }
     494
     495      contourNames = new Dictionary<int, string>();
     496      for (var i = 0; i < contours; i++)
     497        contourNames.Add(i, "[" + borders[i][0] + ";" + borders[i][1] + "]");
     498    }
     499
     500    private static Color ConvertTotalToRgb(double low, double high, double cell) {
     501      var range = high - low;
     502      var h = cell / range;
     503      return HsVtoRgb(h * 0.5, 1.0f, 1.0f);
     504    }
     505
     506    private static Color HsVtoRgb(double hue, double saturation, double value) {
     507      while (hue > 1f) { hue -= 1f; }
     508      while (hue < 0f) { hue += 1f; }
     509      while (saturation > 1f) { saturation -= 1f; }
     510      while (saturation < 0f) { saturation += 1f; }
     511      while (value > 1f) { value -= 1f; }
     512      while (value < 0f) { value += 1f; }
     513      if (hue > 0.999f) { hue = 0.999f; }
     514      if (hue < 0.001f) { hue = 0.001f; }
     515      if (saturation > 0.999f) { saturation = 0.999f; }
     516      if (saturation < 0.001f) { return Color.FromArgb((int) (value * 255f), (int) (value * 255f), (int) (value * 255f)); }
     517      if (value > 0.999f) { value = 0.999f; }
     518      if (value < 0.001f) { value = 0.001f; }
     519
     520      var h6 = hue * 6f;
     521      if (h6.IsAlmost(6f)) { h6 = 0f; }
     522      var ihue = (int) h6;
     523      var p = value * (1f - saturation);
     524      var q = value * (1f - saturation * (h6 - ihue));
     525      var t = value * (1f - saturation * (1f - (h6 - ihue)));
     526      switch (ihue) {
     527        case 0:
     528          return Color.FromArgb((int) (value * 255), (int) (t * 255), (int) (p * 255));
     529        case 1:
     530          return Color.FromArgb((int) (q * 255), (int) (value * 255), (int) (p * 255));
     531        case 2:
     532          return Color.FromArgb((int) (p * 255), (int) (value * 255), (int) (t * 255));
     533        case 3:
     534          return Color.FromArgb((int) (p * 255), (int) (q * 255), (int) (value * 255));
     535        case 4:
     536          return Color.FromArgb((int) (t * 255), (int) (p * 255), (int) (value * 255));
     537        default:
     538          return Color.FromArgb((int) (value * 255), (int) (p * 255), (int) (q * 255));
     539      }
    448540    }
    449541    #endregion
  • branches/Weighted TSNE/3.4/TSNE/TSNEStatic.cs

    r15207 r15451  
    6565  [StorableClass]
    6666  public class TSNEStatic<T> {
    67 
    6867    [StorableClass]
    6968    public sealed class TSNEState : DeepCloneable {
     
    170169      [StorableConstructor]
    171170      public TSNEState(bool deserializing) { }
    172       public TSNEState(T[] data, IDistance<T> distance, IRandom random, int newDimensions, double perplexity, double theta, int stopLyingIter, int momSwitchIter, double momentum, double finalMomentum, double eta) {
     171
     172      public TSNEState(T[] data, IDistance<T> distance, IRandom random, int newDimensions, double perplexity,
     173        double theta, int stopLyingIter, int momSwitchIter, double momentum, double finalMomentum, double eta, bool randomInit) {
    173174        this.distance = distance;
    174175        this.random = random;
     
    193194        gains = new double[noDatapoints, newDimensions];
    194195        for (var i = 0; i < noDatapoints; i++)
    195           for (var j = 0; j < newDimensions; j++)
    196             gains[i, j] = 1.0;
     196        for (var j = 0; j < newDimensions; j++)
     197          gains[i, j] = 1.0;
    197198
    198199        p = null;
     
    212213        var rand = new NormalDistributedRandom(random, 0, 1);
    213214        for (var i = 0; i < noDatapoints; i++)
    214           for (var j = 0; j < newDimensions; j++)
    215             newData[i, j] = rand.NextDouble() * .0001;
     215        for (var j = 0; j < newDimensions; j++)
     216          newData[i, j] = rand.NextDouble() * .0001;
     217
     218        if (data[0] is IReadOnlyList<double> && !randomInit) {
     219          for (var i = 0; i < noDatapoints; i++)
     220          for (var j = 0; j < newDimensions; j++) {
     221            var row = (IReadOnlyList<double>) data[i];
     222            newData[i, j] = row[j % row.Count];
     223          }
     224        }
    216225      }
    217226      #endregion
    218227
    219228      public double EvaluateError() {
    220         return exact ?
    221           EvaluateErrorExact(p, newData, noDatapoints, newDimensions) :
    222           EvaluateErrorApproximate(rowP, colP, valP, newData, theta);
     229        return exact ? EvaluateErrorExact(p, newData, noDatapoints, newDimensions) : EvaluateErrorApproximate(rowP, colP, valP, newData, theta);
    223230      }
    224231
     
    226233      private static void CalculateApproximateSimilarities(T[] data, IDistance<T> distance, double perplexity, out int[] rowP, out int[] colP, out double[] valP) {
    227234        // Compute asymmetric pairwise input similarities
    228         ComputeGaussianPerplexity(data, distance, out rowP, out colP, out valP, perplexity, (int)(3 * perplexity));
     235        ComputeGaussianPerplexity(data, distance, out rowP, out colP, out valP, perplexity, (int) (3 * perplexity));
    229236        // Symmetrize input similarities
    230237        int[] sRowP, symColP;
     
    290297
    291298          // Iterate until we found a good perplexity
    292           var iter = 0; double sumP = 0;
     299          var iter = 0;
     300          double sumP = 0;
    293301          while (!found && iter < 200) {
    294 
    295302            // Compute Gaussian kernel row
    296303            for (var m = 0; m < k; m++) curP[m] = Math.Exp(-beta * distances[m + 1]);
     
    307314            if (hdiff < tol && -hdiff < tol) {
    308315              found = true;
    309             } else {
     316            }
     317            else {
    310318              if (hdiff > 0) {
    311319                minBeta = beta;
     
    314322                else
    315323                  beta = (beta + maxBeta) / 2.0;
    316               } else {
     324              }
     325              else {
    317326                maxBeta = beta;
    318327                if (minBeta.IsAlmost(double.MinValue) || minBeta.IsAlmost(double.MaxValue))
     
    352361          // Iterate until we found a good perplexity
    353362          var iter = 0;
    354           while (!found && iter < 200) {      // 200 iterations as in tSNE implementation by van der Maarten
     363          while (!found && iter < 200) { // 200 iterations as in tSNE implementation by van der Maarten
    355364
    356365            // Compute Gaussian kernel row
     
    369378            if (hdiff < tol && -hdiff < tol) {
    370379              found = true;
    371             } else {
     380            }
     381            else {
    372382              if (hdiff > 0) {
    373383                minBeta = beta;
     
    376386                else
    377387                  beta = (beta + maxBeta) / 2.0;
    378               } else {
     388              }
     389              else {
    379390                maxBeta = beta;
    380391                if (minBeta.IsAlmost(double.MinValue) || minBeta.IsAlmost(double.MaxValue))
     
    425436              q[n1, m] = 1 / (1 + dd[n1, m]);
    426437              sumQ += q[n1, m];
    427             } else q[n1, m] = double.Epsilon;
     438            }
     439            else q[n1, m] = double.Epsilon;
    428440          }
    429441        }
     
    433445        var c = .0;
    434446        for (var i = 0; i < n; i++)
    435           for (var j = 0; j < n; j++) {
    436             c += p[i, j] * Math.Log((p[i, j] + float.Epsilon) / (q[i, j] + float.Epsilon));
    437           }
     447        for (var j = 0; j < n; j++) {
     448          c += p[i, j] * Math.Log((p[i, j] + float.Epsilon) / (q[i, j] + float.Epsilon));
     449        }
    438450        return c;
    439451      }
     
    463475      }
    464476      private static void SymmetrizeMatrix(IReadOnlyList<int> rowP, IReadOnlyList<int> colP, IReadOnlyList<double> valP, out int[] symRowP, out int[] symColP, out double[] symValP) {
    465 
    466477        // Count number of elements and row counts of symmetric matrix
    467478        var n = rowP.Count - 1;
     
    469480        for (var j = 0; j < n; j++) {
    470481          for (var i = rowP[j]; i < rowP[j + 1]; i++) {
    471 
    472482            // Check whether element (col_P[i], n) is present
    473483            var present = false;
     
    497507        var offset = new int[n];
    498508        for (var j = 0; j < n; j++) {
    499           for (var i = rowP[j]; i < rowP[j + 1]; i++) {                                  // considering element(n, colP[i])
     509          for (var i = rowP[j]; i < rowP[j + 1]; i++) { // considering element(n, colP[i])
    500510
    501511            // Check whether element (col_P[i], n) is present
     
    552562      int stopLyingIter = 0, int momSwitchIter = 0, double momentum = .5,
    553563      double finalMomentum = .8, double eta = 10.0
    554       ) {
     564    ) {
    555565      var state = CreateState(data, distance, random, newDimensions, perplexity,
    556566        theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta);
     
    565575      int newDimensions = 2, double perplexity = 25, double theta = 0,
    566576      int stopLyingIter = 0, int momSwitchIter = 0, double momentum = .5,
    567       double finalMomentum = .8, double eta = 10.0
    568       ) {
    569       return new TSNEState(data, distance, random, newDimensions, perplexity, theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta);
     577      double finalMomentum = .8, double eta = 10.0, bool randomInit = true
     578    ) {
     579      return new TSNEState(data, distance, random, newDimensions, perplexity, theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta, randomInit);
    570580    }
    571581
     
    580590        for (var j = 0; j < state.newDimensions; j++) {
    581591          state.gains[i, j] = Math.Sign(state.dY[i, j]) != Math.Sign(state.uY[i, j])
    582             ? state.gains[i, j] + .2  // +0.2 nd *0.8 are used in two separate implementations of tSNE -> seems to be correct
     592            ? state.gains[i, j] + .2 // +0.2 nd *0.8 are used in two separate implementations of tSNE -> seems to be correct
    583593            : state.gains[i, j] * .8;
    584594
     
    590600      // Perform gradient update (with momentum and gains)
    591601      for (var i = 0; i < state.noDatapoints; i++)
    592         for (var j = 0; j < state.newDimensions; j++)
    593           state.uY[i, j] = state.currentMomentum * state.uY[i, j] - state.eta * state.gains[i, j] * state.dY[i, j];
     602      for (var j = 0; j < state.newDimensions; j++)
     603        state.uY[i, j] = state.currentMomentum * state.uY[i, j] - state.eta * state.gains[i, j] * state.dY[i, j];
    594604
    595605      for (var i = 0; i < state.noDatapoints; i++)
    596         for (var j = 0; j < state.newDimensions; j++)
    597           state.newData[i, j] = state.newData[i, j] + state.uY[i, j];
     606      for (var j = 0; j < state.newDimensions; j++)
     607        state.newData[i, j] = state.newData[i, j] + state.uY[i, j];
    598608
    599609      // Make solution zero-mean
     
    604614        if (state.exact)
    605615          for (var i = 0; i < state.noDatapoints; i++)
    606             for (var j = 0; j < state.noDatapoints; j++)
    607               state.p[i, j] /= 12.0;
     616          for (var j = 0; j < state.noDatapoints; j++)
     617            state.p[i, j] /= 12.0;
    608618        else
    609619          for (var i = 0; i < state.rowP[state.noDatapoints]; i++)
     
    634644      // Compute final t-SNE gradient
    635645      for (var i = 0; i < n; i++)
    636         for (var j = 0; j < d; j++) {
    637           dC[i, j] = posF[i, j] - negF[i, j] / sumQ;
    638         }
     646      for (var j = 0; j < d; j++) {
     647        dC[i, j] = posF[i, j] - negF[i, j] / sumQ;
     648      }
    639649    }
    640650
Note: See TracChangeset for help on using the changeset viewer.