Changeset 18213 for branches/3138_Shape_Constraints_Transformations/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression
- Timestamp:
- 02/08/22 13:06:49 (2 years ago)
- Location:
- branches/3138_Shape_Constraints_Transformations/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/3138_Shape_Constraints_Transformations/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.csproj
r18181 r18213 263 263 <Private>False</Private> 264 264 </ProjectReference> 265 <ProjectReference Include="..\..\HeuristicLab.Random\3.3\HeuristicLab.Random-3.3.csproj"> 266 <Project>{F4539FB6-4708-40C9-BE64-0A1390AEA197}</Project> 267 <Name>HeuristicLab.Random-3.3</Name> 268 </ProjectReference> 265 269 </ItemGroup> 266 270 <ItemGroup> -
branches/3138_Shape_Constraints_Transformations/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/ShapeConstraintsAnalyzer.cs
r17958 r18213 20 20 #endregion 21 21 22 using System.Collections; 23 using System.Collections.Generic; 22 24 using System.Linq; 23 25 using HEAL.Attic; … … 28 30 using HeuristicLab.Optimization; 29 31 using HeuristicLab.Parameters; 32 using HeuristicLab.Random; 30 33 31 34 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression { … … 36 39 private const string ConstraintViolationsParameterName = "ConstraintViolations"; 37 40 private const string InfeasibleSolutionsParameterName = "InfeasibleSolutions"; 41 private const string AverageConstraintViolationsParameterName = "AverageConstraintViolations"; 42 private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter"; 38 43 39 44 #region parameter properties … … 48 53 (IResultParameter<DataTable>)Parameters[InfeasibleSolutionsParameterName]; 49 54 55 public IResultParameter<DataTable> AverageConstraintViolationsParameter => 56 (IResultParameter<DataTable>)Parameters[AverageConstraintViolationsParameterName]; 57 58 public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter => 59 (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; 60 50 61 #endregion 51 62 52 #region properties63 #region properties 53 64 public IRegressionProblemData RegressionProblemData => RegressionProblemDataParameter.ActualValue; 54 65 public DataTable ConstraintViolations => ConstraintViolationsParameter.ActualValue; 55 66 public DataTable InfeasibleSolutions => InfeasibleSolutionsParameter.ActualValue; 67 public DataTable AverageConstraintViolations => AverageConstraintViolationsParameter.ActualValue; 56 68 #endregion 57 69 … … 75 87 Parameters.Add(new ResultParameter<DataTable>(InfeasibleSolutionsParameterName, 76 88 "The number of infeasible solutions.")); 89 Parameters.Add(new ResultParameter<DataTable>(AverageConstraintViolationsParameterName, 90 "The average violations of each constraint.")); 91 Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName, 92 "The interpreter that should be used to calculate the output values of the symbolic data analysis tree.") { Hidden = true }); 93 77 94 78 95 … … 90 107 } 91 108 }; 109 110 AverageConstraintViolationsParameter.DefaultValue = new DataTable(SymbolicDataAnalysisTreeInterpreterParameterName) { 111 VisualProperties = { 112 XAxisTitle = "Generations", 113 YAxisTitle = "Average Constraint Violations" 114 } 115 }; 92 116 } 93 117 94 118 95 119 [StorableHook(HookType.AfterDeserialization)] 96 private void AfterDeserialization() { } 120 private void AfterDeserialization() { 121 if (!Parameters.ContainsKey(SymbolicDataAnalysisTreeInterpreterParameterName)) 122 Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName, 123 "The interpreter that should be used to calculate the output values of the symbolic data analysis tree.") { Hidden = true }); 124 } 97 125 98 126 public override IOperation Apply() { 99 127 var problemData = (IShapeConstrainedRegressionProblemData)RegressionProblemData; 100 128 var trees = SymbolicExpressionTree.ToArray(); 101 129 102 130 var results = ResultCollection; 103 var constraints = problemData.ShapeConstraints.EnabledConstraints; 131 var modelConstraints = problemData.ShapeConstraints.EnabledConstraints; 132 var extendedConstraints = problemData.CheckedExtendedConstraints; 104 133 var variableRanges = problemData.VariableRanges; 105 134 var constraintViolationsTable = ConstraintViolations; 135 var averageConstraintViolations = AverageConstraintViolations; 136 var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue; 106 137 var estimator = new IntervalArithBoundsEstimator(); 138 139 if (!constraintViolationsTable.Rows.Any()) { 140 foreach (var constraint in modelConstraints) { 141 constraintViolationsTable.Rows.Add(new DataRow(constraint.ToString())); 142 averageConstraintViolations.Rows.Add(new DataRow(constraint.ToString())); 143 } 107 144 108 if (!constraintViolationsTable.Rows.Any()) 109 foreach (var constraint in constraints) 110 constraintViolationsTable.Rows.Add(new DataRow(constraint.ToString())); 145 foreach (var extendedConstraint in extendedConstraints.SelectMany(x => x.ShapeConstraints.EnabledConstraints)) { 146 constraintViolationsTable.Rows.Add(new DataRow(extendedConstraint.ToString())); 147 averageConstraintViolations.Rows.Add(new DataRow(extendedConstraint.ToString())); 148 } 149 } 111 150 112 foreach (var constraint in constraints) { 113 var numViolations = trees.Count(tree => IntervalUtil.GetConstraintViolation(constraint, estimator, variableRanges, tree) > 0.0); 114 constraintViolationsTable.Rows[constraint.ToString()].Values.Add(numViolations); 151 var violationsPerTree = new Dictionary<ISymbolicExpressionTree, int>(); 152 var violationsPerConstraint = new Dictionary<string, IList<double>>(); 153 154 foreach(var tree in trees) { 155 var violations = NMSESingleObjectiveConstraintsEvaluator.CalculateShapeConstraintsViolations(problemData, tree, interpreter, estimator, new MersenneTwister()); 156 foreach(var violation in violations) { 157 var constraint = violation.Item1; 158 var error = violation.Item2; 159 if (!violationsPerConstraint.ContainsKey(constraint.ToString())) 160 violationsPerConstraint.Add(constraint.ToString(), new List<double>()); 161 violationsPerConstraint[constraint.ToString()].Add(error); 162 } 163 violationsPerTree.Add(tree, violations.Count(x => x.Item2 > 0)); 164 } 165 166 foreach (var constraint in modelConstraints) { 167 var errors = violationsPerConstraint[constraint.ToString()]; 168 constraintViolationsTable.Rows[constraint.ToString()].Values.Add(errors.Count(x => x > 0)); 169 averageConstraintViolations.Rows[constraint.ToString()].Values.Add(errors.Sum() / errors.Count()); 170 } 171 172 foreach (var extendedConstraint in extendedConstraints.SelectMany(x => x.ShapeConstraints.EnabledConstraints)) { 173 var errors = violationsPerConstraint[extendedConstraint.ToString()]; 174 constraintViolationsTable.Rows[extendedConstraint.ToString()].Values.Add(errors.Count(x => x > 0)); 175 averageConstraintViolations.Rows[extendedConstraint.ToString()].Values.Add(errors.Sum() / errors.Count()); 115 176 } 116 177 … … 121 182 infeasibleSolutionsDataTable.Rows[InfeasibleSolutionsParameterName] 122 183 .Values 123 .Add(trees.Count(t => IntervalUtil.GetConstraintViolations(constraints, estimator, variableRanges, t).Any(x => x > 0.0)));184 .Add(trees.Count(t => violationsPerTree[t] > 0)); 124 185 125 186 return base.Apply(); -
branches/3138_Shape_Constraints_Transformations/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/NMSESingleObjectiveConstraintsEvaluator.cs
r18181 r18213 21 21 22 22 using System; 23 using System.Collections; 23 24 using System.Collections.Generic; 24 25 using System.Linq; … … 29 30 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 30 31 using HeuristicLab.Parameters; 32 using HeuristicLab.Random; 31 33 32 34 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression { … … 41 43 private const string BoundsEstimatorParameterName = "BoundsEstimator"; 42 44 private const string PenaltyFactorParameterName = "PenaltyFactor"; 43 private const string ExtendedConstraintsParameterName = "ExtendedConstraints";44 45 45 46 … … 58 59 (IFixedValueParameter<DoubleValue>)Parameters[PenaltyFactorParameterName]; 59 60 60 public IFixedValueParameter<IItemList<ExtendedConstraint>> ExtendedConstraintsParameter =>61 (IFixedValueParameter<IItemList<ExtendedConstraint>>)Parameters[ExtendedConstraintsParameterName];62 61 63 62 … … 86 85 set => PenaltyFactorParameter.Value.Value = value; 87 86 } 88 89 public IEnumerable<ExtendedConstraint> ExtendedConstraints {90 get => ExtendedConstraintsParameter.Value;91 }92 93 94 87 95 88 public override bool Maximization => false; // NMSE is minimized … … 116 109 Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyFactorParameterName, 117 110 "Punishment factor for constraint violations for soft constraint handling (fitness = NMSE + penaltyFactor * avg(violations)) (default: 1.0)", new DoubleValue(1.0))); 118 Parameters.Add(new FixedValueParameter<ItemList<ExtendedConstraint>>(ExtendedConstraintsParameterName, "", new ItemList<ExtendedConstraint>()));119 111 } 120 112 … … 135 127 var estimationLimits = EstimationLimitsParameter.ActualValue; 136 128 var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value; 129 var random = RandomParameter.ActualValue; 137 130 138 131 if (OptimizeParameters) { … … 175 168 176 169 var quality = Calculate(interpreter, tree, estimationLimits.Lower, estimationLimits.Upper, problemData, rows, 177 BoundsEstimator, UseSoftConstraints, PenalityFactor, ExtendedConstraints);170 BoundsEstimator, random, UseSoftConstraints, PenalityFactor); 178 171 QualityParameter.ActualValue = new DoubleValue(quality); 179 172 … … 186 179 double lowerEstimationLimit, double upperEstimationLimit, 187 180 IRegressionProblemData problemData, IEnumerable<int> rows, 188 IBoundsEstimator estimator, 189 bool useSoftConstraints = false, double penaltyFactor = 1.0, 190 IEnumerable<ExtendedConstraint> extendedConstraints = null) { 191 192 var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows); 181 IBoundsEstimator estimator, IRandom random, 182 bool useSoftConstraints = false, double penaltyFactor = 1.0) { 183 184 var trainingEstimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows); 193 185 var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 194 var constraints = Enumerable.Empty<ShapeConstraint>(); 186 187 var trainingBoundedEstimatedValues = trainingEstimatedValues.LimitToRange(lowerEstimationLimit, upperEstimationLimit); 188 var nmse = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(targetValues, trainingBoundedEstimatedValues, 189 out var errorState); 190 191 if (errorState != OnlineCalculatorError.None) 192 return double.MaxValue; 193 194 var violations = Enumerable.Empty<double>(); 195 195 if (problemData is ShapeConstrainedRegressionProblemData scProbData) { 196 constraints = scProbData.ShapeConstraints.EnabledConstraints; 197 } 198 var intervalCollection = problemData.VariableRanges; 199 200 var boundedEstimatedValues = estimatedValues.LimitToRange(lowerEstimationLimit, upperEstimationLimit); 201 var nmse = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(targetValues, boundedEstimatedValues, 202 out var errorState); 203 204 if (errorState != OnlineCalculatorError.None) { 205 return 1.0; 206 } 207 208 var constraintViolations = IntervalUtil.GetConstraintViolations(constraints, estimator, intervalCollection, tree); 209 210 if (constraintViolations.Any(x => double.IsNaN(x) || double.IsInfinity(x))) { 211 return 1.0; 212 } 196 violations = CalculateShapeConstraintsViolations(scProbData, tree, interpreter, estimator, random).Select(x => x.Item2); 197 } 198 199 if (violations.Any(x => double.IsNaN(x) || double.IsInfinity(x))) 200 return double.MaxValue; 213 201 214 202 if (useSoftConstraints) { … … 216 204 throw new ArgumentException("The parameter has to be >= 0.0.", nameof(penaltyFactor)); 217 205 218 var weightedViolationsAvg = constraints 219 .Zip(constraintViolations, (c, v) => c.Weight * v) 220 .Average(); 221 222 return Math.Min(nmse, 1.0) + penaltyFactor * weightedViolationsAvg; 223 } else if (constraintViolations.Any(x => x > 0.0)) { 224 return 1.0; 225 } 226 227 return nmse; 206 return nmse + penaltyFactor * violations.Average(); 207 } 208 return violations.Any(x => x > 0.0) ? 1.0 : nmse; 209 } 210 211 public static IEnumerable<Tuple<ShapeConstraint, double>> CalculateShapeConstraintsViolations( 212 IShapeConstrainedRegressionProblemData problemData, ISymbolicExpressionTree tree, 213 ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IBoundsEstimator estimator, 214 IRandom random) { 215 IList<Tuple<ShapeConstraint, double>> violations = new List<Tuple<ShapeConstraint, double>>(); 216 217 var baseConstraints = problemData.ShapeConstraints.EnabledConstraints; 218 var intervalCollection = problemData.VariableRanges; 219 var extendedShapeConstraints = problemData.CheckedExtendedConstraints; 220 var allEstimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, problemData.AllIndices); 221 222 foreach (var constraint in baseConstraints) 223 violations.Add(Tuple.Create(constraint, IntervalUtil.GetConstraintViolation(constraint, estimator, intervalCollection, tree) * constraint.Weight)); 224 225 IDictionary<string, IList> dict = new Dictionary<string, IList>(); 226 foreach (var varName in problemData.Dataset.VariableNames) { 227 if (varName != problemData.TargetVariable) 228 dict.Add(varName, problemData.Dataset.GetDoubleValues(varName).ToList()); 229 else dict.Add(varName, allEstimatedValues.ToList()); 230 } 231 var tmpDataset = new Dataset(dict.Keys, dict.Values); 232 233 foreach (var extendedConstraint in extendedShapeConstraints) { 234 var enabledConstraints = extendedConstraint.ShapeConstraints.EnabledConstraints; 235 if (enabledConstraints.Any()) { 236 var extendedConstraintExprValues = interpreter.GetSymbolicExpressionTreeValues(extendedConstraint.Tree, tmpDataset, problemData.AllIndices); 237 var extendedConstraintExprInterval = new Interval(extendedConstraintExprValues.Min(), extendedConstraintExprValues.Max()); 238 239 foreach (var constraint in enabledConstraints) { 240 if (constraint.Regions.Count > 0) { 241 // adapt dataset 242 foreach (var kvp in constraint.Regions.GetReadonlyDictionary()) { 243 var lb = double.IsNegativeInfinity(kvp.Value.LowerBound) ? double.MinValue : kvp.Value.LowerBound; 244 var ub = double.IsPositiveInfinity(kvp.Value.UpperBound) ? double.MaxValue : kvp.Value.UpperBound; 245 246 var vals = Enumerable.Range(0, dict[kvp.Key].Count - 2) 247 .Select(x => UniformDistributedRandom.NextDouble(random, lb, ub)) 248 .ToList(); 249 vals.Add(lb); 250 vals.Add(ub); 251 vals.Sort(); 252 dict[kvp.Key] = vals; 253 } 254 // calc again with new regions 255 tmpDataset = new Dataset(dict.Keys, dict.Values); 256 // calc target again 257 allEstimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, tmpDataset, problemData.AllIndices); 258 dict[problemData.TargetVariable] = allEstimatedValues.ToList(); 259 tmpDataset = new Dataset(dict.Keys, dict.Values); 260 extendedConstraintExprValues = interpreter.GetSymbolicExpressionTreeValues(extendedConstraint.Tree, tmpDataset, problemData.AllIndices); 261 extendedConstraintExprInterval = new Interval(extendedConstraintExprValues.Min(), extendedConstraintExprValues.Max()); 262 } 263 violations.Add(Tuple.Create(constraint, IntervalUtil.GetIntervalError(constraint.Interval, extendedConstraintExprInterval, constraint.Threshold) * constraint.Weight)); 264 } 265 } 266 } 267 return violations; 228 268 } 229 269 … … 234 274 EstimationLimitsParameter.ExecutionContext = context; 235 275 ApplyLinearScalingParameter.ExecutionContext = context; 276 RandomParameter.ExecutionContext = context; 236 277 237 278 var nmse = Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, 238 279 EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, 239 problemData, rows, BoundsEstimator, UseSoftConstraints, PenalityFactor, ExtendedConstraints);280 problemData, rows, BoundsEstimator, RandomParameter.Value, UseSoftConstraints, PenalityFactor); 240 281 241 282 SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null; 242 283 EstimationLimitsParameter.ExecutionContext = null; 243 284 ApplyLinearScalingParameter.ExecutionContext = null; 285 RandomParameter.ExecutionContext = null; 244 286 245 287 return nmse;
Note: See TracChangeset
for help on using the changeset viewer.