DenseNet#
eqxvision.models.DenseNet
#
A simple port of torchvision.models.densenet
__init__(self, growth_rate: int = 32, block_config: Tuple[int, int, int, int] = (6, 12, 24, 16), num_init_features: int = 64, bn_size: int = 4, drop_rate: float = 0, num_classes: int = 1000, *, key: Optional[jax.random.PRNGKey] = None)
#
Arguments:
growth_rate
: Number of filters to add in each layer (`k` in the paper)
block_config
: Number of layers in each pooling block
num_init_features
: The number of filters to learn in the first convolution layer
bn_size
: Multiplicative factor for the number of bottleneck layers (i.e. `bn_size * k` features in the bottleneck layer)
drop_rate
: Dropout rate after each dense layer
num_classes
: Number of classes in the classification task. Also controls the final output shape `(num_classes,)`. Defaults to `1000`
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array
#
Arguments:
x
: The input. Should be a JAX array with `3` channels
key
: Required parameter. Utilised by a few layers such as `Dropout` or `DropPath`
eqxvision.models.densenet121(torch_weights: str = None, **kwargs: Any) -> DenseNet
#
Densenet-121 model from Densely Connected Convolutional Networks. The required minimum input size of the model is 29x29.
Arguments:
torch_weights
: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
eqxvision.models.densenet161(torch_weights: str = None, **kwargs: Any) -> DenseNet
#
Densenet-161 model from Densely Connected Convolutional Networks. The required minimum input size of the model is 29x29.
Arguments:
torch_weights
: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
eqxvision.models.densenet169(torch_weights: str = None, **kwargs: Any) -> DenseNet
#
Densenet-169 model from Densely Connected Convolutional Networks. The required minimum input size of the model is 29x29.
Arguments:
torch_weights
: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
eqxvision.models.densenet201(torch_weights: str = None, **kwargs: Any) -> DenseNet
#
Densenet-201 model from Densely Connected Convolutional Networks. The required minimum input size of the model is 29x29.
Arguments:
torch_weights
: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`