Utils
eqxvision.utils.load_torch_weights(model: Module, torch_weights: str = None) -> Module
Loads weights from a PyTorch serialised file.
Warning
- This method requires installation of the torch package.
Note
- This function assumes that Eqxvision's ordering of class attributes mirrors the torchvision.models implementation.
- This method assumes the eqxvision model is not initialised; initialised BN (batch normalisation) modules cause problems.
- The saved checkpoint should contain only model parameters as keys.
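The first and third notes imply a positional mapping: checkpoint entries are paired with the model's array leaves in attribute order, so every entry must be an array and the orders must agree. The sketch below illustrates that idea with plain JAX pytrees. It is not eqxvision's implementation, and the real loader may treat BN buffers and other details differently. `Linear` (a NamedTuple standing in for an eqx.Module) and `load_in_order` are hypothetical names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class Linear(NamedTuple):
    # Hypothetical stand-in for an eqx.Module: a NamedTuple is a JAX pytree
    # whose leaves flatten in field-declaration order.
    weight: jnp.ndarray
    bias: jnp.ndarray


def load_in_order(model, state_dict):
    """Hypothetical helper: replace the array leaves of `model`, in flattening
    order, with the values of `state_dict` (an ordered name -> array mapping)."""
    leaves, treedef = jax.tree_util.tree_flatten(model)
    if len(leaves) != len(state_dict):
        raise ValueError(
            f"model has {len(leaves)} leaves, checkpoint has {len(state_dict)} entries"
        )
    new_leaves = []
    for leaf, (name, value) in zip(leaves, state_dict.items()):
        # Only array entries are expected (parameters, not e.g. an epoch counter).
        if not hasattr(value, "shape"):
            raise TypeError(f"checkpoint entry {name!r} is not an array")
        value = jnp.asarray(value)
        # Pairing is purely positional, so the shape check is the only guard
        # against an attribute order that does not mirror the checkpoint.
        if leaf.shape != value.shape:
            raise ValueError(f"{name!r}: expected shape {leaf.shape}, got {value.shape}")
        new_leaves.append(value)
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


# PyTorch's nn.Linear stores "weight" (out_features, in_features) before "bias".
model = Linear(weight=jnp.zeros((2, 3)), bias=jnp.zeros((2,)))
state_dict = {
    "weight": np.ones((2, 3), dtype=np.float32),
    "bias": np.arange(2, dtype=np.float32),
}
loaded = load_in_order(model, state_dict)
```

Swapping the checkpoint order (bias before weight) makes the first pairing mismatch in shape and raises a ValueError, which is the failure mode the ordering assumption guards against.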
Info
A full list of pretrained URLs is provided here.
Arguments:
- model: An eqx.Module whose jnp.ndarray leaves are replaced by the corresponding PyTorch weights.
- torch_weights: A string pointing either to PyTorch weights on disk or to their download URL.
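Because torch_weights accepts either a local path or a URL, a loader has to tell the two apart before reading the file. The following is a minimal sketch of one way to do that with the standard library; we have not checked how eqxvision itself dispatches, and `weights_source` is a hypothetical helper.

```python
import os
import tempfile
from urllib.parse import urlparse


def weights_source(torch_weights: str) -> str:
    """Hypothetical helper: classify `torch_weights` as a download URL or a file on disk."""
    if urlparse(torch_weights).scheme in ("http", "https"):
        return "url"
    # Anything without an http(s) scheme is treated as a local file path.
    if os.path.isfile(torch_weights):
        return "file"
    raise FileNotFoundError(f"no such weights file: {torch_weights}")


print(weights_source("https://example.com/resnet18.pth"))  # url
with tempfile.NamedTemporaryFile(suffix=".pth") as f:
    print(weights_source(f.name))  # file
```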
Returns:
The model with weights loaded from the PyTorch checkpoint.