
Class Trainer<TInputData, TPrediction>

Namespace
NeuralNetworks.Trainers
Assembly
NeuralNetworks.dll

Represents a trainer that fits a neural network model to data using an optimizer.

public class Trainer<TInputData, TPrediction> where TInputData : notnull where TPrediction : notnull

Type Parameters

TInputData

The type of the input data.

TPrediction

The type of the model's predictions and of the target values.
Inheritance
Trainer<TInputData, TPrediction>
Inherited Members

Constructors

Trainer(Model<TInputData, TPrediction>, Optimizer, ConsoleOutputMode, SeededRandom?, ILogger?)

Initializes a new instance of the Trainer<TInputData, TPrediction> class.

public Trainer(Model<TInputData, TPrediction> model, Optimizer optimizer, ConsoleOutputMode consoleOutputMode = ConsoleOutputMode.OnlyOnEval, SeededRandom? random = null, ILogger? logger = null)

Parameters

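model Model<TInputData, TPrediction>

The model to train.

optimizer Optimizer

The optimizer that updates the model's parameters during training.

consoleOutputMode ConsoleOutputMode

Controls when progress is written to the console. Defaults to ConsoleOutputMode.OnlyOnEval.

random SeededRandom

An optional seeded random number generator. Defaults to null.

logger ILogger

An optional logger. Defaults to null.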

Properties

Memo

Gets or sets the memo associated with the trainer.

public string? Memo { get; set; }

Property Value

string

Methods

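Describe(int)

Returns a description of the trainer as a list of text lines.

public List<string> Describe(int indentation)

Parameters

indentation int

The indentation applied to each line of the description.

Returns

List<string>

The lines of the description.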

Fit(DataSource<TInputData, TPrediction>, EvalFunction<TInputData, TPrediction>?, int, int, int, int, bool, bool, bool, bool, bool, bool)

Fits the neural network to the provided data source.

public void Fit(DataSource<TInputData, TPrediction> dataSource, EvalFunction<TInputData, TPrediction>? evalFunction = null, int epochs = 100, int evalEveryEpochs = 10, int logEveryEpochs = 1, int batchSize = 32, bool earlyStop = false, bool restart = true, bool displayDescriptionOnStart = true, bool operationBackendTimingEnabled = false, bool saveParamsOnBestLoss = false, bool showTrainEval = false)

Parameters

dataSource DataSource<TInputData, TPrediction>

The data source.

evalFunction EvalFunction<TInputData, TPrediction>

The optional evaluation function. Defaults to null.

epochs int

The number of epochs.

evalEveryEpochs int

The number of epochs between evaluations.

logEveryEpochs int

The number of epochs between log entries.

batchSize int

The batch size.

earlyStop bool

A flag indicating whether to enable early stopping.

restart bool

A flag indicating whether to restart training from scratch or continue from the last state.

displayDescriptionOnStart bool

A flag indicating whether to display the fit and model description when training starts.

operationBackendTimingEnabled bool

A flag indicating whether to enable operation backend timing. If true, a timing report for the backend operations is displayed after training.

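saveParamsOnBestLoss bool

A flag indicating whether to save the model parameters whenever a new best loss is reached.

showTrainEval bool

A flag indicating whether to also show evaluation results on the training data.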
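This page does not spell out how epochs, evalEveryEpochs, earlyStop and saveParamsOnBestLoss interact. The block below is a hypothetical sketch of one common scheme, assumed here rather than taken from the library. Under that scheme, an evaluation runs on every evalEveryEpochs-th epoch and the best loss seen so far is tracked. With early stopping on, training ends at the first evaluation that fails to improve on the best loss. The RunSchedule helper and the simulated per-epoch losses are illustrative only; they are not the library's code.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical validation losses, one per epoch, standing in for real evaluations.
double[] simulatedLoss = { 0.90, 0.70, 0.55, 0.50, 0.52, 0.49, 0.60, 0.65, 0.70, 0.75 };

var result = RunSchedule(simulatedLoss, epochs: 10, evalEveryEpochs: 2, earlyStop: true);
Console.WriteLine($"Evaluated after epochs {string.Join(", ", result.EvalEpochs)}; " +
                  $"stopped after epoch {result.StoppedAt}; best loss at epoch {result.BestEpoch}.");

// Hypothetical epoch loop (1-based epochs). Assumption: evaluation happens when
// epoch % evalEveryEpochs == 0, and early stopping triggers on the first
// evaluation whose loss is not better than the best loss so far.
static (List<int> EvalEpochs, int StoppedAt, int BestEpoch) RunSchedule(
    double[] lossPerEpoch, int epochs, int evalEveryEpochs, bool earlyStop)
{
    var evalEpochs = new List<int>();
    double bestLoss = double.PositiveInfinity;
    int bestEpoch = 0;

    for (int epoch = 1; epoch <= epochs; epoch++)
    {
        // A real trainer would run one pass over all batches here.
        if (epoch % evalEveryEpochs != 0) continue;   // not an evaluation epoch

        evalEpochs.Add(epoch);
        double loss = lossPerEpoch[epoch - 1];
        if (loss < bestLoss)
        {
            bestLoss = loss;
            bestEpoch = epoch;   // a saveParamsOnBestLoss-style option would snapshot parameters here
        }
        else if (earlyStop)
        {
            return (evalEpochs, epoch, bestEpoch);   // first non-improving evaluation ends training
        }
    }
    return (evalEpochs, epochs, bestEpoch);
}
```

With these simulated losses, evaluation happens after epochs 2, 4, 6 and 8. The loss worsens at epoch 8, so the sketch stops there and reports epoch 6 as the best. Real trainers often add a "patience" count before stopping; whether this library does is not documented here.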

GenerateBatches(TInputData, TPrediction, int)

Generates batches of input and output matrices.

protected virtual IEnumerable<(TInputData xBatch, TPrediction yBatch)> GenerateBatches(TInputData x, TPrediction y, int batchSize = 32)

Parameters

x TInputData

The input matrix.

y TPrediction

The output matrix.

batchSize int

The batch size.

Returns

IEnumerable<(TInputData xBatch, TPrediction yBatch)>

An enumerable of (xBatch, yBatch) pairs, one per batch.
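GenerateBatches is virtual, so how it slices a given TInputData is left to the implementation. The block below is a minimal sketch that assumes the inputs and targets are stored as double[,] with one sample per row. Under that assumption, batching means cutting consecutive blocks of batchSize rows, and the last batch takes whatever rows remain. The standalone GenerateBatches and SliceRows helpers are illustrations, not the library's implementation.

```csharp
using System;
using System.Collections.Generic;

double[,] inputs = { { 1, 2 }, { 3, 4 }, { 5, 6 }, { 7, 8 }, { 9, 10 } };
double[,] targets = { { 0 }, { 1 }, { 0 }, { 1 }, { 0 } };

// Five rows with batchSize 2 yield batches of 2, 2 and 1 rows.
foreach (var (xBatch, yBatch) in GenerateBatches(inputs, targets, batchSize: 2))
    Console.WriteLine($"batch of {xBatch.GetLength(0)} rows, first input value = {xBatch[0, 0]}");

// Yields consecutive row blocks of x and y; the final block may be smaller than batchSize.
static IEnumerable<(double[,] xBatch, double[,] yBatch)> GenerateBatches(
    double[,] x, double[,] y, int batchSize = 32)
{
    if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
    int rowCount = x.GetLength(0);
    if (y.GetLength(0) != rowCount)
        throw new ArgumentException("x and y must have the same number of rows.");

    for (int start = 0; start < rowCount; start += batchSize)
    {
        int count = Math.Min(batchSize, rowCount - start);
        yield return (SliceRows(x, start, count), SliceRows(y, start, count));
    }
}

// Copies `count` rows of m, beginning at row `start`, into a new 2-D array.
static double[,] SliceRows(double[,] m, int start, int count)
{
    int cols = m.GetLength(1);
    var slice = new double[count, cols];
    for (int r = 0; r < count; r++)
        for (int c = 0; c < cols; c++)
            slice[r, c] = m[start + r, c];
    return slice;
}
```

Because the helper is an iterator, each batch is copied only when the caller asks for it. This suits a training loop that consumes one batch at a time.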

GetRowCount(TInputData)

Gets the number of rows (samples) in the input data.

protected virtual int GetRowCount(TInputData x)

Parameters

x TInputData

The input data.

Returns

int

The number of rows in x.
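A trainer specialized for a concrete TInputData would typically override GetRowCount, and the answer depends on how the samples are stored. The block below sketches two hypothetical representations: a rectangular double[,] and a jagged double[][]. In both, the row count is the size of the outer (first) dimension. The helper names are illustrative and are not part of the library.

```csharp
using System;

Console.WriteLine(GetRowCount2D(new double[4, 3]));                           // 4
Console.WriteLine(GetRowCountJagged(new[] { new double[3], new double[3] })); // 2

// Rectangular storage: samples are indexed by the first dimension.
static int GetRowCount2D(double[,] x) => x.GetLength(0);

// Jagged storage: one inner array per sample.
static int GetRowCountJagged(double[][] x) => x.Length;
```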

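PermuteData(TInputData, TPrediction, Random)

Randomly permutes the rows of x and y, applying the same permutation to both so that each input row stays aligned with its target.

protected virtual void PermuteData(TInputData x, TPrediction y, Random random)

Parameters

x TInputData

The input data.

y TPrediction

The target data.

random Random

The random number generator used to draw the permutation.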
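A standard way to shuffle inputs and targets consistently is a Fisher-Yates shuffle that mirrors every row swap in both arrays. The block below is a sketch under the assumption that both are stored as double[,] with one sample per row. It illustrates the technique and is not this library's implementation. Passing a Random created with a fixed seed makes the shuffle repeatable.

```csharp
using System;

double[,] inputs = { { 1, 10 }, { 2, 20 }, { 3, 30 }, { 4, 40 } };
double[,] targets = { { 1 }, { 2 }, { 3 }, { 4 } };

PermuteData(inputs, targets, new Random(42));   // fixed seed: same order on every run
for (int row = 0; row < inputs.GetLength(0); row++)
    Console.WriteLine($"{inputs[row, 0]}, {inputs[row, 1]} -> {targets[row, 0]}");

// Fisher-Yates shuffle over the rows of x; each swap is repeated on y,
// so row i of x still corresponds to row i of y afterwards.
static void PermuteData(double[,] x, double[,] y, Random random)
{
    int rowCount = x.GetLength(0);
    if (y.GetLength(0) != rowCount)
        throw new ArgumentException("x and y must have the same number of rows.");

    for (int i = rowCount - 1; i > 0; i--)
    {
        int j = random.Next(i + 1);   // uniform index in [0, i]
        SwapRows(x, i, j);
        SwapRows(y, i, j);
    }
}

// Swaps rows a and b of m in place, column by column.
static void SwapRows(double[,] m, int a, int b)
{
    for (int c = 0; c < m.GetLength(1); c++)
        (m[a, c], m[b, c]) = (m[b, c], m[a, c]);
}
```

Drawing j from [0, i] rather than the full range is what makes every permutation equally likely. A naive "swap with any row" loop skews the distribution.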