Add openpilot tests
This commit is contained in:
45
tinygrad_repo/test/external/external_test_jit_on_models.py
vendored
Normal file
45
tinygrad_repo/test/external/external_test_jit_on_models.py
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
|
||||
from tinygrad.helpers import dtypes, CI
|
||||
from tinygrad.ops import Device
|
||||
from test.helpers import derandomize_model
|
||||
|
||||
from examples.llama import Transformer
|
||||
|
||||
def helper_test_jitted_correctness(gen, train, train_jit):
|
||||
nojit = train(*gen()).numpy()
|
||||
for _ in range(5): jit = train_jit(*gen()).numpy()
|
||||
np.testing.assert_allclose(nojit, jit, rtol=1e-3, atol=1e-5)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
|
||||
class TestJittedModels(unittest.TestCase):
|
||||
def test_jitted_tiny_llama(self):
|
||||
old_type = Tensor.default_type
|
||||
Tensor.default_type = dtypes.float16
|
||||
|
||||
args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
|
||||
model = Transformer(**args_tiny)
|
||||
derandomize_model(model)
|
||||
def test(t): return model(t, 0).realize()
|
||||
|
||||
@TinyJit
|
||||
def test_jit(t): return model(t, 0).realize()
|
||||
helper_test_jitted_correctness(lambda: (Tensor([[1,]]),), test, test_jit)
|
||||
Tensor.default_type = old_type
|
||||
|
||||
@unittest.skipUnless(not CI, "huge for CI")
|
||||
def test_jitted_stable_diffusion(self):
|
||||
from examples.stable_diffusion import UNetModel
|
||||
model = UNetModel()
|
||||
derandomize_model(model)
|
||||
def test(t, t2): return model(t, 801, t2).realize()
|
||||
|
||||
@TinyJit
|
||||
def test_jit(t, t2): return model(t, 801, t2).realize()
|
||||
helper_test_jitted_correctness(lambda: (Tensor.randn(1, 4, 16, 16),Tensor.randn(1, 77, 768)), test, test_jit)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user