rannet.utils

Module Contents

Functions

triangular_causal_mask(→ langml.tensor_typing.Tensors)

Generate a triangular causal mask.

prefix_causal_mask(→ langml.tensor_typing.Tensors)

Generate a prefix causal mask.

standard_normalize(→ langml.tensor_typing.Tensors)

mean(→ langml.tensor_typing.Tensors)

rannet.utils.triangular_causal_mask(seq_len: int | langml.tensor_typing.Tensors) → langml.tensor_typing.Tensors

Generate a triangular causal mask.

:param seq_len: sequence length

Examples

For seq_len = 3, the mask is:

array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)
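The example output can be reproduced with a short NumPy sketch. It is not rannet's implementation (which works on backend tensors typed as langml.tensor_typing.Tensors), and the helper name triangular_causal_mask_np is hypothetical. Each position i may attend to positions 0 through i, which is exactly the lower triangle, diagonal included, of a matrix of ones:

```python
import numpy as np


def triangular_causal_mask_np(seq_len: int) -> np.ndarray:
    """Hypothetical NumPy re-creation of the documented example.

    Row i is the query position; column j is the key position.
    Entry (i, j) is 1 when j <= i, i.e. the lower triangle of ones.
    """
    return np.tril(np.ones((seq_len, seq_len), dtype=np.float32))


mask = triangular_causal_mask_np(3)
print(mask)
```

In attention layers such a mask is typically applied by pushing the scores of the 0 entries to a large negative value before the softmax, so each token only sees itself and earlier tokens.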

rannet.utils.prefix_causal_mask(segment: langml.tensor_typing.Tensors) → langml.tensor_typing.Tensors

Generate a prefix causal mask.

:param segment: segment ids

Examples

For segment [[0, 0, 0, 1, 1]], the mask is:

array([[[1., 1., 1., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]]], dtype=float32)
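Reading the example row by row: the three segment-0 tokens (the prefix) attend to the whole prefix, while each segment-1 token attends to the prefix plus itself and the earlier segment-1 tokens. The NumPy sketch below reproduces that output with the cumulative-sum construction common in UniLM-style models. It assumes 0/1 segment ids of shape (batch, seq_len), is not necessarily how rannet builds the mask, and uses a hypothetical name, prefix_causal_mask_np:

```python
import numpy as np


def prefix_causal_mask_np(segment) -> np.ndarray:
    """Hypothetical NumPy sketch of a prefix causal mask.

    segment: (batch, seq_len) array of 0/1 ids, 0 for the prefix and
    1 for the part that is decoded left to right.
    Returns a float32 array of shape (batch, seq_len, seq_len).
    """
    segment = np.asarray(segment)
    # The running sum is 0 across the prefix, then grows by one per
    # segment-1 token: [0, 0, 0, 1, 1] -> [0, 0, 0, 1, 2].
    idxs = np.cumsum(segment, axis=1)
    # Query i (axis 1) may attend to key j (axis 2) iff idxs[j] <= idxs[i].
    mask = idxs[:, None, :] <= idxs[:, :, None]
    return mask.astype(np.float32)


mask = prefix_causal_mask_np([[0, 0, 0, 1, 1]])
print(mask)
```

Because every prefix token shares the running sum 0, the prefix is fully bidirectional, and the strictly increasing sums inside segment 1 make that part causal.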

rannet.utils.standard_normalize(x: langml.tensor_typing.Tensors, epsilon: float = 1e-07) → langml.tensor_typing.Tensors
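The reference gives no description for standard_normalize. Judging only from its name and the epsilon parameter, it presumably performs standard (z-score) normalization. The sketch below assumes the last axis is normalized and that epsilon is added to the standard deviation to avoid division by zero; both choices are assumptions, and standard_normalize_np is a hypothetical name:

```python
import numpy as np


def standard_normalize_np(x, epsilon: float = 1e-7) -> np.ndarray:
    """Hypothetical sketch: z-score normalization over the last axis.

    Assumption: epsilon is added to the standard deviation so that a
    constant row maps to zeros instead of producing NaN.
    """
    x = np.asarray(x, dtype=np.float32)
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + epsilon)


y = standard_normalize_np([[1.0, 2.0, 3.0], [5.0, 5.0, 5.0]])
print(y)
```

The second row is constant, so its standard deviation is 0; the epsilon term keeps the result at 0 rather than NaN.
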

rannet.utils.mean(x: langml.tensor_typing.Tensors, mask: langml.tensor_typing.Tensors | None = None, axis: float = -1, keepdims: bool = False) → langml.tensor_typing.Tensors
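The mean signature takes an optional mask. A common meaning for such a parameter, assumed here rather than confirmed by the reference, is that mask holds 1 for valid positions and 0 for padding, and masked positions are excluded from the average. A hypothetical NumPy version, masked_mean_np, used for mean pooling over the sequence axis:

```python
from typing import Optional

import numpy as np


def masked_mean_np(x, mask: Optional[np.ndarray] = None,
                   axis: int = -1, keepdims: bool = False) -> np.ndarray:
    """Hypothetical sketch of a mask-aware mean.

    Assumption: mask is 1 for valid positions, 0 for padding, and
    broadcasts against x.
    """
    x = np.asarray(x, dtype=np.float32)
    if mask is None:
        return x.mean(axis=axis, keepdims=keepdims)
    mask = np.broadcast_to(np.asarray(mask, dtype=np.float32), x.shape)
    total = (x * mask).sum(axis=axis, keepdims=keepdims)
    count = mask.sum(axis=axis, keepdims=keepdims)
    # Avoid dividing by zero when every position along `axis` is masked.
    return total / np.maximum(count, 1e-7)


# One sequence of three tokens with hidden size 2; the last token is padding.
x = [[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]]
mask = [[[1.0], [1.0], [0.0]]]
pooled = masked_mean_np(x, mask, axis=1)
print(pooled)  # averages only the first two tokens
```

Without the mask the padding vector would dominate the average; with it, only the two valid tokens contribute.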