Changeset 8594 for trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationModel.cs
- Timestamp: 09/07/12 11:27:49 (12 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationModel.cs
r8550 r8594 34 34 [StorableClass] 35 35 [Item(Name = "SymbolicDiscriminantFunctionClassificationModel", Description = "Represents a symbolic classification model unsing a discriminant function.")] 36 public class SymbolicDiscriminantFunctionClassificationModel : Symbolic DataAnalysisModel, ISymbolicDiscriminantFunctionClassificationModel {36 public class SymbolicDiscriminantFunctionClassificationModel : SymbolicClassificationModel, ISymbolicDiscriminantFunctionClassificationModel { 37 37 38 38 [Storable] … … 48 48 private set { classValues = value.ToArray(); } 49 49 } 50 51 private IDiscriminantFunctionThresholdCalculator thresholdCalculator; 50 52 [Storable] 51 p rivate double lowerEstimationLimit;52 public double LowerEstimationLimit { get { return lowerEstimationLimit; }}53 [Storable]54 private double upperEstimationLimit;55 public double UpperEstimationLimit { get { return upperEstimationLimit; } } 53 public IDiscriminantFunctionThresholdCalculator ThresholdCalculator { 54 get { return thresholdCalculator; } 55 private set { thresholdCalculator = value; } 56 } 57 56 58 57 59 [StorableConstructor] … … 61 63 classValues = (double[])original.classValues.Clone(); 62 64 thresholds = (double[])original.thresholds.Clone(); 63 lowerEstimationLimit = original.lowerEstimationLimit; 64 upperEstimationLimit = original.upperEstimationLimit; 65 thresholdCalculator = cloner.Clone(original.thresholdCalculator); 65 66 } 66 public SymbolicDiscriminantFunctionClassificationModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, 67 public SymbolicDiscriminantFunctionClassificationModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDiscriminantFunctionThresholdCalculator thresholdCalculator, 67 68 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue) 68 : base(tree, interpreter ) {69 : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) { 69 
70 this.thresholds = new double[0]; 70 71 this.classValues = new double[0]; 71 this.lowerEstimationLimit = lowerEstimationLimit; 72 this.upperEstimationLimit = upperEstimationLimit; 72 this.ThresholdCalculator = thresholdCalculator; 73 } 74 75 [StorableHook(HookType.AfterDeserialization)] 76 private void AfterDeserialization() { 77 if (ThresholdCalculator == null) ThresholdCalculator = new AccuracyMaximizationThresholdCalculator(); 73 78 } 74 79 … … 87 92 } 88 93 89 public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) { 90 return Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows).LimitToRange(lowerEstimationLimit, upperEstimationLimit); 94 public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) { 95 double[] classValues; 96 double[] thresholds; 97 var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 98 var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, rows); 99 thresholdCalculator.Calculate(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds); 100 SetThresholdsAndClassValues(thresholds, classValues); 91 101 } 92 102 93 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 103 public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) { 104 return Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows).LimitToRange(LowerEstimationLimit, UpperEstimationLimit); 105 } 106 107 public override IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 94 108 if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current symbolic classification model."); 95 109 foreach (var x in GetEstimatedValues(dataset, rows)) { … … 104 118 } 105 119 106 public 
SymbolicDiscriminantFunctionClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 120 121 public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 122 return CreateDiscriminantClassificationSolution(problemData); 123 } 124 public SymbolicDiscriminantFunctionClassificationSolution CreateDiscriminantClassificationSolution(IClassificationProblemData problemData) { 107 125 return new SymbolicDiscriminantFunctionClassificationSolution(this, new ClassificationProblemData(problemData)); 108 126 } 109 127 IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) { 110 return Create ClassificationSolution(problemData);128 return CreateDiscriminantClassificationSolution(problemData); 111 129 } 112 130 IDiscriminantFunctionClassificationSolution IDiscriminantFunctionClassificationModel.CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData) { 113 return Create ClassificationSolution(problemData);131 return CreateDiscriminantClassificationSolution(problemData); 114 132 } 115 133 … … 121 139 } 122 140 #endregion 123 124 public void SetAccuracyMaximizingThresholds(IClassificationProblemData problemData) {125 double[] classValues;126 double[] thresholds;127 var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);128 var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);129 AccuracyMaximizationThresholdCalculator.CalculateThresholds(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds);130 131 SetThresholdsAndClassValues(thresholds, classValues);132 }133 134 public void SetClassDistributionCutPointThresholds(IClassificationProblemData problemData) {135 double[] classValues;136 double[] thresholds;137 var targetClassValues = 
problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);138 var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);139 NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds);140 141 SetThresholdsAndClassValues(thresholds, classValues);142 }143 144 public static void Scale(SymbolicDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) {145 var dataset = problemData.Dataset;146 var targetVariable = problemData.TargetVariable;147 var rows = problemData.TrainingIndices;148 var estimatedValues = model.Interpreter.GetSymbolicExpressionTreeValues(model.SymbolicExpressionTree, dataset, rows);149 var targetValues = dataset.GetDoubleValues(targetVariable, rows);150 double alpha;151 double beta;152 OnlineCalculatorError errorState;153 OnlineLinearScalingParameterCalculator.Calculate(estimatedValues, targetValues, out alpha, out beta, out errorState);154 if (errorState != OnlineCalculatorError.None) return;155 156 ConstantTreeNode alphaTreeNode = null;157 ConstantTreeNode betaTreeNode = null;158 // check if model has been scaled previously by analyzing the structure of the tree159 var startNode = model.SymbolicExpressionTree.Root.GetSubtree(0);160 if (startNode.GetSubtree(0).Symbol is Addition) {161 var addNode = startNode.GetSubtree(0);162 if (addNode.SubtreeCount == 2 && addNode.GetSubtree(0).Symbol is Multiplication && addNode.GetSubtree(1).Symbol is Constant) {163 alphaTreeNode = addNode.GetSubtree(1) as ConstantTreeNode;164 var mulNode = addNode.GetSubtree(0);165 if (mulNode.SubtreeCount == 2 && mulNode.GetSubtree(1).Symbol is Constant) {166 betaTreeNode = mulNode.GetSubtree(1) as ConstantTreeNode;167 }168 }169 }170 // if tree structure matches the structure necessary for linear scaling then reuse the existing tree nodes171 if (alphaTreeNode != null && 
betaTreeNode != null) {172 betaTreeNode.Value *= beta;173 alphaTreeNode.Value *= beta;174 alphaTreeNode.Value += alpha;175 } else {176 var mainBranch = startNode.GetSubtree(0);177 startNode.RemoveSubtree(0);178 var scaledMainBranch = MakeSum(MakeProduct(mainBranch, beta), alpha);179 startNode.AddSubtree(scaledMainBranch);180 }181 }182 183 private static ISymbolicExpressionTreeNode MakeSum(ISymbolicExpressionTreeNode treeNode, double alpha) {184 if (alpha.IsAlmost(0.0)) {185 return treeNode;186 } else {187 var addition = new Addition();188 var node = addition.CreateTreeNode();189 var alphaConst = MakeConstant(alpha);190 node.AddSubtree(treeNode);191 node.AddSubtree(alphaConst);192 return node;193 }194 }195 196 private static ISymbolicExpressionTreeNode MakeProduct(ISymbolicExpressionTreeNode treeNode, double beta) {197 if (beta.IsAlmost(1.0)) {198 return treeNode;199 } else {200 var multipliciation = new Multiplication();201 var node = multipliciation.CreateTreeNode();202 var betaConst = MakeConstant(beta);203 node.AddSubtree(treeNode);204 node.AddSubtree(betaConst);205 return node;206 }207 }208 209 private static ISymbolicExpressionTreeNode MakeConstant(double c) {210 var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();211 node.Value = c;212 return node;213 }214 141 } 215 142 }
Note: See TracChangeset for help on using the changeset viewer.