Free cookie consent management tool by TermsFeed Policy Generator

Changeset 15455


Ignore:
Timestamp:
11/07/17 13:15:55 (7 years ago)
Author:
bwerth
Message:

#2847 added WeightedEuclideanDistance && fixed minor bug in scatterPlot colloring

Location:
branches/Weighted TSNE/3.4/TSNE
Files:
1 added
3 edited

Legend:

Unmodified
Added
Removed
  • branches/Weighted TSNE/3.4/TSNE/TSNEAlgorithm.cs

    r15451 r15455  
    2929using HeuristicLab.Core;
    3030using HeuristicLab.Data;
     31using HeuristicLab.Encodings.RealVectorEncoding;
    3132using HeuristicLab.Optimization;
    3233using HeuristicLab.Parameters;
     
    5758    }
    5859
    59     #region parameter names
     60    #region Parameter names
    6061    private const string DistanceFunctionParameterName = "DistanceFunction";
    6162    private const string PerplexityParameterName = "Perplexity";
     
    7677    #endregion
    7778
    78     #region result names
     79    #region Result names
    7980    private const string IterationResultName = "Iteration";
    8081    private const string ErrorResultName = "Error";
     
    8485    #endregion
    8586
    86     #region parameter properties
     87    #region Parameter properties
    8788    public IFixedValueParameter<DoubleValue> PerplexityParameter {
    8889      get { return Parameters[PerplexityParameterName] as IFixedValueParameter<DoubleValue>; }
     
    202203    #endregion
    203204
     205    #region Storable poperties
     206    [Storable]
     207    private Dictionary<string, List<int>> dataRowNames;
     208    [Storable]
     209    private Dictionary<string, ScatterPlotDataRow> dataRows;
     210    [Storable]
     211    private TSNEStatic<double[]>.TSNEState state;
     212    [Storable]
     213    private int iter;
     214    #endregion
     215
    204216    #region Constructors & Cloning
    205217    [StorableConstructor]
    206218    private TSNEAlgorithm(bool deserializing) : base(deserializing) { }
    207219
     220    [StorableHook(HookType.AfterDeserialization)]
     221    private void AfterDeserialization() {
     222      RegisterParameterEvents();
     223    }
    208224    private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
    209225      if (original.dataRowNames != null)
     
    250266      EtaParameter.Hidden = false;
    251267      Problem = new RegressionProblem();
    252     }
    253     #endregion
    254 
    255     [Storable]
    256     private Dictionary<string, List<int>> dataRowNames;
    257     [Storable]
    258     private Dictionary<string, ScatterPlotDataRow> dataRows;
    259     [Storable]
    260     private TSNEStatic<double[]>.TSNEState state;
    261     [Storable]
    262     private int iter;
     268      RegisterParameterEvents();
     269    }
     270    #endregion
    263271
    264272    public override void Prepare() {
     
    285293      }
    286294      for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
    287         if (iter % UpdateInterval == 0)
    288           Analyze(state);
     295        if (iter % UpdateInterval == 0) Analyze(state);
    289296        TSNEStatic<double[]>.Iterate(state);
    290297      }
    291298      Analyze(state);
    292       dataRowNames = null;
    293       dataRows = null;
    294       state = null;
    295299    }
    296300
     
    306310      Problem.ProblemDataChanged += OnProblemDataChanged;
    307311    }
     312
    308313    protected override void DeregisterProblemEvents() {
    309314      base.DeregisterProblemEvents();
     
    311316    }
    312317
     318    protected override void OnStopped() {
     319      base.OnStopped();
     320      state = null;
     321      dataRowNames = null;
     322      dataRows = null;
     323    }
     324
    313325    private void OnProblemDataChanged(object sender, EventArgs args) {
    314326      if (Problem == null || Problem.ProblemData == null) return;
     327      OnPerplexityChanged(this, null);
     328      Problem.ProblemData.Changed += OnPerplexityChanged;
     329      Problem.ProblemData.Changed += OnColumnsChanged;
     330      Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
     331      Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
    315332      if (!Parameters.ContainsKey(ClassesNameParameterName)) return;
    316333      ClassesNameParameter.ValidValues.Clear();
    317334      foreach (var input in Problem.ProblemData.InputVariables) ClassesNameParameter.ValidValues.Add(input);
     335    }
     336    private void OnColumnsChanged(object sender, EventArgs e) {
     337      if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(DistanceFunctionParameterName)) return;
     338      DistanceFunctionParameter.ValidValues.OfType<WeightedEuclideanDistance>().Single().Weights = new RealVector(Problem.ProblemData.AllowedInputVariables.Select(x => 1.0).ToArray());
     339    }
     340
     341    private void RegisterParameterEvents() {
     342      PerplexityParameter.Value.ValueChanged -= OnPerplexityChanged;
     343      PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
     344    }
     345
     346    private void OnPerplexityChanged(object sender, EventArgs e) {
     347      if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(PerplexityParameterName)) return;
     348      PerplexityParameter.Value.ValueChanged -= OnPerplexityChanged;
     349      PerplexityParameter.Value.Value = Math.Max(1, Math.Min((Problem.ProblemData.Dataset.Rows - 1) / 3.0, Perplexity));
     350      PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
    318351    }
    319352    #endregion
     
    327360      var problemData = Problem.ProblemData;
    328361
     362      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     363      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
     364      if (!results.ContainsKey(ScatterPlotResultName)) results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
     365      if (!results.ContainsKey(DataResultName)) results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     366      if (!results.ContainsKey(ErrorPlotResultName)) {
     367        var errortable = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent") {
     368          VisualProperties = {
     369            XAxisTitle = "UpdateIntervall",
     370            YAxisTitle = "Error",
     371            YAxisLogScale = true
     372          }
     373        };
     374        errortable.Rows.Add(new DataRow("Errors"));
     375        errortable.Rows["Errors"].VisualProperties.StartIndexZero = true;
     376        results.Add(new Result(ErrorPlotResultName, errortable));
     377      }
     378
    329379      //color datapoints acording to classes variable (be it double or string)
    330       if (problemData.Dataset.VariableNames.Contains(ClassesName)) {
    331         var classificationData = problemData as ClassificationProblemData;
    332         if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) {
    333           var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n);
    334           var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray();
    335           for (var i = 0; i < classes.Length; i++) {
    336             if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    337             dataRowNames[classes[i]].Add(i);
    338           }
     380      if (!problemData.Dataset.VariableNames.Contains(ClassesName)) {
     381        dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
     382        dataRowNames.Add("Test", problemData.TestIndices.ToList());
     383        return;
     384      }
     385      var classificationData = problemData as ClassificationProblemData;
     386      if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) {
     387        var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n);
     388        var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray();
     389        for (var i = 0; i < classes.Length; i++) {
     390          if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
     391          dataRowNames[classes[i]].Add(i);
    339392        }
    340         else if (((Dataset) problemData.Dataset).VariableHasType<string>(ClassesName)) {
    341           var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();
    342           for (var i = 0; i < classes.Length; i++) {
    343             if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    344             dataRowNames[classes[i]].Add(i);
    345           }
     393      }
     394      else if (((Dataset) problemData.Dataset).VariableHasType<string>(ClassesName)) {
     395        var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();
     396        for (var i = 0; i < classes.Length; i++) {
     397          if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
     398          dataRowNames[classes[i]].Add(i);
    346399        }
    347         else if (((Dataset) problemData.Dataset).VariableHasType<double>(ClassesName)) {
    348           var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
    349           const int contours = 8;
    350           Dictionary<int, string> contourMap;
    351           IClusteringModel clusterModel;
    352           double[][] borders;
    353           CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
    354           var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
    355           for (var i = 0; i < contours; i++) {
    356             var c = contourorder[i];
    357             var contourname = contourMap[c];
    358             dataRowNames.Add(contourname, new List<int>());
    359             dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
    360             dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
    361             dataRows[contourname].VisualProperties.PointSize = i + 3;
    362           }
    363           var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
    364           for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
     400      }
     401      else if (((Dataset) problemData.Dataset).VariableHasType<double>(ClassesName)) {
     402        var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
     403        const int contours = 8;
     404        Dictionary<int, string> contourMap;
     405        IClusteringModel clusterModel;
     406        double[][] borders;
     407        CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
     408        var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
     409        for (var i = 0; i < contours; i++) {
     410          var c = contourorder[i];
     411          var contourname = contourMap[c];
     412          dataRowNames.Add(contourname, new List<int>());
     413          dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
     414          dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
    365415        }
    366         else if (((Dataset) problemData.Dataset).VariableHasType<DateTime>(ClassesName)) {
    367           var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
    368           const int contours = 8;
    369           Dictionary<int, string> contourMap;
    370           IClusteringModel clusterModel;
    371           double[][] borders;
    372           CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
    373           var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
    374           for (var i = 0; i < contours; i++) {
    375             var c = contourorder[i];
    376             var contourname = contourMap[c];
    377             dataRowNames.Add(contourname, new List<int>());
    378             dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
    379             dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
    380             dataRows[contourname].VisualProperties.PointSize = i + 3;
    381           }
    382           var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
    383           for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
     416        var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
     417        for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
     418      }
     419      else if (((Dataset) problemData.Dataset).VariableHasType<DateTime>(ClassesName)) {
     420        var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
     421        const int contours = 8;
     422        Dictionary<int, string> contourMap;
     423        IClusteringModel clusterModel;
     424        double[][] borders;
     425        CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
     426        var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
     427        for (var i = 0; i < contours; i++) {
     428          var c = contourorder[i];
     429          var contourname = contourMap[c];
     430          dataRowNames.Add(contourname, new List<int>());
     431          dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
     432          dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
    384433        }
    385         else {
    386           dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
    387           dataRowNames.Add("Test", problemData.TestIndices.ToList());
    388         }
    389 
    390         if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
    391         else ((IntValue) results[IterationResultName].Value).Value = 0;
    392 
    393         if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
    394         else ((DoubleValue) results[ErrorResultName].Value).Value = 0;
    395 
    396         if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
    397         else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
    398 
    399         var plot = results[ErrorPlotResultName].Value as DataTable;
    400         if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
    401 
    402         if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
    403         plot.Rows["errors"].Values.Clear();
    404         plot.Rows["errors"].VisualProperties.StartIndexZero = true;
    405 
    406         results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
    407         results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     434        var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
     435        for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
     436      }
     437      else {
     438        dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
     439        dataRowNames.Add("Test", problemData.TestIndices.ToList());
    408440      }
    409441    }
     
    414446      var plot = results[ErrorPlotResultName].Value as DataTable;
    415447      if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
    416       var errors = plot.Rows["errors"].Values;
     448      var errors = plot.Rows["Errors"].Values;
    417449      var c = tsneState.EvaluateError();
    418450      errors.Add(c);
     
    430462        if (!plot.Rows.ContainsKey(rowName)) {
    431463          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
    432           plot.Rows[rowName].VisualProperties.PointSize = 6;
     464          plot.Rows[rowName].VisualProperties.PointSize = 8;
    433465        }
    434466        plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     
    504536    }
    505537
     538    //taken from https://stackoverflow.com/a/17099130
    506539    private static Color HsVtoRgb(double hue, double saturation, double value) {
    507       while (hue > 1f) { hue -= 1f; }
    508       while (hue < 0f) { hue += 1f; }
    509       while (saturation > 1f) { saturation -= 1f; }
    510       while (saturation < 0f) { saturation += 1f; }
    511       while (value > 1f) { value -= 1f; }
    512       while (value < 0f) { value += 1f; }
    513       if (hue > 0.999f) { hue = 0.999f; }
    514       if (hue < 0.001f) { hue = 0.001f; }
    515       if (saturation > 0.999f) { saturation = 0.999f; }
    516       if (saturation < 0.001f) { return Color.FromArgb((int) (value * 255f), (int) (value * 255f), (int) (value * 255f)); }
    517       if (value > 0.999f) { value = 0.999f; }
    518       if (value < 0.001f) { value = 0.001f; }
    519 
    520       var h6 = hue * 6f;
    521       if (h6.IsAlmost(6f)) { h6 = 0f; }
     540      while (hue > 1.0) { hue -= 1.0; }
     541      while (hue < 0.0) { hue += 1.0; }
     542      while (saturation > 1.0) { saturation -= 1.0; }
     543      while (saturation < 0.0) { saturation += 1.0; }
     544      while (value > 1.0) { value -= 1.0; }
     545      while (value < 0.0) { value += 1.0; }
     546      if (hue > 0.999) { hue = 0.999; }
     547      if (hue < 0.001) { hue = 0.001; }
     548      if (saturation > 0.999) { saturation = 0.999; }
     549      if (saturation < 0.001) { return Color.FromArgb((int) (value * 255.0), (int) (value * 255.0), (int) (value * 255.0)); }
     550      if (value > 0.999) { value = 0.999; }
     551      if (value < 0.001) { value = 0.001; }
     552
     553      var h6 = hue * 6.0;
     554      if (h6.IsAlmost(6.0)) { h6 = 0.0; }
    522555      var ihue = (int) h6;
    523       var p = value * (1f - saturation);
    524       var q = value * (1f - saturation * (h6 - ihue));
    525       var t = value * (1f - saturation * (1f - (h6 - ihue)));
     556      var p = value * (1.0 - saturation);
     557      var q = value * (1.0 - saturation * (h6 - ihue));
     558      var t = value * (1.0 - saturation * (1.0 - (h6 - ihue)));
    526559      switch (ihue) {
    527560        case 0:
  • branches/Weighted TSNE/3.4/TSNE/TSNEStatic.cs

    r15451 r15455  
    216216          newData[i, j] = rand.NextDouble() * .0001;
    217217
    218         if (data[0] is IReadOnlyList<double> && !randomInit) {
    219           for (var i = 0; i < noDatapoints; i++)
    220           for (var j = 0; j < newDimensions; j++) {
    221             var row = (IReadOnlyList<double>) data[i];
    222             newData[i, j] = row[j % row.Count];
    223           }
     218        if (!(data[0] is IReadOnlyList<double>) || randomInit) return;
     219        for (var i = 0; i < noDatapoints; i++)
     220        for (var j = 0; j < newDimensions; j++) {
     221          var row = (IReadOnlyList<double>) data[i];
     222          newData[i, j] = row[j % row.Count];
    224223        }
    225224      }
     
    404403        }
    405404      }
    406 
    407405      private static double[][] ComputeDistances(T[] x, IDistance<T> distance) {
    408406        var res = new double[x.Length][];
     
    422420        // return x.Select(m => x.Select(n => distance.Get(m, n)).ToArray()).ToArray();
    423421      }
    424 
    425422      private static double EvaluateErrorExact(double[,] p, double[,] y, int n, int d) {
    426423        // Compute the squared Euclidean distance matrix
     
    450447        return c;
    451448      }
    452 
    453449      private static double EvaluateErrorApproximate(IReadOnlyList<int> rowP, IReadOnlyList<int> colP, IReadOnlyList<double> valP, double[,] y, double theta) {
    454450        // Get estimate of normalization term
     
    592588            ? state.gains[i, j] + .2 // +0.2 nd *0.8 are used in two separate implementations of tSNE -> seems to be correct
    593589            : state.gains[i, j] * .8;
    594 
    595590          if (state.gains[i, j] < .01) state.gains[i, j] = .01;
    596591        }
  • branches/Weighted TSNE/3.4/TSNE/TSNEUtils.cs

    r14414 r15455  
    3535    }
    3636
    37     internal static IList<T> Swap<T>(this IList<T> list, int indexA, int indexB) {
     37    internal static void Swap<T>(this IList<T> list, int indexA, int indexB) {
    3838      var tmp = list[indexA];
    3939      list[indexA] = list[indexB];
    4040      list[indexB] = tmp;
    41       return list;
    4241    }
    4342
    44     internal static int Partition<T>(this IList<T> list, int left, int right, int pivotindex, IComparer<T> comparer) {
     43    private static int Partition<T>(this IList<T> list, int left, int right, int pivotindex, IComparer<T> comparer) {
    4544      var pivotValue = list[pivotindex];
    4645      list.Swap(pivotindex, right);
     
    6766    /// <param name="comparer">comparer for list elemnts </param>
    6867    /// <returns></returns>
    69     internal static T NthElement<T>(this IList<T> list, int left, int right, int n, IComparer<T> comparer) {
     68    internal static void NthElement<T>(this IList<T> list, int left, int right, int n, IComparer<T> comparer) {
    7069      while (true) {
    71         if (left == right) return list[left];
    72         var pivotindex = left + (int)Math.Floor(new System.Random().Next() % (right - (double)left + 1));
     70        if (left == right) return;
     71        var pivotindex = left + (int) Math.Floor(new System.Random().Next() % (right - (double) left + 1));
    7372        pivotindex = list.Partition(left, right, pivotindex, comparer);
    74         if (n == pivotindex) return list[n];
     73        if (n == pivotindex) return;
    7574        if (n < pivotindex) right = pivotindex - 1;
    7675        else left = pivotindex + 1;
Note: See TracChangeset for help on using the changeset viewer.