Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/04/18 17:18:02 (6 years ago)
Author:
pfleck
Message:

#2906 Refactoring

  • Moved transformation-specific parts out of existing interfaces.
  • Moved all Transformation logic to DataAnalysisTransformation.
  • Simplified (Inverse)Transformation of Dataset/ProblemData/Model/Solution.
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2906_Transformations/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisTransformation.cs

    r15879 r15884  
    2020#endregion
    2121
     22using System;
     23using System.Collections.Generic;
     24using System.Linq;
    2225using HeuristicLab.Common;
    2326using HeuristicLab.Core;
     
    6770
    6871    private DataAnalysisTransformation(DataAnalysisTransformation original, Cloner cloner)
    69       : base(original, cloner) {
    70     }
     72      : base(original, cloner) { }
    7173
    7274    public override IDeepCloneable Clone(Cloner cloner) {
     
    8486      return $"{Transformation} ({OriginalVariable} -> {TransformedVariable})";
    8587    }
     88
     89    #region Transformation
     90
     91    #region Variable Extension & Reduction
     92    // originals => include extended
     93    public static IEnumerable<string> ExtendVariables(IEnumerable<string> variables, IEnumerable<IDataAnalysisTransformation> transformations) {
     94      return GetTransitiveVariables(variables, transformations);
     95    }
     96
     97    // extended => originals
     98    public static IEnumerable<string> ReduceVariables(IEnumerable<string> variables, IEnumerable<IDataAnalysisTransformation> transformations) {
     99      var originalVariables = new HashSet<string>();
     100      foreach (var variable in variables)
     101        originalVariables.Add(GetLastTransitiveVariable(variable, transformations, inverse: true));
     102      return originalVariables;
     103    }
     104
     105    public static IEnumerable<string> GetTransitiveVariables(IEnumerable<string> variables, IEnumerable<IDataAnalysisTransformation> transformations, bool inverse = false) {
     106      var reachableVariables = new HashSet<string>(variables);
     107      if (inverse) transformations = transformations.Reverse();
     108      foreach (var transformation in transformations) {
     109        var source = inverse ? transformation.TransformedVariable : transformation.OriginalVariable;
     110        var target = inverse ? transformation.OriginalVariable : transformation.TransformedVariable;
     111        if (reachableVariables.Contains(source))
     112          reachableVariables.Add(target);
     113      }
     114
     115      return reachableVariables;
     116    }
     117
     118    public static string GetLastTransitiveVariable(string variable, IEnumerable<IDataAnalysisTransformation> transformations, bool inverse = false) {
     119      if (inverse) transformations = transformations.Reverse();
     120      foreach (var transformation in transformations) {
     121        var source = inverse ? transformation.TransformedVariable : transformation.OriginalVariable;
     122        var target = inverse ? transformation.OriginalVariable : transformation.TransformedVariable;
     123        if (variable == source)
     124          variable = target;
     125      }
     126
     127      return variable;
     128    }
     129    #endregion
     130
     131    #region Transform Dataset
     132    public static IDataset Transform(IDataset dataset, IEnumerable<IDataAnalysisTransformation> transformations) {
     133      var modifiableDataset = ((Dataset)dataset).ToModifiable();
     134
     135      foreach (var transformation in transformations) {
     136        var trans = (ITransformation<double>)transformation.Transformation;
     137
     138        var originalData = modifiableDataset.GetDoubleValues(transformation.OriginalVariable);
     139        //if (!trans.Check(originalData, out string errorMessage))
     140        //  throw new InvalidOperationException($"Cannot estimate Values, Transformation is invalid: {errorMessage}");
     141        // TODO: check was already called before configure (in preprocessing)
     142        // TODO: newly specified data might not pass the check but it does not matter because the data is not configured with
     143        // e.g. impact calculation -> replacement=most common -> originalMean is zero
     144
     145        var transformedData = trans.Apply(originalData).ToList();
     146        if (modifiableDataset.VariableNames.Contains(transformation.TransformedVariable))
     147          modifiableDataset.ReplaceVariable(transformation.TransformedVariable, transformedData);
     148        else
     149          modifiableDataset.AddVariable(transformation.TransformedVariable, transformedData);
     150      }
     151
     152      return modifiableDataset; // TODO: to regular dataset?
     153    }
     154
     155    public static IDataset InverseTransform(IDataset dataset, IEnumerable<IDataAnalysisTransformation> transformations, bool removeVirtualVariables = true) {
     156      var modifiableDataset = ((Dataset)dataset).ToModifiable();
     157
     158      var transformationsStack = new Stack<IDataAnalysisTransformation>(transformations);
     159      while (transformationsStack.Any()) {
     160        var transformation = transformationsStack.Pop();
     161        var trans = (ITransformation<double>)transformation.Transformation;
     162
     163        var prevTransformations = transformations.Except(transformationsStack);
     164        bool originalWasChanged = prevTransformations.Any(x => x.TransformedVariable == transformation.OriginalVariable);
     165        if (originalWasChanged) {
     166          var transformedData = modifiableDataset.GetDoubleValues(transformation.TransformedVariable);
     167
     168          var originalData = trans.InverseApply(transformedData).ToList();
     169          modifiableDataset.ReplaceVariable(transformation.OriginalVariable, originalData);
     170        }
     171      }
     172
     173      if (removeVirtualVariables) {
     174        var originalVariables = ReduceVariables(dataset.VariableNames, transformations);
     175        var virtualVariables = dataset.VariableNames.Except(originalVariables);
     176        foreach (var virtualVariable in virtualVariables)
     177          modifiableDataset.RemoveVariable(virtualVariable);
     178      }
     179
     180      return modifiableDataset; // TODO: to regular dataset?
     181    }
     182    #endregion
     183
     184    #region Transform ProblemData
     185    public static IDataAnalysisProblemData ApplyTransformations(IDataAnalysisProblemData problemData) {
     186      var newDataset = Transform(problemData.Dataset, problemData.Transformations);
     187      var extendedInputs = ExtendVariables(problemData.AllowedInputVariables, problemData.Transformations);
     188
     189      return CreateNewProblemData(problemData, newDataset, extendedInputs, inverse: false);
     190    }
     191
     192    public static IDataAnalysisProblemData InverseApplyTransformations(IDataAnalysisProblemData problemData) {
     193      var newDataset = InverseTransform(problemData.Dataset, problemData.Transformations);
     194      var reducedInputs = ReduceVariables(problemData.AllowedInputVariables, problemData.Transformations);
     195
     196      return CreateNewProblemData(problemData, newDataset, reducedInputs, inverse: true);
     197    }
     198
     199    private static IDataAnalysisProblemData CreateNewProblemData(IDataAnalysisProblemData problemData, IDataset dataset, IEnumerable<string> inputs, bool inverse = false) {
     200      IDataAnalysisProblemData newProblemData;
     201      if (problemData is IRegressionProblemData regressionProblemData) {
     202        var newTargetVariable = GetLastTransitiveVariable(regressionProblemData.TargetVariable, problemData.Transformations, inverse);
     203        newProblemData = new RegressionProblemData(dataset, inputs, newTargetVariable, problemData.Transformations);
     204      } else if (problemData is IClassificationProblemData classificationProblemData) {
     205        newProblemData = new ClassificationProblemData(dataset, inputs, classificationProblemData.TargetVariable, problemData.Transformations);
     206      } else throw new NotSupportedException("Type of ProblemData not supported");
     207
     208      newProblemData.TrainingPartition.Start = problemData.TrainingPartition.Start;
     209      newProblemData.TrainingPartition.End = problemData.TrainingPartition.End;
     210      newProblemData.TestPartition.Start = problemData.TestPartition.Start;
     211      newProblemData.TestPartition.End = problemData.TestPartition.End;
     212
     213      return newProblemData;
     214    }
     215    #endregion
     216
     217    #region Transform Model
     218    public static IDataAnalysisTransformationModel CreateTransformationIntegratedModel(IDataAnalysisModel model, IEnumerable<IDataAnalysisTransformation> transformations) {
     219      if (model is IDataAnalysisTransformationModel)
     220        throw new InvalidOperationException("Model already is a transformation model.");
     221
     222      switch (model) {
     223        case ITimeSeriesPrognosisModel timeSeriesPrognosisModel:
     224          return new TimeSeriesPrognosisTransformationModel(timeSeriesPrognosisModel, transformations);
     225        case IRegressionModel regressionModel:
     226          return new RegressionTransformationModel(regressionModel, transformations);
     227        case IClassificationModel classificationModel:
     228          return new ClassificationTransformationModel(classificationModel, transformations);
     229        case IClusteringModel clusteringModel:
     230          return new ClusteringTransformationModel(clusteringModel, transformations);
     231        default:
     232          throw new NotSupportedException("Type of the model is not supported;");
     233      }
     234    }
     235
     236    public static IDataAnalysisModel RestoreTrainedModel(IDataAnalysisModel transformationModel, IEnumerable<IDataAnalysisTransformation> transformations) {
     237      if (!(transformationModel is IDataAnalysisTransformationModel model))
     238        throw new InvalidOperationException("Cannot restore because model is not a TransformationModel");
     239      return model.OriginalModel;
     240    }
     241    #endregion
     242
     243    #region Transform Solution
     244    public static IDataAnalysisSolution TransformSolution(IDataAnalysisSolution solution) {
     245      var transformations = solution.ProblemData.Transformations;
     246
     247      var model = solution.Model is IDataAnalysisTransformationModel // TODO: what if model is a integrated sym-reg model?
     248        ? RestoreTrainedModel(solution.Model, transformations)
     249        : CreateTransformationIntegratedModel(solution.Model, transformations);
     250
     251      var data = solution.Model is IDataAnalysisTransformationModel
     252        ? ApplyTransformations(solution.ProblemData) // original -> transformed
     253        : InverseApplyTransformations(solution.ProblemData); // transformed -> original
     254
     255      return CreateSolution(model, data);
     256    }
     257
     258    private static IDataAnalysisSolution CreateSolution(IDataAnalysisModel model, IDataAnalysisProblemData problemData) {
     259      switch (model) {
     260        case ITimeSeriesPrognosisModel timeSeriesPrognosisModel:
     261          return timeSeriesPrognosisModel.CreateTimeSeriesPrognosisSolution((ITimeSeriesPrognosisProblemData)problemData);
     262        case IRegressionModel regressionModel:
     263          return regressionModel.CreateRegressionSolution((IRegressionProblemData)problemData);
     264        case IClassificationModel classificationModel:
     265          return classificationModel.CreateClassificationSolution((IClassificationProblemData)problemData);
     266        default:
     267          throw new NotSupportedException("Cannot create Solution of the model type.");
     268      }
     269    }
     270    #endregion
     271
     272    #endregion
    86273  }
    87274}
Note: See TracChangeset for help on using the changeset viewer.