Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/14/10 18:46:12 (14 years ago)
Author:
mkommend
Message:

updated classification branch (ticket #939)

Location:
branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3
Files:
1 added
1 deleted
6 edited
2 moved

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/ClassificationProblemData.cs

    r4366 r4391  
    3434  [StorableClass]
    3535  public class ClassificationProblemData : DataAnalysisProblemData {
     36    #region default data
     37    private static string[] defaultInputs = new string[] { "sample", "clump thickness", "cell size", "cell shape", "marginal adhesion", "epithelial cell size", "bare nuclei", "chromatin", "nucleoli", "mitoses", "class" };
     38    private static double[,] defaultData = new double[,]{
     39     {1000025,5,1,1,1,2,1,3,1,1,2      },
     40     {1002945,5,4,4,5,7,10,3,2,1,2     },
     41     {1015425,3,1,1,1,2,2,3,1,1,2      },
     42     {1016277,6,8,8,1,3,4,3,7,1,2      },
     43     {1017023,4,1,1,3,2,1,3,1,1,2      },
     44     {1017122,8,10,10,8,7,10,9,7,1,4   },
     45     {1018099,1,1,1,1,2,10,3,1,1,2     },
     46     {1018561,2,1,2,1,2,1,3,1,1,2      },
     47     {1033078,2,1,1,1,2,1,1,1,5,2      },
     48     {1033078,4,2,1,1,2,1,2,1,1,2      },
     49     {1035283,1,1,1,1,1,1,3,1,1,2      },
     50     {1036172,2,1,1,1,2,1,2,1,1,2      },
     51     {1041801,5,3,3,3,2,3,4,4,1,4      },
     52     {1043999,1,1,1,1,2,3,3,1,1,2      },
     53     {1044572,8,7,5,10,7,9,5,5,4,4     },
     54     {1047630,7,4,6,4,6,1,4,3,1,4      },
     55     {1048672,4,1,1,1,2,1,2,1,1,2      },
     56     {1049815,4,1,1,1,2,1,3,1,1,2      },
     57     {1050670,10,7,7,6,4,10,4,1,2,4    },
     58     {1050718,6,1,1,1,2,1,3,1,1,2      },
     59     {1054590,7,3,2,10,5,10,5,4,4,4    },
     60     {1054593,10,5,5,3,6,7,7,10,1,4    },
     61     {1056784,3,1,1,1,2,1,2,1,1,2      },
     62     {1057013,8,4,5,1,2,2,7,3,1,4      },
     63     {1059552,1,1,1,1,2,1,3,1,1,2      },
     64     {1065726,5,2,3,4,2,7,3,6,1,4      },
     65     {1066373,3,2,1,1,1,1,2,1,1,2      },
     66     {1066979,5,1,1,1,2,1,2,1,1,2      },
     67     {1067444,2,1,1,1,2,1,2,1,1,2      },
     68     {1070935,1,1,3,1,2,1,1,1,1,2      },
     69     {1070935,3,1,1,1,1,1,2,1,1,2      },
     70     {1071760,2,1,1,1,2,1,3,1,1,2      },
     71     {1072179,10,7,7,3,8,5,7,4,3,4     },
     72     {1074610,2,1,1,2,2,1,3,1,1,2      },
     73     {1075123,3,1,2,1,2,1,2,1,1,2      },
     74     {1079304,2,1,1,1,2,1,2,1,1,2      },
     75     {1080185,10,10,10,8,6,1,8,9,1,4   },
     76     {1081791,6,2,1,1,1,1,7,1,1,2      },
     77     {1084584,5,4,4,9,2,10,5,6,1,4     },
     78     {1091262,2,5,3,3,6,7,7,5,1,4      },
     79     {1096800,6,6,6,9,6,4,7,8,1,2      },
     80     {1099510,10,4,3,1,3,3,6,5,2,4     },
     81     {1100524,6,10,10,2,8,10,7,3,3,4   },
     82     {1102573,5,6,5,6,10,1,3,1,1,4     },
     83     {1103608,10,10,10,4,8,1,8,10,1,4  },
     84     {1103722,1,1,1,1,2,1,2,1,2,2      },
     85     {1105257,3,7,7,4,4,9,4,8,1,4      },
     86     {1105524,1,1,1,1,2,1,2,1,1,2      },
     87     {1106095,4,1,1,3,2,1,3,1,1,2      },
     88     {1106829,7,8,7,2,4,8,3,8,2,4      },
     89     {1108370,9,5,8,1,2,3,2,1,5,4      },
     90     {1108449,5,3,3,4,2,4,3,4,1,4      },
     91     {1110102,10,3,6,2,3,5,4,10,2,4    },
     92     {1110503,5,5,5,8,10,8,7,3,7,4     },
     93     {1110524,10,5,5,6,8,8,7,1,1,4     },
     94     {1111249,10,6,6,3,4,5,3,6,1,4     },
     95     {1112209,8,10,10,1,3,6,3,9,1,4    },
     96     {1113038,8,2,4,1,5,1,5,4,4,4      },
     97     {1113483,5,2,3,1,6,10,5,1,1,4     },
     98     {1113906,9,5,5,2,2,2,5,1,1,4      },
     99     {1115282,5,3,5,5,3,3,4,10,1,4     },
     100     {1115293,1,1,1,1,2,2,2,1,1,2      },
     101     {1116116,9,10,10,1,10,8,3,3,1,4   },
     102     {1116132,6,3,4,1,5,2,3,9,1,4      },
     103     {1116192,1,1,1,1,2,1,2,1,1,2      },
     104     {1116998,10,4,2,1,3,2,4,3,10,4    },
     105     {1117152,4,1,1,1,2,1,3,1,1,2      },
     106     {1118039,5,3,4,1,8,10,4,9,1,4     },
     107     {1120559,8,3,8,3,4,9,8,9,8,4      },
     108     {1121732,1,1,1,1,2,1,3,2,1,2      },
     109     {1121919,5,1,3,1,2,1,2,1,1,2      },
     110     {1123061,6,10,2,8,10,2,7,8,10,4   },
     111     {1124651,1,3,3,2,2,1,7,2,1,2      },
     112     {1125035,9,4,5,10,6,10,4,8,1,4    },
     113     {1126417,10,6,4,1,3,4,3,2,3,4     },
     114     {1131294,1,1,2,1,2,2,4,2,1,2      },
     115     {1132347,1,1,4,1,2,1,2,1,1,2      },
     116     {1133041,5,3,1,2,2,1,2,1,1,2      },
     117     {1133136,3,1,1,1,2,3,3,1,1,2      },
     118     {1136142,2,1,1,1,3,1,2,1,1,2      },
     119     {1137156,2,2,2,1,1,1,7,1,1,2      },
     120     {1143978,4,1,1,2,2,1,2,1,1,2      },
     121     {1143978,5,2,1,1,2,1,3,1,1,2      },
     122     {1147044,3,1,1,1,2,2,7,1,1,2      },
     123     {1147699,3,5,7,8,8,9,7,10,7,4     },
     124     {1147748,5,10,6,1,10,4,4,10,10,4  },
     125     {1148278,3,3,6,4,5,8,4,4,1,4      },
     126     {1148873,3,6,6,6,5,10,6,8,3,4     },
     127     {1152331,4,1,1,1,2,1,3,1,1,2      },
     128     {1155546,2,1,1,2,3,1,2,1,1,2      },
     129     {1156272,1,1,1,1,2,1,3,1,1,2      },
     130     {1156948,3,1,1,2,2,1,1,1,1,2      },
     131     {1157734,4,1,1,1,2,1,3,1,1,2      },
     132     {1158247,1,1,1,1,2,1,2,1,1,2      },
     133     {1160476,2,1,1,1,2,1,3,1,1,2      },
     134     {1164066,1,1,1,1,2,1,3,1,1,2      },
     135     {1165297,2,1,1,2,2,1,1,1,1,2      },
     136     {1165790,5,1,1,1,2,1,3,1,1,2      },
     137     {1165926,9,6,9,2,10,6,2,9,10,4    },
     138     {1166630,7,5,6,10,5,10,7,9,4,4    },
     139     {1166654,10,3,5,1,10,5,3,10,2,4   },
     140     {1167439,2,3,4,4,2,5,2,5,1,4      },
     141     {1167471,4,1,2,1,2,1,3,1,1,2      },
     142     {1168359,8,2,3,1,6,3,7,1,1,4      },
     143     {1168736,10,10,10,10,10,1,8,8,8,4 },
     144     {1169049,7,3,4,4,3,3,3,2,7,4      },
     145     {1170419,10,10,10,8,2,10,4,1,1,4  },
     146     {1170420,1,6,8,10,8,10,5,7,1,4    },
     147     {1171710,1,1,1,1,2,1,2,3,1,2      },
     148     {1171710,6,5,4,4,3,9,7,8,3,4      },
     149     {1171795,1,3,1,2,2,2,5,3,2,2      },
     150     {1171845,8,6,4,3,5,9,3,1,1,4      },
     151     {1172152,10,3,3,10,2,10,7,3,3,4   },
     152     {1173216,10,10,10,3,10,8,8,1,1,4  },
     153     {1173235,3,3,2,1,2,3,3,1,1,2      },
     154     {1173347,1,1,1,1,2,5,1,1,1,2      },
     155     {1173347,8,3,3,1,2,2,3,2,1,2      },
     156     {1173509,4,5,5,10,4,10,7,5,8,4    },
     157     {1173514,1,1,1,1,4,3,1,1,1,2      },
     158     {1173681,3,2,1,1,2,2,3,1,1,2      },
     159     {1174057,1,1,2,2,2,1,3,1,1,2      },
     160     {1174057,4,2,1,1,2,2,3,1,1,2      },
     161     {1174131,10,10,10,2,10,10,5,3,3,4 },
     162     {1174428,5,3,5,1,8,10,5,3,1,4     },
     163     {1175937,5,4,6,7,9,7,8,10,1,4     },
     164     {1176406,1,1,1,1,2,1,2,1,1,2      },
     165     {1176881,7,5,3,7,4,10,7,5,5,4        }
     166};
     167    #endregion
     168
    36169    private const int MaximumClasses = 20;
    37170    private const string ClassNamesParameterName = "ClassNames";
     
    64197
    65198    public ClassificationProblemData()
    66       : base() {
     199      : base(new Dataset(defaultInputs, defaultData), defaultInputs, defaultInputs[defaultInputs.Length - 1], 0, 60, 60, 120) {
    67200      Parameters.Add(new ValueParameter<StringArray>(ClassNamesParameterName, "An array of the names for all class values."));
    68201      Parameters.Add(new ValueParameter<DoubleMatrix>(MisclassificationMatrixParameterName, "A matrix that describles the penalties for misclassifaction between the single classes."));
    69202      sortedClassValues = new List<double>();
    70203
     204      InputVariables.SetItemCheckedState(InputVariables[InputVariables.Count - 1], false);
    71205      RegisterParameterEvents();
    72206      UpdateClassValues();
     
    173307
    174308    private void UpdateMisclassifciationMatrixHeaders() {
    175       MisclassificationMatrix.RowNames = ClassNames;
    176       MisclassificationMatrix.ColumnNames = ClassNames;
     309      MisclassificationMatrix.RowNames = ClassNames.Select(name => "Estimated " + name);
     310      MisclassificationMatrix.ColumnNames = ClassNames.Select(name => "Actual " + name);
    177311    }
    178312
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/HeuristicLab.Problems.DataAnalysis.Classification-3.3.csproj

    r4388 r4391  
    156156    <Compile Include="Interfaces\ISymbolicClassificationAnalyzer.cs" />
    157157    <Compile Include="Symbolic\Analyzer\ValidationBestSymbolicClassificationSolutionAnalyzer.cs" />
     158    <Compile Include="Symbolic\Evaluators\SymbolicClassificationPearsonRSquaredEvaluator.cs" />
    158159    <Compile Include="Symbolic\SymbolicClassificationSolution.cs" />
    159     <Compile Include="Symbolic\SingleObjectiveSymbolicClassificationEvaluator.cs" />
    160     <Compile Include="Symbolic\SymbolicClassificationMeanSquaredErrorEvaluator.cs" />
     160    <Compile Include="Symbolic\Evaluators\SymbolicClassificationMeanSquaredErrorEvaluator.cs" />
    161161    <Compile Include="Symbolic\SymbolicClassificationProblem.cs" />
    162162  </ItemGroup>
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Interfaces/ISymbolicClassificationAnalyzer.cs

    r4366 r4391  
    2020#endregion
    2121
    22 using HeuristicLab.Core;
    23 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    24 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Interfaces;
    25 using HeuristicLab.Optimization;
    26 using HeuristicLab.Parameters;
    27 
    28 // This interface is exactly the same as ISymbolicRegressionAnalyzer
    29 // consider creating a base interface for both analyzer
    30 
     22using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers;
    3123namespace HeuristicLab.Problems.DataAnalysis.Classification {
    32   public interface ISymbolicClassificationAnalyzer : ISymbolicExpressionTreeAnalyzer {
    33     ILookupParameter<ResultCollection> ResultsParameter { get; }
     24  public interface ISymbolicClassificationAnalyzer : ISymbolicRegressionAnalyzer {
    3425  }
    3526}
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Interfaces/ISymbolicClassificationEvaluator.cs

    r4366 r4391  
    2020#endregion
    2121
    22 using System.Collections.Generic;
    23 using HeuristicLab.Core;
    24 using HeuristicLab.Data;
    25 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    26 using HeuristicLab.Optimization;
    27 using HeuristicLab.Problems.DataAnalysis.Symbolic;
     22using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
    2823namespace HeuristicLab.Problems.DataAnalysis.Classification {
    29   public interface ISymbolicClassificationEvaluator : ISingleObjectiveEvaluator {
    30     ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter { get; }
    31     ILookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter { get; }
    32     ILookupParameter<ClassificationProblemData> RegressionProblemDataParameter { get; }
    33     IValueLookupParameter<IntValue> SamplesStartParameter { get; }
    34     IValueLookupParameter<IntValue> SamplesEndParameter { get; }
    35     IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter { get; }
    36     IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter { get; }
    37 
    38 
    39     double Evaluate(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree tree,
    40           double lowerEstimationLimit, double upperEstimationLimit,
    41           Dataset dataset, string targetVariable, IEnumerable<double> sortedClassValues, IEnumerable<int> rows);
     24  public interface ISymbolicClassificationEvaluator : ISymbolicRegressionEvaluator {
     25    ClassificationProblemData ClassificationProblemData { get; }
    4226  }
    4327}
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/Analyzer/ValidationBestSymbolicClassificationSolutionAnalyzer.cs

    r4366 r4391  
    2121
    2222using System.Collections.Generic;
     23using System.Linq;
     24using HeuristicLab.Analysis;
    2325using HeuristicLab.Core;
    2426using HeuristicLab.Data;
     
    2931using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3032using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
     33using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers;
    3134using HeuristicLab.Problems.DataAnalysis.Symbolic;
    3235
     
    5053
    5154    private const string ResultsParameterName = "Results";
    52     private const string BestValidationQualityParameterName = "BestValidationQuality";
    53     private const string BestValidationSolutionParameterName = "BestValidationSolution";
     55    private const string BestValidationQualityParameterName = "Best validation quality";
     56    private const string BestValidationSolutionParameterName = "Best validation solution";
     57    private const string BestSolutionGenerationParameterName = "Best solution generation";
     58    private const string BestSolutionInputvariableCountParameterName = "Variables used by best solution";
     59    private const string VariableFrequenciesParameterName = "VariableFrequencies";
     60    private const string VariableImpactsParameterName = "Variable Impacts";
    5461
    5562    #region parameter properties
     
    9198      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    9299    }
     100    public ILookupParameter<DataTable> VariableFrequenciesParameter {
     101      get { return (ILookupParameter<DataTable>)Parameters[VariableFrequenciesParameterName]; }
     102    }
    93103
    94104    public ILookupParameter<ResultCollection> ResultsParameter {
     
    100110    public ILookupParameter<SymbolicClassificationSolution> BestValidationSolutionParameter {
    101111      get { return (ILookupParameter<SymbolicClassificationSolution>)Parameters[BestValidationSolutionParameterName]; }
     112    }
     113    public ILookupParameter<IntValue> BestSolutionGenerationParameter {
     114      get { return (ILookupParameter<IntValue>)Parameters[BestSolutionGenerationParameterName]; }
     115    }
     116    public ILookupParameter<DoubleMatrix> VariableImpactsParameter {
     117      get { return (ILookupParameter<DoubleMatrix>)Parameters[VariableImpactsParameterName]; }
     118    }
     119    public ILookupParameter<IntValue> BestSolutionInputvariableCountParameter {
     120      get { return (ILookupParameter<IntValue>)Parameters[BestSolutionInputvariableCountParameterName]; }
    102121    }
    103122    #endregion
     
    140159      get { return LowerEstimationLimitParameter.ActualValue; }
    141160    }
     161    public DataTable VariableFrequencies {
     162      get { return VariableFrequenciesParameter.ActualValue; }
     163    }
    142164
    143165    public ResultCollection Results {
     
    151173      get { return BestValidationSolutionParameter.ActualValue; }
    152174      protected set { BestValidationSolutionParameter.ActualValue = value; }
     175    }
     176    public IntValue BestSolutionGeneration {
     177      get { return BestSolutionGenerationParameter.ActualValue; }
     178      protected set { BestSolutionGenerationParameter.ActualValue = value; }
     179    }
     180    public IntValue BestSolutionInputvariableCount {
     181      get { return BestSolutionInputvariableCountParameter.ActualValue; }
     182      protected set { BestSolutionInputvariableCountParameter.ActualValue = value; }
     183    }
     184    public DoubleMatrix VariableImpacts {
     185      get { return VariableImpactsParameter.ActualValue; }
     186      protected set { VariableImpactsParameter.ActualValue = value; }
    153187    }
    154188    #endregion
     
    169203      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
    170204      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
     205      Parameters.Add(new LookupParameter<DataTable>(VariableFrequenciesParameterName, "The variable frequencies table to use for the calculation of variable impacts"));
    171206
    172207      Parameters.Add(new ValueLookupParameter<ResultCollection>(ResultsParameterName, "The results collection where the analysis values should be stored."));
    173208      Parameters.Add(new LookupParameter<DoubleValue>(BestValidationQualityParameterName, "The validation quality of the best solution in the current run."));
    174209      Parameters.Add(new LookupParameter<SymbolicClassificationSolution>(BestValidationSolutionParameterName, "The best solution on the validation data found in the current run."));
     210      Parameters.Add(new LookupParameter<IntValue>(BestSolutionGenerationParameterName, "The generation in which the best solution was found."));
     211      Parameters.Add(new LookupParameter<DoubleMatrix>(VariableImpactsParameterName, "The impacts of the input variables calculated during the run."));
     212      Parameters.Add(new LookupParameter<IntValue>(BestSolutionInputvariableCountParameterName, "The number of input variables used by the best solution."));
     213
    175214    }
    176215
     
    199238        double quality = Evaluator.Evaluate(SymbolicExpressionTreeInterpreter, tree,
    200239          lowerEstimationLimit, upperEstimationLimit, ClassificationProblemData.Dataset,
    201           targetVariable, ClassificationProblemData.SortedClassValues, rows);
     240          targetVariable, rows);
    202241
    203242        if ((Maximization.Value && quality > bestQuality) ||
     
    214253        (!Maximization.Value && bestQuality < BestValidationQuality.Value);
    215254      if (newBest) {
     255        double alpha, beta;
     256        int trainingStart = ClassificationProblemData.TrainingSamplesStart.Value;
     257        int trainingEnd = ClassificationProblemData.TrainingSamplesEnd.Value;
     258        IEnumerable<int> trainingRows = Enumerable.Range(trainingStart, trainingEnd - trainingStart);
     259        SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(SymbolicExpressionTreeInterpreter, bestTree,
     260          lowerEstimationLimit, upperEstimationLimit,
     261          ClassificationProblemData.Dataset, targetVariable,
     262          trainingRows, out beta, out alpha);
     263
     264        // scale tree for solution
     265        var scaledTree = SymbolicRegressionSolutionLinearScaler.Scale(bestTree, alpha, beta);
    216266        var model = new SymbolicRegressionModel((ISymbolicExpressionTreeInterpreter)SymbolicExpressionTreeInterpreter.Clone(),
    217           bestTree);
     267          scaledTree);
    218268
    219269        if (BestValidationSolution == null) {
     
    222272          BestValidationSolution.Description = "Best solution on validation partition found over the whole run.";
    223273          BestValidationQuality = new DoubleValue(bestQuality);
     274          BestSolutionGeneration = (IntValue)Generations.Clone();
     275          BestSolutionInputvariableCount = new IntValue(BestValidationSolution.Model.InputVariables.Count());
     276
    224277          Results.Add(new Result(BestValidationSolutionParameterName, BestValidationSolution));
    225278          Results.Add(new Result(BestValidationQualityParameterName, BestValidationQuality));
     279          Results.Add(new Result(BestSolutionGenerationParameterName, BestSolutionGeneration));
     280
     281          Results.Add(new Result(BestSolutionInputvariableCountParameterName, BestSolutionInputvariableCount));
     282
     283          if (VariableFrequencies != null) {
     284            VariableImpacts = CalculateVariableImpacts(VariableFrequencies);
     285            Results.Add(new Result(VariableImpactsParameterName, VariableImpacts));
     286          }
    226287
    227288        } else {
    228289          BestValidationSolution.Model = model;
    229290          BestValidationQuality.Value = bestQuality;
     291          BestSolutionGeneration.Value = Generations.Value;
     292          BestSolutionInputvariableCount.Value = BestValidationSolution.Model.InputVariables.Count();
     293
     294          if (VariableFrequencies != null) {
     295            VariableImpacts = CalculateVariableImpacts(VariableFrequencies);
     296            Results[VariableImpactsParameterName].Value = VariableImpacts;
     297          }
    230298        }
    231299      }
    232300      return base.Apply();
    233301    }
     302
     303    private static DoubleMatrix CalculateVariableImpacts(DataTable variableFrequencies) {
     304      if (variableFrequencies != null) {
     305        var impacts = new DoubleMatrix(variableFrequencies.Rows.Count, 1, new string[] { "Impact" }, variableFrequencies.Rows.Select(x => x.Name));
     306        impacts.SortableView = true;
     307        int rowIndex = 0;
     308        foreach (var dataRow in variableFrequencies.Rows) {
     309          string variableName = dataRow.Name;
     310          impacts[rowIndex++, 0] = dataRow.Values.Average();
     311        }
     312        return impacts;
     313      } else return new DoubleMatrix(1, 1);
     314    }
    234315  }
    235316}
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/Evaluators/SymbolicClassificationMeanSquaredErrorEvaluator.cs

    r4366 r4391  
    2020#endregion
    2121
    22 using System;
    23 using System.Collections.Generic;
    24 using System.Linq;
    25 using HeuristicLab.Common;
    2622using HeuristicLab.Core;
    27 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2823using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    29 using HeuristicLab.Problems.DataAnalysis.Evaluators;
    30 using HeuristicLab.Problems.DataAnalysis.Symbolic;
     24using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
    3125
    3226namespace HeuristicLab.Problems.DataAnalysis.Classification {
    3327  [Item("SymbolicClassificationMeanSquaredErrorEvaluator", "Calculates the mean squared error of a symbolic classification solution.")]
    3428  [StorableClass]
    35   public class SymbolicClassifacitionMeanSquaredErrorEvaluator : SingleObjectiveSymbolicClassificationEvaluator {
     29  public class SymbolicClassifacitionMeanSquaredErrorEvaluator : SymbolicRegressionMeanSquaredErrorEvaluator, ISymbolicClassificationEvaluator {
     30    public ClassificationProblemData ClassificationProblemData {
     31      get { return (ClassificationProblemData)RegressionProblemData; }
     32    }
    3633
    3734    public SymbolicClassifacitionMeanSquaredErrorEvaluator()
    3835      : base() {
    3936    }
    40 
    41     public override double Evaluate(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, double lowerEstimationLimit, double upperEstimationLimit, Dataset dataset, string targetVariable, IEnumerable<double> sortedClassValues, IEnumerable<int> rows) {
    42       double mse = Calculate(interpreter, solution, lowerEstimationLimit, upperEstimationLimit, dataset, targetVariable, sortedClassValues, rows);
    43       return mse;
    44     }
    45 
    46     public static double Calculate(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, double lowerEstimationLimit, double upperEstimationLimit, Dataset dataset, string targetVariable, IEnumerable<double> sortedClassValues, IEnumerable<int> rows) {
    47       IEnumerable<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(solution, dataset, rows);
    48       IEnumerable<double> originalValues = dataset.GetEnumeratedVariableValues(targetVariable, rows);
    49       IEnumerator<double> originalEnumerator = originalValues.GetEnumerator();
    50       IEnumerator<double> estimatedEnumerator = estimatedValues.GetEnumerator();
    51       OnlineMeanSquaredErrorEvaluator mseEvaluator = new OnlineMeanSquaredErrorEvaluator();
    52 
    53 
    54       double firstClassValue = sortedClassValues.First();
    55       double lastClassValue = sortedClassValues.Last();
    56       while (originalEnumerator.MoveNext() && estimatedEnumerator.MoveNext()) {
    57         double estimated = estimatedEnumerator.Current;
    58         double original = originalEnumerator.Current;
    59         if (double.IsNaN(estimated))
    60           estimated = upperEstimationLimit;
    61         else if (estimated < original && original.IsAlmost(firstClassValue))
    62           estimated = original;
    63         else if (estimated > original && original.IsAlmost(lastClassValue))
    64           estimated = original;
    65         else
    66           estimated = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, estimated));
    67         mseEvaluator.Add(original, estimated);
    68       }
    69 
    70       if (estimatedEnumerator.MoveNext() || originalEnumerator.MoveNext()) {
    71         throw new ArgumentException("Number of elements in original and estimated enumeration doesn't match.");
    72       } else {
    73         return mseEvaluator.MeanSquaredError;
    74       }
    75     }
    7637  }
    7738}
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/SymbolicClassificationProblem.cs

    r4366 r4391  
    3333using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3434using HeuristicLab.PluginInfrastructure;
     35using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers;
    3536using HeuristicLab.Problems.DataAnalysis.Symbolic;
    3637using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
     
    228229      Operators.AddRange(ApplicationManager.Manager.GetInstances<ISymbolicExpressionTreeOperator>().OfType<IOperator>());
    229230      Operators.Add(new MinAverageMaxSymbolicExpressionTreeSizeAnalyzer());
     231      Operators.Add(new SymbolicRegressionVariableFrequencyAnalyzer());
    230232      Operators.Add(new ValidationBestSymbolicClassificationSolutionAnalyzer());
    231233    }
     
    294296
    295297    private void ParameterizeAnalyzers() {
    296       foreach (ISymbolicExpressionTreeAnalyzer analyzer in Operators.OfType<ISymbolicExpressionTreeAnalyzer>()) {
     298      foreach (ISymbolicRegressionAnalyzer analyzer in Operators.OfType<ISymbolicRegressionAnalyzer>()) {
    297299        analyzer.SymbolicExpressionTreeParameter.ActualName = SolutionCreator.SymbolicExpressionTreeParameter.ActualName;
    298300        var bestValidationSolutionAnalyzer = analyzer as ValidationBestSymbolicClassificationSolutionAnalyzer;
     
    306308          bestValidationSolutionAnalyzer.ValidationSamplesEndParameter.Value = ValidationSamplesEnd;
    307309        }
     310        var varFreqAnalyzer = analyzer as SymbolicRegressionVariableFrequencyAnalyzer;
     311        if (varFreqAnalyzer != null) {
     312          varFreqAnalyzer.ProblemDataParameter.ActualName = ClassificationProblemDataParameter.Name;
     313        }
    308314      }
    309315    }
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/SymbolicClassificationSolution.cs

    r4366 r4391  
    3131namespace HeuristicLab.Problems.DataAnalysis.Classification {
    3232  /// <summary>
    33   /// Represents a solution for a symbolic regression problem which can be visualized in the GUI.
     33  /// Represents a solution for a symbolic classification problem which can be visualized in the GUI.
    3434  /// </summary>
    3535  [Item("SymbolicClassificationSolution", "Represents a solution for a symbolic classification problem which can be visualized in the GUI.")]
    3636  [StorableClass]
    37   public sealed class SymbolicClassificationSolution : DataAnalysisSolution, IClassificationSolution {
     37  public class SymbolicClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    3838    private SymbolicClassificationSolution() : base() { }
    3939    public SymbolicClassificationSolution(ClassificationProblemData problemData, SymbolicRegressionModel model, double lowerEstimationLimit, double upperEstimationLimit)
     
    4646    }
    4747
     48    public new SymbolicRegressionModel Model {
     49      get { return (SymbolicRegressionModel)base.Model; }
     50      set { base.Model = value; }
     51    }
     52
    4853    public new ClassificationProblemData ProblemData {
    4954      get { return (ClassificationProblemData)base.ProblemData; }
    5055      set { base.ProblemData = value; }
    51     }
    52 
    53     public new SymbolicRegressionModel Model {
    54       get { return (SymbolicRegressionModel)base.Model; }
    55       set { base.Model = value; }
    5656    }
    5757
     
    6565    }
    6666
     67    protected List<double> estimatedValues;
     68    public override IEnumerable<double> EstimatedValues {
     69      get {
     70        if (estimatedValues == null) RecalculateEstimatedValues();
     71        return estimatedValues.AsEnumerable();
     72      }
     73    }
     74
     75    public override IEnumerable<double> EstimatedTrainingValues {
     76      get {
     77        if (estimatedValues == null) RecalculateEstimatedValues();
     78        int start = ProblemData.TrainingSamplesStart.Value;
     79        int n = ProblemData.TrainingSamplesEnd.Value - start;
     80        return estimatedValues.Skip(start).Take(n).ToList();
     81      }
     82    }
     83
     84    public override IEnumerable<double> EstimatedTestValues {
     85      get {
     86        if (estimatedValues == null) RecalculateEstimatedValues();
     87        int start = ProblemData.TestSamplesStart.Value;
     88        int n = ProblemData.TestSamplesEnd.Value - start;
     89        return estimatedValues.Skip(start).Take(n).ToList();
     90      }
     91    }
     92
    6793    private void RecalculateClassIntermediates() {
    68       int slices = 1000;
     94      int slices = 100;
     95
     96      List<int> classInstances = (from classValue in ProblemData.Dataset.GetVariableValues(ProblemData.TargetVariable.Value)
     97                                  group classValue by classValue into grouping
     98                                  select grouping.Count()).ToList();
    6999
    70100      List<KeyValuePair<double, double>> estimatedTargetValues =
    71         (from row in Enumerable.Range(ProblemData.TrainingSamplesStart.Value, ProblemData.TrainingSamplesEnd.Value - ProblemData.TrainingSamplesStart.Value)
    72          select new KeyValuePair<double, double>(
    73            estimatedValues[row],
    74            ProblemData.Dataset[ProblemData.TargetVariable.Value, row])).ToList();
    75 
    76       List<double> originalClasses = ProblemData.Dataset.GetVariableValues(ProblemData.TargetVariable.Value).Distinct().OrderBy(x => x).ToList();
    77       int numberOfClasses = originalClasses.Count;
    78 
    79       double[] thresholds = new double[numberOfClasses + 1];
     101         (from row in Enumerable.Range(ProblemData.TrainingSamplesStart.Value, ProblemData.TrainingSamplesEnd.Value - ProblemData.TrainingSamplesStart.Value)
     102          select new KeyValuePair<double, double>(
     103            estimatedValues[row],
     104            ProblemData.Dataset[ProblemData.TargetVariable.Value, row])).ToList();
     105
     106      List<double> originalClasses = ProblemData.SortedClassValues.ToList();
     107      double[] thresholds = new double[ProblemData.NumberOfClasses + 1];
    80108      thresholds[0] = double.NegativeInfinity;
    81109      thresholds[thresholds.Length - 1] = double.PositiveInfinity;
    82 
    83110
    84111      for (int i = 1; i < thresholds.Length - 1; i++) {
     
    88115
    89116        double bestThreshold = double.NaN;
    90         double bestQuality = double.NegativeInfinity;
     117        double bestClassificationScore = double.PositiveInfinity;
    91118
    92119        while (actualThreshold < originalClasses[i]) {
    93           int truePosivites = 0;
    94           int falsePosivites = 0;
    95           int trueNegatives = 0;
    96           int falseNegatives = 0;
     120          double classificationScore = 0.0;
    97121
    98122          foreach (KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) {
     
    100124            if (estimatedTarget.Value.IsAlmost(originalClasses[i - 1])) {
    101125              if (estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold)
    102                 truePosivites++;
     126                //true positive
     127                classificationScore += ProblemData.MisclassificationMatrix[i - 1, i - 1] / classInstances[i - 1];
    103128              else
    104                 falseNegatives++;
     129                //false negative
     130                classificationScore += ProblemData.MisclassificationMatrix[i, i - 1] / classInstances[i - 1];
    105131            }
    106132              //all negatives
    107133            else {
    108134              if (estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold)
    109                 falsePosivites++;
     135                classificationScore += ProblemData.MisclassificationMatrix[i - 1, i] / classInstances[i];
    110136              else
    111                 trueNegatives++;
     137                //true negative, consider only upper class
     138                classificationScore += ProblemData.MisclassificationMatrix[i, i] / classInstances[i];
    112139            }
    113140          }
    114 
    115           //mkommend 30.08.2010
    116           //matthews correlation coefficient taken from http://en.wikipedia.org/wiki/Matthews_correlation_coefficient
    117           //MCC = [(TP * FP) - (FP * FN)] / sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
    118           double dividend = truePosivites * falsePosivites - falsePosivites * falseNegatives;
    119           double divisor = Math.Sqrt((truePosivites + falsePosivites) * (truePosivites + falsePosivites) *
    120             (trueNegatives + falsePosivites) * (trueNegatives + falseNegatives));
    121           if (divisor == 0)
    122             divisor = 1;
    123 
    124           double mcc = dividend / divisor;
    125 
    126           if (bestQuality < mcc) {
    127             bestQuality = mcc;
     141          if (classificationScore < bestClassificationScore) {
     142            bestClassificationScore = classificationScore;
    128143            bestThreshold = actualThreshold;
    129144          }
     
    149164        actualThresholds = new List<double>(value);
    150165        OnThresholdsChanged();
    151       }
    152     }
    153 
    154     private List<double> estimatedValues;
    155     public override IEnumerable<double> EstimatedValues {
    156       get {
    157         if (estimatedValues == null) RecalculateEstimatedValues();
    158         return estimatedValues.AsEnumerable();
    159166      }
    160167    }
     
    172179    }
    173180
    174     public override IEnumerable<double> EstimatedTrainingValues {
    175       get {
    176         if (estimatedValues == null) RecalculateEstimatedValues();
    177         int start = ProblemData.TrainingSamplesStart.Value;
    178         int n = ProblemData.TrainingSamplesEnd.Value - start;
    179         return estimatedValues.Skip(start).Take(n).ToList();
    180       }
    181     }
    182181    public IEnumerable<double> EstimatedTrainingClassValues {
    183182      get {
     
    188187    }
    189188
    190     public override IEnumerable<double> EstimatedTestValues {
    191       get {
    192         if (estimatedValues == null) RecalculateEstimatedValues();
    193         int start = ProblemData.TestSamplesStart.Value;
    194         int n = ProblemData.TestSamplesEnd.Value - start;
    195         return estimatedValues.Skip(start).Take(n).ToList();
    196       }
    197     }
    198189    public IEnumerable<double> EstimatedTestClassValues {
    199190      get {
Note: See TracChangeset for help on using the changeset viewer.