Conversation


@AndreasKaratzas (Contributor) commented Dec 8, 2025

This PR addresses several ROCm-specific issues with multi-modal/vision-language models and improves attention backend dispatching for encoder-only self-attention models. It turns the following test groups green on ROCm:

  • Multi-Modal Models Test (Standard)
  • Multi-Modal Models Test (Extended) 1
  • Multi-Modal Models Test (Extended) 2
  • Multi-Modal Models Test (Extended) 3

Key Changes

Attention Backend Selection (vllm/platforms/rocm.py):

  • Added dtype validation (fp16/bf16 only) for Flash Attention backend selection
  • Added automatic FlexAttention fallback for ENCODER_ONLY attention types
  • Implemented VLLM_ATTENTION_BACKEND environment variable override mechanism for model-specific backend selection (see the sketch after this list)
  • Improved logging messages with "V1 engine" suffix for clarity
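
The dispatch rules described above might look roughly like the following. This is a hedged sketch: the function name, backend strings, and priority ordering are illustrative, not the actual vllm/platforms/rocm.py implementation.

import os

import torch


def select_rocm_attn_backend(dtype: torch.dtype, attn_type: str | None) -> str:
    # An explicit VLLM_ATTENTION_BACKEND override wins.
    override = os.environ.get("VLLM_ATTENTION_BACKEND")
    if override is not None:
        return override
    # Encoder-only self-attention models fall back to FlexAttention.
    if attn_type == "encoder_only":
        return "FLEX_ATTENTION"
    # Flash Attention is only valid for half-precision dtypes.
    if dtype in (torch.float16, torch.bfloat16):
        return "FLASH_ATTN"
    return "TORCH_SDPA"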

Qwen3-VL Support (vllm/model_executor/models/qwen3_vl.py):

  • Added ROCm-specific attention backend override to use ROCM_AITER_FA on gfx9 architectures
  • Sets VLLM_ATTENTION_BACKEND env var to propagate backend choice to platform layer

Transformers Multimodal Mixin (vllm/model_executor/models/transformers/multimodal.py):

  • Force the MATH SDP backend for the vision encoder on ROCm to avoid accuracy issues with flash_sdp and mem_efficient_sdp (see the sketch after this list)
  • Patches issue #30167
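
A minimal sketch of forcing the MATH SDP backend via the public PyTorch API (the actual mixin wiring in vLLM may differ):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)

# Restrict scaled_dot_product_attention to the MATH kernel, excluding
# flash_sdp and mem_efficient_sdp, which exhibit the accuracy issues on ROCm.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)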

SigLIP2 NaViT (vllm/model_executor/models/siglip2navit.py):

  • Disabled vllm_flash_attn rotary embedding on ROCm

Test Updates:

  • Moved conftest.py up one level to share ROCm SDP workarounds across multimodal tests
  • Use the float16 dtype instead of float (float32) for CLIP/SigLIP pooling tests on ROCm
  • Added ROCm-specific processor kwargs for Qwen2.5-VL tests
  • Skip MiniMax-VL-01 on ROCm due to GPU OOM and pickle issues with spawn + tp>1 (see the sketch after this list)
  • Adjusted CI timeouts for the multi-modal test suites so they more precisely match actual runtimes
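
A minimal sketch of the ROCm-conditional skip, assuming a torch.version.hip check and an illustrative test name (vLLM's tests may use a different guard):

import pytest
import torch

# torch.version.hip is set only on ROCm builds of PyTorch.
IS_ROCM = torch.version.hip is not None


@pytest.mark.skipif(
    IS_ROCM,
    reason="GPU OOM and pickle issues with spawn + tp>1 on ROCm",
)
def test_minimax_vl_01():
    ...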

CI Configuration (.buildkite/test-amd.yaml):

  • Extended timeouts for Multi-Modal Models Test suites (Standard: 100min, Extended 2: 120min, Accuracy Eval: 240min)

Commits:

  • …s; added automation for Qwen3-VL backend; fixes for groups
  • …dels - fixes siglip fp32 on ROCm
  • …mers test case

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas (Contributor, Author) commented:

cc @hmellor

@mergify bot added the ci/build, multi-modality, qwen, and rocm labels on Dec 8, 2025
@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 329 to 332
# Priority 5: If model is Encoder-only self-attention type
if attn_type is not None and attn_type in (AttentionType.ENCODER_ONLY):
logger.info("Using FlexAttention backend on V1 engine.")
return AttentionBackendEnum.FLEX_ATTENTION.get_path()


P1: Encoder-only fallback condition raises at runtime

The new FlexAttention fallback uses attn_type in (AttentionType.ENCODER_ONLY) without a trailing comma, so the right-hand side is a single AttentionType value rather than an iterable. As soon as attn_type is provided (e.g., for encoder-only models such as CLIP/SigLIP on ROCm), evaluating this condition raises TypeError: argument of type 'AttentionType' is not iterable, preventing backend selection and crashing initialization instead of falling back to FlexAttention.
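
For illustration, either of the following forms would restore the intended check (a sketch of the two obvious fixes; the PR's actual fix landed in 9c9f225):

# A trailing comma makes the right-hand side a one-element tuple:
if attn_type is not None and attn_type in (AttentionType.ENCODER_ONLY,):
    logger.info("Using FlexAttention backend on V1 engine.")
    return AttentionBackendEnum.FLEX_ATTENTION.get_path()

# Equivalently, compare against the enum value directly:
if attn_type == AttentionType.ENCODER_ONLY:
    ...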


@AndreasKaratzas (Contributor, Author) replied:

Comment addressed in 9c9f225

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a range of fixes and improvements for multi-modal model support on ROCm. The changes include updates to CI timeouts, model-specific workarounds for attention backends and data types, and enhancements to the attention backend selection logic. While the changes are generally positive and address clear needs, I have identified two significant concerns. First, a critical issue with the use of an environment variable for configuration, which introduces global state and risks race conditions in a concurrent environment. Second, a high-severity maintainability issue arising from code duplication to resolve a circular import. Addressing these points would improve the robustness and long-term maintainability of the codebase.

attn_backend_override = AttentionBackendEnum.ROCM_AITER_FA
# Set also env variable for platform rocm to use ROCM_AITER_FA
# for `selected_backend` in attention backend getter.
os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_FA"

critical

Using os.environ to control behavior between different parts of the application introduces a global state, which is a potential source of bugs. In a concurrent environment, such as a server handling multiple model loading requests simultaneously, this can lead to race conditions where one request unintentionally affects another. A safer approach would be to pass this configuration through the VllmConfig or a similar mechanism that avoids global state. For instance, you could add an attn_backend_override to ModelConfig which can be set by model-specific code and then read by the attention backend selection logic.
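
A hypothetical sketch of that suggestion; attn_backend_override is an illustrative field here, not an existing vLLM ModelConfig attribute:

from dataclasses import dataclass


@dataclass
class ModelConfig:
    # Set by model-specific code (e.g. Qwen3-VL on gfx9) instead of
    # mutating process-global os.environ.
    attn_backend_override: str | None = None


def select_backend(model_config: ModelConfig, default: str) -> str:
    # Read by the attention backend selection logic.
    if model_config.attn_backend_override is not None:
        return model_config.attn_backend_override
    return default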

@AndreasKaratzas (Contributor, Author) replied:

There is no concurrency involving this part of the code. This behavior is ROCm-specific and, for now, critical for this model. Nonetheless, we are going to add support for the rest of the ROCm-specific attention backends and make this mini-patch obsolete in the future.

Comment on lines +245 to +258
def getattr_iter(obj: object, names: Iterable[str], default: Any) -> Any:
"""
A helper function that retrieves an attribute from an object which may
have multiple possible names. This is useful when fetching attributes from
``transformers.PretrainedConfig`` instances.
Note:
Duplicated from ``vllm.config.utils`` to avoid circular import
(vllm.transformers_utils.config <-> vllm.config.model).
"""
for name in names:
if hasattr(obj, name):
return getattr(obj, name)
return default

high

This function getattr_iter is duplicated from vllm.config.utils to avoid a circular import, as noted in the docstring. Code duplication poses a significant maintainability risk, as changes made to one copy may not be propagated to the other, leading to inconsistent behavior and potential bugs. To resolve this, please consider refactoring to break the circular dependency. A common solution is to move the shared utility function to a more fundamental module (e.g., a new file like vllm/utils/config_utils.py) that can be safely imported by both vllm.config.utils and vllm.transformers_utils.config.

@AndreasKaratzas (Contributor, Author) replied:

This was a standard function. It serves a very specific purpose and is highly unlikely to ever need maintenance or modification. Therefore, the duplication only serves to resolve the circular import bug.
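
For reference, a typical call to the duplicated helper looks like this (the attribute names are illustrative):

# Try several HF config spellings for the head count; fall back to None.
num_heads = getattr_iter(hf_config, ("num_attention_heads", "n_head"), None)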
