
Utils

eqxvision.utils.load_torch_weights(model: Module, torch_weights: str = None) -> Module

Loads weights from a PyTorch serialised file.

Warning
  • This method requires the torch package to be installed.

Note

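  • This function assumes that the order of class attributes in the Eqxvision model mirrors the corresponding torchvision.models implementation, since weights are matched by position rather than by name.
  • This method assumes the Eqxvision model has not been initialised; already-initialised batch-normalisation (BN) modules cause problems.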
  • The saved checkpoint should only contain model parameters as keys.
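The ordering assumption in the Note can be illustrated with a minimal NumPy sketch. It is not eqxvision's implementation: the function name load_by_order, the plain list of leaves standing in for a flattened eqx.Module, and the choice of num_batches_tracked as the only skipped key are all assumptions for illustration. Because weights are consumed strictly by position, a checkpoint whose entries are out of order can be assigned to the wrong layer when sizes happen to agree, or rejected by the size check when they do not.

```python
from collections import OrderedDict

import numpy as np

# Assumed list of non-parameter keys to drop. Only num_batches_tracked is
# skipped here; how BN running statistics are handled is out of scope.
SKIP_SUFFIXES = ("num_batches_tracked",)


def load_by_order(leaves, state_dict):
    """Replace array leaves, in order, with arrays from an ordered state_dict.

    Hypothetical helper. `leaves` stands in for the flattened leaves of a
    model; non-array leaves (activation names, flags, ...) pass through.
    """
    weights = iter(
        value for name, value in state_dict.items()
        if not name.endswith(SKIP_SUFFIXES)
    )
    new_leaves = []
    for leaf in leaves:
        if not isinstance(leaf, np.ndarray):
            new_leaves.append(leaf)
            continue
        try:
            new = np.asarray(next(weights))
        except StopIteration:
            raise ValueError("checkpoint has fewer weights than the model") from None
        if new.size != leaf.size:
            raise ValueError(f"size mismatch: got {new.shape}, expected {leaf.shape}")
        # Reshape to the model's layout, e.g. a PyTorch conv bias of shape (C,)
        # into an (assumed) model bias of shape (C, 1, 1).
        new_leaves.append(new.reshape(leaf.shape))
    if next(weights, None) is not None:
        raise ValueError("checkpoint has more weights than the model")
    return new_leaves


# Toy "model": conv weight, conv bias, a non-array leaf, BN weight.
model_leaves = [np.zeros((4, 3)), np.zeros((4, 1, 1)), "relu", np.zeros(4)]
checkpoint = OrderedDict([
    ("conv.weight", np.ones((4, 3))),
    ("conv.bias", np.arange(4.0)),
    ("bn.num_batches_tracked", np.array(0)),
    ("bn.weight", np.full(4, 2.0)),
])
loaded = load_by_order(model_leaves, checkpoint)
print(loaded[1].shape)  # (4, 1, 1)
```

Matching by position rather than by name is what makes the attribute-ordering assumption necessary: parameter names in the PyTorch checkpoint are never compared with anything on the model side.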

Info

A full list of pretrained URLs is provided in eqxvision.utils.CLASSIFICATION_URLS and eqxvision.utils.SEGMENTATION_URLS.

Arguments:

  • model: An eqx.Module for which the jnp.ndarray leaves are replaced by corresponding PyTorch weights.
  • torch_weights: A string containing either the path to PyTorch weights on disk or a download URL.

Returns: The model with weights loaded from the PyTorch checkpoint.
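The torch_weights argument accepts either a local path or a download URL. One way a loader could tell the two apart is sketched below in Python; the helper name weights_source and the rule that only http and https schemes count as URLs are assumptions, and eqxvision's own check may differ.

```python
from urllib.parse import urlparse


def weights_source(torch_weights: str) -> str:
    """Classify a torch_weights string as a download URL or a local path.

    Hypothetical helper: http(s) schemes are treated as URLs; everything
    else, including relative paths and Windows drive paths, as a file on disk.
    """
    scheme = urlparse(torch_weights).scheme.lower()
    return "url" if scheme in ("http", "https") else "path"


print(weights_source("https://example.com/weights.pth"))  # url
print(weights_source("checkpoints/resnet18.pth"))  # path
```

Restricting the URL branch to http and https keeps a drive letter such as "C:", which urlparse reports as a one-letter scheme, on the local-path side.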

eqxvision.utils.CLASSIFICATION_URLS

eqxvision.utils.SEGMENTATION_URLS