Skip to content

vllm.attention.ops.vit_attn_wrappers

This file contains ops that make ViT attention compatible with torch.compile, since some of the operations involved are not supported by torch.compile (for instance, `.tolist()` in xformers attention, or `.item()` in flash attention).

Using these ops and wrapping vision blocks with torch.compile can speed up throughput in vision models by ~5% relative on H100, and improve token latencies by ~7% (see qwen2_5_vl for example usage)

To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)

flash_attn_maxseqlen_wrapper

flash_attn_maxseqlen_wrapper(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    """Run non-causal variable-length flash attention over q, k and v.

    The first two axes of ``q``, ``k`` and ``v`` (``b`` and ``s``) are merged
    into one before the kernel is called; the kernel output is then reshaped
    from ``(b s) h d`` to ``s b (h d)`` and made contiguous.

    Args:
        q: Query tensor whose two leading axes are ``b`` and ``s``.
        k: Key tensor, laid out like ``q``.
        v: Value tensor, laid out like ``q``.
        cu_seqlens: Passed unchanged as both ``cu_seqlens_q`` and
            ``cu_seqlens_k``.
        max_seqlen: Single-element tensor; its scalar value is passed as both
            ``max_seqlen_q`` and ``max_seqlen_k``.
        batch_size: Size of the ``b`` axis used to split the merged
            ``(b s)`` axis of the kernel output.
        is_rocm_aiter: When true, use ``aiter.flash_attn_varlen_func``.
        use_upstream_fa: When ``is_rocm_aiter`` is false, selects
            ``flash_attn.flash_attn_varlen_func`` over
            ``vllm.vllm_flash_attn.flash_attn_varlen_func``.

    Returns:
        The attention output with layout ``s b (h d)``.
    """
    # Backends are imported lazily so only the selected one has to be
    # importable.
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    elif use_upstream_fa:
        from flash_attn import flash_attn_varlen_func
    else:
        from vllm.vllm_flash_attn import flash_attn_varlen_func

    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

    # Read the scalar out of the tensor once and reuse it for both the query
    # and key limits instead of calling .item() twice.
    # NOTE(review): if max_seqlen lives on an accelerator, each .item() is
    # presumably a device-to-host sync, so this saves one per call - confirm.
    max_len = max_seqlen.item()

    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_len,
        max_seqlen_k=max_len,
        dropout_p=0.0,
        causal=False,
    )
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer

flash_attn_maxseqlen_wrapper_fake

flash_attn_maxseqlen_wrapper_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    """Return an uninitialized tensor shaped like the flash-attention output.

    Only ``q`` is inspected: it must be 4-D ``(b, s, h, d)``. The result has
    shape ``(s, b, h * d)`` and shares ``q``'s dtype and device. All other
    arguments are accepted for signature parity and ignored.
    """
    num_batches, seq_len, num_heads, head_dim = q.shape
    out_shape = (seq_len, num_batches, num_heads * head_dim)
    return torch.empty(out_shape, dtype=q.dtype, device=q.device)

vit_flash_attn_wrapper

vit_flash_attn_wrapper(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    """Call ``torch.ops.vllm.flash_attn_maxseqlen_wrapper`` with the arguments.

    All parameters are forwarded positionally, in order and unchanged, and the
    op's result is returned as is.
    """
    flash_attn_op = torch.ops.vllm.flash_attn_maxseqlen_wrapper
    return flash_attn_op(
        q,
        k,
        v,
        cu_seqlens,
        max_seqlen,
        batch_size,
        is_rocm_aiter,
        use_upstream_fa,
    )

vit_xformers_attn_wrapper

vit_xformers_attn_wrapper(
    q: Tensor, k: Tensor, v: Tensor, seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def vit_xformers_attn_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    """Call ``torch.ops.vllm.xformers_attn_seqlens_wrapper`` with the arguments.

    ``q``, ``k``, ``v`` and ``seqlens`` are forwarded positionally and
    unchanged; the op's result is returned as is.
    """
    xformers_op = torch.ops.vllm.xformers_attn_seqlens_wrapper
    return xformers_op(q, k, v, seqlens)

xformers_attn_seqlens_wrapper

xformers_attn_seqlens_wrapper(
    q: Tensor, k: Tensor, v: Tensor, seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def xformers_attn_seqlens_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    """Run xformers memory-efficient attention with a block-diagonal mask.

    The mask is built from ``seqlens`` (converted to a Python list) on
    ``q``'s device, with ``kv_seqlen`` left as ``None``. The attention output
    is rearranged from ``b s h d`` to ``s b (h d)`` and made contiguous.
    """
    # xformers is imported lazily, inside the function body.
    from xformers.ops import memory_efficient_attention_forward
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    block_mask = BlockDiagonalMask.from_seqlens(
        q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
    )
    attn_out = memory_efficient_attention_forward(
        q, k, v, attn_bias=block_mask, p=0, scale=None
    )
    return einops.rearrange(attn_out, "b s h d -> s b (h d)").contiguous()

xformers_attn_seqlens_wrapper_fake

xformers_attn_seqlens_wrapper_fake(
    q: Tensor, k: Tensor, v: Tensor, seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def xformers_attn_seqlens_wrapper_fake(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    """Return an uninitialized tensor shaped like the xformers attention output.

    ``q`` must be 4-D ``(b, s, h, d)``; the result has shape ``(s, b, h * d)``
    with ``q``'s dtype and device. ``k``, ``v`` and ``seqlens`` are ignored.
    """
    num_batches, seq_len, num_heads, head_dim = q.shape
    out_shape = (seq_len, num_batches, num_heads * head_dim)
    return torch.empty(out_shape, dtype=q.dtype, device=q.device)