Changeset 15870 for branches/2906_Transformations/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisTransformationModel.cs
- Timestamp: 03/28/18 17:17:31 (6 years ago)
- File: 1 moved
Legend:
- Unmodified
- Added
- Removed
branches/2906_Transformations/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisTransformationModel.cs
r15869 r15870 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; … … 27 28 28 29 namespace HeuristicLab.Problems.DataAnalysis { 29 [Item(" Transformed Regression Model", "A model that was transformed back to match the original variables after the training was performed on transformed variables.")]30 [Item("Data Analysis Transformation Model", "A model that was transformed back to match the original variables after the training was performed on transformed variables.")] 30 31 [StorableClass] 31 public class TransformedRegressionModel : RegressionModel, ITransformedRegressionModel {32 public abstract class DataAnalysisTransformationModel : DataAnalysisModel, IDataAnalysisTransformationModel { 32 33 33 34 [Storable] 34 public I RegressionModel OriginalModel { get; privateset; }35 public IDataAnalysisModel OriginalModel { get; protected set; } 35 36 36 37 [Storable] 37 public ItemList<IDataAnalysisTransformation> Transformations { get; private set; } 38 public ReadOnlyItemList<IDataAnalysisTransformation> InputTransformations { get; protected set; } 39 40 [Storable] 41 public ReadOnlyItemList<IDataAnalysisTransformation> TargetTransformations { get; protected set; } 42 43 // Usually, the TargetVariable is usually only implemented for Regression and Classification. 44 // However, we implement it in the base class for reducing code duplication and to avoid quasi-identical views. 
45 [Storable] 46 private string targetVariable; 47 public string TargetVariable { 48 get { return targetVariable; } 49 set { 50 if (string.IsNullOrEmpty(value) || targetVariable == value) return; 51 targetVariable = value; 52 OnTargetVariableChanged(this, EventArgs.Empty); 53 } 54 } 38 55 39 56 public override IEnumerable<string> VariablesUsedForPrediction { 40 get { return OriginalModel.VariablesUsedForPrediction; }57 get { return OriginalModel.VariablesUsedForPrediction; /* TODO: reduce extend-inputs */} 41 58 } 42 59 43 60 #region Constructor, Cloning & Persistence 44 public TransformedRegressionModel(IRegressionModel originalModel, IEnumerable<IDataAnalysisTransformation> transformations) 45 : base(RegressionProblemData.GetOriginalTragetVariable(originalModel.TargetVariable, transformations)) { 46 Name = "Transformed " + originalModel.Name; 61 protected DataAnalysisTransformationModel(IDataAnalysisModel originalModel, IEnumerable<IDataAnalysisTransformation> transformations) 62 : base(originalModel.Name) { 47 63 OriginalModel = originalModel; 48 Transformations = new ItemList<IDataAnalysisTransformation>(transformations); 64 var transitiveInputs = CalculateTransitiveVariables(originalModel.VariablesUsedForPrediction, transformations); 65 InputTransformations = new ItemList<IDataAnalysisTransformation>(transformations.Where(t => transitiveInputs.Contains(t.OriginalVariable))).AsReadOnly(); 66 TargetTransformations = new ReadOnlyItemList<IDataAnalysisTransformation>(); 49 67 } 50 68 51 protected TransformedRegressionModel(TransformedRegressionModel original, Cloner cloner)69 protected DataAnalysisTransformationModel(DataAnalysisTransformationModel original, Cloner cloner) 52 70 : base(original, cloner) { 53 71 OriginalModel = cloner.Clone(original.OriginalModel); 54 Transformations = cloner.Clone(original.Transformations); 55 } 56 57 public override IDeepCloneable Clone(Cloner cloner) { 58 return new TransformedRegressionModel(this, cloner); 72 
InputTransformations = cloner.Clone(original.InputTransformations); 73 TargetTransformations = cloner.Clone(original.TargetTransformations); 74 targetVariable = original.targetVariable; 59 75 } 60 76 61 77 [StorableConstructor] 62 protected TransformedRegressionModel(bool deserializing)78 protected DataAnalysisTransformationModel(bool deserializing) 63 79 : base(deserializing) { } 64 80 #endregion 65 81 82 public static ISet<string> CalculateTransitiveVariables(IEnumerable<string> inputVariables, IEnumerable<IDataAnalysisTransformation> transformations) { 83 var transitiveInputs = new HashSet<string>(inputVariables); 84 85 foreach (var transformation in transformations.Reverse()) { 86 if (transitiveInputs.Contains(transformation.TransformedVariable)) { 87 transitiveInputs.Add(transformation.OriginalVariable); 88 } 89 } 90 91 return transitiveInputs; 92 } 93 94 public static IDataset Transform(IDataset dataset, IEnumerable<IDataAnalysisTransformation> transformations) { 95 var modifiableDataset = ((Dataset)dataset).ToModifiable(); 96 97 foreach (var transformation in transformations) { 98 var trans = (ITransformation<double>)transformation.Transformation; 99 100 var originalData = modifiableDataset.GetDoubleValues(transformation.OriginalVariable); 101 if (!trans.Check(originalData, out string errorMessage)) 102 throw new InvalidOperationException($"Cannot estimate Values, Transformation is invalid: {errorMessage}"); 103 104 var transformedData = trans.Apply(originalData).ToList(); 105 if (modifiableDataset.VariableNames.Contains(transformation.TransformedVariable)) 106 modifiableDataset.ReplaceVariable(transformation.TransformedVariable, transformedData); 107 else 108 modifiableDataset.AddVariable(transformation.TransformedVariable, transformedData); 109 } 110 111 return modifiableDataset; 112 } 113 114 115 116 #region Events 117 public event EventHandler TargetVariableChanged; 118 private void OnTargetVariableChanged(object sender, EventArgs args) { 119 var changed 
= TargetVariableChanged; 120 if (changed != null) 121 changed(sender, args); 122 } 123 #endregion 124 125 126 127 128 129 130 /* 66 131 // dataset in original data range 67 132 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 68 var transformedDataset = Transform Inputs(dataset, Transformations);133 var transformedDataset = Transform(dataset, Transformations); 69 134 70 135 var estimates = OriginalModel.GetEstimatedValues(transformedDataset, rows); 71 136 72 return InverseTransform Estimates(estimates, Transformations, OriginalModel.TargetVariable);137 return InverseTransform(estimates, Transformations, OriginalModel.TargetVariable); 73 138 } 74 139 … … 76 141 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 77 142 // TODO: specialized views for the original solution type are lost (RandomForestSolutionView, ...) 78 return new TransformedRegressionSolution(this, new RegressionProblemData(problemData)); 79 } 80 81 private static IDataset TransformInputs(IDataset dataset, IEnumerable<IDataAnalysisTransformation> transformations) { 82 return DataAnalysisProblemData.Transform(dataset, transformations); 83 } 84 85 private static IEnumerable<double> InverseTransformEstimates(IEnumerable<double> data, IEnumerable<IDataAnalysisTransformation> transformations, string targetVariable) { 86 var estimates = data.ToList(); 87 88 foreach (var transformation in transformations.Reverse()) { 89 if (transformation.TransformedVariable == targetVariable) { 90 var trans = (ITransformation<double>)transformation.Transformation; 91 92 estimates = trans.InverseApply(estimates).ToList(); 93 94 // setup next iteration 95 targetVariable = transformation.OriginalVariable; 96 } 97 } 98 99 return estimates; 100 } 143 return new RegressionSolution(this, new RegressionProblemData(problemData)); 144 } */ 101 145 } 102 146 }
Note: See TracChangeset for help on using the changeset viewer.