Conversation

@danielafrimi commented Dec 8, 2025

The Issues

  1. Unaligned TP Shapes: In TP4, splitting some models produced input/weight dimensions (e.g., 464) that were not aligned with the required block size (16) or with the hardware tile constraints of some kernels, causing mismatched shapes between inputs, weights, and their scales.

  2. Uninitialized Scale Memory: The quantization kernel (scaled_fp4_quant) allocates output scales with torch.empty. Because hardware alignment requires allocating rounded-up sizes (e.g., to the nearest multiple of 128 blocks), the padding region of the tensor contained garbage (potentially NaNs) from uninitialized memory.

  3. Kernel OOB Reads: When we padded the input/weights to fix the alignment issues (e.g., 464 -> 480), the matrix multiplication kernel correctly processed the padded data (zeros), but it still read the corresponding scales for those padded blocks. Because those scales lived in the uninitialized memory region, NaNs propagated into the final result and crashed the model (see the sketch below).
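
A minimal repro sketch of issues 2 and 3, using hypothetical block counts (the real allocation happens inside scaled_fp4_quant):

    import torch

    # Hypothetical sizes: 29 logical scale blocks, rounded up to 32 for hardware alignment.
    logical_blocks, allocated_blocks = 29, 32

    # torch.empty leaves the trailing "padding" blocks holding whatever bytes were in memory.
    scales = torch.empty(allocated_blocks, dtype=torch.float32)
    scales[:logical_blocks] = 1.0   # the quantizer only writes the logical region
    tail = scales[logical_blocks:]  # uninitialized: may be garbage or NaN

    # A GEMM that also reads the padded blocks multiplies otherwise-valid zeros by garbage
    # scales, so any NaN in the tail poisons the whole output.
    print(torch.isnan(tail).any())  # non-deterministic; can print True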

The Fix

Weight Preprocessing (modelopt.py)

  • Added logic to detect N/K-dimension misalignment.
  • Padded weights with zeros to match the swizzled scale dimensions, ensuring that the weight tensor shape is compatible with the block scales required by the kernel (simplified sketch below).
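
A simplified sketch of that weight padding, with hypothetical shapes (the actual logic in modelopt.py also logs the padding and records execution_padding_k_bytes; see the diff quoted in the review below):

    import torch

    group_size = 16
    weight = torch.zeros(464, 232, dtype=torch.uint8)         # packed NVFP4 weight: [N, K // 2]
    swizzled_scale = torch.zeros(512, 30, dtype=torch.uint8)  # swizzled block scales (padded layout)

    # Pad rows (N) up to the swizzled scale row count, and columns (packed K bytes)
    # up to the byte width implied by the padded number of K blocks.
    pad_rows = swizzled_scale.shape[0] - weight.shape[0]                    # 512 - 464 = 48
    pad_cols = swizzled_scale.shape[1] * group_size // 2 - weight.shape[1]  # 240 - 232 = 8
    weight = torch.nn.functional.pad(weight, (0, pad_cols, 0, pad_rows)).contiguous()
    print(weight.shape)  # torch.Size([512, 240])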

Activation Padding (modelopt.py)

  • Added logic in apply() to pad x to match the padded weight's packed K dimension, and to pad x_blockscale with 0.0 where the padded input extends beyond the valid scale blocks. This ensures that the kernel has valid scale values to read for the padded input regions.

Scale Initialization (_custom_ops.py)

  • Changed scale allocation from torch.empty to torch.zeros.
  • Trade-off: torch.zeros is slightly slower than torch.empty, but it is strictly necessary here. It ensures that "gap" bytes (allocated to satisfy hardware alignment requirements but untouched by the quantizer) are initialized to 0.0 instead of potentially containing garbage/NaNs. This prevents crashes when downstream kernels (such as the matmul) access these padded regions during vectorized loads (see the sketch below).
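
A tiny before/after sketch of the allocation change, with a hypothetical rounded-up size (the dtype stands in for the real scale dtype):

    import torch

    rounded_blocks = 32  # hypothetical: 29 logical blocks rounded up for alignment

    # Before: elements beyond what the quantizer writes hold whatever was in memory.
    scales = torch.empty(rounded_blocks, dtype=torch.uint8)

    # After: every element, including the alignment gap, starts at zero.
    scales = torch.zeros(rounded_blocks, dtype=torch.uint8)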

Example (TP4 Scenario)

  • Original Input: [..., 464] (29 blocks).
  • Required Alignment: 480 (30 blocks).
  • Padding: We pad the input to 480 elements to match the alignment.
  • Before Fix: The kernel attempts to read Scale[29] for the padded block. Because that memory was uninitialized (allocated with torch.empty and never padded), it reads garbage (often NaN), and the entire output becomes NaN.
  • After Fix: The kernel reads Scale[29], which is now explicitly initialized/padded to 0.0. The computation 0 (input) * 0.0 (scale) yields 0, keeping the output valid (see the worked sketch below).
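
The same scenario as a worked sketch (hypothetical tensors; the real kernel operates on packed FP4 data, but the arithmetic is the same):

    import torch

    group_size = 16
    k_orig, k_padded = 464, 480  # 29 blocks -> 30 blocks

    x = torch.randn(8, k_orig)
    x = torch.nn.functional.pad(x, (0, k_padded - k_orig))  # padded tail of the input is zeros
    scales = torch.zeros(k_padded // group_size)            # 30 scales, all initialized
    scales[: k_orig // group_size] = 1.0                    # blocks 0..28 carry real scales

    # Block 29 contributes x[..., 464:480] * scales[29] = 0 * 0.0 = 0, so the output stays finite.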

danielafrimi and others added 6 commits December 7, 2025 06:28

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Comment on lines +1146 to 1147
if self.backend.startswith("flashinfer-cutlass"):
    backend_name = self.backend[len("flashinfer-") :]

P1: FlashInfer TRTLLM/CUDNN backend crashes in apply()

The apply path now only dispatches when self.backend.startswith("flashinfer-cutlass"), so a configured NVFP4 backend of flashinfer-trtllm or flashinfer-cudnn (both advertised in VLLM_NVFP4_GEMM_BACKEND) will fall through to the cutlass assert and raise immediately. Previously the generic flashinfer- branch covered those options; after this change those backends can no longer run inference.
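
A minimal sketch of the fall-through, using the backend strings named above (the assert mirrors the one in the else branch and fires for the trtllm/cudnn options):

    # Hypothetical reproduction of the dispatch regression.
    for backend in ("flashinfer-cutlass", "flashinfer-trtllm", "flashinfer-cudnn"):
        if backend.startswith("flashinfer-cutlass"):
            print(backend, "-> flashinfer path")
        else:
            # flashinfer-trtllm / flashinfer-cudnn land here and trip the assert
            assert backend == "cutlass", f"{backend} no longer dispatches"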

Comment on lines +1075 to +1078
if pad_rows > 0:
    weight = torch.nn.functional.pad(
        weight, (0, 0, 0, pad_rows)
    ).contiguous()

P1: Cutlass path returns padded outputs for non-128-aligned rows

Weights are now explicitly padded with extra rows to match the swizzled scale shape (e.g., when the output dimension is not a multiple of 128), but the cutlass execution path still derives output_shape from layer.weight.shape and never slices the result back to the logical output size. For cutlass backends with unaligned output widths this returns extra columns (and will also make bias addition fail to broadcast), breaking the layer’s expected output shape.
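
A hedged sketch of the resulting shape problem, with hypothetical sizes: the cutlass result keeps the padded width until it is sliced back.

    import torch

    m, logical_n, padded_n = 8, 464, 512   # hypothetical: output dim padded to the scale layout
    out = torch.zeros(m, padded_n)         # what the GEMM returns when the weight rows are padded
    bias = torch.zeros(logical_n)

    # out + bias                           # RuntimeError: shapes (8, 512) and (464,) don't broadcast
    out = out[:, :logical_n].contiguous()  # slicing restores the layer's expected output shape
    print((out + bias).shape)              # torch.Size([8, 464])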

@gemini-code-assist bot left a comment

Code Review

This pull request introduces dynamic padding to support Tensor Parallelism for NVFP4 kernels, specifically targeting flashinfer-cutlass. The changes correctly address NaN propagation issues by initializing scale memory with zeros and padding weights and activations to meet hardware alignment requirements. However, I've identified a critical regression that breaks support for other flashinfer backends and a high-severity issue where the cutlass backend may face shape mismatches due to inconsistent padding. My review includes suggestions to resolve these issues.

Comment on lines +1146 to 1195
if self.backend.startswith("flashinfer-cutlass"):
    backend_name = self.backend[len("flashinfer-") :]

    # Match packed-K bytes between activations and weights using
    # pre-calculated padding
    pad_k_bytes = getattr(layer, "execution_padding_k_bytes", 0)
    output_shape = [x.shape[0], layer.output_size_per_partition]
    x_fp4 = torch.nn.functional.pad(x_fp4, (0, pad_k_bytes)).contiguous()

    # If we pad x_fp4 so we maybe need to add pad x_block scale as well
    original_scale_blocks = x_blockscale.shape[1]
    k_elements = (
        x_fp4.shape[1] * 2
    )  # 2 fp4 items are packed in the input dimension
    required_scale_blocks = k_elements // layer.quant_config.group_size

    # Align to 4 to be safe with int32 packing assumptions in kernels
    if required_scale_blocks % 4 != 0:
        required_scale_blocks += 4 - (required_scale_blocks % 4)

    if original_scale_blocks < required_scale_blocks:
        pad_scales = required_scale_blocks - original_scale_blocks
        x_blockscale = torch.nn.functional.pad(
            x_blockscale, (0, pad_scales), value=0.0
        ).contiguous()

    mm_args = (
        x_fp4,
        layer.weight,
        x_blockscale,
        layer.weight_scale,
        layer.alpha,
        output_dtype,
    )
    out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)

    # Slice output to remove padding if weight was padded in N dimension
    if out.shape[1] != output_shape[1]:
        out = out[:, : output_shape[1]].contiguous()
else:
    assert self.backend == "cutlass"
    mm_args = (
        x_fp4,
        layer.weight,
        x_blockscale,
        layer.weight_scale,
        layer.alpha,
        output_dtype,
    )
    out = cutlass_scaled_fp4_mm(*mm_args)

critical

This change introduces a regression for flashinfer- backends other than flashinfer-cutlass (e.g., flashinfer-trtllm). The original if self.backend.startswith("flashinfer-") check was replaced with a more specific startswith("flashinfer-cutlass"), and the else block now asserts self.backend == "cutlass". This will cause an assertion failure for other flashinfer- backends. The previous structure should be restored, with the new padding logic nested inside a check for flashinfer-cutlass to maintain support for other backends.

        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]

            _x_fp4, _x_blockscale = x_fp4, x_blockscale
            _output_shape = output_shape

            if self.backend == "flashinfer-cutlass":
                # Match packed-K bytes between activations and weights using
                # pre-calculated padding
                pad_k_bytes = getattr(layer, "execution_padding_k_bytes", 0)
                _output_shape = [x.shape[0], layer.output_size_per_partition]
                _x_fp4 = torch.nn.functional.pad(x_fp4, (0, pad_k_bytes)).contiguous()

                # If we pad x_fp4 so we maybe need to add pad x_block scale as well
                original_scale_blocks = x_blockscale.shape[1]
                k_elements = (
                    _x_fp4.shape[1] * 2
                )  # 2 fp4 items are packed in the input dimension
                required_scale_blocks = k_elements // layer.quant_config.group_size

                # Align to 4 to be safe with int32 packing assumptions in kernels
                if required_scale_blocks % 4 != 0:
                    required_scale_blocks += 4 - (required_scale_blocks % 4)

                if original_scale_blocks < required_scale_blocks:
                    pad_scales = required_scale_blocks - original_scale_blocks
                    _x_blockscale = torch.nn.functional.pad(
                        x_blockscale, (0, pad_scales), value=0.0
                    ).contiguous()

            mm_args = (
                _x_fp4,
                layer.weight,
                _x_blockscale,
                layer.weight_scale,
                layer.alpha,
                output_dtype,
            )
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)

            if self.backend == "flashinfer-cutlass":
                # Slice output to remove padding if weight was padded in N dimension
                if out.shape[1] != _output_shape[1]:
                    out = out[:, : _output_shape[1]].contiguous()
        else:
            assert self.backend == "cutlass"
            mm_args = (
                x_fp4,
                layer.weight,
                x_blockscale,
                layer.weight_scale,
                layer.alpha,
                output_dtype,
            )
            out = cutlass_scaled_fp4_mm(*mm_args)

Comment on lines +1058 to +1110
# Swizzle block scales and then pad the packed NVFP4 weights so that
# both N (rows) and K (columns) satisfy the alignment constraints
# implied by the FlashInfer / Cutlass kernels.
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight = layer.weight.data

weight_current_rows = weight.shape[0]
weight_scale_rows_padded = swizzled_weight_scale.shape[0]

# Pad weight to match swizzled scale dimensions
if weight_scale_rows_padded != weight_current_rows:
    pad_rows = weight_scale_rows_padded - weight_current_rows
    assert pad_rows >= 0, (
        f"Weight scale rows ({weight_scale_rows_padded}) < "
        f"weight rows ({weight_current_rows})."
    )
    if pad_rows > 0:
        weight = torch.nn.functional.pad(
            weight, (0, 0, 0, pad_rows)
        ).contiguous()

# Calculate the number of k blocks padded to satisfy alignment
# constraints for the weight
layer.execution_padding_k_bytes = 0
group_size = self.quant_config.group_size
num_k_blocks_padded = swizzled_weight_scale.shape[1]
k_bytes_padded = (num_k_blocks_padded * group_size) // 2
k_bytes_orig = weight.shape[1]

logger.info(
    "[FP4 Weight Prep] K-dim alignment: original_bytes=%d, "
    "padded_bytes=%d, group_size=%d",
    k_bytes_orig,
    k_bytes_padded,
    group_size,
)

if k_bytes_padded != k_bytes_orig:
    pad_bytes = k_bytes_padded - k_bytes_orig
    assert pad_bytes > 0
    logger.info(
        "[FP4 Weight Prep] Padding K-dim: %d -> %d (pad_bytes=%d)",
        weight.shape[1],
        weight.shape[1] + pad_bytes,
        pad_bytes,
    )
    weight = torch.nn.functional.pad(
        weight, (0, pad_bytes, 0, 0)
    ).contiguous()
    layer.execution_padding_k_bytes = pad_bytes

layer.weight = Parameter(weight, requires_grad=False)

high

This logic block pads weights for both cutlass and flashinfer-cutlass backends, which is correct. However, the corresponding activation padding in the apply method is only performed for flashinfer-cutlass. This inconsistency will likely lead to a shape mismatch error during matrix multiplication for the cutlass backend, as it will receive padded weights but unpadded activations. To ensure correctness, the activation padding logic should be applied for both backends if the weight padding is intended for both.
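
A sketch of the mismatch with hypothetical packed shapes: the cutlass path would receive a weight padded in K but an activation that is not.

    import torch

    k_bytes_orig, k_bytes_padded = 232, 240  # packed FP4 bytes per row (hypothetical)
    x_fp4 = torch.zeros(8, k_bytes_orig, dtype=torch.uint8)       # activation, left unpadded
    weight = torch.zeros(512, k_bytes_padded, dtype=torch.uint8)  # weight, padded during weight preprocessing

    # A GEMM contracting over K sees 232 vs 240 packed bytes and rejects the shapes
    # unless the activation is padded the same way as in the flashinfer-cutlass branch.
    print(x_fp4.shape[1] == weight.shape[1])  # False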
