Changeset 14859


Ignore:
Timestamp:
04/13/17 13:59:42 (5 months ago)
Author:
gkronber
Message:

#2700 update only every x iterations and fixed output paths

Location:
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r14836 r14859  
    5454    <DebugType>pdbonly</DebugType>
    5555    <Optimize>true</Optimize>
    56     <OutputPath>$(SolutionDir)\bin\</OutputPath>
     56    <OutputPath>..\..\..\..\trunk\sources\bin\</OutputPath>
    5757    <DefineConstants>TRACE</DefineConstants>
    5858    <ErrorReport>prompt</ErrorReport>
     
    6565  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x86' ">
    6666    <DebugSymbols>true</DebugSymbols>
    67     <OutputPath>$(SolutionDir)\bin\</OutputPath>
     67    <OutputPath>..\..\..\..\trunk\sources\bin\</OutputPath>
    6868    <DefineConstants>DEBUG;TRACE</DefineConstants>
    6969    <DebugType>full</DebugType>
     
    7474  </PropertyGroup>
    7575  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x86' ">
    76     <OutputPath>$(SolutionDir)\bin\</OutputPath>
     76    <OutputPath>..\..\..\..\trunk\sources\bin\</OutputPath>
    7777    <DefineConstants>TRACE</DefineConstants>
    7878    <DocumentationFile>
     
    8787  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x64' ">
    8888    <DebugSymbols>true</DebugSymbols>
    89     <OutputPath>$(SolutionDir)\bin\</OutputPath>
     89    <OutputPath>..\..\..\..\trunk\sources\bin\</OutputPath>
    9090    <DefineConstants>DEBUG;TRACE</DefineConstants>
    9191    <DebugType>full</DebugType>
     
    9696  </PropertyGroup>
    9797  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x64' ">
    98     <OutputPath>$(SolutionDir)\bin\</OutputPath>
     98    <OutputPath>..\..\..\..\trunk\sources\bin\</OutputPath>
    9999    <DefineConstants>TRACE</DefineConstants>
    100100    <DocumentationFile>
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs

    r14855 r14859  
    7171    private const string ClassesParameterName = "ClassNames";
    7272    private const string NormalizationParameterName = "Normalization";
     73    private const string UpdateIntervalParameterName = "UpdateInterval";
    7374    #endregion
    7475
     
    124125      get { return Parameters[NormalizationParameterName] as IFixedValueParameter<BoolValue>; }
    125126    }
     127    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
     128      get { return Parameters[UpdateIntervalParameterName] as IFixedValueParameter<IntValue>; }
     129    }
    126130    #endregion
    127131
     
    182186      set { NormalizationParameter.Value.Value = value; }
    183187    }
     188
     189    public int UpdateInterval {
     190      get { return UpdateIntervalParameter.Value.Value; }
     191      set { UpdateIntervalParameter.Value.Value = value; }
     192    }
    184193    #endregion
    185194
     
    191200      this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
    192201      this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
    193       if(original.state != null)
     202      if (original.state != null)
    194203        this.state = cloner.Clone(original.state);
    195204      this.iter = original.iter;
     
    217226      Parameters.Add(new FixedValueParameter<StringValue>(ClassesParameterName, "name of the column specifying the class lables of each data point. If the label column can not be found training/test is used as labels.", new StringValue("none")));
    218227      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
     228      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "", new IntValue(50)));
     229      Parameters[UpdateIntervalParameterName].Hidden = true;
    219230
    220231      MomentumSwitchIterationParameter.Hidden = true;
     
    245256      var problemData = Problem.ProblemData;
    246257      // set up and initialized everything if necessary
    247       if(state == null) {
    248         if(SetSeedRandomly) Seed = new System.Random().Next();
     258      if (state == null) {
     259        if (SetSeedRandomly) Seed = new System.Random().Next();
    249260        var random = new MersenneTwister((uint)Seed);
    250261        var dataset = problemData.Dataset;
    251262        var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
    252263        var data = new double[dataset.Rows][];
    253         for(var row = 0; row < dataset.Rows; row++)
     264        for (var row = 0; row < dataset.Rows; row++)
    254265          data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
    255266
    256         if(Normalization) data = NormalizeData(data);
     267        if (Normalization) data = NormalizeData(data);
    257268
    258269        state = TSNEStatic<double[]>.CreateState(data, Distance, random, NewDimensions, Perplexity, Theta,
     
    262273        iter = 0;
    263274      }
    264       for(; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
     275      for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
     276        if (iter % UpdateInterval == 0)
     277          Analyze(state);
    265278        TSNEStatic<double[]>.Iterate(state);
    266         Analyze(state);
    267       }
     279      }
     280      Analyze(state);
    268281    }
    269282
    270283    private void SetUpResults(IReadOnlyCollection<double[]> data) {
    271       if(Results == null) return;
     284      if (Results == null) return;
    272285      var results = Results;
    273286      dataRowNames = new Dictionary<string, List<int>>();
     
    276289
    277290      //color datapoints acording to classes variable (be it double or string)
    278       if(problemData.Dataset.VariableNames.Contains(Classes)) {
    279         if((problemData.Dataset as Dataset).VariableHasType<string>(Classes)) {
     291      if (problemData.Dataset.VariableNames.Contains(Classes)) {
     292        if ((problemData.Dataset as Dataset).VariableHasType<string>(Classes)) {
    280293          var classes = problemData.Dataset.GetStringValues(Classes).ToArray();
    281           for(var i = 0; i < classes.Length; i++) {
    282             if(!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
     294          for (var i = 0; i < classes.Length; i++) {
     295            if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    283296            dataRowNames[classes[i]].Add(i);
    284297          }
    285         } else if((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) {
     298        } else if ((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) {
    286299          var classValues = problemData.Dataset.GetDoubleValues(Classes).ToArray();
    287           var max = classValues.Max() + 0.1;   
     300          var max = classValues.Max() + 0.1;
    288301          var min = classValues.Min() - 0.1;
    289302          const int contours = 8;
    290           for(var i = 0; i < contours; i++) {
     303          for (var i = 0; i < contours; i++) {
    291304            var contourname = GetContourName(i, min, max, contours);
    292305            dataRowNames.Add(contourname, new List<int>());
     
    295308            dataRows[contourname].VisualProperties.PointSize = i + 3;
    296309          }
    297           for(var i = 0; i < classValues.Length; i++) {
     310          for (var i = 0; i < classValues.Length; i++) {
    298311            dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i);
    299312          }
     
    304317      }
    305318
    306       if(!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     319      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
    307320      else ((IntValue)results[IterationResultName].Value).Value = 0;
    308321
    309       if(!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
     322      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
    310323      else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
    311324
    312       if(!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
     325      if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
    313326      else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
    314327
    315328      var plot = results[ErrorPlotResultName].Value as DataTable;
    316       if(plot == null) throw new ArgumentException("could not create/access error data table in results collection");
    317 
    318       if(!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
     329      if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
     330
     331      if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
    319332      plot.Rows["errors"].Values.Clear();
     333      plot.Rows["errors"].VisualProperties.StartIndexZero = true;
    320334
    321335      results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
     
    324338
    325339    private void Analyze(TSNEStatic<double[]>.TSNEState tsneState) {
    326       if(Results == null) return;
     340      if (Results == null) return;
    327341      var results = Results;
    328342      var plot = results[ErrorPlotResultName].Value as DataTable;
    329       if(plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
     343      if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
    330344      var errors = plot.Rows["errors"].Values;
    331345      var c = tsneState.EvaluateError();
     
    341355
    342356    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
    343       foreach(var rowName in dataRowNames.Keys) {
    344         if(!plot.Rows.ContainsKey(rowName))
     357      foreach (var rowName in dataRowNames.Keys) {
     358        if (!plot.Rows.ContainsKey(rowName))
    345359          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
    346360        plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     
    352366      var min = new double[data.GetLength(1)];
    353367      var res = new double[data.GetLength(0), data.GetLength(1)];
    354       for(var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
    355       for(var i = 0; i < data.GetLength(0); i++)
    356         for(var j = 0; j < data.GetLength(1); j++) {
     368      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
     369      for (var i = 0; i < data.GetLength(0); i++)
     370        for (var j = 0; j < data.GetLength(1); j++) {
    357371          var v = data[i, j];
    358372          max[j] = Math.Max(max[j], v);
    359373          min[j] = Math.Min(min[j], v);
    360374        }
    361       for(var i = 0; i < data.GetLength(0); i++) {
    362         for(var j = 0; j < data.GetLength(1); j++) {
     375      for (var i = 0; i < data.GetLength(0); i++) {
     376        for (var j = 0; j < data.GetLength(1); j++) {
    363377          res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]);
    364378        }
     
    368382
    369383    private static double[][] NormalizeData(IReadOnlyList<double[]> data) {
     384      // as in tSNE implementation by van der Maaten
    370385      var n = data[0].Length;
    371386      var mean = new double[n];
    372       var sd = new double[n];
     387      var max = new double[n];
    373388      var nData = new double[data.Count][];
    374       for(var i = 0; i < n; i++) {
    375         var i1 = i;
    376         sd[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).StandardDeviation();
    377         mean[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).Average();
    378       }
    379       for(var i = 0; i < data.Count; i++) {
     389      for (var i = 0; i < n; i++) {
     390        mean[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i]).Average();
     391        max[i] = Enumerable.Range(0, data.Count).Max(x => Math.Abs(data[x][i]));
     392      }
     393      for (var i = 0; i < data.Count; i++) {
    380394        nData[i] = new double[n];
    381         for(var j = 0; j < n; j++) nData[i][j] = (data[i][j] - mean[j]) / sd[j];
     395        for (var j = 0; j < n; j++) nData[i][j] = (data[i][j] - mean[j]) / max[j];
    382396      }
    383397      return nData;
Note: See TracChangeset for help on using the changeset viewer.