Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2457_ExpertSystem/HeuristicLab.Algorithms.MemPR/3.3/Util/CkMeans1D.cs @ 16310

Last change on this file since 16310 was 14420, checked in by abeham, 8 years ago

#2708: added binary version of mempr with new concepts of scope in basic alg

File size: 3.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading.Tasks;
6using HeuristicLab.Common;
7
namespace HeuristicLab.Algorithms.MemPR.Util {
  /// <summary>
  /// Implements the Ckmeans.1d.dp method. It is described in the paper:
  /// Haizhou Wang and Mingzhou Song. 2011.
  /// Ckmeans.1d.dp: Optimal k-means Clustering in One Dimension by Dynamic Programming
  /// The R Journal Vol. 3/2, pp. 29-33.
  /// available at https://journal.r-project.org/archive/2011-2/RJournal_2011-2_Wang+Song.pdf
  /// </summary>
  /// <remarks>
  /// Duplicate values are removed before clustering: the dynamic program runs on the
  /// sorted distinct values only, so repeated values carry no extra weight.
  /// </remarks>
  public class CkMeans1D {
    /// <summary>
    /// Clusters the 1-dimensional data given in <paramref name="estimations"/>.
    /// </summary>
    /// <param name="estimations">The 1-dimensional data that should be clustered. Must contain at least one value.</param>
    /// <param name="k">The number of clusters to produce. Must be at least 1.</param>
    /// <param name="clusterValues">A vector of the same length as estimations that assigns to each point a cluster id.</param>
    /// <returns>
    /// A sorted list with <paramref name="k"/> entries that maps each cluster centroid to its cluster id.
    /// If there are no more than <paramref name="k"/> distinct values, every distinct value is its own
    /// centroid and the remaining ids are assigned to artificial centroids placed above the largest value.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="estimations"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="estimations"/> is empty.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="k"/> is smaller than 1.</exception>
    public static SortedList<double, int> Cluster(double[] estimations, int k, out int[] clusterValues) {
      if (estimations == null) throw new ArgumentNullException("estimations");
      if (estimations.Length == 0) throw new ArgumentException("At least one value is required for clustering.", "estimations");
      if (k < 1) throw new ArgumentOutOfRangeException("k", k, "The number of clusters must be at least 1.");

      int nPoints = estimations.Length;
      var distinct = estimations.Distinct().OrderBy(x => x).ToArray();

      if (distinct.Length <= k) {
        // Trivial case: every distinct value forms its own cluster, ids follow the sort order.
        var dict = new Dictionary<double, int>(k);
        for (int i = 0; i < distinct.Length; i++)
          dict.Add(distinct[i], i);

        // Pad with artificial centroids max + 1, max + 2, ... so that exactly k ids exist.
        // NOTE: for a very large or infinite maximum these sums are not distinct keys and Add throws.
        var max = distinct[distinct.Length - 1]; // the array is sorted ascending
        for (int i = distinct.Length; i < k; i++)
          dict.Add(max + i - distinct.Length + 1, i);

        clusterValues = new int[nPoints];
        for (int i = 0; i < nPoints; i++) {
          int id;
          clusterValues[i] = dict.TryGetValue(estimations[i], out id) ? id : 0;
        }

        return new SortedList<double, int>(dict);
      }

      var n = distinct.Length;
      // D[j, m]: minimal sum of squared distances when the first j + 1 values form m + 1 clusters.
      // B[j, m]: index of the first value of the right-most cluster in that split (stays 0 for m == 0).
      var D = new double[n, k];
      var B = new int[n, k];

      for (int m = 0; m < k; m++) {
        // j is limited to n - k + m because the values after j must still fill the remaining clusters.
        for (int j = m; j <= n - k + m; j++) {
          if (m == 0) {
            D[j, m] = SumOfSquaredDistances(distinct, 0, j + 1);
            continue;
          }

          var minD = double.MaxValue;
          var minI = 0;
          // The last cluster starts at i and the values before it are split into m clusters,
          // which needs at least m values. Starting at i = m also avoids reading cells
          // D[i - 1, m - 1] that were never written for smaller i.
          for (int i = m; i <= j; i++) {
            var d = D[i - 1, m - 1] + SumOfSquaredDistances(distinct, i, j + 1);
            if (d < minD) {
              minD = d;
              minI = i;
            }
          }
          D[j, m] = minD;
          B[j, m] = minI;
        }
      }

      // Backtrack from the right-most cluster to the left-most one to obtain the centroids.
      var centers = new SortedList<double, int>();
      var upper = B[n - 1, k - 1];
      centers.Add(Mean(distinct, upper, n), k - 1);
      for (int i = k - 2; i >= 0; i--) {
        var lower = B[upper - 1, i];
        centers.Add(Mean(distinct, lower, upper), i);
        upper = lower;
      }

      clusterValues = new int[nPoints];
      for (int i = 0; i < nPoints; i++) {
        clusterValues[i] = NearestCluster(centers, estimations[i]);
      }

      return centers;
    }

    /// <summary>
    /// Returns the id of the centroid in <paramref name="centers"/> that is closest to <paramref name="value"/>.
    /// If several centroids are equally close, the one with the smallest centroid value is chosen.
    /// </summary>
    private static int NearestCluster(SortedList<double, int> centers, double value) {
      var keys = centers.Keys;
      var best = 0;
      var bestDistance = Math.Abs(value - keys[0]);
      for (int c = 1; c < keys.Count; c++) {
        var distance = Math.Abs(value - keys[c]);
        // CompareTo (not <) is used on purpose: it orders NaN before every number, so an infinite
        // value, whose distance to its own infinite centroid is NaN, is still assigned to it.
        if (distance.CompareTo(bestDistance) < 0) {
          best = c;
          bestDistance = distance;
        }
      }
      return centers.Values[best];
    }

    /// <summary>
    /// Calculates the sum of squared distances to the mean for the values x[start] to x[end - 1].
    /// </summary>
    /// <exception cref="InvalidOperationException">The range is empty, i.e. start equals end.</exception>
    private static double SumOfSquaredDistances(double[] x, int start, int end) {
      if (start == end) throw new InvalidOperationException();
      if (start + 1 == end) return 0.0; // a single value has no spread
      double mean = Mean(x, start, end);
      var sum = 0.0;
      for (int i = start; i < end; i++) {
        sum += (x[i] - mean) * (x[i] - mean);
      }
      return sum;
    }

    /// <summary>
    /// Calculates the arithmetic mean of the values x[start] to x[end - 1].
    /// </summary>
    /// <exception cref="InvalidOperationException">The range is empty, i.e. start equals end.</exception>
    private static double Mean(double[] x, int start, int end) {
      if (start == end) throw new InvalidOperationException();
      double mean = 0.0;
      for (int i = start; i < end; i++) {
        mean += x[i];
      }
      mean /= (end - start);
      return mean;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.