Class Model<TInputData, TPrediction>
Namespace: NeuralNetworks.Models
Assembly: NeuralNetworks.dll
Represents an abstract neural network model that processes input data and produces predictions. Provides core functionality for forward and backward passes, training, parameter updates, and checkpointing.
public abstract class Model<TInputData, TPrediction> where TInputData : notnull where TPrediction : notnull
Type Parameters
TInputData: The type of input data provided to the model. Must be a non-nullable type (notnull constraint).
TPrediction: The type of prediction output produced by the model. Must be a non-nullable type (notnull constraint).
Inheritance: object → Model<TInputData, TPrediction>
Remarks
This class is the base for neural network models with customizable layers and loss functions. It supports training workflows, including batch training and parameter optimization, and provides mechanisms for saving and restoring model checkpoints. Derived classes must implement the abstract method that constructs the layer list. Thread safety is not guaranteed; callers must synchronize concurrent access to an instance externally.
Constructors
Model(LayerListBuilder<TInputData, TPrediction>?, Loss<TPrediction>, SeededRandom?, string?)
protected Model(LayerListBuilder<TInputData, TPrediction>? layerListBuilder, Loss<TPrediction> lossFunction, SeededRandom? random = null, string? modelFilePath = null)
Parameters
layerListBuilder: LayerListBuilder<TInputData, TPrediction>?
lossFunction: Loss<TPrediction>
random: SeededRandom? (optional, defaults to null)
modelFilePath: string? (optional, defaults to null)
Properties
LossFunction
public Loss<TPrediction> LossFunction { get; }
Property Value
- Loss<TPrediction>
Random
protected SeededRandom? Random { get; }
Property Value
- SeededRandom
Methods
Backward(TPrediction)
public void Backward(TPrediction lossGrad)
Parameters
lossGrad: TPrediction
Clone()
Makes a deep copy of this neural network.
public Model<TInputData, TPrediction> Clone()
Returns
- Model<TInputData, TPrediction>
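Clone is documented as producing a deep copy. As a hedged illustration of why that matters, the sketch below stands in for a model's parameters with hypothetical jagged float arrays (original, shallow, and deep are illustrative names, not members of NeuralNetworks.dll) and compares a shallow copy with a deep one. It does not show how Model.Clone is actually implemented.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in: a "model" reduced to one weight array per layer.
float[][] original = { new float[] { 0.5f, -1.0f }, new float[] { 2.0f } };

// Shallow copy: a new outer array, but the inner weight arrays are shared.
float[][] shallow = (float[][])original.Clone();

// Deep copy: every inner weight array is cloned as well.
float[][] deep = original.Select(layer => (float[])layer.Clone()).ToArray();

// Simulate a training update applied to the original model only.
original[0][0] = 9.0f;

Console.WriteLine($"shallow sees the update: {shallow[0][0] == 9.0f}"); // True
Console.WriteLine($"deep copy unchanged: {deep[0][0] == 0.5f}");        // True
```

A deep copy lets a clone keep training, or serve as a frozen reference, without the two copies silently sharing weights.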
Describe(int)
public List<string> Describe(int indentation = 0)
Parameters
indentation: int
Returns
- List<string>
Forward(TInputData, bool)
public TPrediction Forward(TInputData input, bool inference)
Parameters
input: TInputData
inference: bool
Returns
- TPrediction
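This page does not say what the inference flag of Forward controls. One common reason a forward pass takes such a flag is that layers such as dropout behave differently during training and inference. The sketch below is a hypothetical inverted-dropout function (DropoutForward and keepProb are illustrative names, not part of this library). Whether this library's layers use the flag this way is an assumption.

```csharp
using System;

// Hypothetical inverted-dropout layer showing why a forward pass may need an
// inference flag; this is NOT taken from NeuralNetworks.dll.
var rng = new Random(42);
const float keepProb = 0.8f;

float[] DropoutForward(float[] input, bool inference)
{
    // At inference time dropout is the identity function.
    if (inference) return (float[])input.Clone();

    var output = new float[input.Length];
    for (int i = 0; i < input.Length; i++)
    {
        // Keep each unit with probability keepProb and rescale survivors by
        // 1 / keepProb so the expected activation matches inference mode.
        output[i] = rng.NextDouble() < keepProb ? input[i] / keepProb : 0.0f;
    }
    return output;
}

float[] activations = { 1.0f, 2.0f, 3.0f, 4.0f };
float[] trainOut = DropoutForward(activations, inference: false);
float[] inferOut = DropoutForward(activations, inference: true);

Console.WriteLine(string.Join(", ", trainOut)); // some units zeroed, the rest scaled up
Console.WriteLine(string.Join(", ", inferOut)); // identical to activations
```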
GetParamCount()
public int GetParamCount()
Returns
- int
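GetParamCount is not described on this page. Assuming it counts trainable parameters, and assuming a network built only from fully connected layers (both are assumptions), the count for layer widths n0, n1, ..., nk is the sum of n(i) × n(i+1) + n(i+1) over consecutive pairs: one weight matrix plus one bias vector per layer. The sketch below computes this for a hypothetical 784-128-10 network; CountDenseParams is an illustrative helper, not a library method.

```csharp
using System;

// Hypothetical layer widths of a fully connected network: input, hidden, output.
int[] layerSizes = { 784, 128, 10 };

// A dense layer mapping n inputs to m outputs holds an n x m weight matrix
// plus m biases, i.e. n * m + m trainable parameters.
int CountDenseParams(int[] sizes)
{
    int total = 0;
    for (int i = 0; i < sizes.Length - 1; i++)
    {
        int inputs = sizes[i];
        int outputs = sizes[i + 1];
        total += inputs * outputs + outputs;
    }
    return total;
}

int paramCount = CountDenseParams(layerSizes);
Console.WriteLine(paramCount); // 101770
```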
HasCheckpoint()
public bool HasCheckpoint()
Returns
- bool
LoadParams(string, TInputData?)
public void LoadParams(string filePath, TInputData? initializationSample = default)
Parameters
filePath: string
initializationSample: TInputData? (optional)
RestoreCheckpoint()
public void RestoreCheckpoint()
SaveCheckpoint()
public void SaveCheckpoint()
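The checkpoint methods are not described beyond their names. One typical use of an in-memory checkpoint is early stopping: save the parameters whenever validation loss improves, then roll back to the best ones when training ends. The sketch below mimics HasCheckpoint, SaveCheckpoint, and RestoreCheckpoint with hypothetical local functions over a plain float array and made-up validation losses. It is an assumption about how these methods are meant to be used, not their implementation.

```csharp
using System;

// Hypothetical stand-in for a model's parameters; not the NeuralNetworks.dll API.
float[] parameters = { 1.0f, 2.0f };
float[]? checkpoint = null;

bool HasCheckpoint() => checkpoint is not null;
void SaveCheckpoint() => checkpoint = (float[])parameters.Clone();
void RestoreCheckpoint()
{
    if (checkpoint is null)
        throw new InvalidOperationException("No checkpoint has been saved.");
    Array.Copy(checkpoint, parameters, parameters.Length);
}

// Early stopping: keep the parameters from the epoch with the lowest validation loss.
float[] validationLosses = { 0.90f, 0.70f, 0.75f, 0.80f }; // made-up values
float bestLoss = float.MaxValue;
for (int epoch = 0; epoch < validationLosses.Length; epoch++)
{
    // Pretend one epoch of training changed the parameters.
    parameters[0] += 1.0f;

    if (validationLosses[epoch] < bestLoss)
    {
        bestLoss = validationLosses[epoch];
        SaveCheckpoint();
    }
}

if (HasCheckpoint())
    RestoreCheckpoint();

Console.WriteLine(parameters[0]); // 3 (value saved after the second epoch)
```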
SaveParams(string, string?)
public void SaveParams(string filePath, string? comment = null)
Parameters
filePath: string
comment: string? (optional, defaults to null)
TrainBatch(TInputData, TPrediction)
public float TrainBatch(TInputData xBatch, TPrediction yBatch)
Parameters
xBatch: TInputData
yBatch: TPrediction
Returns
- float
UpdateParams(Optimizer)
public void UpdateParams(Optimizer optimizer)
Parameters
optimizer: Optimizer
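Taken together, Forward, Backward, TrainBatch, and UpdateParams suggest a training step split into a forward pass, a loss computation, a backward pass, and an optimizer update. The sketch below walks through that cycle for a hypothetical one-weight linear model y = w * x with mean squared error and plain gradient descent. It assumes, without confirmation from this page, that TrainBatch runs the forward and backward passes and returns the batch loss, while UpdateParams applies the optimizer step. All names are local stand-ins, not NeuralNetworks.dll members.

```csharp
using System;

// Hypothetical stand-in, NOT the NeuralNetworks.dll API: a one-weight linear
// model y = w * x trained with mean squared error and plain gradient descent.
float w = 0.0f;                           // the only trainable parameter
float wGrad = 0.0f;                       // gradient computed by Backward
float[] lastInput = Array.Empty<float>(); // cached by Forward for Backward

float[] xBatch = { 1.0f, 2.0f, 3.0f };
float[] yBatch = { 2.0f, 4.0f, 6.0f };    // generated by the "true" weight 2

float[] Forward(float[] x)
{
    lastInput = x; // cache the input; Backward needs it to form dL/dw
    var pred = new float[x.Length];
    for (int i = 0; i < x.Length; i++) pred[i] = w * x[i];
    return pred;
}

float MeanSquaredError(float[] pred, float[] y)
{
    float sum = 0.0f;
    for (int i = 0; i < pred.Length; i++) sum += (pred[i] - y[i]) * (pred[i] - y[i]);
    return sum / pred.Length;
}

// dL/dpred_i = 2 * (pred_i - y_i) / n for the mean squared error.
float[] MeanSquaredErrorGrad(float[] pred, float[] y)
{
    var g = new float[pred.Length];
    for (int i = 0; i < pred.Length; i++) g[i] = 2.0f * (pred[i] - y[i]) / pred.Length;
    return g;
}

// Chain rule: dL/dw = sum_i dL/dpred_i * x_i, using the input cached by Forward.
void Backward(float[] lossGrad)
{
    wGrad = 0.0f;
    for (int i = 0; i < lossGrad.Length; i++) wGrad += lossGrad[i] * lastInput[i];
}

// Plain gradient-descent step; a real optimizer might add momentum or Adam.
void UpdateParams(float learningRate) => w -= learningRate * wGrad;

// One training step on a batch: forward, loss, backward; returns the loss.
float TrainBatch(float[] x, float[] y)
{
    float[] pred = Forward(x);
    float loss = MeanSquaredError(pred, y);
    Backward(MeanSquaredErrorGrad(pred, y));
    return loss;
}

float firstLoss = TrainBatch(xBatch, yBatch);
UpdateParams(0.05f);
float lastLoss = firstLoss;
for (int step = 0; step < 50; step++)
{
    lastLoss = TrainBatch(xBatch, yBatch);
    UpdateParams(0.05f);
}

Console.WriteLine($"first loss {firstLoss}, last loss {lastLoss}, w {w:F3}");
```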