Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/07/12 11:27:49 (12 years ago)
Author:
mkommend
Message:

#1940: Added support in symbolic classification for different methods of creating the classification model (ModelCreators).

  • Added ModelCreators
  • Refactored SymbolicClassificationModel and SymbolicDiscriminantFunctionClassificationModel
  • Added ModelCreatorParameter to Analyzers and Evaluators if needed
  • Corrected wiring in symbolic classification problems (single- and multiobjective)
  • Adapted simplifier
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationModel.cs

    r8550 r8594  
    3434  [StorableClass]
    3535  [Item(Name = "SymbolicDiscriminantFunctionClassificationModel", Description = "Represents a symbolic classification model unsing a discriminant function.")]
    36   public class SymbolicDiscriminantFunctionClassificationModel : SymbolicDataAnalysisModel, ISymbolicDiscriminantFunctionClassificationModel {
     36  public class SymbolicDiscriminantFunctionClassificationModel : SymbolicClassificationModel, ISymbolicDiscriminantFunctionClassificationModel {
    3737
    3838    [Storable]
     
    4848      private set { classValues = value.ToArray(); }
    4949    }
     50
     51    private IDiscriminantFunctionThresholdCalculator thresholdCalculator;
    5052    [Storable]
    51     private double lowerEstimationLimit;
    52     public double LowerEstimationLimit { get { return lowerEstimationLimit; } }
    53     [Storable]
    54     private double upperEstimationLimit;
    55     public double UpperEstimationLimit { get { return upperEstimationLimit; } }
     53    public IDiscriminantFunctionThresholdCalculator ThresholdCalculator {
     54      get { return thresholdCalculator; }
     55      private set { thresholdCalculator = value; }
     56    }
     57
    5658
    5759    [StorableConstructor]
     
    6163      classValues = (double[])original.classValues.Clone();
    6264      thresholds = (double[])original.thresholds.Clone();
    63       lowerEstimationLimit = original.lowerEstimationLimit;
    64       upperEstimationLimit = original.upperEstimationLimit;
     65      thresholdCalculator = cloner.Clone(original.thresholdCalculator);
    6566    }
    66     public SymbolicDiscriminantFunctionClassificationModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     67    public SymbolicDiscriminantFunctionClassificationModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDiscriminantFunctionThresholdCalculator thresholdCalculator,
    6768      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue)
    68       : base(tree, interpreter) {
     69      : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) {
    6970      this.thresholds = new double[0];
    7071      this.classValues = new double[0];
    71       this.lowerEstimationLimit = lowerEstimationLimit;
    72       this.upperEstimationLimit = upperEstimationLimit;
     72      this.ThresholdCalculator = thresholdCalculator;
     73    }
     74
     75    [StorableHook(HookType.AfterDeserialization)]
     76    private void AfterDeserialization() {
     77      if (ThresholdCalculator == null) ThresholdCalculator = new AccuracyMaximizationThresholdCalculator();
    7378    }
    7479
     
    8792    }
    8893
    89     public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    90       return Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows).LimitToRange(lowerEstimationLimit, upperEstimationLimit);
     94    public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
     95      double[] classValues;
     96      double[] thresholds;
     97      var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     98      var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, rows);
     99      thresholdCalculator.Calculate(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds);
     100      SetThresholdsAndClassValues(thresholds, classValues);
    91101    }
    92102
    93     public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     103    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
     104      return Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows).LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
     105    }
     106
     107    public override IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    94108      if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current symbolic classification model.");
    95109      foreach (var x in GetEstimatedValues(dataset, rows)) {
     
    104118    }
    105119
    106     public SymbolicDiscriminantFunctionClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
     120
     121    public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
     122      return CreateDiscriminantClassificationSolution(problemData);
     123    }
     124    public SymbolicDiscriminantFunctionClassificationSolution CreateDiscriminantClassificationSolution(IClassificationProblemData problemData) {
    107125      return new SymbolicDiscriminantFunctionClassificationSolution(this, new ClassificationProblemData(problemData));
    108126    }
    109127    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
    110       return CreateClassificationSolution(problemData);
     128      return CreateDiscriminantClassificationSolution(problemData);
    111129    }
    112130    IDiscriminantFunctionClassificationSolution IDiscriminantFunctionClassificationModel.CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData) {
    113       return CreateClassificationSolution(problemData);
     131      return CreateDiscriminantClassificationSolution(problemData);
    114132    }
    115133
     
    121139    }
    122140    #endregion
    123 
    124     public void SetAccuracyMaximizingThresholds(IClassificationProblemData problemData) {
    125       double[] classValues;
    126       double[] thresholds;
    127       var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    128       var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
    129       AccuracyMaximizationThresholdCalculator.CalculateThresholds(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    130 
    131       SetThresholdsAndClassValues(thresholds, classValues);
    132     }
    133 
    134     public void SetClassDistributionCutPointThresholds(IClassificationProblemData problemData) {
    135       double[] classValues;
    136       double[] thresholds;
    137       var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    138       var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
    139       NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    140 
    141       SetThresholdsAndClassValues(thresholds, classValues);
    142     }
    143 
    144     public static void Scale(SymbolicDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) {
    145       var dataset = problemData.Dataset;
    146       var targetVariable = problemData.TargetVariable;
    147       var rows = problemData.TrainingIndices;
    148       var estimatedValues = model.Interpreter.GetSymbolicExpressionTreeValues(model.SymbolicExpressionTree, dataset, rows);
    149       var targetValues = dataset.GetDoubleValues(targetVariable, rows);
    150       double alpha;
    151       double beta;
    152       OnlineCalculatorError errorState;
    153       OnlineLinearScalingParameterCalculator.Calculate(estimatedValues, targetValues, out alpha, out beta, out errorState);
    154       if (errorState != OnlineCalculatorError.None) return;
    155 
    156       ConstantTreeNode alphaTreeNode = null;
    157       ConstantTreeNode betaTreeNode = null;
    158       // check if model has been scaled previously by analyzing the structure of the tree
    159       var startNode = model.SymbolicExpressionTree.Root.GetSubtree(0);
    160       if (startNode.GetSubtree(0).Symbol is Addition) {
    161         var addNode = startNode.GetSubtree(0);
    162         if (addNode.SubtreeCount == 2 && addNode.GetSubtree(0).Symbol is Multiplication && addNode.GetSubtree(1).Symbol is Constant) {
    163           alphaTreeNode = addNode.GetSubtree(1) as ConstantTreeNode;
    164           var mulNode = addNode.GetSubtree(0);
    165           if (mulNode.SubtreeCount == 2 && mulNode.GetSubtree(1).Symbol is Constant) {
    166             betaTreeNode = mulNode.GetSubtree(1) as ConstantTreeNode;
    167           }
    168         }
    169       }
    170       // if tree structure matches the structure necessary for linear scaling then reuse the existing tree nodes
    171       if (alphaTreeNode != null && betaTreeNode != null) {
    172         betaTreeNode.Value *= beta;
    173         alphaTreeNode.Value *= beta;
    174         alphaTreeNode.Value += alpha;
    175       } else {
    176         var mainBranch = startNode.GetSubtree(0);
    177         startNode.RemoveSubtree(0);
    178         var scaledMainBranch = MakeSum(MakeProduct(mainBranch, beta), alpha);
    179         startNode.AddSubtree(scaledMainBranch);
    180       }
    181     }
    182 
    183     private static ISymbolicExpressionTreeNode MakeSum(ISymbolicExpressionTreeNode treeNode, double alpha) {
    184       if (alpha.IsAlmost(0.0)) {
    185         return treeNode;
    186       } else {
    187         var addition = new Addition();
    188         var node = addition.CreateTreeNode();
    189         var alphaConst = MakeConstant(alpha);
    190         node.AddSubtree(treeNode);
    191         node.AddSubtree(alphaConst);
    192         return node;
    193       }
    194     }
    195 
    196     private static ISymbolicExpressionTreeNode MakeProduct(ISymbolicExpressionTreeNode treeNode, double beta) {
    197       if (beta.IsAlmost(1.0)) {
    198         return treeNode;
    199       } else {
    200         var multipliciation = new Multiplication();
    201         var node = multipliciation.CreateTreeNode();
    202         var betaConst = MakeConstant(beta);
    203         node.AddSubtree(treeNode);
    204         node.AddSubtree(betaConst);
    205         return node;
    206       }
    207     }
    208 
    209     private static ISymbolicExpressionTreeNode MakeConstant(double c) {
    210       var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
    211       node.Value = c;
    212       return node;
    213     }
    214141  }
    215142}
Note: See TracChangeset for help on using the changeset viewer.