Class Trainer2D
- Namespace: NeuralNetworks.Trainers
- Assembly: NeuralNetworks.dll
public class Trainer2D : Trainer<float[,], float[,]>
- Inheritance: Trainer<float[,], float[,]> → Trainer2D
- Inherited Members
Constructors
Trainer2D(Model<float[,], float[,]>, Optimizer, SeededRandom?, ILogger<Trainer2D>?, bool)
public Trainer2D(Model<float[,], float[,]> model, Optimizer optimizer, SeededRandom? random, ILogger<Trainer2D>? logger = null, bool operationBackendTimingEnabled = false)
Parameters
- model Model<float[,], float[,]>
- optimizer Optimizer
- random SeededRandom?
- logger ILogger<Trainer2D>? (default: null)
- operationBackendTimingEnabled bool (default: false)
Methods
GenerateBatches(float[,], float[,], int)
Generates batches of input and output matrices.
protected override IEnumerable<(float[,] xBatch, float[,] yBatch)> GenerateBatches(float[,] x, float[,] y, int batchSize)
Parameters
- x float[,]
- y float[,]
- batchSize int
Returns
- IEnumerable<(float[,] xBatch, float[,] yBatch)>
An enumerable of batches.
GetRows(float[,])
protected override float GetRows(float[,] x)
Parameters
- x float[,]
Returns
- float
PermuteData(float[,], float[,], Random)
protected override void PermuteData(float[,] x, float[,] y, Random random)