"""Tests for the `residual_rms` kernel.

Run `pytest tests/kernels/test_residual_rms.py`.
"""
from typing import List | |
import pytest | |
import torch | |
from residual_rms._ops import ops | |
from residual_rms.residual_rms import residual_rms | |
def test_residual_rms(shape: List[int]) -> None:
    """Compare the `residual_rms` wrapper with the raw `ops.residual_rms` op.

    A random tensor of the given shape is passed to both entry points: the
    wrapper is called with a pre-allocated zero tensor as its first argument
    (which it is expected to fill in place), and the op is called with the
    input alone. The two results must agree.

    Args:
        shape: Dimensions of the random input tensor.
    """
    # NOTE(review): no parametrization or fixture for `shape` is visible in
    # this file; it has to be supplied elsewhere (e.g. a conftest fixture or a
    # `pytest.mark.parametrize` decorator) -- confirm.
    x = torch.randn(shape)
    out = torch.zeros_like(x)
    residual_rms(out, x)
    expected = ops.residual_rms(x)

    # `assert_close` reports which elements differ and by how much, and it
    # also rejects shape/dtype/device mismatches that `torch.allclose` would
    # let through via broadcasting. The tolerances are the `torch.allclose`
    # defaults, so the numeric strictness of the check is unchanged.
    torch.testing.assert_close(out, expected, rtol=1e-5, atol=1e-8)