
[RFC][WIP] Code Organization for New Backend Support #965

@pillumina

Description


Status: Draft (will be actively revised)

This RFC proposes a code organization strategy to support multiple hardware backends (e.g., Ascend NPU) in Liger-Kernel while maintaining clean separation from the existing CUDA implementation. The design enables chip vendors to contribute backend-specific optimizations without modifying core kernel code.

Related: [RFC #954 - Native Ascend NPU Support for Liger Kernel](#954)

Motivation

As noted in RFC #954, Ascend NPU and other hardware backends are seeking native support in Liger-Kernel. The current codebase has some backend-specific code scattered across multiple files (e.g., is_npu_available() checks, conditional imports). This approach has several drawbacks:

  • Code coupling: Backend-specific logic is mixed with core implementations
  • Maintenance burden: Adding or modifying a backend requires changes across many files
  • Regression risk: Backend modifications may inadvertently affect other devices
  • Testing complexity: Difficult to isolate and test individual backends

We need a clean code organization that allows:

  1. Independent development and maintenance per backend
  2. Adding new backends without modifying existing code
  3. Incremental operator coverage (not all operators need backend-specific implementations)

Design Overview

The core principles:

  1. Directory Separation: Each new backend has its own directory under backends/
  2. Incremental Override: Backends only implement the operators that need adaptation
  3. Automatic Fallback: The runtime automatically falls back to the base implementations under ops/ when no backend-specific version exists
  4. Zero Breaking Changes: Existing APIs and behavior remain unchanged

Directory Structure

src/liger_kernel/
├── ops/                        # Base implementations (current CUDA code)
│   ├── __init__.py
│   ├── cross_entropy.py
│   ├── geglu.py
│   ├── rms_norm.py
│   └── ...
├── backends/                   # NEW: Backend-specific implementations
│   ├── __init__.py             # Public API: get_op(), get_current_backend()
│   ├── base.py                 # BackendInfo base class
│   ├── registry.py             # Backend registry
│   └── _npu/                   # Ascend NPU backend
│       ├── __init__.py         # NPUBackendInfo + auto-registration
│       └── ops/                # Only operators that need NPU adaptation
│           ├── __init__.py
│           └── geglu.py        # NPU-specific GEGLU (e.g. UB overflow fix)
└── transformers/               # Unchanged
    └── ...

Key Design Decisions

  1. Incremental Override Strategy
    Backends should only implement operators that require adaptation. For example, Ascend NPU may only need to override geglu (to handle UB overflow) while all other operators work with the base implementation.
# backends/_npu/ops/__init__.py
"""
NPU backend operator overrides.
Only operators listed here use NPU-specific implementations.
All others automatically fall back to base implementations in ops/.
"""
from liger_kernel.backends._npu.ops.geglu import geglu_forward, geglu_backward

__all__ = [
    "geglu_forward",
    "geglu_backward",
    # Add other operators (cross_entropy, swiglu, rope, etc.) only as they gain NPU-specific overrides
]
  2. Automatic Fallback
    The registry implements a lookup chain:
1. backends/_npu/ops/{op_name}    → Use if exists (backend-specific)
2. ops/{op_name}                   → Fallback to base implementation
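A minimal sketch of how this lookup could be implemented (it assumes ops/__init__.py and each backend's ops/__init__.py re-export their operator functions; names here are illustrative, not final API):
# Sketch only: one possible fallback resolution for the registry
import importlib

def _resolve_op(op_name: str, device: str):
    # 1. Try the backend-specific package, e.g. liger_kernel.backends._npu.ops
    try:
        backend_ops = importlib.import_module(f"liger_kernel.backends._{device}.ops")
        if hasattr(backend_ops, op_name):
            return getattr(backend_ops, op_name)
    except ImportError:
        pass  # no backend package available for this device

    # 2. Fall back to the base implementation under liger_kernel.ops
    base_ops = importlib.import_module("liger_kernel.ops")
    if hasattr(base_ops, op_name):
        return getattr(base_ops, op_name)

    raise RuntimeError(f"Operator '{op_name}' not found for device '{device}'")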
  3. Dispatch at autograd.Function Level
    Dispatch happens inside torch.autograd.Function, keeping the public API unchanged:
# ops/geglu.py (modified)
import torch

from liger_kernel.backends import get_op

class LigerGELUMulFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        device = a.device.type
        geglu_forward = get_op("geglu_forward", device)  # Dispatch here
        a, b, c = geglu_forward(a, b)
        ctx.save_for_backward(a, b)
        ctx.device_type = device
        return c
  • No changes to user-facing APIs
  • Device detection is automatic
  • Can use caching to minimize dispatch overhead
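For instance, the resolved operator can be memoized so that every call after the first is a plain dictionary lookup; a sketch building on the hypothetical _resolve_op helper above:
# Sketch: memoize resolved operators per (op_name, device)
from functools import lru_cache

@lru_cache(maxsize=None)
def get_op(op_name: str, device: str):
    # _resolve_op is the illustrative lookup helper sketched in the
    # "Automatic Fallback" section above
    return _resolve_op(op_name, device)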

API Design

BackendInfo Base Class

# backends/base.py

from dataclasses import dataclass, field
from typing import Any, Dict, Tuple

@dataclass
class BackendInfo:
    """Base class for backend configuration."""
    
    vendor_name: str              # e.g.,  "ascend"
    device_name: str              # e.g.,  "npu"
   
    # Performance tuning parameters
    # ...
    
    # Optional extra configuration
    extra_config: Dict[str, Any] = field(default_factory=dict)
    
    def get_multiprocessor_count(self) -> int:
        """Return the number of multiprocessors/compute units."""
        raise NotImplementedError
    
    def calculate_settings(self, n: int) -> Tuple[int, int]:
        """
        Calculate BLOCK_SIZE and num_warps for a given problem size.
        Can be overridden by backends with different heuristics.
        
        Returns:
            Tuple of (BLOCK_SIZE, num_warps)
        """
        ...
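As a reference point, the base calculate_settings could reuse a heuristic similar to the one the CUDA kernels use today (power-of-two block size, warp count scaled with block size); the thresholds below are illustrative, not final:
# Sketch of a possible default heuristic for BackendInfo.calculate_settings
import triton

MAX_FUSED_SIZE = 65536  # illustrative cap; a backend could override this with its own limit

def default_calculate_settings(n: int) -> tuple:
    block_size = triton.next_power_of_2(n)
    if block_size > MAX_FUSED_SIZE:
        raise RuntimeError(f"feature dimension {n} exceeds MAX_FUSED_SIZE={MAX_FUSED_SIZE}")
    num_warps = 4
    if block_size >= 32768:
        num_warps = 32
    elif block_size >= 8192:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    return block_size, num_warps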

BackendRegistry

# backends/registry.py

from typing import Callable, List, Optional

from liger_kernel.backends.base import BackendInfo

class BackendRegistry:
    """
    Registry for hardware backends with automatic fallback support.
    """
    
    @classmethod
    def register(cls, backend_info: BackendInfo) -> None:
        """Register a backend configuration."""
        ...
    
    @classmethod
    def get_op(cls, op_name: str, device_name: Optional[str] = None) -> Callable:
        """
        Get operator implementation with fallback.
        
        Lookup order:
        1. backends/_{device}/ops/{op_name} (backend-specific)
        2. ops/{op_name} (the original implementation)
        
        Args:
            op_name: Operator name, e.g., "geglu_forward"
            device_name: Device name, e.g., "npu". Auto-detected if None.
            
        Returns:
            The operator function.
            
        Raises:
            RuntimeError: If operator not found in any location.
        """
        ...
    
    @classmethod
    def get_backend(cls, device_name: str) -> Optional[BackendInfo]:
        """Get backend configuration by device name."""
        ...
    
    @classmethod
    def list_overridden_ops(cls, device_name: str) -> List[str]:
        """List operators overridden by a specific backend."""
        ...
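list_overridden_ops could simply reflect the backend ops package's public names; a sketch, assuming each backend keeps __all__ in its ops/__init__.py up to date (as in the NPU override example above):
# Sketch: derive the override list from the backend's ops/__init__.py
import importlib
from typing import List

def _overridden_ops(device_name: str) -> List[str]:
    try:
        mod = importlib.import_module(f"liger_kernel.backends._{device_name}.ops")
    except ImportError:
        return []
    return sorted(getattr(mod, "__all__", []))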

Other Public APIs

# backends/__init__.py

from typing import Callable, List, Optional

from liger_kernel.backends.base import BackendInfo

def get_op(op_name: str, device: Optional[str] = None) -> Callable:
    """
    Get operator implementation for the specified device.
    Falls back to base implementation if backend doesn't override.
    """
    ...

def get_current_backend() -> Optional[BackendInfo]:
    """Get the currently active backend configuration."""
    ...

def list_overridden_ops(device_name: str) -> List[str]:
    """List operators that have backend-specific implementations."""
    ...
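Illustrative usage (the device string and the printed list are examples, not guaranteed output):
from liger_kernel.backends import get_op, get_current_backend, list_overridden_ops

geglu_forward = get_op("geglu_forward", device="npu")  # NPU override if present, else base op
backend = get_current_backend()                        # e.g. NPUBackendInfo, or None on plain CUDA
print(list_overridden_ops("npu"))                      # e.g. ["geglu_backward", "geglu_forward"]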

NPU Backend Example

# backends/_npu/__init__.py

from typing import Tuple

from liger_kernel.backends.base import BackendInfo
from liger_kernel.backends.registry import BackendRegistry

class NPUBackendInfo(BackendInfo):
    """Ascend NPU backend configuration."""
    
    def __init__(self):
        super().__init__(
            vendor_name="ascend",
            device_name="npu",
            max_fused_size=32768,  # NPU may have different limits
            default_num_warps=4,
            extra_config={
                "block_size_sub_default": 1024,  # For UB overflow prevention
            }
        )
    
    def get_multiprocessor_count(self) -> int:
        """Return NPU vector core count."""
        ...
    
    def calculate_settings(self, n: int) -> Tuple[int, int]:
        """NPU-specific parameter calculation."""
        ...

# Auto-register on module import
BackendRegistry.register(NPUBackendInfo())
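One possible way to trigger this auto-registration is for backends/__init__.py to import vendor subpackages only when their runtime is present; the torch_npu availability check below is an assumption, not a settled mechanism:
# backends/__init__.py (sketch): registration as a side effect of package import
import importlib.util

if importlib.util.find_spec("torch_npu") is not None:
    from liger_kernel.backends import _npu  # noqa: F401  (registers NPUBackendInfo)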

FAQs

Q: What's the runtime overhead of dispatch?
A: Operator lookups can be cached after the first call. Subsequent calls are simple dictionary lookups (nanosecond-level overhead).

Q: Will this break existing code?
A: No. All public APIs remain unchanged. The dispatch mechanism is internal and transparent to users.
(The current plan is to leave the existing ops/ directory unchanged; for example, a separate _cuda folder under backends/ is not being considered yet. The design only targets newly added backends such as the NPU. This remains an open question and can be refined in stages.)
