Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/15/17 10:29:40 (7 years ago)
Author:
gkronber
Message:

#2699,#2700
merged r14862, r14863, r14911, r14936, r15156, r15157, r15158, r15164, r15169, r15207:15209, r15225, r15227, r15234, r15248 from trunk to stable

Location:
stable
Files:
3 edited
1 copied

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs

    r14863 r15249  
    3232using HeuristicLab.Parameters;
    3333using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     34using HeuristicLab.PluginInfrastructure;
    3435using HeuristicLab.Problems.DataAnalysis;
    3536using HeuristicLab.Random;
     
    4142  /// </summary>
    4243  [Item("tSNE", "t-distributed stochastic neighbourhood embedding projects the data in a low " +
    43                 "dimensional space to allow visual cluster identification.")]
     44                "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
    4445  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
    4546  [StorableClass]
     
    5758
    5859    #region parameter names
    59     private const string DistanceParameterName = "DistanceFunction";
     60    private const string DistanceFunctionParameterName = "DistanceFunction";
    6061    private const string PerplexityParameterName = "Perplexity";
    6162    private const string ThetaParameterName = "Theta";
     
    6970    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    7071    private const string SeedParameterName = "Seed";
    71     private const string ClassesParameterName = "ClassNames";
     72    private const string ClassesNameParameterName = "ClassesName";
    7273    private const string NormalizationParameterName = "Normalization";
    7374    private const string UpdateIntervalParameterName = "UpdateInterval";
     
    8687      get { return Parameters[PerplexityParameterName] as IFixedValueParameter<DoubleValue>; }
    8788    }
    88     public IFixedValueParameter<DoubleValue> ThetaParameter {
    89       get { return Parameters[ThetaParameterName] as IFixedValueParameter<DoubleValue>; }
     89    public IFixedValueParameter<PercentValue> ThetaParameter {
     90      get { return Parameters[ThetaParameterName] as IFixedValueParameter<PercentValue>; }
    9091    }
    9192    public IFixedValueParameter<IntValue> NewDimensionsParameter {
    9293      get { return Parameters[NewDimensionsParameterName] as IFixedValueParameter<IntValue>; }
    9394    }
    94     public IValueParameter<IDistance<double[]>> DistanceParameter {
    95       get { return Parameters[DistanceParameterName] as IValueParameter<IDistance<double[]>>; }
     95    public IConstrainedValueParameter<IDistance<double[]>> DistanceFunctionParameter {
     96      get { return Parameters[DistanceFunctionParameterName] as IConstrainedValueParameter<IDistance<double[]>>; }
    9697    }
    9798    public IFixedValueParameter<IntValue> MaxIterationsParameter {
     
    119120      get { return Parameters[SeedParameterName] as IFixedValueParameter<IntValue>; }
    120121    }
    121     public IFixedValueParameter<StringValue> ClassesParameter {
    122       get { return Parameters[ClassesParameterName] as IFixedValueParameter<StringValue>; }
     122    public IConstrainedValueParameter<StringValue> ClassesNameParameter {
     123      get { return Parameters[ClassesNameParameterName] as IConstrainedValueParameter<StringValue>; }
    123124    }
    124125    public IFixedValueParameter<BoolValue> NormalizationParameter {
     
    131132
    132133    #region  Properties
    133     public IDistance<double[]> Distance {
    134       get { return DistanceParameter.Value; }
     134    public IDistance<double[]> DistanceFunction {
     135      get { return DistanceFunctionParameter.Value; }
    135136    }
    136137    public double Perplexity {
     
    178179      set { SeedParameter.Value.Value = value; }
    179180    }
    180     public string Classes {
    181       get { return ClassesParameter.Value.Value; }
    182       set { ClassesParameter.Value.Value = value; }
     181    public string ClassesName {
     182      get { return ClassesNameParameter.Value != null ? ClassesNameParameter.Value.Value : null; }
     183      set { ClassesNameParameter.Value.Value = value; }
    183184    }
    184185    public bool Normalization {
     
    198199
    199200    private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
    200       if(original.dataRowNames!=null)
    201       this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
     201      if (original.dataRowNames != null)
     202        this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
    202203      if (original.dataRows != null)
    203204        this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
     
    208209    public override IDeepCloneable Clone(Cloner cloner) { return new TSNEAlgorithm(this, cloner); }
    209210    public TSNEAlgorithm() {
    210       Problem = new RegressionProblem();
    211       Parameters.Add(new ValueParameter<IDistance<double[]>>(DistanceParameterName, "The distance function used to differentiate similar from non-similar points", new EuclideanDistance()));
     211      var distances = new ItemSet<IDistance<double[]>>(ApplicationManager.Manager.GetInstances<IDistance<double[]>>());
     212      Parameters.Add(new ConstrainedValueParameter<IDistance<double[]>>(DistanceFunctionParameterName, "The distance function used to differentiate similar from non-similar points", distances, distances.OfType<EuclideanDistance>().FirstOrDefault()));
    212213      Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25)));
    213       Parameters.Add(new FixedValueParameter<DoubleValue>(ThetaParameterName, "Value describing how much appoximated " +
     214      Parameters.Add(new FixedValueParameter<PercentValue>(ThetaParameterName, "Value describing how much appoximated " +
    214215                                                                              "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
    215216                                                                              "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
    216217                                                                              "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
    217218                                                                              "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
    218                                                                               " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new DoubleValue(0)));
     219                                                                              " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0)));
    219220      Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis)", new IntValue(2)));
    220221      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent.", new IntValue(1000)));
     
    226227      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "If the seed should be random.", new BoolValue(true)));
    227228      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The seed used if it should not be random.", new IntValue(0)));
    228       Parameters.Add(new FixedValueParameter<StringValue>(ClassesParameterName, "name of the column specifying the class lables of each data point. If the label column can not be found training/test is used as labels.", new StringValue("none")));
     229      Parameters.Add(new OptionalConstrainedValueParameter<StringValue>(ClassesNameParameterName, "Name of the column specifying the class lables of each data point. If this is not set training/test is used as labels."));
    229230      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
    230       Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "", new IntValue(50)));
     231      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "The interval after which the results will be updated.", new IntValue(50)));
    231232      Parameters[UpdateIntervalParameterName].Hidden = true;
    232233
     
    236237      StopLyingIterationParameter.Hidden = true;
    237238      EtaParameter.Hidden = false;
     239      Problem = new RegressionProblem();
    238240    }
    239241    #endregion
     
    269271        if (Normalization) data = NormalizeData(data);
    270272
    271         state = TSNEStatic<double[]>.CreateState(data, Distance, random, NewDimensions, Perplexity, Theta,
     273        state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta,
    272274          StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
    273275
     
    283285    }
    284286
     287    #region Events
     288    protected override void OnProblemChanged() {
     289      base.OnProblemChanged();
     290      if (Problem == null) return;
     291      OnProblemDataChanged(this, null);
     292    }
     293
     294    protected override void RegisterProblemEvents() {
     295      base.RegisterProblemEvents();
     296      Problem.ProblemDataChanged += OnProblemDataChanged;
     297    }
     298    protected override void DeregisterProblemEvents() {
     299      base.DeregisterProblemEvents();
     300      Problem.ProblemDataChanged -= OnProblemDataChanged;
     301    }
     302
     303    private void OnProblemDataChanged(object sender, EventArgs args) {
     304      if (Problem == null || Problem.ProblemData == null) return;
     305      if (!Parameters.ContainsKey(ClassesNameParameterName)) return;
     306      ClassesNameParameter.ValidValues.Clear();
     307      foreach (var input in Problem.ProblemData.InputVariables) ClassesNameParameter.ValidValues.Add(input);
     308    }
     309
     310    #endregion
     311
     312    #region Helpers
    285313    private void SetUpResults(IReadOnlyCollection<double[]> data) {
    286314      if (Results == null) return;
     
    291319
    292320      //color datapoints acording to classes variable (be it double or string)
    293       if (problemData.Dataset.VariableNames.Contains(Classes)) {
    294         if ((problemData.Dataset as Dataset).VariableHasType<string>(Classes)) {
    295           var classes = problemData.Dataset.GetStringValues(Classes).ToArray();
     321      if (problemData.Dataset.VariableNames.Contains(ClassesName)) {
     322        if ((problemData.Dataset as Dataset).VariableHasType<string>(ClassesName)) {
     323          var classes = problemData.Dataset.GetStringValues(ClassesName).ToArray();
    296324          for (var i = 0; i < classes.Length; i++) {
    297325            if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    298326            dataRowNames[classes[i]].Add(i);
    299327          }
    300         } else if ((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) {
    301           var classValues = problemData.Dataset.GetDoubleValues(Classes).ToArray();
     328        } else if ((problemData.Dataset as Dataset).VariableHasType<double>(ClassesName)) {
     329          var classValues = problemData.Dataset.GetDoubleValues(ClassesName).ToArray();
    302330          var max = classValues.Max() + 0.1;
    303331          var min = classValues.Min() - 0.1;
     
    377405      for (var i = 0; i < data.GetLength(0); i++) {
    378406        for (var j = 0; j < data.GetLength(1); j++) {
    379           res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]);
     407          var d = max[j] - min[j];
     408          var s = data[i, j] - (max[j] + min[j]) / 2;  //shift data
     409          if (d.IsAlmost(0)) res[i, j] = data[i, j];   //no scaling possible
     410          else res[i, j] = s / d;  //scale data
    380411        }
    381412      }
     
    395426      for (var i = 0; i < data.Count; i++) {
    396427        nData[i] = new double[n];
    397         for (var j = 0; j < n; j++) nData[i][j] = (data[i][j] - mean[j]) / max[j];
     428        for (var j = 0; j < n; j++) nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
    398429      }
    399430      return nData;
     
    416447      return "[" + (min + i * size) + ";" + (min + (i + 1) * size) + ")";
    417448    }
     449    #endregion
    418450  }
    419451}
Note: See TracChangeset for help on using the changeset viewer.