Add openpilot tests
This commit is contained in:
57
tinygrad_repo/test/models/test_bert.py
Normal file
57
tinygrad_repo/test/models/test_bert.py
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import Device
|
||||
import torch
|
||||
|
||||
def get_question_samp(bsz, seq_len, vocab_size, seed):
|
||||
np.random.seed(seed)
|
||||
in_ids= np.random.randint(vocab_size, size=(bsz, seq_len))
|
||||
mask = np.random.choice([True, False], size=(bsz, seq_len))
|
||||
seg_ids = np.random.randint(1, size=(bsz, seq_len))
|
||||
return in_ids, mask, seg_ids
|
||||
|
||||
def set_equal_weights(mdl, torch_mdl):
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
state, torch_state = get_state_dict(mdl), torch_mdl.state_dict()
|
||||
assert len(state) == len(torch_state)
|
||||
for k, v in state.items():
|
||||
assert k in torch_state
|
||||
torch_state[k].copy_(torch.from_numpy(v.numpy()))
|
||||
torch_mdl.eval()
|
||||
|
||||
class TestBert(unittest.TestCase):
|
||||
def test_questions(self):
|
||||
from models.bert import BertForQuestionAnswering
|
||||
from transformers import BertForQuestionAnswering as TorchBertForQuestionAnswering
|
||||
from transformers import BertConfig
|
||||
|
||||
# small
|
||||
config = {
|
||||
'vocab_size':24, 'hidden_size':2, 'num_hidden_layers':2, 'num_attention_heads':2,
|
||||
'intermediate_size':32, 'hidden_dropout_prob':0.1, 'attention_probs_dropout_prob':0.1,
|
||||
'max_position_embeddings':512, 'type_vocab_size':2
|
||||
}
|
||||
|
||||
# Create in tinygrad
|
||||
Tensor.manual_seed(1337)
|
||||
mdl = BertForQuestionAnswering(**config)
|
||||
|
||||
# Create in torch
|
||||
with torch.no_grad():
|
||||
torch_mdl = TorchBertForQuestionAnswering(BertConfig(**config))
|
||||
|
||||
set_equal_weights(mdl, torch_mdl)
|
||||
|
||||
seeds = (1337, 3141)
|
||||
bsz, seq_len = 1, 16
|
||||
for _, seed in enumerate(seeds):
|
||||
in_ids, mask, seg_ids = get_question_samp(bsz, seq_len, config['vocab_size'], seed)
|
||||
out = mdl(Tensor(in_ids), Tensor(mask), Tensor(seg_ids))
|
||||
torch_out = torch_mdl.forward(torch.from_numpy(in_ids).long(), torch.from_numpy(mask), torch.from_numpy(seg_ids).long())[:2]
|
||||
torch_out = torch.cat(torch_out).unsqueeze(2)
|
||||
np.testing.assert_allclose(out.numpy(), torch_out.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user