Class Trainer<TInputData, TPrediction>
Namespace: NeuralNetworks.Trainers
Assembly: NeuralNetworks.dll
Represents a trainer for a neural network.
public abstract class Trainer<TInputData, TPrediction> where TInputData : notnull where TPrediction : notnull
Type Parameters
TInputData
The type of the input data.
TPrediction
The type of the model's predictions and of the expected outputs.
Inheritance
object
Trainer<TInputData, TPrediction>
Constructors
Trainer(Model<TInputData, TPrediction>, Optimizer, ConsoleOutputMode, SeededRandom?, ILogger<Trainer<TInputData, TPrediction>>?, bool)
Initializes a new instance of the Trainer<TInputData, TPrediction> class.
protected Trainer(Model<TInputData, TPrediction> model, Optimizer optimizer, ConsoleOutputMode consoleOutputMode = ConsoleOutputMode.OnlyOnEval, SeededRandom? random = null, ILogger<Trainer<TInputData, TPrediction>>? logger = null, bool operationBackendTimingEnabled = false)
Parameters
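model Model<TInputData, TPrediction>
The model to train.
optimizer Optimizer
The optimizer used to update the model's parameters.
consoleOutputMode ConsoleOutputMode
Controls when output is written to the console (default: ConsoleOutputMode.OnlyOnEval).
random SeededRandom?
An optional seeded random number generator (default: null).
logger ILogger<Trainer<TInputData, TPrediction>>?
An optional logger (default: null).
operationBackendTimingEnabled bool
A flag indicating whether operation backend timing is enabled (default: false).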
Properties
Memo
Gets or sets the memo associated with the trainer.
public string? Memo { get; set; }
Property Value
string?
Methods
Describe(int)
public List<string> Describe(int indentation)
Parameters
indentation int
Returns
List<string>
Fit(DataSource<TInputData, TPrediction>, EvalFunction<TInputData, TPrediction>?, int, int, int, int, bool, bool, bool)
Fits the neural network to the provided data source.
public void Fit(DataSource<TInputData, TPrediction> dataSource, EvalFunction<TInputData, TPrediction>? evalFunction = null, int epochs = 100, int evalEveryEpochs = 10, int logEveryEpochs = 1, int batchSize = 32, bool earlyStop = false, bool restart = true, bool displayDescriptionOnStart = true)
Parameters
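dataSource DataSource<TInputData, TPrediction>
The data source.
evalFunction EvalFunction<TInputData, TPrediction>?
The evaluation function, or null (the default).
epochs int
The number of epochs to train for (default: 100).
evalEveryEpochs int
The number of epochs between evaluations (default: 10).
logEveryEpochs int
The number of epochs between log entries (default: 1).
batchSize int
The batch size (default: 32).
earlyStop bool
A flag indicating whether training may stop early (default: false).
restart bool
A flag indicating whether to restart the training (default: true).
displayDescriptionOnStart bool
A flag indicating whether to display the trainer's description when training starts (default: true).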
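The interaction of epochs, evalEveryEpochs and logEveryEpochs can be illustrated with a small scheduling sketch. It assumes, without confirmation from this page, that epochs are counted from 1 and that "every N epochs" means the epoch number is divisible by N. Schedule is a hypothetical helper, not part of NeuralNetworks, and earlyStop and restart are not modeled.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not part of NeuralNetworks): lists the 1-based epochs on
// which a training loop would evaluate and log, ASSUMING that "every N epochs"
// means epoch % N == 0.
(List<int> evalEpochs, List<int> logEpochs) Schedule(int epochs, int evalEveryEpochs, int logEveryEpochs)
{
    if (epochs < 0) throw new ArgumentOutOfRangeException(nameof(epochs));
    if (evalEveryEpochs <= 0) throw new ArgumentOutOfRangeException(nameof(evalEveryEpochs));
    if (logEveryEpochs <= 0) throw new ArgumentOutOfRangeException(nameof(logEveryEpochs));

    var evalEpochs = new List<int>();
    var logEpochs = new List<int>();
    for (int epoch = 1; epoch <= epochs; epoch++)
    {
        // In a real loop, one pass over all batches of the training data would happen here.
        if (epoch % logEveryEpochs == 0) logEpochs.Add(epoch);
        if (epoch % evalEveryEpochs == 0) evalEpochs.Add(epoch);
    }
    return (evalEpochs, logEpochs);
}

// The defaults of Fit: epochs = 100, evalEveryEpochs = 10, logEveryEpochs = 1.
var (evals, logs) = Schedule(100, 10, 1);
Console.WriteLine($"evaluations: {evals.Count}, log entries: {logs.Count}"); // prints "evaluations: 10, log entries: 100"
```

Under these assumptions the defaults give an evaluation on epochs 10, 20, ..., 100 and a log entry on every epoch.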
GenerateBatches(TInputData, TPrediction, int)
Generates batches of input and output matrices.
protected abstract IEnumerable<(TInputData xBatch, TPrediction yBatch)> GenerateBatches(TInputData x, TPrediction y, int batchSize = 32)
Parameters
x TInputData
The input matrix.
y TPrediction
The output matrix.
batchSize int
The batch size.
Returns
IEnumerable<(TInputData xBatch, TPrediction yBatch)>
An enumerable of batches.
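Because GenerateBatches is abstract, each derived trainer decides how its input type is split. As a sketch, assuming TInputData and TPrediction are both float[,] with one sample per row (an assumption; the concrete types used by the library's derived trainers are not listed here), batching can walk the rows in steps of batchSize and let the last batch be smaller. GenerateBatches and SliceRows below are hypothetical standalone functions, not the library's implementation.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for a GenerateBatches override where x and y are
// float[,] matrices with one sample per row. Consecutive runs of batchSize
// rows are copied out; the final batch may be smaller.
IEnumerable<(float[,] xBatch, float[,] yBatch)> GenerateBatches(float[,] x, float[,] y, int batchSize = 32)
{
    // Because this is an iterator, these checks run on the first MoveNext call.
    if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
    int rows = x.GetLength(0);
    if (y.GetLength(0) != rows) throw new ArgumentException("x and y must have the same number of rows.");

    for (int start = 0; start < rows; start += batchSize)
    {
        int count = Math.Min(batchSize, rows - start);
        yield return (SliceRows(x, start, count), SliceRows(y, start, count));
    }
}

// Copies `count` rows of m, beginning at row `start`, into a new matrix.
float[,] SliceRows(float[,] m, int start, int count)
{
    int cols = m.GetLength(1);
    var result = new float[count, cols];
    for (int r = 0; r < count; r++)
        for (int c = 0; c < cols; c++)
            result[r, c] = m[start + r, c];
    return result;
}

var features = new float[5, 2];
var labels = new float[5, 1];
foreach (var (xb, yb) in GenerateBatches(features, labels, batchSize: 2))
    Console.WriteLine($"batch: {xb.GetLength(0)} input rows, {yb.GetLength(0)} output rows"); // row counts 2, 2, 1
```

Writing the method as an iterator (yield return) means each batch is copied only when the training loop asks for it, so memory use stays proportional to one batch rather than to the whole dataset.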
GetRows(TInputData)
protected abstract float GetRows(TInputData x)
Parameters
x TInputData
Returns
float
PermuteData(TInputData, TPrediction, Random)
protected abstract void PermuteData(TInputData x, TPrediction y, Random random)
Parameters
x TInputData
y TPrediction
random Random
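PermuteData returns void, which suggests it shuffles x and y in place. A minimal sketch of that idea, again assuming float[,] inputs and outputs with one sample per row: a single Fisher-Yates permutation drawn from the supplied Random is applied to the rows of both matrices, so every input row stays aligned with its expected output. PermuteData and ReorderRows here are hypothetical stand-ins, not the library's code.

```csharp
using System;

// Hypothetical stand-in for a PermuteData override where x and y are float[,]
// matrices with one sample per row. One permutation is drawn and applied to
// both, so each input row stays paired with its target row.
void PermuteData(float[,] x, float[,] y, Random random)
{
    int rows = x.GetLength(0);
    if (y.GetLength(0) != rows) throw new ArgumentException("x and y must have the same number of rows.");

    // Fisher-Yates shuffle of the row indices 0..rows-1.
    var order = new int[rows];
    for (int i = 0; i < rows; i++) order[i] = i;
    for (int i = rows - 1; i > 0; i--)
    {
        int j = random.Next(i + 1); // uniform in 0..i inclusive
        (order[i], order[j]) = (order[j], order[i]);
    }

    ReorderRows(x, order);
    ReorderRows(y, order);
}

// Rewrites m in place so that new row r is the old row order[r].
void ReorderRows(float[,] m, int[] order)
{
    var copy = (float[,])m.Clone();
    int cols = m.GetLength(1);
    for (int r = 0; r < order.Length; r++)
        for (int c = 0; c < cols; c++)
            m[r, c] = copy[order[r], c];
}

var inputs = new float[,] { { 0f }, { 1f }, { 2f }, { 3f } };
var targets = new float[,] { { 0f }, { 10f }, { 20f }, { 30f } };
PermuteData(inputs, targets, new Random(42));
for (int k = 0; k < 4; k++)
    Console.WriteLine($"{inputs[k, 0]} -> {targets[k, 0]}"); // each line has the form n -> 10n, in shuffled order
```

Passing a Random constructed with a fixed seed, such as new Random(42), makes the shuffle reproducible between runs on the same runtime.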