[ROCm][CI][Bugfix] Multi-Modal Model Support Fixes and Attention Backend Improvements #30270
base: main
Changes from 10 commits
84c26b1
d4c47f3
9af82a4
f63c60e
e8dc42f
ac960de
5fe66ac
008e755
206b410
ea2111c
9c9f225
8fb6c74
33f9bff
@@ -8,6 +8,7 @@
 import torch

 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -200,7 +201,11 @@ def get_vit_attn_backend(
         # TODO: Add support for other VL models in their model class.
         return AttentionBackendEnum.ROCM_AITER_FA

-    if on_gfx9() and find_spec("flash_attn") is not None:
+    if (
+        on_gfx9()
+        and find_spec("flash_attn") is not None
+        and (dtype == torch.float16 or dtype == torch.bfloat16)
+    ):
         return AttentionBackendEnum.FLASH_ATTN

     return AttentionBackendEnum.TORCH_SDPA
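As a reading aid, here is a minimal standalone sketch of the dtype gate this hunk introduces: FLASH_ATTN is only chosen on gfx9 when the `flash_attn` package is importable and the model runs in fp16/bf16; everything else falls back to SDPA. The helper name `pick_vit_attn_backend` and the plain-string return values are illustrative stand-ins, not vLLM's API, and `on_gfx9()` is replaced by a boolean parameter.

```python
# Illustrative sketch only (assumed names), mirroring the condition added above.
from importlib.util import find_spec

import torch


def pick_vit_attn_backend(dtype: torch.dtype, is_gfx9: bool) -> str:
    # FLASH_ATTN requires: gfx9 hardware, an importable flash_attn package,
    # and half-precision activations (fp16 or bf16).
    if (
        is_gfx9
        and find_spec("flash_attn") is not None
        and dtype in (torch.float16, torch.bfloat16)
    ):
        return "FLASH_ATTN"
    # Anything else (e.g. fp32 ViT weights) falls back to torch SDPA.
    return "TORCH_SDPA"


print(pick_vit_attn_backend(torch.float32, is_gfx9=True))    # TORCH_SDPA
print(pick_vit_attn_backend(torch.bfloat16, is_gfx9=False))  # TORCH_SDPA
```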
@@ -241,26 +246,34 @@ def get_attn_backend_cls(
         )
         if selected_backend == AttentionBackendEnum.TRITON_MLA:
             if block_size != 1:
-                logger.info_once("Using Triton MLA backend.")
+                logger.info_once("Using Triton MLA backend on V1 engine.")
                 return AttentionBackendEnum.TRITON_MLA.get_path()
             raise ValueError(
                 f" The selected backend, {selected_backend.name},"
                 f"does not support block size {block_size}."
             )
         if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
-            logger.info("Using AITER MLA backend.")
+            logger.info("Using AITER MLA backend on V1 engine.")
             return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
         if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
-            logger.info("Using AITER TRITON MLA backend.")
+            logger.info("Using AITER TRITON MLA backend on V1 engine.")
             return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

         raise ValueError(
             f" The selected backend, {selected_backend.name},"
             f"is not MLA type while requested for MLA backend."
         )

+    attn_backend_override = os.environ.get("VLLM_ATTENTION_BACKEND")
+    if selected_backend is None and attn_backend_override is not None:
+        logger.info(
+            "Detected VLLM_ATTENTION_BACKEND=%s (set by model architecture).",
+            attn_backend_override,
+        )
+        selected_backend = AttentionBackendEnum[attn_backend_override]
+
     if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
-        logger.info("Using FlexAttention backend.")
+        logger.info("Using FlexAttention backend on V1 engine.")
         return AttentionBackendEnum.FLEX_ATTENTION.get_path()

     if selected_backend == AttentionBackendEnum.TRITON_ATTN:
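The added block above lets model-specific code steer backend selection through the VLLM_ATTENTION_BACKEND environment variable, but only when no backend was selected explicitly. Below is a standalone sketch of that precedence, using plain strings instead of AttentionBackendEnum members; the function name is illustrative, not vLLM's API.

```python
# Sketch of the override precedence, assuming string backend names for brevity.
import os


def resolve_backend(selected_backend: str | None) -> str | None:
    # The diff above additionally logs the override and maps the string through
    # AttentionBackendEnum[...]; here we simply return it.
    override = os.environ.get("VLLM_ATTENTION_BACKEND")
    if selected_backend is None and override is not None:
        return override
    return selected_backend


os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_FA"
print(resolve_backend(None))           # ROCM_AITER_FA -- env override applies
print(resolve_backend("TRITON_ATTN"))  # TRITON_ATTN   -- explicit choice wins
```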
@@ -313,6 +326,11 @@ def get_attn_backend_cls(
         logger.info("Using Aiter Flash Attention backend on V1 engine.")
         return AttentionBackendEnum.ROCM_AITER_FA.get_path()

+    # Priority 5: If model is Encoder-only self-attention type
+    if attn_type is not None and attn_type in (AttentionType.ENCODER_ONLY):
+        logger.info("Using FlexAttention backend on V1 engine.")
+        return AttentionBackendEnum.FLEX_ATTENTION.get_path()
+
     # Default: Triton Unified Attention
     logger.info("Using Triton Attention backend on V1 engine.")
     return AttentionBackendEnum.TRITON_ATTN.get_path()

Review comment on lines 329 to 332: The new FlexAttention fallback uses …

Author (Contributor): Comment addressed in 9c9f225
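One detail of the added check that is easy to miss: in `attn_type in (AttentionType.ENCODER_ONLY)` the parentheses do not create a tuple, so the test runs against the single constant itself. Assuming the attention-type constants are plain strings (an assumption here, suggested by their names), that turns the check into a substring test rather than a membership test. The snippet below demonstrates the difference with ordinary strings as stand-ins.

```python
# Demonstration of the one-element-tuple pitfall, using plain strings as stand-ins
# for the attention-type constants (an assumption, not vLLM's actual definitions).
ENCODER_ONLY = "encoder_only"

attn_type = "encoder"  # a different attention type that is a prefix of the constant

print(attn_type in (ENCODER_ONLY))   # True  -- `(x)` is just `x`, so this is a substring test
print(attn_type in (ENCODER_ONLY,))  # False -- a real one-element tuple gives membership semantics
print(attn_type == ENCODER_ONLY)     # False -- simplest exact comparison
```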
Review comment: Using `os.environ` to control behavior between different parts of the application introduces a global state, which is a potential source of bugs. In a concurrent environment, such as a server handling multiple model loading requests simultaneously, this can lead to race conditions where one request unintentionally affects another. A safer approach would be to pass this configuration through the `VllmConfig` or a similar mechanism that avoids global state. For instance, you could add an `attn_backend_override` to `ModelConfig` which can be set by model-specific code and then read by the attention backend selection logic.

Author (Contributor): There is no concurrency that involves this part of the code. This behavior is ROCm-specific and is critical for this model for now. Nonetheless, we are going to add the necessary support for the rest of the ROCm-specific attention backends and make this mini-patch obsolete in the future.
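For completeness, here is a hypothetical sketch of the reviewer's config-based alternative. The `ModelConfigSketch` class and its `attn_backend_override` field are illustrative stand-ins, not vLLM's existing `ModelConfig` API; the point is only that the override travels with the config object instead of process-global environment state.

```python
# Hypothetical sketch of passing the override through a config object.
from dataclasses import dataclass


@dataclass
class ModelConfigSketch:
    model: str
    attn_backend_override: str | None = None  # set by model-specific code


def resolve_backend(selected: str | None, config: ModelConfigSketch) -> str | None:
    # Explicit selection still wins; otherwise use the per-config override.
    if selected is None and config.attn_backend_override is not None:
        return config.attn_backend_override
    return selected


cfg = ModelConfigSketch(model="some/vl-model", attn_backend_override="ROCM_AITER_FA")
print(resolve_backend(None, cfg))  # ROCM_AITER_FA -- no environment variable involved
```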