Table of Contents

Class Trainer<TInputData, TPrediction>

Namespace
NeuralNetworks.Trainers
Assembly
NeuralNetworks.dll

Represents a trainer for a neural network.

public abstract class Trainer<TInputData, TPrediction> where TInputData : notnull where TPrediction : notnull

Type Parameters

TInputData
TPrediction
Inheritance
Trainer<TInputData, TPrediction>
Derived
Inherited Members

Constructors

Trainer(Model<TInputData, TPrediction>, Optimizer, ConsoleOutputMode, SeededRandom?, ILogger<Trainer<TInputData, TPrediction>>?, bool)

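Initializes a new instance of the Trainer<TInputData, TPrediction> class with the specified model, optimizer, console output mode, random source, logger, and operation-backend timing setting.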

protected Trainer(Model<TInputData, TPrediction> model, Optimizer optimizer, ConsoleOutputMode consoleOutputMode = ConsoleOutputMode.OnlyOnEval, SeededRandom? random = null, ILogger<Trainer<TInputData, TPrediction>>? logger = null, bool operationBackendTimingEnabled = false)

Parameters

model Model<TInputData, TPrediction>
optimizer Optimizer
consoleOutputMode ConsoleOutputMode
random SeededRandom
logger ILogger<Trainer<TInputData, TPrediction>>
operationBackendTimingEnabled bool

Properties

Memo

Gets or sets the memo associated with the trainer.

public string? Memo { get; set; }

Property Value

string

Methods

Describe(int)

public List<string> Describe(int indentation)

Parameters

indentation int

Returns

List<string>

Fit(DataSource<TInputData, TPrediction>, EvalFunction<TInputData, TPrediction>?, int, int, int, int, bool, bool, bool)

Fits the neural network to the provided data source.

public void Fit(DataSource<TInputData, TPrediction> dataSource, EvalFunction<TInputData, TPrediction>? evalFunction = null, int epochs = 100, int evalEveryEpochs = 10, int logEveryEpochs = 1, int batchSize = 32, bool earlyStop = false, bool restart = true, bool displayDescriptionOnStart = true)

Parameters

dataSource DataSource<TInputData, TPrediction>

The data source.

evalFunction EvalFunction<TInputData, TPrediction>

The evaluation function.

epochs int

The number of epochs.

evalEveryEpochs int

The number of epochs between evaluations.

logEveryEpochs int

The number of epochs between log outputs.
batchSize int

The batch size.

earlyStop bool
restart bool

A flag indicating whether to restart the training.

displayDescriptionOnStart bool

GenerateBatches(TInputData, TPrediction, int)

Generates batches of input data and the corresponding output data.

protected abstract IEnumerable<(TInputData xBatch, TPrediction yBatch)> GenerateBatches(TInputData x, TPrediction y, int batchSize = 32)

Parameters

x TInputData

The input data.

y TPrediction

The output data corresponding to x.

batchSize int

The batch size.

Returns

IEnumerable<(TInputData xBatch, TPrediction yBatch)>

An enumerable of batches.

GetRows(TInputData)

protected abstract float GetRows(TInputData x)

Parameters

x TInputData

Returns

float

PermuteData(TInputData, TPrediction, Random)

protected abstract void PermuteData(TInputData x, TPrediction y, Random random)

Parameters

x TInputData
y TPrediction
random Random