"""Tests for the `residual_rms` kernel.

Run `pytest tests/kernels/test_residual_rms.py`.
"""

from typing import List, Tuple

import pytest
import torch

from residual_rms._ops import ops
from residual_rms.residual_rms import residual_rms


@pytest.mark.parametrize("shape", [(2, 3, 4, 5), (2, 3, 4, 5, 6)])
def test_residual_rms(shape: Tuple[int, ...]) -> None:
    """Compare the out-parameter ``residual_rms`` against ``ops.residual_rms``.

    ``residual_rms(out, x)`` is expected to write its result into ``out``.
    That buffer is compared with the tensor returned by ``ops.residual_rms(x)``
    for the same random input.

    Args:
        shape: Shape of the random input tensor (parametrized above).
    """
    # A fixed-seed local generator keeps the input reproducible across runs
    # without mutating PyTorch's global RNG state.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(shape, generator=gen)

    # Pre-fill with NaN rather than zeros: any element the call fails to
    # write stays NaN and makes the comparison fail, instead of silently
    # matching an expected value that happens to be ~0.
    out = torch.full_like(x, float("nan"))
    residual_rms(out, x)

    expected = ops.residual_rms(x)
    # Same tolerances as torch.allclose's defaults, but assert_close also
    # checks shape/dtype/device (no silent broadcasting) and reports which
    # elements mismatch on failure.
    torch.testing.assert_close(out, expected, rtol=1e-5, atol=1e-8)