Swin Transformer

eqxvision.models.SwinTransformer

A simple port of torchvision.models.swin_transformer.
__init__(self, patch_size: List[int], embed_dim: int, depths: List[int], num_heads: List[int], window_size: List[int], mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.1, num_classes: int = 1000, norm_layer: Callable = None, block: eqx.Module = None, downsample_layer: eqx.Module = None, *, key: Optional[jax.random.PRNGKey] = None)
Arguments:

patch_size: Size of the patch to construct from the input image. Defaults to (4, 4).
embed_dim: The dimension of the resulting embedding of the patch.
depths: Depth of each Swin Transformer layer.
num_heads: Number of attention heads in different layers.
window_size: Window size.
mlp_ratio: Ratio of MLP hidden dim to embedding dim. Defaults to 4.0.
dropout: Dropout rate. Defaults to 0.0.
attention_dropout: Attention dropout rate. Defaults to 0.0.
stochastic_depth_prob: Stochastic depth rate. Defaults to 0.1.
num_classes: Number of classes in the classification task. Also controls the final output shape (num_classes,).
norm_layer: Normalisation applied to the intermediate outputs. Defaults to LayerNorm2d.
block: The Swin Transformer v1/v2 block to use. Defaults to _SwinTransformerBlock, which is used in v1.
downsample_layer: Downsample layer (patch merging). Defaults to _PatchMerging, which is used in v1.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
Note

Currently, input image dimensions are required to be divisible by the window_size at each level of depth. For example, an input image of shape (3, 224, 224) works for window_size (7, 7), and (3, 256, 256) works for (8, 8).
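As a concrete illustration, the sketch below constructs the model directly; the hyperparameters are those of the tiny (swin_t) recipe from torchvision and are shown only as an example, not as required values.

```python
import jax
from eqxvision.models import SwinTransformer

# Example configuration matching the tiny (swin_t) recipe; an input of
# shape (3, 224, 224) is divisible by window_size (7, 7) at every depth.
model = SwinTransformer(
    patch_size=[4, 4],
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=[7, 7],
    key=jax.random.PRNGKey(0),
)
```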
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array
Arguments:

x: The input JAX array.
key: Required parameter. Utilised by a few layers such as Dropout or DropPath.
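Since __call__ operates on a single (channels, height, width) image, batches are handled with jax.vmap in the usual Equinox pattern, giving each image its own key. A minimal sketch, assuming randomly initialised swin_t weights:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
from eqxvision.models import swin_t

model = swin_t()  # randomly initialised weights
# Switch Dropout/DropPath to inference mode so the keys become inert.
model = eqx.tree_inference(model, value=True)

batch = jnp.zeros((4, 3, 224, 224))  # (batch, channels, height, width)
keys = jax.random.split(jax.random.PRNGKey(0), batch.shape[0])
logits = jax.vmap(model)(batch, key=keys)  # shape: (4, 1000)
```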
eqxvision.models.swin_t(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
Constructs a swin_tiny architecture from Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.

Arguments:

torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
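For instance, pretrained torchvision weights can be loaded by pointing torch_weights at a file or URL. The sketch below assumes eqxvision.utils.CLASSIFICATION_URLS contains a "swin_t" entry; substitute your own path or URL if it does not.

```python
from eqxvision.models import swin_t
from eqxvision.utils import CLASSIFICATION_URLS

# Assumes CLASSIFICATION_URLS maps "swin_t" to the torchvision weight URL;
# a local file path string works equally well.
model = swin_t(torch_weights=CLASSIFICATION_URLS["swin_t"])
```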
eqxvision.models.swin_s(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
Constructs a swin_small architecture from Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.

Arguments:

torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
eqxvision.models.swin_b(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
Constructs a swin_base architecture from Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.

Arguments:

torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
eqxvision.models.swin_v2_t(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
Constructs a swin_v2_tiny architecture from Swin Transformer V2: Scaling Up Capacity and Resolution.

Arguments:

torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
eqxvision.models.swin_v2_s(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
Constructs a swin_v2_small architecture from Swin Transformer V2: Scaling Up Capacity and Resolution.

Arguments:

torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
eqxvision.models.swin_v2_b(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
Constructs a swin_v2_base architecture from Swin Transformer V2: Scaling Up Capacity and Resolution.

Arguments:

torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
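Usage of the v2 helpers mirrors the v1 ones; the sketch below assumes extra keyword arguments (here num_classes) are forwarded to the SwinTransformer constructor via **kwargs, with the v2 block and downsample layers wired in by the helper itself.

```python
from eqxvision.models import swin_v2_t

# Assumes **kwargs are forwarded to SwinTransformer.__init__;
# the v2 block and patch-merging layers are selected internally.
model = swin_v2_t(num_classes=10)
```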