Class Trainer<TInputData, TPrediction>
- Namespace: NeuralNetworks.Trainers
- Assembly: NeuralNetworks.dll
Represents a trainer for a neural network.
public class Trainer<TInputData, TPrediction> where TInputData : notnull where TPrediction : notnull
Type Parameters
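- TInputData: The type of the input data. Must be a non-nullable type (notnull constraint).
- TPrediction: The type of the predictions and training targets. Must be a non-nullable type (notnull constraint).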
- Inheritance: object → Trainer<TInputData, TPrediction>
Constructors
Trainer(Model<TInputData, TPrediction>, Optimizer, ConsoleOutputMode, SeededRandom?, ILogger?)
Initializes a new instance of the Trainer<TInputData, TPrediction> class for the given model and optimizer.
public Trainer(Model<TInputData, TPrediction> model, Optimizer optimizer, ConsoleOutputMode consoleOutputMode = ConsoleOutputMode.OnlyOnEval, SeededRandom? random = null, ILogger? logger = null)
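Parameters
- model (Model<TInputData, TPrediction>): The model to train.
- optimizer (Optimizer): The optimizer used to update the model parameters.
- consoleOutputMode (ConsoleOutputMode): Controls when training progress is written to the console. Defaults to ConsoleOutputMode.OnlyOnEval.
- random (SeededRandom?): An optional random number source. Defaults to null.
- logger (ILogger?): An optional logger. Defaults to null.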
Properties
Memo
Gets or sets the memo associated with the trainer.
public string? Memo { get; set; }
Property Value
- string?
Methods
Describe(int)
public List<string> Describe(int indentation)
Parameters
- indentation (int)
Returns
- List<string>
Fit(DataSource<TInputData, TPrediction>, EvalFunction<TInputData, TPrediction>?, int, int, int, int, bool, bool, bool, bool, bool, bool)
Fits the neural network to the provided data source.
public void Fit(DataSource<TInputData, TPrediction> dataSource, EvalFunction<TInputData, TPrediction>? evalFunction = null, int epochs = 100, int evalEveryEpochs = 10, int logEveryEpochs = 1, int batchSize = 32, bool earlyStop = false, bool restart = true, bool displayDescriptionOnStart = true, bool operationBackendTimingEnabled = false, bool saveParamsOnBestLoss = false, bool showTrainEval = false)
Parameters
- dataSource (DataSource<TInputData, TPrediction>): The data source.
- evalFunction (EvalFunction<TInputData, TPrediction>?): The evaluation function. Defaults to null.
- epochs (int): The number of training epochs. Defaults to 100.
- evalEveryEpochs (int): The number of epochs between evaluations. Defaults to 10.
- logEveryEpochs (int): The number of epochs between log entries. Defaults to 1.
- batchSize (int): The batch size. Defaults to 32.
- earlyStop (bool): Whether to enable early stopping. Defaults to false.
- restart (bool): Whether to restart training (true) or continue from the last state (false). Defaults to true.
- displayDescriptionOnStart (bool): Whether to display the fit and model description when training starts. Defaults to true.
- operationBackendTimingEnabled (bool): Whether to enable timing of backend operations. If true, a timing report for the backend operations is displayed after training. Defaults to false.
- saveParamsOnBestLoss (bool): Defaults to false.
- showTrainEval (bool): Defaults to false.
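The exact scheduling inside Fit is not documented on this page. As a hedged sketch, assuming evaluation runs on epochs that are multiples of evalEveryEpochs and that early stopping halts at the first evaluation that fails to improve the best loss (a common policy, but an assumption here), the epoch-related parameters could interact as below. FitLoopSketch, trainEpoch and evaluate are hypothetical names standing in for the model's training pass and the evalFunction; they are not part of the library.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of an epoch loop driven by Fit-style parameters.
// trainEpoch and evaluate are placeholder delegates, not the library's API.
public static class FitLoopSketch
{
    // Runs up to `epochs` epochs (1-based). Evaluates on every epoch that is a
    // multiple of `evalEveryEpochs`; when earlyStop is true, stops at the first
    // evaluation whose loss does not improve on the best loss so far (assumed policy).
    // Returns the epochs at which evaluation ran.
    public static List<int> Run(
        int epochs,
        int evalEveryEpochs,
        bool earlyStop,
        Action<int> trainEpoch,
        Func<int, double> evaluate)
    {
        if (evalEveryEpochs <= 0)
            throw new ArgumentOutOfRangeException(nameof(evalEveryEpochs));

        var evaluatedAt = new List<int>();
        double bestLoss = double.PositiveInfinity;

        for (int epoch = 1; epoch <= epochs; epoch++)
        {
            trainEpoch(epoch); // one pass over all batches

            if (epoch % evalEveryEpochs != 0)
                continue; // not an evaluation epoch

            double loss = evaluate(epoch);
            evaluatedAt.Add(epoch);

            if (loss < bestLoss)
                bestLoss = loss; // new best loss
            else if (earlyStop)
                break; // no improvement: stop training early
        }

        return evaluatedAt;
    }
}
```

With epochs = 100 and evalEveryEpochs = 10 (the documented defaults) and early stopping off, this sketch evaluates at epochs 10, 20, ..., 100.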
GenerateBatches(TInputData, TPrediction, int)
Generates batches of input and output matrices.
protected virtual IEnumerable<(TInputData xBatch, TPrediction yBatch)> GenerateBatches(TInputData x, TPrediction y, int batchSize = 32)
Parameters
- x (TInputData): The input matrix.
- y (TPrediction): The output matrix.
- batchSize (int): The batch size. Defaults to 32.
Returns
- IEnumerable<(TInputData xBatch, TPrediction yBatch)>
An enumerable of (xBatch, yBatch) pairs, one per batch.
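Because GenerateBatches and GetRowCount are protected virtual, a subclass can supply batching for its own data types. The following is a minimal sketch, not the library's implementation, assuming both TInputData and TPrediction are float[,] with one sample per row and that the final batch may be smaller than batchSize. BatchingSketch and SliceRows are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for GenerateBatches when inputs and targets are
// float[,] (one sample per row). Not the library's implementation.
public static class BatchingSketch
{
    // Counterpart of GetRowCount for a 2-D array: the number of samples (rows).
    public static int GetRowCount(float[,] x) => x.GetLength(0);

    // Yields consecutive row slices of x and y; the last batch may be smaller.
    // Note: as an iterator, the argument checks run on first enumeration.
    public static IEnumerable<(float[,] xBatch, float[,] yBatch)> GenerateBatches(
        float[,] x, float[,] y, int batchSize = 32)
    {
        if (batchSize <= 0)
            throw new ArgumentOutOfRangeException(nameof(batchSize));
        int rows = GetRowCount(x);
        if (GetRowCount(y) != rows)
            throw new ArgumentException("x and y must have the same number of rows.");

        for (int start = 0; start < rows; start += batchSize)
        {
            int count = Math.Min(batchSize, rows - start);
            yield return (SliceRows(x, start, count), SliceRows(y, start, count));
        }
    }

    // Copies `count` rows starting at `start` into a new array.
    private static float[,] SliceRows(float[,] source, int start, int count)
    {
        int cols = source.GetLength(1);
        var slice = new float[count, cols];
        for (int r = 0; r < count; r++)
            for (int c = 0; c < cols; c++)
                slice[r, c] = source[start + r, c];
        return slice;
    }
}
```

For example, with 5 rows and batchSize = 2 the sketch yields three batches of 2, 2 and 1 rows, and every xBatch stays aligned with its yBatch.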
GetRowCount(TInputData)
protected virtual int GetRowCount(TInputData x)
Parameters
- x (TInputData)
Returns
- int
PermuteData(TInputData, TPrediction, Random)
protected virtual void PermuteData(TInputData x, TPrediction y, Random random)
Parameters
- x (TInputData)
- y (TPrediction)
- random (Random)
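PermuteData returns void, which suggests it reorders x and y in place; this page does not say so explicitly. The following is a hedged sketch under that assumption, with float[,] data (one sample per row) and Random taken to be System.Random. It applies a single Fisher-Yates shuffle to the rows of both arrays. PermuteSketch and SwapRows are hypothetical names, not library members.

```csharp
using System;

// Hypothetical stand-in for PermuteData with float[,] inputs and targets.
// Shuffles the rows of x and y in place with the same permutation
// (Fisher-Yates), so each input row stays paired with its target row.
public static class PermuteSketch
{
    public static void PermuteData(float[,] x, float[,] y, Random random)
    {
        int rows = x.GetLength(0);
        if (y.GetLength(0) != rows)
            throw new ArgumentException("x and y must have the same number of rows.");

        for (int i = rows - 1; i > 0; i--)
        {
            int j = random.Next(i + 1); // uniform index in [0, i]
            SwapRows(x, i, j);
            SwapRows(y, i, j);          // same swap keeps pairs aligned
        }
    }

    // Swaps rows a and b of a 2-D array.
    private static void SwapRows(float[,] m, int a, int b)
    {
        if (a == b) return;
        for (int c = 0; c < m.GetLength(1); c++)
            (m[a, c], m[b, c]) = (m[b, c], m[a, c]);
    }
}
```

Applying one permutation to both arrays keeps each sample matched with its target; shuffling x and y independently would scramble the labels.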