Class Model<TInputData, TPrediction>
- Namespace: NeuralNetworks.Models
- Assembly: NeuralNetworks.dll
Represents an abstract neural network model that processes input data and produces predictions. Provides core functionality for forward and backward passes, training, parameter updates, and checkpointing.
public abstract class Model<TInputData, TPrediction> where TInputData : notnull where TPrediction : notnull
Type Parameters
- TInputData: The type of input data provided to the model. Must not be null.
- TPrediction: The type of prediction output produced by the model. Must not be null.
- Inheritance: object → Model<TInputData, TPrediction>
Remarks
This class serves as a base for implementing neural network models with customizable layers and loss functions. It supports training workflows, including batch training and parameter optimization, and provides mechanisms for saving and restoring model checkpoints. Derived classes must implement CreateLayerListBuilderInternal() to construct the layer list. Thread safety is not guaranteed; callers must synchronize concurrent access externally.
Constructors
Model(LayerListBuilder<TInputData, TPrediction>?, Loss<TPrediction>, SeededRandom?)
protected Model(LayerListBuilder<TInputData, TPrediction>? layerListBuilder, Loss<TPrediction> lossFunction, SeededRandom? random)
Parameters
- layerListBuilder: LayerListBuilder<TInputData, TPrediction>?
- lossFunction: Loss<TPrediction>
- random: SeededRandom?
Properties
Layers
public IReadOnlyList<Layer> Layers { get; }
Property Value
- IReadOnlyList<Layer>
LossFunction
public Loss<TPrediction> LossFunction { get; }
Property Value
- Loss<TPrediction>
Random
protected SeededRandom? Random { get; }
Property Value
- SeededRandom?
Methods
Backward(TPrediction)
public void Backward(TPrediction lossGrad)
Parameters
- lossGrad: TPrediction
Clone()
Makes a deep copy of this neural network.
public Model<TInputData, TPrediction> Clone()
Returns
- Model<TInputData, TPrediction>
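The page states only that Clone makes a deep copy; it does not list which state (layers, parameters, loss function, checkpoint) is duplicated. To illustrate what "deep" means in general, the sketch below uses a hypothetical jagged array as a stand-in for per-layer parameters and contrasts a deep copy with a shallow one. None of these names come from NeuralNetworks.dll.

```csharp
using System;
using System.Linq;

// Hypothetical per-layer parameter storage: one weight array per layer.
float[][] original = { new[] { 1f, 2f }, new[] { 3f } };

// Shallow copy: a new outer array whose elements still point at the same inner arrays.
float[][] shallow = (float[][])original.Clone();

// Deep copy: every inner array is duplicated as well.
float[][] deep = original.Select(layer => (float[])layer.Clone()).ToArray();

// Mutate a parameter in the original after copying.
original[0][0] = 99f;

Console.WriteLine(shallow[0][0]); // 99 (shared with the original)
Console.WriteLine(deep[0][0]);    // 1 (independent copy)
```

Under that reading, training a clone should leave the original model's parameters untouched.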
CreateLayerListBuilderInternal()
protected abstract LayerListBuilder<TInputData, TPrediction> CreateLayerListBuilderInternal()
Returns
- LayerListBuilder<TInputData, TPrediction>
Describe(int)
public List<string> Describe(int indentation)
Parameters
- indentation: int
Returns
- List<string>
Forward(TInputData, bool)
public TPrediction Forward(TInputData input, bool inference)
Parameters
- input: TInputData
- inference: bool
Returns
- TPrediction
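Forward takes an inference flag, but the page does not say what it changes. A common use of such a flag is to switch off training-only behavior such as dropout. The sketch below is a hypothetical inverted-dropout function (not part of this library) that shows the pattern: in training mode each activation is dropped with probability rate and the survivors are rescaled, while in inference mode the input passes through unchanged.

```csharp
using System;

// Hypothetical inverted dropout; not an API of NeuralNetworks.dll.
// Training: zero each activation with probability `rate` and scale survivors by
// 1 / (1 - rate) so the expected activation is unchanged.
// Inference: return a copy of the input unchanged.
float[] DropoutForward(float[] input, float rate, bool inference, Random rng)
{
    if (inference || rate == 0f)
    {
        return (float[])input.Clone();
    }

    float scale = 1f / (1f - rate);
    var output = new float[input.Length];
    for (int i = 0; i < input.Length; i++)
    {
        bool keep = rng.NextDouble() >= rate;
        output[i] = keep ? input[i] * scale : 0f;
    }
    return output;
}

var random = new Random(42);
float[] activations = { 1f, 2f, 3f, 4f };

float[] trainOut = DropoutForward(activations, rate: 0.5f, inference: false, rng: random);
float[] inferOut = DropoutForward(activations, rate: 0.5f, inference: true, rng: random);

Console.WriteLine(string.Join(", ", trainOut)); // each entry is 0 or twice the input
Console.WriteLine(string.Join(", ", inferOut)); // 1, 2, 3, 4
```

Rescaling during training rather than at inference keeps the inference path a plain pass-through.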
GetParamCount()
public int GetParamCount()
Returns
- int
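GetParamCount returns an int, presumably the total number of trainable scalars across Layers, although the page does not define it. For fully connected layers that total is easy to check by hand: a layer with n inputs and m outputs holds n * m weights plus m biases. The sketch assumes a hypothetical 784 → 128 → 10 network, giving (784 * 128 + 128) + (128 * 10 + 10) = 100480 + 1290 = 101770.

```csharp
using System;
using System.Linq;

// Parameters in one dense layer: an inputs x outputs weight matrix plus one bias per output.
int DenseParamCount(int inputs, int outputs) => inputs * outputs + outputs;

// Hypothetical layer widths: 784 inputs, one hidden layer of 128 units, 10 outputs.
int[] widths = { 784, 128, 10 };

// Sum over consecutive pairs of widths (one dense layer per pair).
int total = Enumerable.Range(0, widths.Length - 1)
    .Sum(i => DenseParamCount(widths[i], widths[i + 1]));

Console.WriteLine(total); // 101770
```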
HasCheckpoint()
public bool HasCheckpoint()
Returns
- bool
RestoreCheckpoint()
public void RestoreCheckpoint()
SaveCheckpoint()
public void SaveCheckpoint()
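HasCheckpoint, SaveCheckpoint, and RestoreCheckpoint are listed without descriptions. A typical use of such a trio is keeping the best parameters seen during training and rolling back to them at the end. The sketch below assumes that a checkpoint is a value snapshot of the parameters and that restoring without one is an error; the weight array, the validation losses, and the fake update step are all hypothetical.

```csharp
using System;

// Hypothetical stand-in for a model's trainable parameters.
float[] weights = { 0.5f, -1.0f, 2.0f };

// Snapshot storage; null means no checkpoint has been saved yet.
float[]? checkpoint = null;

bool HasCheckpoint() => checkpoint is not null;

// Copy the values, not the reference, so later updates cannot alter the snapshot.
void SaveCheckpoint() => checkpoint = (float[])weights.Clone();

// Overwrite the current parameters with the saved snapshot.
void RestoreCheckpoint()
{
    if (checkpoint is null)
    {
        throw new InvalidOperationException("No checkpoint has been saved.");
    }
    Array.Copy(checkpoint, weights, weights.Length);
}

// Hypothetical validation loss observed after each epoch.
float[] validationLosses = { 0.9f, 0.4f, 0.6f };
float bestLoss = float.MaxValue;

foreach (float loss in validationLosses)
{
    // Pretend an optimizer step changed every parameter.
    for (int i = 0; i < weights.Length; i++)
    {
        weights[i] += 0.1f;
    }

    // Keep a snapshot whenever validation loss improves.
    if (loss < bestLoss)
    {
        bestLoss = loss;
        SaveCheckpoint();
    }
}

// Roll back to the best epoch (the second one) before using the model.
if (HasCheckpoint())
{
    RestoreCheckpoint();
}

Console.WriteLine(string.Join(", ", weights)); // approximately 0.7, -0.8, 2.2
```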
TrainBatch(TInputData, TPrediction)
public float TrainBatch(TInputData xBatch, TPrediction yBatch)
Parameters
- xBatch: TInputData
- yBatch: TPrediction
Returns
- float
UpdateParams(Optimizer)
public void UpdateParams(Optimizer optimizer)
Parameters
- optimizer: Optimizer
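The page documents TrainBatch (returns a float, presumably the batch loss), Backward (takes lossGrad, presumably the gradient of the loss with respect to the prediction), and UpdateParams (takes an Optimizer), but not how they interact. The sketch below illustrates the generic step these members suggest on a hypothetical single linear unit y = w * x + b with mean squared error and plain SGD: forward pass, loss, loss gradient, backward pass by the chain rule, then the parameter update. It is a stand-in written for illustration, not this library's implementation; the data and learning rate are made up.

```csharp
using System;

// Hypothetical toy data drawn from y = 2x + 1.
float[] xs = { 0f, 1f, 2f, 3f };
float[] ys = { 1f, 3f, 5f, 7f };

// Parameters of a single linear unit y = w * x + b.
float w = 0f, b = 0f;
const float learningRate = 0.05f;

// One full-batch step: forward, loss, backward, then a plain SGD update.
// Returns the loss measured before the update.
float TrainBatch()
{
    int n = xs.Length;
    float loss = 0f, gradW = 0f, gradB = 0f;
    for (int i = 0; i < n; i++)
    {
        float prediction = w * xs[i] + b;   // forward pass
        float error = prediction - ys[i];
        loss += error * error / n;          // mean squared error
        float lossGrad = 2f * error / n;    // dLoss/dPrediction for this sample
        gradW += lossGrad * xs[i];          // backward pass via the chain rule
        gradB += lossGrad;
    }

    // Parameter update (stand-in for UpdateParams with an SGD optimizer).
    w -= learningRate * gradW;
    b -= learningRate * gradB;
    return loss;
}

float firstLoss = TrainBatch();
float lastLoss = firstLoss;
for (int epoch = 1; epoch < 500; epoch++)
{
    lastLoss = TrainBatch();
}

Console.WriteLine($"loss {firstLoss:F3} -> {lastLoss:F6}, w = {w:F3}, b = {b:F3}");
```

With the documented class, one iteration would presumably be `float loss = model.TrainBatch(xBatch, yBatch);` followed by `model.UpdateParams(optimizer);`, but the page does not confirm whether TrainBatch already applies the update itself.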