Changeset 18239 for branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression
- Timestamp:
- 03/22/22 13:28:56 (3 years ago)
- Location:
- branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.csproj
r17930 r18239 47 47 <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet> 48 48 <Prefer32Bit>false</Prefer32Bit> 49 <LangVersion> 7</LangVersion>49 <LangVersion>latest</LangVersion> 50 50 </PropertyGroup> 51 51 <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> … … 58 58 <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet> 59 59 <Prefer32Bit>false</Prefer32Bit> 60 <LangVersion> 7</LangVersion>60 <LangVersion>latest</LangVersion> 61 61 </PropertyGroup> 62 62 <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x64' "> … … 69 69 <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet> 70 70 <Prefer32Bit>false</Prefer32Bit> 71 <LangVersion> 7</LangVersion>71 <LangVersion>latest</LangVersion> 72 72 </PropertyGroup> 73 73 <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x64' "> … … 80 80 <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet> 81 81 <Prefer32Bit>false</Prefer32Bit> 82 <LangVersion> 7</LangVersion>82 <LangVersion>latest</LangVersion> 83 83 </PropertyGroup> 84 84 <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x86' "> … … 91 91 <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet> 92 92 <Prefer32Bit>false</Prefer32Bit> 93 <LangVersion> 7</LangVersion>93 <LangVersion>latest</LangVersion> 94 94 </PropertyGroup> 95 95 <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x86' "> … … 102 102 <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet> 103 103 <Prefer32Bit>false</Prefer32Bit> 104 <LangVersion> 7</LangVersion>104 <LangVersion>latest</LangVersion> 105 105 </PropertyGroup> 106 106 <ItemGroup> … … 109 109 <HintPath>..\..\bin\ALGLIB-3.7.0.dll</HintPath> 110 110 <Private>False</Private> 111 </Reference>112 <Reference Include="DiffSharp.Merged, Version=0.8.4.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=AMD64">113 <SpecificVersion>False</SpecificVersion>114 <HintPath>..\..\bin\DiffSharp.Merged.dll</HintPath>115 111 </Reference> 116 112 <Reference Include="MathNet.Numerics"> … … 133 129 <Reference Include="System.Data" /> 134 130 <Reference Include="System.Xml" /> 135 <Reference Include="Tensor Flow.NET.Merged, Version=0.15.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">131 <Reference Include="Tensorflow.Binding, Version=0.70.1.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=AMD64"> 136 132 <SpecificVersion>False</SpecificVersion> 137 <HintPath>..\..\bin\TensorFlow.NET.Merged.dll</HintPath> 133 <HintPath>..\..\bin\Tensorflow.Binding.dll</HintPath> 134 </Reference> 135 <Reference Include="Tensorflow.Keras, Version=0.7.0.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=AMD64"> 136 <SpecificVersion>False</SpecificVersion> 137 <HintPath>..\..\bin\Tensorflow.Keras.dll</HintPath> 138 138 </Reference> 139 139 </ItemGroup> -
branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/Plugin.cs.frame
r17786 r18239 43 43 [PluginDependency("HeuristicLab.MathNet.Numerics", "4.9.0")] 44 44 [PluginDependency("HeuristicLab.TensorFlowNet", "0.15.0")] 45 [PluginDependency("HeuristicLab.DiffSharp", "0.7.7")]45 //[PluginDependency("HeuristicLab.DiffSharp", "0.7.7")] 46 46 public class HeuristicLabProblemsDataAnalysisSymbolicRegressionPlugin : PluginBase { 47 47 } -
branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/NonlinearLeastSquaresVectorConstantOptimizationEvaluator.cs
r17930 r18239 19 19 */ 20 20 #endregion 21 22 #if INCLUDE_DIFFSHARP 21 23 22 24 using System; … … 194 196 } 195 197 } 198 199 #endif -
branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs
r17721 r18239 25 25 26 26 using System; 27 using System.Collections;28 27 using System.Collections.Generic; 29 28 #if LOG_CONSOLE … … 42 41 using HeuristicLab.Parameters; 43 42 using HEAL.Attic; 44 using NumSharp;45 43 using Tensorflow; 44 using Tensorflow.NumPy; 46 45 using static Tensorflow.Binding; 46 using static Tensorflow.KerasApi; 47 47 using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>; 48 48 … … 54 54 private const string LearningRateName = "LearningRate"; 55 55 56 //private static readonly TF_DataType DataType = tf.float64; 56 57 private static readonly TF_DataType DataType = tf.float32; 57 58 … … 105 106 CancellationToken cancellationToken = default(CancellationToken), IProgress<double> progress = null) { 106 107 107 int numRows = rows.Count(); 108 var variableLengths = problemData.AllowedInputVariables.ToDictionary( 109 var => var, 110 var => { 111 if (problemData.Dataset.VariableHasType<double>(var)) return 1; 112 if (problemData.Dataset.VariableHasType<DoubleVector>(var)) return problemData.Dataset.GetDoubleVectorValue(var, 0).Count; 113 throw new NotSupportedException($"Type of variable {var} is not supported."); 114 }); 115 116 bool success = TreeToTensorConverter.TryConvert(tree, 117 numRows, variableLengths, 108 const bool eager = true; 109 110 bool prepared = TreeToTensorConverter.TryPrepareTree( 111 tree, 112 problemData, rows.ToList(), 118 113 updateVariableWeights, applyLinearScaling, 119 out Tensor prediction,120 out Dictionary< Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/);121 122 if (! success)114 eager, 115 out Dictionary<string, Tensor> inputFeatures, out Tensor target, 116 out Dictionary<ISymbolicExpressionTreeNode, ResourceVariable[]> variables); 117 if (!prepared) 123 118 return (ISymbolicExpressionTree)tree.Clone(); 124 119 125 var target = tf.placeholder(DataType, new TensorShape(numRows), name: problemData.TargetVariable); 126 // MSE 127 var cost = tf.reduce_mean(tf.square(target - prediction)); 128 129 var optimizer = tf.train.AdamOptimizer((float)learningRate); 130 //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate); 131 var optimizationOperation = optimizer.minimize(cost); 132 133 #if EXPORT_GRAPH 134 //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging 135 tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false, 136 clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true); 137 #endif 138 139 // features as feed items 140 var variablesFeed = new Hashtable(); 141 foreach (var kvp in parameters) { 142 var variable = kvp.Key; 143 var variableName = kvp.Value; 144 if (problemData.Dataset.VariableHasType<double>(variableName)) { 145 var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray(); 146 variablesFeed.Add(variable, np.array(data).reshape(numRows, 1)); 147 } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) { 148 var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.Select(y => (float)y).ToArray()).ToArray(); 149 variablesFeed.Add(variable, np.array(data)); 150 } else 151 throw new NotSupportedException($"Type of the variable is not supported: {variableName}"); 120 var optimizer = keras.optimizers.Adam((float)learningRate); 121 122 for (int i = 0; i < maxIterations; i++) { 123 if (cancellationToken.IsCancellationRequested) break; 124 125 using var tape = tf.GradientTape(); 126 127 bool success = TreeToTensorConverter.TryEvaluate( 128 tree, 129 inputFeatures, variables, 130 updateVariableWeights, applyLinearScaling, 131 eager, 132 out Tensor prediction); 133 if (!success) 134 return (ISymbolicExpressionTree)tree.Clone(); 135 136 var loss = tf.reduce_mean(tf.square(target - prediction)); 137 138 progress?.Report(loss.ToArray<float>()[0]); 139 140 var variablesList = variables.Values.SelectMany(x => x).ToList(); 141 var gradients = tape.gradient(loss, variablesList); 142 143 optimizer.apply_gradients(zip(gradients, variablesList)); 152 144 } 153 var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray(); 154 variablesFeed.Add(target, np.array(targetData)); 155 156 157 List<NDArray> constants; 158 using (var session = tf.Session()) { 159 160 #if LOG_FILE 161 var directoryName = $"C:\\temp\\TFboard\\logdir\\manual_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}"; 162 Directory.CreateDirectory(directoryName); 163 var costsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Costs.csv"))); 164 var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv"))); 165 var gradientsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Gradients.csv"))); 166 #endif 167 168 #if LOG_CONSOLE || LOG_FILE 169 var gradients = optimizer.compute_gradients(cost); 170 #endif 171 172 session.run(tf.global_variables_initializer()); 173 174 progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0)); 175 176 177 #if LOG_CONSOLE 178 Trace.WriteLine("Costs:"); 179 Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}"); 180 181 Trace.WriteLine("Weights:"); 182 foreach (var v in variables) { 183 Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}"); 184 } 185 186 Trace.WriteLine("Gradients:"); 187 foreach (var t in gradients) { 188 Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}"); 189 } 190 #endif 191 192 #if LOG_FILE 193 costsWriter.WriteLine("MSE"); 194 costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture)); 195 196 weightsWriter.WriteLine(string.Join(";", variables.Select(v => v.name))); 197 weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture)))); 198 199 gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => t.Item2.name))); 200 gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture)))); 201 #endif 202 203 for (int i = 0; i < maxIterations; i++) { 204 if (cancellationToken.IsCancellationRequested) 145 146 var cloner = new Cloner(); 147 var newTree = cloner.Clone(tree); 148 var newConstants = variables.ToDictionary( 149 kvp => (ISymbolicExpressionTreeNode)cloner.GetClone(kvp.Key), 150 kvp => kvp.Value.Select(x => (double)(x.numpy().ToArray<float>()[0])).ToArray() 151 ); 152 UpdateConstants(newTree, newConstants); 153 154 155 return newTree; 156 157 158 159 160 161 // //int numRows = rows.Count(); 162 163 164 165 166 167 168 // var variableLengths = problemData.AllowedInputVariables.ToDictionary( 169 // var => var, 170 // var => { 171 // if (problemData.Dataset.VariableHasType<double>(var)) return 1; 172 // if (problemData.Dataset.VariableHasType<DoubleVector>(var)) return problemData.Dataset.GetDoubleVectorValue(var, 0).Count; 173 // throw new NotSupportedException($"Type of variable {var} is not supported."); 174 // }); 175 176 // var variablesDict = problemData.AllowedInputVariables.ToDictionary( 177 // var => var, 178 // var => { 179 // if (problemData.Dataset.VariableHasType<double>(var)) { 180 // var data = problemData.Dataset.GetDoubleValues(var, rows).Select(x => (float)x).ToArray(); 181 // return tf.convert_to_tensor(np.array(data).reshape(new Shape(numRows, 1)), DataType); 182 // } else if (problemData.Dataset.VariableHasType<DoubleVector>(var)) { 183 // var data = problemData.Dataset.GetDoubleVectorValues(var, rows).SelectMany(x => x.Select(y => (float)y)).ToArray(); 184 // return tf.convert_to_tensor(np.array(data).reshape(new Shape(numRows, -1)), DataType); 185 // } else throw new NotSupportedException($"Type of the variable is not supported: {var}"); 186 // } 187 // ); 188 189 // using var tape = tf.GradientTape(persistent: true); 190 191 // bool success = TreeToTensorConverter.TryEvaluateEager(tree, 192 // numRows, variablesDict, 193 // updateVariableWeights, applyLinearScaling, 194 // out Tensor prediction, 195 // out Dictionary<Tensor, string> parameters, out List<ResourceVariable> variables); 196 197 // //bool success = TreeToTensorConverter.TryConvert(tree, 198 // // numRows, variableLengths, 199 // // updateVariableWeights, applyLinearScaling, 200 // // out Tensor prediction, 201 // // out Dictionary<Tensor, string> parameters, out List<Tensor> variables); 202 203 // if (!success) 204 // return (ISymbolicExpressionTree)tree.Clone(); 205 206 // //var target = tf.placeholder(DataType, new Shape(numRows), name: problemData.TargetVariable); 207 // var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray(); 208 // var target = tf.convert_to_tensor(np.array(targetData).reshape(new Shape(numRows)), DataType); 209 // // MSE 210 // var cost = tf.reduce_sum(tf.square(prediction - target)); 211 212 // tape.watch(cost); 213 214 // //var optimizer = tf.train.AdamOptimizer((float)learningRate); 215 // //var optimizer = tf.train.AdamOptimizer(tf.constant(learningRate, DataType)); 216 // //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate); 217 // //var optimizer = tf.train.GradientDescentOptimizer(tf.constant(learningRate, DataType)); 218 // //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate); 219 // //var optimizer = tf.train.AdamOptimizer((float)learningRate); 220 // //var optimizationOperation = optimizer.minimize(cost); 221 // var optimizer = keras.optimizers.Adam((float)learningRate); 222 223 // #if EXPORT_GRAPH 224 // //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging 225 // tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false, 226 // clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true); 227 //#endif 228 229 // //// features as feed items 230 // //var variablesFeed = new Hashtable(); 231 // //foreach (var kvp in parameters) { 232 // // var variable = kvp.Key; 233 // // var variableName = kvp.Value; 234 // // if (problemData.Dataset.VariableHasType<double>(variableName)) { 235 // // var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray(); 236 // // variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, 1))); 237 // // } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) { 238 // // var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).SelectMany(x => x.Select(y => (float)y)).ToArray(); 239 // // variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, -1))); 240 // // } else 241 // // throw new NotSupportedException($"Type of the variable is not supported: {variableName}"); 242 // //} 243 // //var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray(); 244 // //variablesFeed.Add(target, np.array(targetData)); 245 246 247 // List<NDArray> constants; 248 // //using (var session = tf.Session()) { 249 250 //#if LOG_FILE 251 // var directoryName = $"C:\\temp\\TFboard\\logdir\\manual_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}"; 252 // Directory.CreateDirectory(directoryName); 253 // var costsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Costs.csv"))); 254 // var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv"))); 255 // var gradientsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Gradients.csv"))); 256 //#endif 257 258 // //session.run(tf.global_variables_initializer()); 259 260 //#if LOG_CONSOLE || LOG_FILE 261 // var gradients = optimizer.compute_gradients(cost); 262 //#endif 263 264 // //var vars = variables.Select(v => session.run(v, variablesFeed)[0].ToArray<float>()[0]).ToList(); 265 // //var gradient = optimizer.compute_gradients(cost) 266 // // .Where(g => g.Item1 != null) 267 // // //.Select(g => session.run(g.Item1, variablesFeed)[0].GetValue<float>(0)). 268 // // .Select(g => session.run(g.Item1, variablesFeed)[0].ToArray<float>()[0]) 269 // // .ToList(); 270 271 // //var gradientPrediction = optimizer.compute_gradients(prediction) 272 // // .Where(g => g.Item1 != null) 273 // // .Select(g => session.run(g.Item1, variablesFeed)[0].ToArray<float>()[0]) 274 // // .ToList(); 275 276 277 // //progress?.Report(session.run(cost, variablesFeed)[0].ToArray<float>()[0]); 278 // progress?.Report(cost.ToArray<float>()[0]); 279 280 281 282 283 284 //#if LOG_CONSOLE 285 // Trace.WriteLine("Costs:"); 286 // Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}"); 287 288 // Trace.WriteLine("Weights:"); 289 // foreach (var v in variables) { 290 // Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}"); 291 // } 292 293 // Trace.WriteLine("Gradients:"); 294 // foreach (var t in gradients) { 295 // Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}"); 296 // } 297 //#endif 298 299 //#if LOG_FILE 300 // costsWriter.WriteLine("MSE"); 301 // costsWriter.WriteLine(session.run(cost, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture)); 302 303 // weightsWriter.WriteLine(string.Join(";", variables.Select(v => v.name))); 304 // weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).ToArray<float>()[0].ToString(CultureInfo.InvariantCulture)))); 305 306 // gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => t.Item2.Name))); 307 // gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture)))); 308 //#endif 309 310 // for (int i = 0; i < maxIterations; i++) { 311 // if (cancellationToken.IsCancellationRequested) 312 // break; 313 314 315 // var gradients = tape.gradient(cost, variables); 316 // //optimizer.apply_gradients(gradients.Zip(variables, Tuple.Create<Tensor, IVariableV1>).ToArray()); 317 // optimizer.apply_gradients(zip(gradients, variables)); 318 319 320 // //session.run(optimizationOperation, variablesFeed); 321 322 // progress?.Report(cost.ToArray<float>()[0]); 323 // //progress?.Report(session.run(cost, variablesFeed)[0].ToArray<float>()[0]); 324 325 //#if LOG_CONSOLE 326 // Trace.WriteLine("Costs:"); 327 // Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}"); 328 329 // Trace.WriteLine("Weights:"); 330 // foreach (var v in variables) { 331 // Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}"); 332 // } 333 334 // Trace.WriteLine("Gradients:"); 335 // foreach (var t in gradients) { 336 // Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}"); 337 // } 338 //#endif 339 340 //#if LOG_FILE 341 // costsWriter.WriteLine(session.run(cost, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture)); 342 // weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).ToArray<float>()[0].ToString(CultureInfo.InvariantCulture)))); 343 // gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture)))); 344 //#endif 345 // } 346 347 //#if LOG_FILE 348 // costsWriter.Close(); 349 // weightsWriter.Close(); 350 // gradientsWriter.Close(); 351 //#endif 352 // //constants = variables.Select(v => session.run(v)).ToList(); 353 // constants = variables.Select(v => v.numpy()).ToList(); 354 // //} 355 356 // if (applyLinearScaling) 357 // constants = constants.Skip(2).ToList(); 358 // var newTree = (ISymbolicExpressionTree)tree.Clone(); 359 // UpdateConstants(newTree, constants, updateVariableWeights); 360 361 // return newTree; 362 } 363 364 private static void UpdateConstants(ISymbolicExpressionTree tree, Dictionary<ISymbolicExpressionTreeNode, double[]> constants) { 365 foreach (var kvp in constants) { 366 var node = kvp.Key; 367 var value = kvp.Value; 368 369 switch (node) { 370 case ConstantTreeNode constantTreeNode: 371 constantTreeNode.Value = value[0]; 205 372 break; 206 207 session.run(optimizationOperation, variablesFeed); 208 209 progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0)); 210 211 #if LOG_CONSOLE 212 Trace.WriteLine("Costs:"); 213 Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}"); 214 215 Trace.WriteLine("Weights:"); 216 foreach (var v in variables) { 217 Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}"); 373 case VariableTreeNodeBase variableTreeNodeBase: 374 variableTreeNodeBase.Weight = value[0]; 375 break; 376 case FactorVariableTreeNode factorVarTreeNode: { 377 for (int i = 0; i < factorVarTreeNode.Weights.Length; i++) { 378 factorVarTreeNode.Weights[i] = value[i]; 379 } 380 break; 218 381 } 219 220 Trace.WriteLine("Gradients:");221 foreach (var t in gradients) {222 Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");223 }224 #endif225 226 #if LOG_FILE227 costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));228 weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));229 gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));230 #endif231 }232 233 #if LOG_FILE234 costsWriter.Close();235 weightsWriter.Close();236 gradientsWriter.Close();237 #endif238 constants = variables.Select(v => session.run(v)).ToList();239 }240 241 if (applyLinearScaling)242 constants = constants.Skip(2).ToList();243 var newTree = (ISymbolicExpressionTree)tree.Clone();244 UpdateConstants(newTree, constants, updateVariableWeights);245 246 return newTree;247 }248 249 private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) {250 int i = 0;251 foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {252 if (node is ConstantTreeNode constantTreeNode)253 constantTreeNode.Value = constants[i++].GetValue<float>(0, 0);254 else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights)255 variableTreeNodeBase.Weight = constants[i++].GetValue<float>(0, 0);256 else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {257 for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)258 factorVarTreeNode.Weights[j] = constants[i++].GetValue<float>(0, 0);259 382 } 260 383 } 261 384 } 262 385 386 //private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) { 387 // int i = 0; 388 // foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) { 389 // if (node is ConstantTreeNode constantTreeNode) { 390 // constantTreeNode.Value = constants[i++].ToArray<float>()[0]; 391 // } else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights) { 392 // variableTreeNodeBase.Weight = constants[i++].ToArray<float>()[0]; 393 // } else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) { 394 // for (int j = 0; j < factorVarTreeNode.Weights.Length; j++) 395 // factorVarTreeNode.Weights[j] = constants[i++].ToArray<float>()[0]; 396 // } 397 // } 398 //} 399 263 400 public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) { 264 401 return TreeToTensorConverter.IsCompatible(tree);
Note: See TracChangeset
for help on using the changeset viewer.