# residual_rms/test/kernels/test_residual_rms.py
# Commit 79aac9d (drbh): "feat: impl residual rms kernel repo"
"""Tests for the `residual_rms` kernel.

Run with `pytest test/kernels/test_residual_rms.py`.
"""
from typing import List, Tuple

import pytest
import torch

from residual_rms._ops import ops
from residual_rms.residual_rms import residual_rms
@pytest.mark.parametrize("shape", [(2, 3, 4, 5), (2, 3, 4, 5, 6)])
def test_residual_rms(shape: Tuple[int, ...]) -> None:
    """Check the out-parameter wrapper against the raw op binding.

    A zero-filled ``out`` buffer is passed to ``residual_rms(out, x)``.
    Its contents afterwards must match the tensor returned by
    ``ops.residual_rms(x)`` for the same input.

    Args:
        shape: Shape of the random input tensor (4-D and 5-D cases).
    """
    # A dedicated seeded generator makes the input reproducible across runs
    # without touching the global RNG state other tests may depend on.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(shape, generator=gen)
    # NOTE(review): tensors are created on the default (CPU) device; if the
    # kernel only has a CUDA implementation this needs a device -- confirm.
    out = torch.zeros_like(x)
    residual_rms(out, x)
    # assert_close reports which elements differ (and checks shape and dtype)
    # instead of a bare boolean. rtol/atol match torch.allclose's defaults,
    # so the pass/fail threshold is unchanged.
    torch.testing.assert_close(out, ops.residual_rms(x), rtol=1e-5, atol=1e-8)