|
"""Unit tests for the Andromeda model."""

import unittest

import torch

from Andromeda.model import Andromeda


class TestAndromeda(unittest.TestCase):
    """Tests for the Andromeda model built with its default configuration."""
|
    def setUp(self):
        self.model = Andromeda()
|
    def test_initialization(self):
        self.assertIsNotNone(self.model.Andromeda, "Transformer is not initialized.")
        self.assertIsNotNone(self.model.decoder, "AutoregressiveWrapper is not initialized.")
|
    def test_forward_pass(self):
        input_tokens = torch.randint(0, 50432, (1, 8192))
        output = self.model(input_tokens)
        self.assertIsInstance(output, torch.Tensor, "Output is not a PyTorch tensor.")
        self.assertEqual(output.shape[0], input_tokens.shape[0], "Output batch size does not match input.")
|
    def test_error_handling(self):
        with self.assertRaises(Exception):
            self.model.forward(None)
|
    def test_model_parameters(self):
        self.assertEqual(self.model.Andromeda.num_tokens, 50432, "Number of tokens is not correctly set.")
        self.assertEqual(self.model.Andromeda.max_seq_len, 8192, "Max sequence length is not correctly set.")
|
    def test_model_output(self):
        input_tokens = torch.randint(0, 50432, (1, 8192))
        # Run in eval mode under no_grad so dropout cannot make the two passes diverge
        self.model.eval()
        with torch.no_grad():
            output1 = self.model(input_tokens)
            output2 = self.model(input_tokens)
        self.assertTrue(torch.allclose(output1, output2), "Model does not produce consistent output.")


class TestAndromedaExtended(unittest.TestCase):
    """Extends the basic suite with varied input sizes and constructor arguments."""
|
    def setUp(self):
        self.model = Andromeda()
|
    def test_input_size(self):
        for seq_len in [512, 1024, 2048, 4096]:
            input_tokens = torch.randint(0, 50432, (1, seq_len))
            output = self.model(input_tokens)
            self.assertEqual(output.shape[1], seq_len, f"Output sequence length does not match input for seq_len={seq_len}.")
|
    def test_batch_size(self):
        for batch_size in [2, 4, 8, 16]:
            input_tokens = torch.randint(0, 50432, (batch_size, 8192))
            output = self.model(input_tokens)
            self.assertEqual(output.shape[0], batch_size, f"Output batch size does not match input for batch_size={batch_size}.")
|
    def test_token_range(self):
        # Exercise both ends of the vocabulary (0 and num_tokens - 1)
        for token in [0, 50431]:
            input_tokens = torch.full((1, 8192), fill_value=token, dtype=torch.long)
            output = self.model(input_tokens)
            self.assertIsInstance(output, torch.Tensor, f"Output is not a PyTorch tensor for token={token}.")
|
    def test_model_depth(self):
        for depth in [16, 32, 64]:
            model = Andromeda(depth=depth)
            self.assertEqual(model.Andromeda.attn_layers.depth, depth, f"Model depth is not correctly set for depth={depth}.")
|
    def test_model_dim(self):
        for dim in [1280, 2560, 5120]:
            model = Andromeda(dim=dim)
            self.assertEqual(model.Andromeda.attn_layers.dim, dim, f"Model dimension is not correctly set for dim={dim}.")
|
    def test_model_heads(self):
        for heads in [12, 24, 48]:
            model = Andromeda(heads=heads)
            self.assertEqual(model.Andromeda.attn_layers.heads, heads, f"Number of heads is not correctly set for heads={heads}.")
|
    def test_model_dim_head(self):
        for dim_head in [64, 128, 256]:
            model = Andromeda(dim_head=dim_head)
            self.assertEqual(model.Andromeda.attn_layers.dim_head, dim_head, f"Head dimension is not correctly set for dim_head={dim_head}.")
|
    def test_model_alibi_num_heads(self):
        for alibi_num_heads in [6, 12, 24]:
            model = Andromeda(alibi_num_heads=alibi_num_heads)
            self.assertEqual(model.Andromeda.attn_layers.alibi_num_heads, alibi_num_heads, f"Number of alibi heads is not correctly set for alibi_num_heads={alibi_num_heads}.")
|
    def test_model_shift_tokens(self):
        for shift_tokens in [0, 1, 2]:
            model = Andromeda(shift_tokens=shift_tokens)
            self.assertEqual(model.Andromeda.attn_layers.shift_tokens, shift_tokens, f"Number of shift tokens is not correctly set for shift_tokens={shift_tokens}.")
|
    def test_model_use_abs_pos_emb(self):
        for use_abs_pos_emb in [True, False]:
            model = Andromeda(use_abs_pos_emb=use_abs_pos_emb)
            self.assertEqual(model.Andromeda.use_abs_pos_emb, use_abs_pos_emb, f"Use absolute position embedding flag is not correctly set for use_abs_pos_emb={use_abs_pos_emb}.")
|
    def test_model_alibi_pos_bias(self):
        for alibi_pos_bias in [True, False]:
            model = Andromeda(alibi_pos_bias=alibi_pos_bias)
            self.assertEqual(model.Andromeda.attn_layers.alibi_pos_bias, alibi_pos_bias, f"Alibi position bias flag is not correctly set for alibi_pos_bias={alibi_pos_bias}.")
|
    def test_model_rotary_xpos(self):
        for rotary_xpos in [True, False]:
            model = Andromeda(rotary_xpos=rotary_xpos)
            self.assertEqual(model.Andromeda.attn_layers.rotary_xpos, rotary_xpos, f"Rotary position flag is not correctly set for rotary_xpos={rotary_xpos}.")
|
    def test_model_attn_flash(self):
        for attn_flash in [True, False]:
            model = Andromeda(attn_flash=attn_flash)
            self.assertEqual(model.Andromeda.attn_layers.attn_flash, attn_flash, f"Attention flash flag is not correctly set for attn_flash={attn_flash}.")


if __name__ == '__main__':
    unittest.main()