Changeset 6587 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis
- Timestamp:
- 07/22/11 14:57:24 (13 years ago)
- Location:
- trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4
- Files:
-
- 1 edited
- 1 copied
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj
r6583 r6587 107 107 </ItemGroup> 108 108 <ItemGroup> 109 <Compile Include="RegressionWorkbench.cs" /> 109 110 <Compile Include="CrossValidation.cs"> 110 111 <SubType>Code</SubType> -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RegressionWorkbench.cs
r6584 r6587 32 32 using HeuristicLab.Problems.DataAnalysis; 33 33 using HeuristicLab.Problems.DataAnalysis.Symbolic; 34 using HeuristicLab.Parameters; 34 35 35 36 namespace HeuristicLab.Algorithms.DataAnalysis { 36 [Item(" Cross Validation", "Cross-validation wrapper for data analysis algorithms.")]37 [Item("Regression Workbench", "Experiment containing multiple algorithms for regression analysis.")] 37 38 [Creatable("Data Analysis")] 38 39 [StorableClass] 39 public sealed class CrossValidation : ParameterizedNamedItem, IAlgorithm, IStorableContent { 40 public CrossValidation() 40 public sealed class RegressionWorkbench : ParameterizedNamedItem, IOptimizer, IStorableContent { 41 public string Filename { get; set; } 42 43 private const string ProblemDataParameterName = "ProblemData"; 44 45 #region parameter properties 46 public IValueParameter<IRegressionProblemData> ProblemDataParameter { 47 get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; } 48 } 49 #endregion 50 #region properties 51 public IRegressionProblemData ProblemData { 52 get { return ProblemDataParameter.Value; } 53 set { ProblemDataParameter.Value = value; } 54 } 55 #endregion 56 [Storable] 57 private Experiment experiment; 58 59 [StorableConstructor] 60 private RegressionWorkbench(bool deserializing) 61 : base(deserializing) { 62 } 63 private RegressionWorkbench(RegressionWorkbench original, Cloner cloner) 64 : base(original, cloner) { 65 experiment = cloner.Clone(original.experiment); 66 RegisterEventHandlers(); 67 } 68 public RegressionWorkbench() 41 69 : base() { 42 70 name = ItemName; 43 71 description = ItemDescription; 44 72 45 executionState = ExecutionState.Stopped; 46 runs = new RunCollection(); 47 runsCounter = 0; 48 49 algorithm = null; 50 clonedAlgorithms = new ItemCollection<IAlgorithm>(); 51 results = new ResultCollection(); 52 53 folds = new IntValue(2); 54 numberOfWorkers = new IntValue(1); 55 samplesStart = new IntValue(0); 56 samplesEnd = new IntValue(0); 57 storeAlgorithmInEachRun = false; 58 59 RegisterEvents(); 60 if (Algorithm != null) RegisterAlgorithmEvents(); 61 } 62 63 public string Filename { get; set; } 64 65 #region persistence and cloning 66 [StorableConstructor] 67 private CrossValidation(bool deserializing) 68 : base(deserializing) { 69 } 73 Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The regression problem data that should be used for modeling.", new RegressionProblemData())); 74 75 experiment = new Experiment(); 76 77 //var svmExperiments = CreateSvmExperiment(); 78 var rfExperiments = CreateRandomForestExperiments(); 79 80 experiment.Optimizers.Add(new LinearRegression()); 81 experiment.Optimizers.Add(rfExperiments); 82 //experiment.Optimizers.Add(svmExperiments); 83 84 RegisterEventHandlers(); 85 } 86 70 87 [StorableHook(HookType.AfterDeserialization)] 71 88 private void AfterDeserialization() { 72 RegisterEvents(); 73 if (Algorithm != null) RegisterAlgorithmEvents(); 74 } 75 76 private CrossValidation(CrossValidation original, Cloner cloner) 77 : base(original, cloner) { 78 executionState = original.executionState; 79 storeAlgorithmInEachRun = original.storeAlgorithmInEachRun; 80 runs = cloner.Clone(original.runs); 81 runsCounter = original.runsCounter; 82 algorithm = cloner.Clone(original.algorithm); 83 clonedAlgorithms = cloner.Clone(original.clonedAlgorithms); 84 results = cloner.Clone(original.results); 85 86 folds = cloner.Clone(original.folds); 87 numberOfWorkers = cloner.Clone(original.numberOfWorkers); 88 samplesStart = cloner.Clone(original.samplesStart); 89 samplesEnd = cloner.Clone(original.samplesEnd); 90 RegisterEvents(); 91 if (Algorithm != null) RegisterAlgorithmEvents(); 92 } 89 RegisterEventHandlers(); 90 } 91 93 92 public override IDeepCloneable Clone(Cloner cloner) { 94 return new CrossValidation(this, cloner); 95 } 96 97 #endregion 98 99 #region properties 100 [Storable] 101 private IAlgorithm algorithm; 102 public IAlgorithm Algorithm { 103 get { return algorithm; } 104 set { 105 if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped) 106 throw new InvalidOperationException("Changing the algorithm is only allowed if the CrossValidation is stopped or prepared."); 107 if (algorithm != value) { 108 if (value != null && value.Problem != null && !(value.Problem is IDataAnalysisProblem)) 109 throw new ArgumentException("Only algorithms with a DataAnalysisProblem could be used for the cross validation."); 110 if (algorithm != null) DeregisterAlgorithmEvents(); 111 algorithm = value; 112 Parameters.Clear(); 113 114 if (algorithm != null) { 115 algorithm.StoreAlgorithmInEachRun = false; 116 RegisterAlgorithmEvents(); 117 algorithm.Prepare(true); 118 Parameters.AddRange(algorithm.Parameters); 93 return new RegressionWorkbench(this, cloner); 94 } 95 96 private void RegisterEventHandlers() { 97 ProblemDataParameter.ValueChanged += ProblemDataParameterValueChanged; 98 99 experiment.ExceptionOccurred += (sender, e) => OnExceptionOccured(e); 100 experiment.ExecutionStateChanged += (sender, e) => OnExecutionStateChanged(e); 101 experiment.ExecutionTimeChanged += (sender, e) => OnExecutionTimeChanged(e); 102 experiment.Paused += (sender, e) => OnPaused(e); 103 experiment.Prepared += (sender, e) => OnPrepared(e); 104 experiment.Started += (sender, e) => OnStarted(e); 105 experiment.Stopped += (sender, e) => OnStopped(e); 106 } 107 108 private IOptimizer CreateRandomForestExperiments() { 109 var exp = new Experiment(); 110 double[] rs = new double[] { 0.2, 0.3, 0.4, 0.5, 0.6, 0.65 }; 111 foreach (var r in rs) { 112 var cv = new CrossValidation(); 113 var rf = new RandomForestRegression(); 114 rf.R = r; 115 cv.Algorithm = rf; 116 cv.Folds.Value = 5; 117 exp.Optimizers.Add(cv); 118 } 119 return exp; 120 } 121 122 private IOptimizer CreateSvmExperiment() { 123 var exp = new Experiment(); 124 var costs = new double[] { Math.Pow(2, -5), Math.Pow(2, -3), Math.Pow(2, -1), Math.Pow(2, 1), Math.Pow(2, 3), Math.Pow(2, 5), Math.Pow(2, 7), Math.Pow(2, 9), Math.Pow(2, 11), Math.Pow(2, 13), Math.Pow(2, 15) }; 125 var gammas = new double[] { Math.Pow(2, -15), Math.Pow(2, -13), Math.Pow(2, -11), Math.Pow(2, -9), Math.Pow(2, -7), Math.Pow(2, -5), Math.Pow(2, -3), Math.Pow(2, -1), Math.Pow(2, 1), Math.Pow(2, 3) }; 126 var nus = new double[] { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 }; 127 foreach (var gamma in gammas) 128 foreach (var cost in costs) 129 foreach (var nu in nus) { 130 var cv = new CrossValidation(); 131 var svr = new SupportVectorRegression(); 132 svr.Nu.Value = nu; 133 svr.Cost.Value = cost; 134 svr.Gamma.Value = gamma; 135 cv.Algorithm = svr; 136 cv.Folds.Value = 5; 137 exp.Optimizers.Add(cv); 119 138 } 120 OnAlgorithmChanged(); 121 if (algorithm != null) OnProblemChanged(); 122 Prepare(); 123 } 139 return exp; 140 } 141 142 public RunCollection Runs { 143 get { return experiment.Runs; } 144 } 145 146 public void Prepare(bool clearRuns) { 147 experiment.Prepare(clearRuns); 148 } 149 150 public IEnumerable<IOptimizer> NestedOptimizers { 151 get { return experiment.NestedOptimizers; } 152 } 153 154 public ExecutionState ExecutionState { 155 get { return experiment.ExecutionState; } 156 } 157 158 public TimeSpan ExecutionTime { 159 get { return experiment.ExecutionTime; } 160 } 161 162 public void Prepare() { 163 experiment.Prepare(); 164 } 165 166 public void Start() { 167 experiment.Start(); 168 } 169 170 public void Pause() { 171 experiment.Pause(); 172 } 173 174 public void Stop() { 175 experiment.Stop(); 176 } 177 178 public event EventHandler ExecutionStateChanged; 179 private void OnExecutionStateChanged(EventArgs e) { 180 var handler = ExecutionStateChanged; 181 if (handler != null) handler(this, e); 182 } 183 184 public event EventHandler ExecutionTimeChanged; 185 private void OnExecutionTimeChanged(EventArgs e) { 186 var handler = ExecutionTimeChanged; 187 if (handler != null) handler(this, e); 188 } 189 190 public event EventHandler Prepared; 191 private void OnPrepared(EventArgs e) { 192 var handler = Prepared; 193 if (handler != null) handler(this, e); 194 } 195 196 public event EventHandler Started; 197 private void OnStarted(EventArgs e) { 198 var handler = Started; 199 if (handler != null) handler(this, e); 200 } 201 202 public event EventHandler Paused; 203 private void OnPaused(EventArgs e) { 204 var handler = Paused; 205 if (handler != null) handler(this, e); 206 } 207 208 public event EventHandler Stopped; 209 private void OnStopped(EventArgs e) { 210 var handler = Stopped; 211 if (handler != null) handler(this, e); 212 } 213 214 public event EventHandler<EventArgs<Exception>> ExceptionOccurred; 215 private void OnExceptionOccured(EventArgs<Exception> e) { 216 var handler = ExceptionOccurred; 217 if (handler != null) handler(this, e); 218 } 219 220 public void ProblemDataParameterValueChanged(object source, EventArgs e) { 221 foreach (var op in NestedOptimizers.OfType<IDataAnalysisAlgorithm<IRegressionProblem>>()) { 222 op.Problem.ProblemDataParameter.Value = ProblemData; 124 223 } 125 224 } 126 127 128 [Storable]129 private IDataAnalysisProblem problem;130 public IDataAnalysisProblem Problem {131 get {132 if (algorithm == null)133 return null;134 return (IDataAnalysisProblem)algorithm.Problem;135 }136 set {137 if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)138 throw new InvalidOperationException("Changing the problem is only allowed if the CrossValidation is stopped or prepared.");139 if (algorithm == null) throw new ArgumentNullException("Could not set a problem before an algorithm was set.");140 algorithm.Problem = value;141 problem = value;142 }143 }144 145 IProblem IAlgorithm.Problem {146 get { return Problem; }147 set {148 if (value != null && !ProblemType.IsInstanceOfType(value))149 throw new ArgumentException("Only DataAnalysisProblems could be used for the cross validation.");150 Problem = (IDataAnalysisProblem)value;151 }152 }153 public Type ProblemType {154 get { return typeof(IDataAnalysisProblem); }155 }156 157 [Storable]158 private ItemCollection<IAlgorithm> clonedAlgorithms;159 160 public IEnumerable<IOptimizer> NestedOptimizers {161 get {162 if (Algorithm == null) yield break;163 yield return Algorithm;164 }165 }166 167 [Storable]168 private ResultCollection results;169 public ResultCollection Results {170 get { return results; }171 }172 173 [Storable]174 private IntValue folds;175 public IntValue Folds {176 get { return folds; }177 }178 [Storable]179 private IntValue samplesStart;180 public IntValue SamplesStart {181 get { return samplesStart; }182 }183 [Storable]184 private IntValue samplesEnd;185 public IntValue SamplesEnd {186 get { return samplesEnd; }187 }188 [Storable]189 private IntValue numberOfWorkers;190 public IntValue NumberOfWorkers {191 get { return numberOfWorkers; }192 }193 194 [Storable]195 private bool storeAlgorithmInEachRun;196 public bool StoreAlgorithmInEachRun {197 get { return storeAlgorithmInEachRun; }198 set {199 if (storeAlgorithmInEachRun != value) {200 storeAlgorithmInEachRun = value;201 OnStoreAlgorithmInEachRunChanged();202 }203 }204 }205 206 [Storable]207 private int runsCounter;208 [Storable]209 private RunCollection runs;210 public RunCollection Runs {211 get { return runs; }212 }213 214 [Storable]215 private ExecutionState executionState;216 public ExecutionState ExecutionState {217 get { return executionState; }218 private set {219 if (executionState != value) {220 executionState = value;221 OnExecutionStateChanged();222 OnItemImageChanged();223 }224 }225 }226 public override Image ItemImage {227 get {228 if (ExecutionState == ExecutionState.Prepared) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePrepared;229 else if (ExecutionState == ExecutionState.Started) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStarted;230 else if (ExecutionState == ExecutionState.Paused) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePaused;231 else if (ExecutionState == ExecutionState.Stopped) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStopped;232 else return HeuristicLab.Common.Resources.VSImageLibrary.Event;233 }234 }235 236 public TimeSpan ExecutionTime {237 get {238 if (ExecutionState != ExecutionState.Prepared)239 return TimeSpan.FromMilliseconds(clonedAlgorithms.Select(x => x.ExecutionTime.TotalMilliseconds).Sum());240 return TimeSpan.Zero;241 }242 }243 #endregion244 245 public void Prepare() {246 if (ExecutionState == ExecutionState.Started)247 throw new InvalidOperationException(string.Format("Prepare not allowed in execution state \"{0}\".", ExecutionState));248 results.Clear();249 clonedAlgorithms.Clear();250 if (Algorithm != null) {251 Algorithm.Prepare();252 if (Algorithm.ExecutionState == ExecutionState.Prepared) OnPrepared();253 }254 }255 public void Prepare(bool clearRuns) {256 if (clearRuns) runs.Clear();257 Prepare();258 }259 260 private bool startPending;261 public void Start() {262 if ((ExecutionState != ExecutionState.Prepared) && (ExecutionState != ExecutionState.Paused))263 throw new InvalidOperationException(string.Format("Start not allowed in execution state \"{0}\".", ExecutionState));264 265 if (Algorithm != null && !startPending) {266 startPending = true;267 //create cloned algorithms268 if (clonedAlgorithms.Count == 0) {269 int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / Folds.Value;270 271 for (int i = 0; i < Folds.Value; i++) {272 IAlgorithm clonedAlgorithm = (IAlgorithm)algorithm.Clone();273 clonedAlgorithm.Name = algorithm.Name + " Fold " + i;274 IDataAnalysisProblem problem = clonedAlgorithm.Problem as IDataAnalysisProblem;275 ISymbolicDataAnalysisProblem symbolicProblem = problem as ISymbolicDataAnalysisProblem;276 277 int testStart = (i * testSamplesCount) + SamplesStart.Value;278 int testEnd = (i + 1) == Folds.Value ? SamplesEnd.Value : (i + 1) * testSamplesCount + SamplesStart.Value;279 280 problem.ProblemData.TestPartition.Start = testStart;281 problem.ProblemData.TestPartition.End = testEnd;282 DataAnalysisProblemData problemData = problem.ProblemData as DataAnalysisProblemData;283 if (problemData != null) {284 problemData.TrainingPartitionParameter.Hidden = false;285 problemData.TestPartitionParameter.Hidden = false;286 }287 288 if (symbolicProblem != null) {289 symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;290 symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;291 }292 293 clonedAlgorithms.Add(clonedAlgorithm);294 }295 }296 297 //start prepared or paused cloned algorithms298 int startedAlgorithms = 0;299 foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {300 if (startedAlgorithms < NumberOfWorkers.Value) {301 if (clonedAlgorithm.ExecutionState == ExecutionState.Prepared ||302 clonedAlgorithm.ExecutionState == ExecutionState.Paused) {303 clonedAlgorithm.Start();304 startedAlgorithms++;305 }306 }307 }308 OnStarted();309 }310 }311 312 private bool pausePending;313 public void Pause() {314 if (ExecutionState != ExecutionState.Started)315 throw new InvalidOperationException(string.Format("Pause not allowed in execution state \"{0}\".", ExecutionState));316 if (!pausePending) {317 pausePending = true;318 if (!startPending) PauseAllClonedAlgorithms();319 }320 }321 private void PauseAllClonedAlgorithms() {322 foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {323 if (clonedAlgorithm.ExecutionState == ExecutionState.Started)324 clonedAlgorithm.Pause();325 }326 }327 328 private bool stopPending;329 public void Stop() {330 if ((ExecutionState != ExecutionState.Started) && (ExecutionState != ExecutionState.Paused))331 throw new InvalidOperationException(string.Format("Stop not allowed in execution state \"{0}\".",332 ExecutionState));333 if (!stopPending) {334 stopPending = true;335 if (!startPending) StopAllClonedAlgorithms();336 }337 }338 private void StopAllClonedAlgorithms() {339 foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {340 if (clonedAlgorithm.ExecutionState == ExecutionState.Started ||341 clonedAlgorithm.ExecutionState == ExecutionState.Paused)342 clonedAlgorithm.Stop();343 }344 }345 346 #region collect parameters and results347 public override void CollectParameterValues(IDictionary<string, IItem> values) {348 values.Add("Algorithm Name", new StringValue(Name));349 values.Add("Algorithm Type", new StringValue(GetType().GetPrettyName()));350 values.Add("Folds", new IntValue(Folds.Value));351 352 if (algorithm != null) {353 values.Add("CrossValidation Algorithm Name", new StringValue(Algorithm.Name));354 values.Add("CrossValidation Algorithm Type", new StringValue(Algorithm.GetType().GetPrettyName()));355 base.CollectParameterValues(values);356 }357 if (Problem != null) {358 values.Add("Problem Name", new StringValue(Problem.Name));359 values.Add("Problem Type", new StringValue(Problem.GetType().GetPrettyName()));360 Problem.CollectParameterValues(values);361 }362 }363 364 public void CollectResultValues(IDictionary<string, IItem> results) {365 var clonedResults = (ResultCollection)this.results.Clone();366 foreach (var result in clonedResults) {367 results.Add(result.Name, result.Value);368 }369 }370 371 private void AggregateResultValues(IDictionary<string, IItem> results) {372 Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>();373 IEnumerable<IRun> runs = clonedAlgorithms.Select(alg => alg.Runs.FirstOrDefault()).Where(run => run != null);374 IEnumerable<KeyValuePair<string, IItem>> resultCollections = runs.Where(x => x != null).SelectMany(x => x.Results).ToList();375 376 foreach (IResult result in ExtractAndAggregateResults<IntValue>(resultCollections))377 results.Add(result.Name, result.Value);378 foreach (IResult result in ExtractAndAggregateResults<DoubleValue>(resultCollections))379 results.Add(result.Name, result.Value);380 foreach (IResult result in ExtractAndAggregateResults<PercentValue>(resultCollections))381 results.Add(result.Name, result.Value);382 foreach (IResult result in ExtractAndAggregateRegressionSolutions(resultCollections)) {383 results.Add(result.Name, result.Value);384 }385 foreach (IResult result in ExtractAndAggregateClassificationSolutions(resultCollections)) {386 results.Add(result.Name, result.Value);387 }388 results.Add("Execution Time", new TimeSpanValue(this.ExecutionTime));389 results.Add("CrossValidation Folds", new RunCollection(runs));390 }391 392 private IEnumerable<IResult> ExtractAndAggregateRegressionSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {393 Dictionary<string, List<IRegressionSolution>> resultSolutions = new Dictionary<string, List<IRegressionSolution>>();394 foreach (var result in resultCollections) {395 var regressionSolution = result.Value as IRegressionSolution;396 if (regressionSolution != null) {397 if (resultSolutions.ContainsKey(result.Key)) {398 resultSolutions[result.Key].Add(regressionSolution);399 } else {400 resultSolutions.Add(result.Key, new List<IRegressionSolution>() { regressionSolution });401 }402 }403 }404 List<IResult> aggregatedResults = new List<IResult>();405 foreach (KeyValuePair<string, List<IRegressionSolution>> solutions in resultSolutions) {406 // clone manually to correctly clone references between cloned root objects407 Cloner cloner = new Cloner();408 var problemDataClone = (IRegressionProblemData)cloner.Clone(Problem.ProblemData);409 // set partitions of problem data clone correctly410 problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;411 problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;412 // clone models413 var ensembleSolution = new RegressionEnsembleSolution(414 solutions.Value.Select(x => cloner.Clone(x.Model)),415 problemDataClone,416 solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)),417 solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition)));418 419 aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));420 }421 List<IResult> flattenedResults = new List<IResult>();422 CollectResultsRecursively("", aggregatedResults, flattenedResults);423 return flattenedResults;424 }425 426 private IEnumerable<IResult> ExtractAndAggregateClassificationSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {427 Dictionary<string, List<IClassificationSolution>> resultSolutions = new Dictionary<string, List<IClassificationSolution>>();428 foreach (var result in resultCollections) {429 var classificationSolution = result.Value as IClassificationSolution;430 if (classificationSolution != null) {431 if (resultSolutions.ContainsKey(result.Key)) {432 resultSolutions[result.Key].Add(classificationSolution);433 } else {434 resultSolutions.Add(result.Key, new List<IClassificationSolution>() { classificationSolution });435 }436 }437 }438 var aggregatedResults = new List<IResult>();439 foreach (KeyValuePair<string, List<IClassificationSolution>> solutions in resultSolutions) {440 // clone manually to correctly clone references between cloned root objects441 Cloner cloner = new Cloner();442 var problemDataClone = (IClassificationProblemData)cloner.Clone(Problem.ProblemData);443 // set partitions of problem data clone correctly444 problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;445 problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;446 // clone models447 var ensembleSolution = new ClassificationEnsembleSolution(448 solutions.Value.Select(x => cloner.Clone(x.Model)),449 problemDataClone,450 solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)),451 solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition)));452 453 aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));454 }455 List<IResult> flattenedResults = new List<IResult>();456 CollectResultsRecursively("", aggregatedResults, flattenedResults);457 return flattenedResults;458 }459 460 private void CollectResultsRecursively(string path, IEnumerable<IResult> results, IList<IResult> flattenedResults) {461 foreach (IResult result in results) {462 flattenedResults.Add(new Result(path + result.Name, result.Value));463 ResultCollection childCollection = result.Value as ResultCollection;464 if (childCollection != null) {465 CollectResultsRecursively(path + result.Name + ".", childCollection, flattenedResults);466 }467 }468 }469 470 private static IEnumerable<IResult> ExtractAndAggregateResults<T>(IEnumerable<KeyValuePair<string, IItem>> results)471 where T : class, IItem, new() {472 Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>();473 foreach (var resultValue in results.Where(r => r.Value.GetType() == typeof(T))) {474 if (!resultValues.ContainsKey(resultValue.Key))475 resultValues[resultValue.Key] = new List<double>();476 resultValues[resultValue.Key].Add(ConvertToDouble(resultValue.Value));477 }478 479 DoubleValue doubleValue;480 if (typeof(T) == typeof(PercentValue))481 doubleValue = new PercentValue();482 else if (typeof(T) == typeof(DoubleValue))483 doubleValue = new DoubleValue();484 else if (typeof(T) == typeof(IntValue))485 doubleValue = new DoubleValue();486 else487 throw new NotSupportedException();488 489 List<IResult> aggregatedResults = new List<IResult>();490 foreach (KeyValuePair<string, List<double>> resultValue in resultValues) {491 doubleValue.Value = resultValue.Value.Average();492 aggregatedResults.Add(new Result(resultValue.Key + " (average)", (IItem)doubleValue.Clone()));493 doubleValue.Value = resultValue.Value.StandardDeviation();494 aggregatedResults.Add(new Result(resultValue.Key + " (std.dev.)", (IItem)doubleValue.Clone()));495 }496 return aggregatedResults;497 }498 499 private static double ConvertToDouble(IItem item) {500 if (item is DoubleValue) return ((DoubleValue)item).Value;501 else if (item is IntValue) return ((IntValue)item).Value;502 else throw new NotSupportedException("Could not convert any item type to double");503 }504 #endregion505 506 #region events507 private void RegisterEvents() {508 Folds.ValueChanged += new EventHandler(Folds_ValueChanged);509 SamplesStart.ValueChanged += new EventHandler(SamplesStart_ValueChanged);510 SamplesEnd.ValueChanged += new EventHandler(SamplesEnd_ValueChanged);511 RegisterClonedAlgorithmsEvents();512 }513 private void Folds_ValueChanged(object sender, EventArgs e) {514 if (ExecutionState != ExecutionState.Prepared)515 throw new InvalidOperationException("Can not change number of folds if the execution state is not prepared.");516 }517 private void SamplesStart_ValueChanged(object sender, EventArgs e) {518 if (Problem != null) Problem.ProblemData.TrainingPartition.Start = SamplesStart.Value;519 }520 private void SamplesEnd_ValueChanged(object sender, EventArgs e) {521 if (Problem != null) Problem.ProblemData.TrainingPartition.End = SamplesEnd.Value;522 }523 524 #region template algorithms events525 public event EventHandler AlgorithmChanged;526 private void OnAlgorithmChanged() {527 EventHandler handler = AlgorithmChanged;528 if (handler != null) handler(this, EventArgs.Empty);529 OnProblemChanged();530 if (Problem == null) ExecutionState = ExecutionState.Stopped;531 }532 private void RegisterAlgorithmEvents() {533 algorithm.ProblemChanged += new EventHandler(Algorithm_ProblemChanged);534 algorithm.ExecutionStateChanged += new EventHandler(Algorithm_ExecutionStateChanged);535 }536 private void DeregisterAlgorithmEvents() {537 algorithm.ProblemChanged -= new EventHandler(Algorithm_ProblemChanged);538 algorithm.ExecutionStateChanged -= new EventHandler(Algorithm_ExecutionStateChanged);539 }540 private void Algorithm_ProblemChanged(object sender, EventArgs e) {541 if (algorithm.Problem != null && !(algorithm.Problem is IDataAnalysisProblem)) {542 algorithm.Problem = problem;543 throw new ArgumentException("A cross validation algorithm can only contain DataAnalysisProblems.");544 }545 problem = (IDataAnalysisProblem)algorithm.Problem;546 OnProblemChanged();547 }548 public event EventHandler ProblemChanged;549 private void OnProblemChanged() {550 EventHandler handler = ProblemChanged;551 if (handler != null) handler(this, EventArgs.Empty);552 553 SamplesStart.Value = 0;554 if (Problem != null) {555 Problem.ProblemDataChanged += (object sender, EventArgs e) => OnProblemChanged();556 SamplesEnd.Value = Problem.ProblemData.Dataset.Rows;557 558 DataAnalysisProblemData problemData = Problem.ProblemData as DataAnalysisProblemData;559 if (problemData != null) {560 problemData.TrainingPartitionParameter.Hidden = true;561 problemData.TestPartitionParameter.Hidden = true;562 }563 ISymbolicDataAnalysisProblem symbolicProblem = Problem as ISymbolicDataAnalysisProblem;564 if (symbolicProblem != null) {565 symbolicProblem.FitnessCalculationPartitionParameter.Hidden = true;566 symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;567 symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;568 symbolicProblem.ValidationPartitionParameter.Hidden = true;569 symbolicProblem.ValidationPartition.Start = 0;570 symbolicProblem.ValidationPartition.End = 0;571 }572 } else573 SamplesEnd.Value = 0;574 575 SamplesStart_ValueChanged(this, EventArgs.Empty);576 SamplesEnd_ValueChanged(this, EventArgs.Empty);577 }578 579 private void Algorithm_ExecutionStateChanged(object sender, EventArgs e) {580 switch (Algorithm.ExecutionState) {581 case ExecutionState.Prepared: OnPrepared();582 break;583 case ExecutionState.Started: throw new InvalidOperationException("Algorithm template can not be started.");584 case ExecutionState.Paused: throw new InvalidOperationException("Algorithm template can not be paused.");585 case ExecutionState.Stopped: OnStopped();586 break;587 }588 }589 #endregion590 591 #region clonedAlgorithms events592 private void RegisterClonedAlgorithmsEvents() {593 clonedAlgorithms.ItemsAdded += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);594 clonedAlgorithms.ItemsRemoved += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);595 clonedAlgorithms.CollectionReset += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);596 foreach (IAlgorithm algorithm in clonedAlgorithms)597 RegisterClonedAlgorithmEvents(algorithm);598 }599 private void DeregisterClonedAlgorithmsEvents() {600 clonedAlgorithms.ItemsAdded -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);601 clonedAlgorithms.ItemsRemoved -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);602 clonedAlgorithms.CollectionReset -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);603 foreach (IAlgorithm algorithm in clonedAlgorithms)604 DeregisterClonedAlgorithmEvents(algorithm);605 }606 private void ClonedAlgorithms_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {607 foreach (IAlgorithm algorithm in e.Items)608 RegisterClonedAlgorithmEvents(algorithm);609 }610 private void ClonedAlgorithms_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {611 foreach (IAlgorithm algorithm in e.Items)612 DeregisterClonedAlgorithmEvents(algorithm);613 }614 private void ClonedAlgorithms_CollectionReset(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {615 foreach (IAlgorithm algorithm in e.OldItems)616 DeregisterClonedAlgorithmEvents(algorithm);617 foreach (IAlgorithm algorithm in e.Items)618 RegisterClonedAlgorithmEvents(algorithm);619 }620 private void RegisterClonedAlgorithmEvents(IAlgorithm algorithm) {621 algorithm.ExceptionOccurred += new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);622 algorithm.ExecutionTimeChanged += new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);623 algorithm.Started += new EventHandler(ClonedAlgorithm_Started);624 algorithm.Paused += new EventHandler(ClonedAlgorithm_Paused);625 algorithm.Stopped += new EventHandler(ClonedAlgorithm_Stopped);626 }627 private void DeregisterClonedAlgorithmEvents(IAlgorithm algorithm) {628 algorithm.ExceptionOccurred -= new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);629 algorithm.ExecutionTimeChanged -= new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);630 algorithm.Started -= new EventHandler(ClonedAlgorithm_Started);631 algorithm.Paused -= new EventHandler(ClonedAlgorithm_Paused);632 algorithm.Stopped -= new EventHandler(ClonedAlgorithm_Stopped);633 }634 private void ClonedAlgorithm_ExceptionOccurred(object sender, EventArgs<Exception> e) {635 OnExceptionOccurred(e.Value);636 }637 private void ClonedAlgorithm_ExecutionTimeChanged(object sender, EventArgs e) {638 OnExecutionTimeChanged();639 }640 641 private readonly object locker = new object();642 private void ClonedAlgorithm_Started(object sender, EventArgs e) {643 lock (locker) {644 IAlgorithm algorithm = sender as IAlgorithm;645 if (algorithm != null && !results.ContainsKey(algorithm.Name))646 results.Add(new Result(algorithm.Name, "Contains results for the specific fold.", algorithm.Results));647 648 if (startPending) {649 int startedAlgorithms = clonedAlgorithms.Count(alg => alg.ExecutionState == ExecutionState.Started);650 if (startedAlgorithms == NumberOfWorkers.Value ||651 clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Prepared))652 startPending = false;653 654 if (pausePending) PauseAllClonedAlgorithms();655 if (stopPending) StopAllClonedAlgorithms();656 }657 }658 }659 660 private void ClonedAlgorithm_Paused(object sender, EventArgs e) {661 lock (locker) {662 if (pausePending && clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Started))663 OnPaused();664 }665 }666 667 private void ClonedAlgorithm_Stopped(object sender, EventArgs e) {668 lock (locker) {669 if (!stopPending && ExecutionState == ExecutionState.Started) {670 IAlgorithm preparedAlgorithm = clonedAlgorithms.Where(alg => alg.ExecutionState == ExecutionState.Prepared ||671 alg.ExecutionState == ExecutionState.Paused).FirstOrDefault();672 if (preparedAlgorithm != null) preparedAlgorithm.Start();673 }674 if (ExecutionState != ExecutionState.Stopped) {675 if (clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Stopped))676 OnStopped();677 else if (stopPending &&678 clonedAlgorithms.All(679 alg => alg.ExecutionState == ExecutionState.Prepared || alg.ExecutionState == ExecutionState.Stopped))680 OnStopped();681 }682 }683 }684 #endregion685 #endregion686 687 #region event firing688 public event EventHandler ExecutionStateChanged;689 private void OnExecutionStateChanged() {690 EventHandler handler = ExecutionStateChanged;691 if (handler != null) handler(this, EventArgs.Empty);692 }693 public event EventHandler ExecutionTimeChanged;694 private void OnExecutionTimeChanged() {695 EventHandler handler = ExecutionTimeChanged;696 if (handler != null) handler(this, EventArgs.Empty);697 }698 public event EventHandler Prepared;699 private void OnPrepared() {700 ExecutionState = ExecutionState.Prepared;701 EventHandler handler = Prepared;702 if (handler != null) handler(this, EventArgs.Empty);703 OnExecutionTimeChanged();704 }705 public event EventHandler Started;706 private void OnStarted() {707 startPending = false;708 ExecutionState = ExecutionState.Started;709 EventHandler handler = Started;710 if (handler != null) handler(this, EventArgs.Empty);711 }712 public event EventHandler Paused;713 private void OnPaused() {714 pausePending = false;715 ExecutionState = ExecutionState.Paused;716 EventHandler handler = Paused;717 if (handler != null) handler(this, EventArgs.Empty);718 }719 public event EventHandler Stopped;720 private void OnStopped() {721 stopPending = false;722 Dictionary<string, IItem> collectedResults = new Dictionary<string, IItem>();723 AggregateResultValues(collectedResults);724 results.AddRange(collectedResults.Select(x => new Result(x.Key, x.Value)).Cast<IResult>().ToArray());725 runsCounter++;726 runs.Add(new Run(string.Format("{0} Run {1}", Name, runsCounter), this));727 ExecutionState = ExecutionState.Stopped;728 EventHandler handler = Stopped;729 if (handler != null) handler(this, EventArgs.Empty);730 }731 public event EventHandler<EventArgs<Exception>> ExceptionOccurred;732 private void OnExceptionOccurred(Exception exception) {733 EventHandler<EventArgs<Exception>> handler = ExceptionOccurred;734 if (handler != null) handler(this, new EventArgs<Exception>(exception));735 }736 public event EventHandler StoreAlgorithmInEachRunChanged;737 private void OnStoreAlgorithmInEachRunChanged() {738 EventHandler handler = StoreAlgorithmInEachRunChanged;739 if (handler != null) handler(this, EventArgs.Empty);740 }741 #endregion742 225 } 743 226 }
Note: See TracChangeset
for help on using the changeset viewer.