Utils
eqxvision.utils.load_torch_weights(model: Module, torch_weights: str = None) -> Module
Loads weights from a PyTorch serialised file.
Warning
- This method requires installation of the torch package.
Note
- This function assumes that Eqxvision's ordering of class attributes mirrors the torchvision.models implementation.
- This method assumes the eqxvision model is not initialised. Problems arise from initialised BN modules.
- The saved checkpoint should only contain model parameters as keys.
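The ordering assumption in the note above matters because weights are matched to model leaves by position, not by name. The sketch below illustrates this idea under stated assumptions. `load_by_order` is a hypothetical helper and not eqxvision's implementation. A list of `NamedTuple` layers stands in for an `eqx.Module`, and an `OrderedDict` of NumPy arrays stands in for a PyTorch `state_dict`. The sketch also shows why a checkpoint holding extra, non-parameter entries breaks the one-to-one match.

```python
from collections import OrderedDict
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


def load_by_order(model, state_dict):
    """Replace the array leaves of `model`, in flattening order, with the
    values of `state_dict`.

    Hypothetical helper, not eqxvision's implementation: it assumes both sides
    list their parameters in the same order and that `state_dict` holds only
    parameter arrays (no buffers such as BatchNorm running statistics).
    """
    leaves, treedef = jax.tree_util.tree_flatten(model)
    # Only JAX array leaves are candidates for replacement; other leaves are kept.
    array_idx = [i for i, leaf in enumerate(leaves) if isinstance(leaf, jnp.ndarray)]
    weights = [np.asarray(w) for w in state_dict.values()]
    if len(array_idx) != len(weights):
        raise ValueError(
            f"model has {len(array_idx)} array leaves, checkpoint has {len(weights)} entries"
        )
    new_leaves = list(leaves)
    for i, name, w in zip(array_idx, state_dict.keys(), weights):
        if leaves[i].shape != w.shape:
            raise ValueError(f"{name}: expected shape {leaves[i].shape}, got {w.shape}")
        new_leaves[i] = jnp.asarray(w, dtype=leaves[i].dtype)
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


# Stand-in for an Equinox module: NamedTuple fields flatten in declaration order,
# whereas plain dicts flatten in sorted-key order.
class Linear(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray


model = [Linear(jnp.zeros((3, 2)), jnp.zeros(3)), Linear(jnp.zeros((1, 3)), jnp.zeros(1))]
# Stand-in for a PyTorch state_dict, using torch-style key names.
state_dict = OrderedDict(
    [
        ("0.weight", np.ones((3, 2), dtype=np.float32)),
        ("0.bias", np.full(3, 0.5, dtype=np.float32)),
        ("1.weight", np.ones((1, 3), dtype=np.float32)),
        ("1.bias", np.zeros(1, dtype=np.float32)),
    ]
)
loaded = load_by_order(model, state_dict)
```

Matching by position keeps the sketch independent of naming differences between the two libraries. The cost is that any reordering of attributes, or any extra checkpoint entry, silently shifts or breaks the correspondence.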
Info
A full list of pretrained URLs is provided here.
Arguments:
- model: An eqx.Module for which the jnp.ndarray leaves are replaced by corresponding PyTorch weights.
- torch_weights: A string pointing either to PyTorch weights on disk or to a download URL.
Returns:
The model with weights loaded from the PyTorch checkpoint.
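The torch_weights argument accepts either a local file path or a download URL. The stdlib-only sketch below shows one way such a string might be told apart. `resolve_weights_source` is a hypothetical helper and not part of eqxvision. Treating only `http`, `https` and `ftp` schemes as URLs is an assumption. A real loader could then pass a URL to `torch.hub.load_state_dict_from_url` and a path to `torch.load`.

```python
import os
from typing import Tuple
from urllib.parse import urlparse


def resolve_weights_source(torch_weights: str) -> Tuple[str, str]:
    """Classify `torch_weights` as a download URL or a local file path.

    Hypothetical helper: assumes only http(s)/ftp strings are URLs; everything
    else is treated as a filesystem path that must already exist.
    """
    scheme = urlparse(torch_weights).scheme.lower()
    if scheme in ("http", "https", "ftp"):
        return "url", torch_weights
    # Expand "~" so paths such as "~/weights/resnet18.pth" resolve correctly.
    path = os.path.expanduser(torch_weights)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"no PyTorch checkpoint found at {path!r}")
    return "path", path
```

Checking the URL scheme explicitly, rather than testing whether the string contains "://", also keeps Windows paths such as `C:\weights.pth` on the file-path branch.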