source: branches/2906_Transformations/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisTransformation.cs @ 15885

Last change on this file since 15885 was 15885, checked in by pfleck, 4 years ago

#2906 Updated project references + small refactoring

File size: 13.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
31namespace HeuristicLab.Problems.DataAnalysis {
32  [Item("Transformation", "A transformation applied to a DataAnalysisProblemData")]
33  [StorableClass]
34  public sealed class DataAnalysisTransformation : ParameterizedNamedItem, IDataAnalysisTransformation {
35    #region Parameter Properties
36    private IFixedValueParameter<StringValue> OriginalVariableParameter {
37      get { return (IFixedValueParameter<StringValue>)Parameters["Original Variable"]; }
38    }
39
40    private IFixedValueParameter<StringValue> TransformedVariableParameter {
41      get { return (IFixedValueParameter<StringValue>)Parameters["Transformed Variable"]; }
42    }
43
44    private ValueParameter<ITransformation> TransformationParameter {
45      get { return (ValueParameter<ITransformation>)Parameters["Transformation"]; }
46    }
47    #endregion
48
49    #region Properties
50    public string OriginalVariable {
51      get { return OriginalVariableParameter.Value.Value; }
52    }
53
54    public string TransformedVariable {
55      get { return TransformedVariableParameter.Value.Value; }
56    }
57
58    public ITransformation Transformation {
59      get { return TransformationParameter.Value; }
60    }
61    #endregion
62
63    #region Constructor, Cloning & Persistence
64    public DataAnalysisTransformation(string originalVariable, string transformedVariable, ITransformation transformation)
65      : base() {
66      Parameters.Add(new FixedValueParameter<StringValue>("Original Variable", new StringValue(originalVariable).AsReadOnly()));
67      Parameters.Add(new FixedValueParameter<StringValue>("Transformed Variable", new StringValue(transformedVariable).AsReadOnly()));
68      Parameters.Add(new ValueParameter<ITransformation>("Transformation", transformation)); // TODO: should be readonly/fixed; alternatively lock in view
69    }
70
71    private DataAnalysisTransformation(DataAnalysisTransformation original, Cloner cloner)
72      : base(original, cloner) { }
73
74    public override IDeepCloneable Clone(Cloner cloner) {
75      return new DataAnalysisTransformation(this, cloner);
76    }
77
78    [StorableConstructor]
79    private DataAnalysisTransformation(bool deserializing)
80      : base(deserializing) { }
81
82    [StorableHook(HookType.AfterDeserialization)]
83    #endregion
84
85    public override string ToString() {
86      return $"{Transformation} ({OriginalVariable} -> {TransformedVariable})";
87    }
88
89    #region Transformation
90
91    #region Variable Extension & Reduction
92    // originals => include extended
93    public static IEnumerable<string> ExtendVariables(IEnumerable<string> variables, IEnumerable<IDataAnalysisTransformation> transformations) {
94      return GetTransitiveVariables(variables, transformations, inverse: false);
95    }
96
97    // extended => originals
98    public static IEnumerable<string> ReduceVariables(IEnumerable<string> variables, IEnumerable<IDataAnalysisTransformation> transformations) {
99      var originalVariables = new HashSet<string>();
100      foreach (var variable in variables) {
101        var originalVariable = GetStrictTransitiveVariables(variable, transformations, inverse: true).Last();
102        originalVariables.Add(originalVariable);
103      }
104
105      return originalVariables;
106    }
107
108    // return all reachable variables
109    public static IEnumerable<string> GetTransitiveVariables(IEnumerable<string> variables, IEnumerable<IDataAnalysisTransformation> transformations, bool inverse = false) {
110      var reachableVariables = new HashSet<string>(variables);
111      if (inverse) transformations = transformations.Reverse();
112      foreach (var transformation in transformations) {
113        var source = inverse ? transformation.TransformedVariable : transformation.OriginalVariable;
114        var target = inverse ? transformation.OriginalVariable : transformation.TransformedVariable;
115        if (reachableVariables.Contains(source))
116          reachableVariables.Add(target);
117      }
118
119      return reachableVariables;
120    }
121
122    // return the (unique) chain of transformations for a given variable
123    public static IEnumerable<string> GetStrictTransitiveVariables(string variable, IEnumerable<IDataAnalysisTransformation> transformations, bool inverse = false) {
124      yield return variable;
125      if (inverse) transformations = transformations.Reverse();
126      foreach (var transformation in transformations) {
127        var source = inverse ? transformation.TransformedVariable : transformation.OriginalVariable;
128        var target = inverse ? transformation.OriginalVariable : transformation.TransformedVariable;
129        if (variable == source) {
130          variable = target;
131          yield return variable;
132        }
133      }
134    }
135    #endregion
136
137    #region Transform Dataset
138    public static IDataset Transform(IDataset dataset, IEnumerable<IDataAnalysisTransformation> transformations) {
139      var modifiableDataset = ((Dataset)dataset).ToModifiable();
140
141      foreach (var transformation in transformations) {
142        var trans = (ITransformation<double>)transformation.Transformation;
143
144        var originalData = modifiableDataset.GetDoubleValues(transformation.OriginalVariable);
145        var transformedData = trans.Apply(originalData).ToList();
146        if (modifiableDataset.VariableNames.Contains(transformation.TransformedVariable))
147          modifiableDataset.ReplaceVariable(transformation.TransformedVariable, transformedData);
148        else
149          modifiableDataset.AddVariable(transformation.TransformedVariable, transformedData);
150      }
151
152      return new Dataset(modifiableDataset);
153    }
154
155    public static IDataset InverseTransform(IDataset dataset, IEnumerable<IDataAnalysisTransformation> transformations, bool removeVirtualVariables = true) {
156      var modifiableDataset = ((Dataset)dataset).ToModifiable();
157
158      var transformationsStack = new Stack<IDataAnalysisTransformation>(transformations);
159      while (transformationsStack.Any()) {
160        var transformation = transformationsStack.Pop();
161        var trans = (ITransformation<double>)transformation.Transformation;
162
163        var prevTransformations = transformations.Except(transformationsStack);
164        bool originalWasChanged = prevTransformations.Any(x => x.TransformedVariable == transformation.OriginalVariable);
165        if (originalWasChanged) {
166          var transformedData = modifiableDataset.GetDoubleValues(transformation.TransformedVariable);
167
168          var originalData = trans.InverseApply(transformedData).ToList();
169          modifiableDataset.ReplaceVariable(transformation.OriginalVariable, originalData);
170        }
171      }
172
173      if (removeVirtualVariables) {
174        var originalVariables = ReduceVariables(dataset.VariableNames, transformations);
175        var virtualVariables = dataset.VariableNames.Except(originalVariables);
176        foreach (var virtualVariable in virtualVariables)
177          modifiableDataset.RemoveVariable(virtualVariable);
178      }
179
180      return new Dataset(modifiableDataset);
181    }
182    #endregion
183
184    #region Transform ProblemData
185    public static IDataAnalysisProblemData ApplyTransformations(IDataAnalysisProblemData problemData) {
186      var newDataset = Transform(problemData.Dataset, problemData.Transformations);
187      var extendedInputs = ExtendVariables(problemData.AllowedInputVariables, problemData.Transformations);
188
189      return CreateNewProblemData(problemData, newDataset, extendedInputs, inverse: false);
190    }
191
192    public static IDataAnalysisProblemData InverseApplyTransformations(IDataAnalysisProblemData problemData) {
193      var newDataset = InverseTransform(problemData.Dataset, problemData.Transformations);
194      var reducedInputs = ReduceVariables(problemData.AllowedInputVariables, problemData.Transformations);
195
196      return CreateNewProblemData(problemData, newDataset, reducedInputs, inverse: true);
197    }
198
199    private static IDataAnalysisProblemData CreateNewProblemData(IDataAnalysisProblemData problemData, IDataset dataset, IEnumerable<string> inputs, bool inverse = false) {
200      IDataAnalysisProblemData newProblemData;
201      if (problemData is IRegressionProblemData regressionProblemData) {
202        var newTargetVariable = GetStrictTransitiveVariables(regressionProblemData.TargetVariable, problemData.Transformations, inverse).Last();
203        if (problemData is ITimeSeriesPrognosisProblemData timeSeriesPrognosisProblemData) {
204          newProblemData = new TimeSeriesPrognosisProblemData(dataset, inputs, newTargetVariable, problemData.Transformations) {
205            TrainingHorizon = timeSeriesPrognosisProblemData.TrainingHorizon,
206            TestHorizon = timeSeriesPrognosisProblemData.TestHorizon,
207          };
208
209        } else
210          newProblemData = new RegressionProblemData(dataset, inputs, newTargetVariable, problemData.Transformations);
211      } else if (problemData is IClassificationProblemData classificationProblemData) {
212        newProblemData = new ClassificationProblemData(dataset, inputs, classificationProblemData.TargetVariable, problemData.Transformations);
213      } else if (problemData is IClusteringProblemData) {
214        newProblemData = new ClusteringProblemData(dataset, inputs, problemData.Transformations);
215      } else throw new NotSupportedException("Type of ProblemData not supported");
216
217      newProblemData.TrainingPartition.Start = problemData.TrainingPartition.Start;
218      newProblemData.TrainingPartition.End = problemData.TrainingPartition.End;
219      newProblemData.TestPartition.Start = problemData.TestPartition.Start;
220      newProblemData.TestPartition.End = problemData.TestPartition.End;
221
222      return newProblemData;
223    }
224    #endregion
225
226    #region Transform Model
227    // problemdata required for type-switch. cannot differ based on model type (e.g. RF model is both regression and classification)
228    public static IDataAnalysisTransformationModel CreateTransformationIntegratedModel(IDataAnalysisModel model, IEnumerable<IDataAnalysisTransformation> transformations, IDataAnalysisProblemData problemData) {
229      if (model is IDataAnalysisTransformationModel)
230        throw new InvalidOperationException("Model already is a transformation model.");
231
232      if (problemData is ITimeSeriesPrognosisProblemData)
233        return new TimeSeriesPrognosisTransformationModel((ITimeSeriesPrognosisModel)model, transformations);
234      if (problemData is IRegressionProblemData)
235        return new RegressionTransformationModel((IRegressionModel)model, transformations);
236      if (problemData is IClassificationProblemData)
237        return new ClassificationTransformationModel((IClassificationModel)model, transformations);
238      if (problemData is IClusteringProblemData)
239        return new ClusteringTransformationModel((IClusteringModel)model, transformations);
240
241      throw new NotSupportedException("Type of the model is not supported;");
242    }
243
244    public static IDataAnalysisModel RestoreTrainedModel(IDataAnalysisModel transformationModel, IEnumerable<IDataAnalysisTransformation> transformations) {
245      if (!(transformationModel is IDataAnalysisTransformationModel model))
246        throw new InvalidOperationException("Cannot restore because model is not a TransformationModel");
247      return model.OriginalModel;
248    }
249    #endregion
250
251    #region Transform Solution
252    public static IDataAnalysisSolution TransformSolution(IDataAnalysisSolution solution) {
253      var transformations = solution.ProblemData.Transformations;
254
255      var model = solution.Model is IDataAnalysisTransformationModel // TODO: what if model is a integrated sym-reg model?
256        ? RestoreTrainedModel(solution.Model, transformations)
257        : CreateTransformationIntegratedModel(solution.Model, transformations, solution.ProblemData);
258
259      var data = solution.Model is IDataAnalysisTransformationModel
260        ? ApplyTransformations(solution.ProblemData) // original -> transformed
261        : InverseApplyTransformations(solution.ProblemData); // transformed -> original
262
263      return CreateSolution(model, data);
264    }
265
266    private static IDataAnalysisSolution CreateSolution(IDataAnalysisModel model, IDataAnalysisProblemData problemData) {
267      if (problemData is ITimeSeriesPrognosisProblemData)
268        return ((ITimeSeriesPrognosisModel)model).CreateTimeSeriesPrognosisSolution((ITimeSeriesPrognosisProblemData)problemData);
269      if (problemData is IRegressionProblemData)
270        return ((IRegressionModel)model).CreateRegressionSolution((IRegressionProblemData)problemData);
271      if (problemData is IClassificationProblemData)
272        return ((IClassificationModel)model).CreateClassificationSolution((IClassificationProblemData)problemData);
273      //if (problemData is IClusteringProblemData)
274      //  return ((IClusteringModel)model).CreateClusteringSolution((IClusteringProblemData)problemData);
275
276      throw new NotSupportedException("Cannot create Solution of the model type.");
277    }
278    #endregion
279
280    #endregion
281  }
282}
Note: See TracBrowser for help on using the repository browser.