Support TP which is not divided for NVFP4 kernels (flashinfer-cutlass) by adding dynamic padding #30260
Conversation
Signed-off-by: Daniel Afrimi <dafrimi@nvidia.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
if self.backend.startswith("flashinfer-cutlass"):
    backend_name = self.backend[len("flashinfer-") :]
FlashInfer TRTLLM/CUDNN backend crashes in apply()
The apply path now only dispatches when self.backend.startswith("flashinfer-cutlass"), so a configured NVFP4 backend of flashinfer-trtllm or flashinfer-cudnn (both advertised in VLLM_NVFP4_GEMM_BACKEND) will fall through to the cutlass assert and raise immediately. Previously the generic flashinfer- branch covered those options; after this change those backends can no longer run inference.
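For reference, a minimal sketch (not the actual vLLM code) of the dispatch shape this finding implies: keep the generic flashinfer- branch so flashinfer-trtllm and flashinfer-cudnn still dispatch, and gate only the new padding on flashinfer-cutlass. run_flashinfer and run_cutlass are hypothetical stand-ins for the real matmul helpers.

def dispatch(backend: str, run_flashinfer, run_cutlass):
    if backend.startswith("flashinfer-"):
        name = backend[len("flashinfer-"):]
        # Only the cutlass flavor needs the new activation/scale padding.
        pad_inputs = backend == "flashinfer-cutlass"
        return run_flashinfer(name, pad_inputs=pad_inputs)
    assert backend == "cutlass"
    return run_cutlass()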
if pad_rows > 0:
    weight = torch.nn.functional.pad(
        weight, (0, 0, 0, pad_rows)
    ).contiguous()
Cutlass path returns padded outputs for non-128-aligned rows
Weights are now explicitly padded with extra rows to match the swizzled scale shape (e.g., when the output dimension is not a multiple of 128), but the cutlass execution path still derives output_shape from layer.weight.shape and never slices the result back to the logical output size. For cutlass backends with unaligned output widths this returns extra columns (and will also make bias addition fail to broadcast), breaking the layer’s expected output shape.
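A hedged sketch of the missing slice-back step; logical_n stands in for layer.output_size_per_partition, and the 464 -> 512 shapes are only an example of the 128-row swizzle padding.

import torch

def unpad_output(out: torch.Tensor, logical_n: int) -> torch.Tensor:
    # Drop the columns that exist only because the weight rows were padded.
    if out.shape[1] != logical_n:
        out = out[:, :logical_n].contiguous()
    return out

print(unpad_output(torch.randn(8, 512), 464).shape)  # torch.Size([8, 464])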
Code Review
This pull request introduces dynamic padding to support Tensor Parallelism for NVFP4 kernels, specifically targeting flashinfer-cutlass. The changes correctly address NaN propagation issues by initializing scale memory with zeros and padding weights and activations to meet hardware alignment requirements. However, I've identified a critical regression that breaks support for other flashinfer backends and a high-severity issue where the cutlass backend may face shape mismatches due to inconsistent padding. My review includes suggestions to resolve these issues.
if self.backend.startswith("flashinfer-cutlass"):
    backend_name = self.backend[len("flashinfer-") :]

    # Match packed-K bytes between activations and weights using
    # pre-calculated padding
    pad_k_bytes = getattr(layer, "execution_padding_k_bytes", 0)
    output_shape = [x.shape[0], layer.output_size_per_partition]
    x_fp4 = torch.nn.functional.pad(x_fp4, (0, pad_k_bytes)).contiguous()

    # If we padded x_fp4, we may need to pad x_blockscale as well
    original_scale_blocks = x_blockscale.shape[1]
    k_elements = (
        x_fp4.shape[1] * 2
    )  # 2 fp4 items are packed in the input dimension
    required_scale_blocks = k_elements // layer.quant_config.group_size

    # Align to 4 to be safe with int32 packing assumptions in kernels
    if required_scale_blocks % 4 != 0:
        required_scale_blocks += 4 - (required_scale_blocks % 4)

    if original_scale_blocks < required_scale_blocks:
        pad_scales = required_scale_blocks - original_scale_blocks
        x_blockscale = torch.nn.functional.pad(
            x_blockscale, (0, pad_scales), value=0.0
        ).contiguous()

    mm_args = (
        x_fp4,
        layer.weight,
        x_blockscale,
        layer.weight_scale,
        layer.alpha,
        output_dtype,
    )
    out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)

    # Slice output to remove padding if weight was padded in N dimension
    if out.shape[1] != output_shape[1]:
        out = out[:, : output_shape[1]].contiguous()
else:
    assert self.backend == "cutlass"
    mm_args = (
        x_fp4,
        layer.weight,
        x_blockscale,
        layer.weight_scale,
        layer.alpha,
        output_dtype,
    )
    out = cutlass_scaled_fp4_mm(*mm_args)
This change introduces a regression for flashinfer- backends other than flashinfer-cutlass (e.g., flashinfer-trtllm). The original if self.backend.startswith("flashinfer-") check was replaced with a more specific startswith("flashinfer-cutlass"), and the else block now asserts self.backend == "cutlass". This will cause an assertion failure for other flashinfer- backends. The previous structure should be restored, with the new padding logic nested inside a check for flashinfer-cutlass to maintain support for other backends.
if self.backend.startswith("flashinfer-"):
    backend_name = self.backend[len("flashinfer-") :]
    _x_fp4, _x_blockscale = x_fp4, x_blockscale
    _output_shape = output_shape
    if self.backend == "flashinfer-cutlass":
        # Match packed-K bytes between activations and weights using
        # pre-calculated padding
        pad_k_bytes = getattr(layer, "execution_padding_k_bytes", 0)
        _output_shape = [x.shape[0], layer.output_size_per_partition]
        _x_fp4 = torch.nn.functional.pad(x_fp4, (0, pad_k_bytes)).contiguous()
        # If we padded x_fp4, we may need to pad x_blockscale as well
        original_scale_blocks = x_blockscale.shape[1]
        k_elements = (
            _x_fp4.shape[1] * 2
        )  # 2 fp4 items are packed in the input dimension
        required_scale_blocks = k_elements // layer.quant_config.group_size
        # Align to 4 to be safe with int32 packing assumptions in kernels
        if required_scale_blocks % 4 != 0:
            required_scale_blocks += 4 - (required_scale_blocks % 4)
        if original_scale_blocks < required_scale_blocks:
            pad_scales = required_scale_blocks - original_scale_blocks
            _x_blockscale = torch.nn.functional.pad(
                x_blockscale, (0, pad_scales), value=0.0
            ).contiguous()
    mm_args = (
        _x_fp4,
        layer.weight,
        _x_blockscale,
        layer.weight_scale,
        layer.alpha,
        output_dtype,
    )
    out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
    if self.backend == "flashinfer-cutlass":
        # Slice output to remove padding if weight was padded in N dimension
        if out.shape[1] != _output_shape[1]:
            out = out[:, : _output_shape[1]].contiguous()
else:
    assert self.backend == "cutlass"
    mm_args = (
        x_fp4,
        layer.weight,
        x_blockscale,
        layer.weight_scale,
        layer.alpha,
        output_dtype,
    )
    out = cutlass_scaled_fp4_mm(*mm_args)

# Swizzle block scales and then pad the packed NVFP4 weights so that
# both N (rows) and K (columns) satisfy the alignment constraints
# implied by the FlashInfer / Cutlass kernels.
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight = layer.weight.data

weight_current_rows = weight.shape[0]
weight_scale_rows_padded = swizzled_weight_scale.shape[0]

# Pad weight to match swizzled scale dimensions
if weight_scale_rows_padded != weight_current_rows:
    pad_rows = weight_scale_rows_padded - weight_current_rows
    assert pad_rows >= 0, (
        f"Weight scale rows ({weight_scale_rows_padded}) < "
        f"weight rows ({weight_current_rows})."
    )
    if pad_rows > 0:
        weight = torch.nn.functional.pad(
            weight, (0, 0, 0, pad_rows)
        ).contiguous()

# Calculate the number of k blocks padded to satisfy alignment
# constraints for the weight
layer.execution_padding_k_bytes = 0
group_size = self.quant_config.group_size
num_k_blocks_padded = swizzled_weight_scale.shape[1]
k_bytes_padded = (num_k_blocks_padded * group_size) // 2
k_bytes_orig = weight.shape[1]

logger.info(
    "[FP4 Weight Prep] K-dim alignment: original_bytes=%d, "
    "padded_bytes=%d, group_size=%d",
    k_bytes_orig,
    k_bytes_padded,
    group_size,
)

if k_bytes_padded != k_bytes_orig:
    pad_bytes = k_bytes_padded - k_bytes_orig
    assert pad_bytes > 0
    logger.info(
        "[FP4 Weight Prep] Padding K-dim: %d -> %d (pad_bytes=%d)",
        weight.shape[1],
        weight.shape[1] + pad_bytes,
        pad_bytes,
    )
    weight = torch.nn.functional.pad(
        weight, (0, pad_bytes, 0, 0)
    ).contiguous()
    layer.execution_padding_k_bytes = pad_bytes

layer.weight = Parameter(weight, requires_grad=False)
This logic block pads weights for both cutlass and flashinfer-cutlass backends, which is correct. However, the corresponding activation padding in the apply method is only performed for flashinfer-cutlass. This inconsistency will likely lead to a shape mismatch error during matrix multiplication for the cutlass backend, as it will receive padded weights but unpadded activations. To ensure correctness, the activation padding logic should be applied for both backends if the weight padding is intended for both.
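A hedged sketch of the symmetric fix suggested here, assuming layer.execution_padding_k_bytes is set whenever the weight's K dimension was padded; running it before either backend's matmul keeps the packed K bytes of activations and weights in sync.

import torch

def pad_activations_to_weight_k(x_fp4: torch.Tensor, pad_k_bytes: int) -> torch.Tensor:
    # Append pad_k_bytes zero bytes so the packed activation width matches the
    # padded weight, for cutlass as well as flashinfer-cutlass.
    if pad_k_bytes > 0:
        x_fp4 = torch.nn.functional.pad(x_fp4, (0, pad_k_bytes)).contiguous()
    return x_fp4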
The Issues
- Unaligned TP Shapes: In TP4, splitting some models resulted in input/weight dimensions (e.g., 464) that were not aligned with the required block size (16) or with the tile constraints of some kernels, causing mismatched shapes between inputs, weights, and their scales.
- Uninitialized Scale Memory: The quantization kernel (scaled_fp4_quant) allocates output scales using torch.empty. Because hardware alignment requires allocating "rounded up" sizes (e.g., to the nearest multiple of 128 blocks), the padding area of the tensor contained garbage (NaNs) from uninitialized memory.
- Kernel OOB Reads: When we padded the input/weights to fix the alignment issues (e.g., 464 -> 480), the matrix multiplication kernel correctly processed the padded data (zeros). However, it still attempted to read the corresponding scales for those padded blocks. Since those scales resided in the uninitialized/garbage memory region, this propagated NaNs into the final result, causing model crashes. (A short numeric sketch of these block counts follows this list.)
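A short numeric sketch of the block counts above, using the 464-wide TP4 shard and the 464 -> 480 padding from the description (group size 16).

group_size = 16
k_orig, k_padded = 464, 480
valid_blocks = k_orig // group_size    # 29 blocks hold real scales (indices 0..28)
total_blocks = k_padded // group_size  # 30 blocks once the input is padded
print(valid_blocks, total_blocks)      # 29 30 -> the kernel also reads Scale[29]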
The Fix
- Weight Preprocessing (modelopt.py): After loading, pad the packed NVFP4 weights so that both N (rows) and K (columns) match the swizzled block-scale shapes, and record the extra K bytes in layer.execution_padding_k_bytes (see the quoted hunk above).
- Activation Padding (modelopt.py): In apply(), pad x to match the padded weights and pad x_blockscale with 0.0 if the padded input x extends beyond the valid scale blocks. This ensures that the kernel has valid scale values to read for the padded input regions.
- Scale Initialization (_custom_ops.py): Switch the scale allocation from torch.empty to torch.zeros. torch.zeros is slightly slower than torch.empty, but it is strictly necessary here: it ensures that "gap" bytes (allocated to satisfy hardware alignment requirements but untouched by the quantizer) are initialized to 0.0 instead of potentially containing garbage/NaNs. This prevents crashes when downstream kernels (like the MatMul) access these padded regions during vectorized loads. (A minimal allocation sketch follows this list.)
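A minimal, illustrative sketch of the torch.empty -> torch.zeros change. The shape arithmetic (rows rounded to 128, four scales packed per int32) follows the alignment comments in the hunks above and is an assumption, not the verbatim _custom_ops.py code.

import torch

def alloc_swizzled_scales(m: int, k: int, group_size: int = 16) -> torch.Tensor:
    rounded_m = ((m + 127) // 128) * 128        # rows rounded up to a multiple of 128
    blocks = -(-k // group_size)                # one block scale per 16 elements
    rounded_blocks = ((blocks + 3) // 4) * 4    # assume 4 scales packed per int32 word
    # torch.zeros instead of torch.empty: the rounded-up "gap" region must read
    # as 0.0 rather than leftover garbage that can decode to NaN downstream.
    return torch.zeros(rounded_m, rounded_blocks // 4, dtype=torch.int32)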
Example (TP4 Scenario)
- The per-rank input shard has shape [..., 464], i.e. 29 valid scale blocks.
- Before the fix: the kernel reads Scale[29] for the padded region of the input. Since this memory was uninitialized (via torch.empty or lack of padding), it reads NaN (garbage), causing the entire output calculation to become NaN.
- After the fix: the kernel reads Scale[29], which is now explicitly initialized/padded to 0.0. The computation 0 (input) * 0.0 (scale) results in 0, keeping the output valid.
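A tiny runnable illustration of the last point: with a zero-initialized scale the padded block is inert, while an uninitialized (NaN) scale poisons the whole accumulation.

import torch

padded_vals = torch.zeros(4)              # dequantized values of a padded block
nan_scale = torch.tensor(float("nan"))    # what torch.empty can leave behind
zero_scale = torch.tensor(0.0)            # what torch.zeros guarantees
print((padded_vals * nan_scale).sum())    # tensor(nan) -> corrupts the GEMM output
print((padded_vals * zero_scale).sum())   # tensor(0.)  -> contributes nothing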