Drop-Path
eqxvision.layers.DropPath
Effectively drops a sample from the call.
Often used inside a network alongside a residual connection.
Equivalent to torchvision.ops.stochastic_depth.
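The idea can be sketched in plain JAX without eqxvision. The helpers drop_path_global and residual_block below are hypothetical names used only for illustration, not the library's implementation. The sketch also assumes that surviving inputs are rescaled by 1 / (1 - p), which is what torchvision.ops.stochastic_depth does; whether eqxvision does the same is not stated here.

```python
import jax
import jax.numpy as jnp


def drop_path_global(x, p, *, key):
    """Hypothetical stand-in for global drop-path; not eqxvision's code."""
    if p == 0.0:
        return x
    keep_prob = 1.0 - p
    if keep_prob == 0.0:
        # p == 1.0: the input is always dropped.
        return jnp.zeros_like(x)
    # A single Bernoulli draw decides the fate of the whole input.
    keep = jax.random.bernoulli(key, keep_prob)
    # Assumption: survivors are rescaled by 1 / keep_prob so that the
    # expected output equals x (as in torchvision's stochastic depth).
    return jnp.where(keep, x / keep_prob, jnp.zeros_like(x))


def residual_block(x, branch, p, *, key):
    # The skip connection always survives; only the branch output is dropped.
    return x + drop_path_global(branch(x), p, key=key)


x = jnp.ones((3, 4, 4))  # (channels, height, width)
key = jax.random.PRNGKey(0)
y = residual_block(x, lambda v: 2.0 * v, 0.5, key=key)
```

With p = 0.5, y is either x itself (branch dropped) or x plus the rescaled branch output, depending on the key.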
__init__(self, p: float = 0.0, inference: bool = False, mode = 'global')
Arguments:
p: The probability of dropping a sample entirely during the forward pass.
inference: Defaults to False. If True, then the input is returned unchanged. This may be toggled with equinox.tree_inference.
mode: Can be set to global or local. If global, the whole input is dropped or retained. If local, then the decision on each input unit is computed independently. Defaults to global.
Note
For mode = local, an input of shape (channels, dim_0, dim_1, ...) is reshaped to
(channels, dims) and then transposed. For each of the dims x channels elements,
the decision to drop or keep is made independently.
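The reshape-and-transpose step described in this note can be sketched as follows. The function drop_path_local is a hypothetical illustration written from the note alone, not eqxvision's code, and it leaves out any rescaling of the surviving elements to keep the focus on the mask layout.

```python
import jax
import jax.numpy as jnp


def drop_path_local(x, p, *, key):
    """Hypothetical sketch of per-element dropping; not eqxvision's code."""
    keep_prob = 1.0 - p
    channels = x.shape[0]
    # (channels, dim_0, dim_1, ...) -> (channels, dims) -> (dims, channels)
    flat = x.reshape(channels, -1).T
    # One independent keep/drop decision per (dim, channel) element.
    keep = jax.random.bernoulli(key, keep_prob, shape=flat.shape)
    dropped = jnp.where(keep, flat, 0.0)
    # Undo the transpose and the reshape to recover the input layout.
    return dropped.T.reshape(x.shape)


x_local = jnp.arange(24.0).reshape(2, 3, 4) + 1.0  # all entries non-zero
y_local = drop_path_local(x_local, 0.5, key=jax.random.PRNGKey(1))
```

Every entry of y_local is either the matching entry of x_local or zero, and each entry is decided on its own.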
__call__(self, x, *, key: jax.random.PRNGKey) -> Array
Arguments:
x: An any-dimensional JAX array to drop.
key: A jax.random.PRNGKey used to provide randomness for calculating which elements to drop. (Keyword-only argument.)
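Because the randomness comes entirely from key, reusing a key repeats the same decision, while keys obtained from jax.random.split give independent ones. The snippet below is a small sketch of that behaviour; it uses a bare Bernoulli draw as a stand-in for the keep/drop decision of a global call, which is an assumption about how the decision is made rather than the library's code.

```python
import jax

key = jax.random.PRNGKey(0)
# Derive one sub-key per drop-path call site.
key_a, key_b = jax.random.split(key)

keep_prob = 0.5
# Stand-in for the keep/drop decision of a global drop-path call.
first = jax.random.bernoulli(key_a, keep_prob)
again = jax.random.bernoulli(key_a, keep_prob)  # same key, same decision
other = jax.random.bernoulli(key_b, keep_prob)  # independent decision
```

In a network with several such layers, passing each one its own sub-key keeps their decisions from being tied together.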