Add openpilot tests
This commit is contained in:
66
tinygrad_repo/test/unit/test_disk_cache.py
Normal file
66
tinygrad_repo/test/unit/test_disk_cache.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import unittest
|
||||
import pickle
|
||||
from tinygrad.helpers import diskcache_get, diskcache_put
|
||||
|
||||
def remote_get(table,q,k): q.put(diskcache_get(table, k))
|
||||
def remote_put(table,k,v): diskcache_put(table, k, v)
|
||||
|
||||
class DiskCache(unittest.TestCase):
|
||||
def test_putget(self):
|
||||
table = "test_putget"
|
||||
diskcache_put(table, "hello", "world")
|
||||
self.assertEqual(diskcache_get(table, "hello"), "world")
|
||||
diskcache_put(table, "hello", "world2")
|
||||
self.assertEqual(diskcache_get(table, "hello"), "world2")
|
||||
|
||||
def test_putcomplex(self):
|
||||
table = "test_putcomplex"
|
||||
diskcache_put(table, "k", ("complex", 123, "object"))
|
||||
ret = diskcache_get(table, "k")
|
||||
self.assertEqual(ret, ("complex", 123, "object"))
|
||||
|
||||
def test_getotherprocess(self):
|
||||
table = "test_getotherprocess"
|
||||
from multiprocessing import Process, Queue
|
||||
diskcache_put(table, "k", "getme")
|
||||
q = Queue()
|
||||
p = Process(target=remote_get, args=(table,q,"k"))
|
||||
p.start()
|
||||
p.join()
|
||||
self.assertEqual(q.get(), "getme")
|
||||
|
||||
def test_putotherprocess(self):
|
||||
table = "test_putotherprocess"
|
||||
from multiprocessing import Process
|
||||
p = Process(target=remote_put, args=(table,"k", "remote"))
|
||||
p.start()
|
||||
p.join()
|
||||
self.assertEqual(diskcache_get(table, "k"), "remote")
|
||||
|
||||
def test_no_table(self):
|
||||
self.assertIsNone(diskcache_get("faketable", "k"))
|
||||
|
||||
def test_ret(self):
|
||||
table = "test_ret"
|
||||
self.assertEqual(diskcache_put(table, "key", ("vvs",)), ("vvs",))
|
||||
|
||||
def test_non_str_key(self):
|
||||
table = "test_non_str_key"
|
||||
diskcache_put(table, 4, 5)
|
||||
self.assertEqual(diskcache_get(table, 4), 5)
|
||||
self.assertEqual(diskcache_get(table, "4"), 5)
|
||||
|
||||
def test_dict_key(self):
|
||||
table = "test_dict_key"
|
||||
fancy_key = {"hello": "world", "goodbye": 7, "good": True, "pkl": pickle.dumps("cat")}
|
||||
fancy_key2 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("cat")}
|
||||
fancy_key3 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("dog")}
|
||||
diskcache_put(table, fancy_key, 5)
|
||||
self.assertEqual(diskcache_get(table, fancy_key), 5)
|
||||
diskcache_put(table, fancy_key2, 8)
|
||||
self.assertEqual(diskcache_get(table, fancy_key2), 8)
|
||||
self.assertEqual(diskcache_get(table, fancy_key), 5)
|
||||
self.assertEqual(diskcache_get(table, fancy_key3), None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
150
tinygrad_repo/test/unit/test_disk_tensor.py
Normal file
150
tinygrad_repo/test/unit/test_disk_tensor.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import pathlib
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
from tinygrad.helpers import Timing
|
||||
from extra.utils import fetch_as_file, temp
|
||||
|
||||
def compare_weights_both(url):
|
||||
import torch
|
||||
fn = fetch_as_file(url)
|
||||
tg_weights = get_state_dict(torch_load(fn))
|
||||
torch_weights = get_state_dict(torch.load(fn), tensor_type=torch.Tensor)
|
||||
assert list(tg_weights.keys()) == list(torch_weights.keys())
|
||||
for k in tg_weights:
|
||||
np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
|
||||
print(f"compared {len(tg_weights)} weights")
|
||||
|
||||
class TestTorchLoad(unittest.TestCase):
|
||||
# pytorch pkl format
|
||||
def test_load_enet(self): compare_weights_both("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
|
||||
# pytorch zip format
|
||||
def test_load_enet_alt(self): compare_weights_both("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth")
|
||||
# pytorch zip format
|
||||
def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
|
||||
# TODO: support pytorch tar format with minimal lines
|
||||
#def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
|
||||
|
||||
test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.pth"
|
||||
#test_size = test_fn.stat().st_size
|
||||
test_size = 1024*1024*1024*2
|
||||
|
||||
# sudo su -c 'sync; echo 1 > /proc/sys/vm/drop_caches' && python3 test/unit/test_disk_tensor.py TestRawDiskBuffer.test_readinto_read_speed
|
||||
@unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests")
|
||||
class TestRawDiskBuffer(unittest.TestCase):
|
||||
def test_readinto_read_speed(self):
|
||||
tst = np.empty(test_size, np.uint8)
|
||||
with open(test_fn, "rb") as f:
|
||||
with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
|
||||
f.readinto(tst)
|
||||
|
||||
def test_mmap_read_speed(self):
|
||||
db = RawDiskBuffer(test_size, dtype=dtypes.uint8, device=test_fn)
|
||||
tst = np.empty(test_size, np.uint8)
|
||||
with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
|
||||
np.copyto(tst, db.toCPU())
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
|
||||
class TestSafetensors(unittest.TestCase):
|
||||
def test_real_safetensors(self):
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
torch.manual_seed(1337)
|
||||
tensors = {
|
||||
"weight1": torch.randn((16, 16)),
|
||||
"weight2": torch.arange(0, 17, dtype=torch.uint8),
|
||||
"weight3": torch.arange(0, 17, dtype=torch.int32).reshape(17,1,1),
|
||||
"weight4": torch.arange(0, 2, dtype=torch.uint8),
|
||||
}
|
||||
save_file(tensors, temp("model.safetensors"))
|
||||
|
||||
ret = safe_load(temp("model.safetensors"))
|
||||
for k,v in tensors.items(): np.testing.assert_array_equal(ret[k].numpy(), v.numpy())
|
||||
safe_save(ret, temp("model.safetensors_alt"))
|
||||
with open(temp("model.safetensors"), "rb") as f:
|
||||
with open(temp("model.safetensors_alt"), "rb") as g:
|
||||
assert f.read() == g.read()
|
||||
ret2 = safe_load(temp("model.safetensors_alt"))
|
||||
for k,v in tensors.items(): np.testing.assert_array_equal(ret2[k].numpy(), v.numpy())
|
||||
|
||||
def test_efficientnet_safetensors(self):
|
||||
from models.efficientnet import EfficientNet
|
||||
model = EfficientNet(0)
|
||||
state_dict = get_state_dict(model)
|
||||
safe_save(state_dict, temp("eff0"))
|
||||
state_dict_loaded = safe_load(temp("eff0"))
|
||||
assert sorted(list(state_dict_loaded.keys())) == sorted(list(state_dict.keys()))
|
||||
for k,v in state_dict.items():
|
||||
np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy())
|
||||
|
||||
# load with the real safetensors
|
||||
from safetensors import safe_open
|
||||
with safe_open(temp("eff0"), framework="pt", device="cpu") as f:
|
||||
assert sorted(list(f.keys())) == sorted(list(state_dict.keys()))
|
||||
for k in f.keys():
|
||||
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
|
||||
|
||||
def test_huggingface_enet_safetensors(self):
|
||||
# test a real file
|
||||
fn = fetch_as_file("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
|
||||
state_dict = safe_load(fn)
|
||||
assert len(state_dict.keys()) == 244
|
||||
assert 'blocks.2.2.se.conv_reduce.weight' in state_dict
|
||||
assert state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy() == 276570
|
||||
assert state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy() == 276570
|
||||
|
||||
def test_metadata(self):
|
||||
metadata = {"hello": "world"}
|
||||
safe_save({}, temp('metadata.safetensors'), metadata)
|
||||
import struct
|
||||
with open(temp('metadata.safetensors'), 'rb') as f:
|
||||
dat = f.read()
|
||||
sz = struct.unpack(">Q", dat[0:8])[0]
|
||||
import json
|
||||
assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
|
||||
|
||||
def helper_test_disk_tensor(fn, data, np_fxn, tinygrad_fxn=None):
|
||||
if tinygrad_fxn is None: tinygrad_fxn = np_fxn
|
||||
pathlib.Path(temp(fn)).unlink(missing_ok=True)
|
||||
tinygrad_tensor = Tensor(data, device="CPU").to(f"disk:{temp(fn)}")
|
||||
numpy_arr = np.array(data)
|
||||
tinygrad_fxn(tinygrad_tensor)
|
||||
np_fxn(numpy_arr)
|
||||
np.testing.assert_allclose(tinygrad_tensor.numpy(), numpy_arr)
|
||||
|
||||
class TestDiskTensor(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
pathlib.Path(temp("dt1")).unlink(missing_ok=True)
|
||||
Tensor.empty(100, 100, device=f"disk:{temp('dt1')}")
|
||||
|
||||
def test_write_ones(self):
|
||||
pathlib.Path(temp("dt2")).unlink(missing_ok=True)
|
||||
|
||||
out = Tensor.ones(10, 10, device="CPU")
|
||||
outdisk = out.to(f"disk:{temp('dt2')}")
|
||||
print(outdisk)
|
||||
outdisk.realize()
|
||||
del out, outdisk
|
||||
|
||||
# test file
|
||||
with open(temp("dt2"), "rb") as f:
|
||||
assert f.read() == b"\x00\x00\x80\x3F" * 100
|
||||
|
||||
# test load alt
|
||||
reloaded = Tensor.empty(10, 10, device=f"disk:{temp('dt2')}")
|
||||
out = reloaded.numpy()
|
||||
assert np.all(out == 1.)
|
||||
|
||||
def test_assign_slice(self):
|
||||
def assign(x,s,y): x[s] = y
|
||||
helper_test_disk_tensor("dt3", [0,1,2,3], lambda x: assign(x, slice(0,2), [13, 12]))
|
||||
helper_test_disk_tensor("dt4", [[0,1,2,3],[4,5,6,7]], lambda x: assign(x, slice(0,1), [[13, 12, 11, 10]]))
|
||||
|
||||
def test_reshape(self):
|
||||
helper_test_disk_tensor("dt5", [1,2,3,4,5], lambda x: x.reshape((1,5)))
|
||||
helper_test_disk_tensor("dt6", [1,2,3,4], lambda x: x.reshape((2,2)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
44
tinygrad_repo/test/unit/test_flopcounter.py
Normal file
44
tinygrad_repo/test/unit/test_flopcounter.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.helpers import dtypes
|
||||
|
||||
class TestFlopCounter(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.buf0 = LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
|
||||
self.buf1 = LazyOp(BufferOps.MEM, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
|
||||
|
||||
def test_flops_add(self):
|
||||
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
|
||||
info = get_lazyop_info(op0)
|
||||
self.assertEqual(info.flops, 4)
|
||||
|
||||
def test_flops_add_twice(self):
|
||||
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
|
||||
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
|
||||
info = get_lazyop_info(op1)
|
||||
self.assertEqual(info.flops, 8)
|
||||
|
||||
def test_flops_add_self(self):
|
||||
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
|
||||
op1 = LazyOp(BinaryOps.ADD, (op0,op0,), None)
|
||||
info = get_lazyop_info(op1)
|
||||
self.assertEqual(info.flops, 8)
|
||||
|
||||
def test_flops_add_roundabout_self(self):
|
||||
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
|
||||
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
|
||||
op2 = LazyOp(BinaryOps.ADD, (op0,op1,), None)
|
||||
info = get_lazyop_info(op2)
|
||||
self.assertEqual(info.flops, 12)
|
||||
|
||||
def test_flops_red(self):
|
||||
op0 = LazyOp(BinaryOps.MUL, (self.buf0,self.buf1,), None)
|
||||
op1 = LazyOp(ReduceOps.SUM, (op0,), (1,))
|
||||
op2 = LazyOp(BinaryOps.ADD, (op1, op1,), None)
|
||||
info = get_lazyop_info(op2)
|
||||
self.assertEqual(info.flops, 9)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
142
tinygrad_repo/test/unit/test_helpers.py
Normal file
142
tinygrad_repo/test/unit/test_helpers.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod
|
||||
from tinygrad.shape.symbolic import Variable, NumNode
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
|
||||
class TestContextVars(unittest.TestCase):
|
||||
# Ensuring that the test does not modify variables outside the tests.
|
||||
ctx = Context()
|
||||
def setUp(self): TestContextVars.ctx.__enter__()
|
||||
def tearDown(self): TestContextVars.ctx.__exit__()
|
||||
|
||||
def test_initial_value_is_set(self):
|
||||
_TMP = ContextVar("_TMP", 5)
|
||||
self.assertEqual(_TMP.value, 5)
|
||||
|
||||
def test_multiple_creation_ignored(self):
|
||||
_TMP2 = ContextVar("_TMP2", 1)
|
||||
_TMP2 = ContextVar("_TMP2", 2)
|
||||
self.assertEqual(_TMP2.value, 1)
|
||||
|
||||
def test_new_var_inside_context(self):
|
||||
# Creating a _new_ variable inside a context should not have any effect on its scope (?)
|
||||
with Context(VARIABLE=1):
|
||||
_TMP3 = ContextVar("_TMP3", 1)
|
||||
_TMP3 = ContextVar("_TMP3", 2)
|
||||
self.assertEqual(_TMP3.value, 1)
|
||||
|
||||
def test_value_accross_modules(self):
|
||||
# Mocking module import by invoking the code but not in our globals().
|
||||
exec('from tinygrad.helpers import ContextVar;C = ContextVar("C", 13)', {}) # pylint:disable=exec-used
|
||||
# It should not matter that the first creation was in another module.
|
||||
C = ContextVar("C", 0)
|
||||
self.assertEqual(C.value, 13)
|
||||
|
||||
def test_assignment_across_modules(self):
|
||||
B = ContextVar("B", 1)
|
||||
# local assignment
|
||||
B.value = 2
|
||||
self.assertEqual(B.value, 2)
|
||||
# Assignment in another module.
|
||||
exec('from tinygrad.helpers import ContextVar;B = ContextVar("B", 0);B.value = 3;', {}) # pylint:disable=exec-used
|
||||
# Assignment in another module should affect this one as well.
|
||||
self.assertEqual(B.value, 3)
|
||||
|
||||
def test_context_assignment(self):
|
||||
with Context(VARIABLE=1):
|
||||
self.assertEqual(VARIABLE.value, 1)
|
||||
self.assertEqual(VARIABLE.value, 0)
|
||||
|
||||
def test_unknown_param_to_context(self):
|
||||
with self.assertRaises(KeyError):
|
||||
with Context(SOMETHING_ELSE=1):
|
||||
pass
|
||||
|
||||
def test_inside_context_assignment(self):
|
||||
with Context(VARIABLE=4):
|
||||
# What you can and cannot do inside a context.
|
||||
# 1. This type of statement has no effect.
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
self.assertTrue(VARIABLE >= 4, "ContextVars inside contextmanager may not set a new value")
|
||||
|
||||
# 2. The call syntax however has a local effect.
|
||||
VARIABLE.value = 13
|
||||
self.assertTrue(VARIABLE.value == 13, "Call syntax however works inside a contextmanager.")
|
||||
|
||||
# Related to 2. above. Note that VARIABLE is back to 0 again as expected.
|
||||
self.assertEqual(VARIABLE.value, 0)
|
||||
|
||||
def test_new_var_inside_context_other_module(self):
|
||||
with Context(VARIABLE=1):
|
||||
_NEW2 = ContextVar("_NEW2", 0)
|
||||
_NEW2 = ContextVar("_NEW2", 1)
|
||||
self.assertEqual(_NEW2.value, 0)
|
||||
|
||||
code = """\
|
||||
from tinygrad.helpers import Context, ContextVar
|
||||
with Context(VARIABLE=1):
|
||||
_NEW3 = ContextVar("_NEW3", 0)"""
|
||||
exec(code, {}) # pylint:disable=exec-used
|
||||
# While _NEW3 was created in an outside scope it should still work the same as above.
|
||||
_NEW3 = ContextVar("_NEW3", 1)
|
||||
self.assertEqual(_NEW3.value, 0)
|
||||
|
||||
def test_nested_context(self):
|
||||
with Context(VARIABLE=1):
|
||||
with Context(VARIABLE=2):
|
||||
with Context(VARIABLE=3):
|
||||
self.assertEqual(VARIABLE.value, 3)
|
||||
self.assertEqual(VARIABLE.value, 2)
|
||||
self.assertEqual(VARIABLE.value, 1)
|
||||
self.assertEqual(VARIABLE.value, 0)
|
||||
|
||||
def test_decorator(self):
|
||||
@Context(VARIABLE=1, DEBUG=4)
|
||||
def test():
|
||||
self.assertEqual(VARIABLE.value, 1)
|
||||
|
||||
self.assertEqual(VARIABLE.value, 0)
|
||||
test()
|
||||
self.assertEqual(VARIABLE.value, 0)
|
||||
|
||||
def test_context_exit_reverts_updated_values(self):
|
||||
D = ContextVar("D", 1)
|
||||
D.value = 2
|
||||
with Context(D=3):
|
||||
...
|
||||
assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."
|
||||
|
||||
class TestMergeDicts(unittest.TestCase):
|
||||
def test_merge_dicts(self):
|
||||
a = {"a": 1, "b": 2}
|
||||
b = {"a": 1, "c": 3}
|
||||
c = {}
|
||||
d = {"a": 2, "b": 2}
|
||||
assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3}
|
||||
assert merge_dicts([a, c]) == a
|
||||
assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3}
|
||||
with self.assertRaises(AssertionError):
|
||||
merge_dicts([a, d])
|
||||
|
||||
class TestDtypes(unittest.TestCase):
|
||||
def test_dtypes_fields(self):
|
||||
fields = dtypes.fields()
|
||||
self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
|
||||
self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
|
||||
|
||||
class TestStripParens(unittest.TestCase):
|
||||
def test_simple(self): self.assertEqual("1+2", strip_parens("(1+2)"))
|
||||
def test_nested(self): self.assertEqual("1+(2+3)", strip_parens("(1+(2+3))"))
|
||||
def test_casted_no_strip(self): self.assertEqual("(int)(1+2)", strip_parens("(int)(1+2)"))
|
||||
|
||||
class TestProd(unittest.TestCase):
|
||||
def test_empty(self): self.assertEqual(1, prod(tuple()))
|
||||
def test_ints(self): self.assertEqual(30, prod((2, 3, 5)))
|
||||
def test_variable(self): self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render())
|
||||
def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
|
||||
def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3))))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
663
tinygrad_repo/test/unit/test_shapetracker.py
Normal file
663
tinygrad_repo/test/unit/test_shapetracker.py
Normal file
@@ -0,0 +1,663 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from itertools import product
|
||||
|
||||
def shapetracker_getitem(st, val):
|
||||
locals = {"idx": val, "valid": 1}
|
||||
idx, valid = st.expr_node()
|
||||
exec(f"valid={valid.render()};idx={idx.render()}", None, locals)
|
||||
return locals["idx"] if locals["valid"] else -1
|
||||
|
||||
class CheckingShapeTracker:
|
||||
def __init__(self, shape):
|
||||
self.st = ShapeTracker.from_shape(shape)
|
||||
self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.t.shape
|
||||
|
||||
def simplify(self):
|
||||
self.st = self.st.simplify()
|
||||
return self
|
||||
|
||||
def reshape(self, new_shape):
|
||||
self.st = self.st.reshape(new_shape)
|
||||
self.t = self.t.reshape(new_shape)
|
||||
return self
|
||||
|
||||
def permute(self, axis):
|
||||
self.st = self.st.permute(axis)
|
||||
self.t = np.transpose(self.t, axis)
|
||||
return self
|
||||
|
||||
def expand(self, new_shape):
|
||||
self.st = self.st.expand(new_shape)
|
||||
self.t = np.broadcast_to(self.t, new_shape)
|
||||
return self
|
||||
|
||||
def flip(self, axis):
|
||||
self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
|
||||
self.t = np.flip(self.t, axis)
|
||||
return self
|
||||
|
||||
def shrink(self, arg):
|
||||
self.st = self.st.shrink(arg)
|
||||
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
|
||||
return self
|
||||
|
||||
def pad(self, arg):
|
||||
self.st = self.st.pad(arg)
|
||||
self.t = np.pad(self.t, arg, constant_values=-1)
|
||||
return self
|
||||
|
||||
def stride(self, arg):
|
||||
self.st = self.st.stride(arg)
|
||||
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
|
||||
return self
|
||||
|
||||
def __getitem__(self, val):
|
||||
return self.t.flatten()[val]
|
||||
|
||||
@property
|
||||
def views(self): return self.st.views
|
||||
|
||||
@property
|
||||
def contiguous(self): return self.st.contiguous
|
||||
|
||||
def assert_same(self):
|
||||
x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))]
|
||||
y = [self[i] for i in range(prod(self.shape))]
|
||||
idx, valid = self.st.expr_node()
|
||||
if DEBUG >= 1: print(x, y, self.st.shape, self.shape, idx.render(), valid.render(), self.st)
|
||||
assert self.st.shape == self.shape
|
||||
assert x == y, f"mismatch shapetracker:{x} real:{y}"
|
||||
|
||||
class TestRealIssues(unittest.TestCase):
|
||||
def test_reshape_doesnt_multiview(self):
|
||||
self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),))
|
||||
self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
|
||||
assert len(self.st.views) == 1
|
||||
|
||||
class TestRealDoesntSimplify(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
st = self.st.real_strides()
|
||||
print(st)
|
||||
self.st = self.st.simplify()
|
||||
assert len(self.st.views) != 1
|
||||
assert None in st
|
||||
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((
|
||||
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
|
||||
View.create((8, 6, 11), (66, 11, 1), 0, None)))
|
||||
assert self.st.real_strides() == (33, None, 1)
|
||||
|
||||
def test_2(self):
|
||||
self.st = ShapeTracker((
|
||||
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
|
||||
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
|
||||
assert self.st.real_strides() == (None, 18, -3, -1)
|
||||
|
||||
class TestRealStrides(unittest.TestCase):
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((
|
||||
View.create((2048,), (1,), 0, ((0, 512),)),
|
||||
View.create((16, 32, 4), (128, 4, 1), 0, None)))
|
||||
st = self.st.real_strides()
|
||||
print(self.st, st)
|
||||
assert st == (None, 4, 1)
|
||||
|
||||
class TestRealSimplifies(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
st = self.st.real_strides()
|
||||
self.st = self.st.simplify()
|
||||
assert len(self.st.views) == 1
|
||||
print(self.st.views[-1].strides, st)
|
||||
assert self.st.views[-1].strides == st
|
||||
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((
|
||||
View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
|
||||
View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)))
|
||||
|
||||
def test_2(self):
|
||||
self.st = ShapeTracker((
|
||||
View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
|
||||
View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)))
|
||||
|
||||
class TestIndexExpressions2d(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
|
||||
offsets = [0, 1, 15, 28, 10000]
|
||||
self.sts = [ShapeTracker((View.create(base_shape, offset=offset),)) for base_shape in shapes for offset in offsets]
|
||||
self.offset = [Variable.num(offset) for base_shape in shapes for offset in offsets]
|
||||
self.shapes = [shape for shape in shapes for offset in offsets]
|
||||
self.node_exprs = []
|
||||
self.idxs_exprs = []
|
||||
|
||||
def tearDown(self):
|
||||
for st, offset, shape, node_expr, idxs_expr in zip(self.sts, self.offset, self.shapes, self.node_exprs, self.idxs_exprs):
|
||||
numel = prod(shape)
|
||||
assert node_expr(self.default_idx(st.shape)) == st.expr_node()[0]
|
||||
assert node_expr(self.default_idx(st.shape)) == st.expr_node(None)[0]
|
||||
assert node_expr(self.default_idx(st.shape)) == st.expr_node('idx')[0]
|
||||
self.check_bounds(node_expr(self.default_idx(st.shape)), offset, numel)
|
||||
for idx in [(0, numel-1), (7, 203), (2, 5), (0, 0), (numel, numel), (0, numel), (0, numel+1), (numel+100, numel+100)]:
|
||||
idx = Variable("idx", idx[0], idx[1])
|
||||
assert node_expr(idx) == st.expr_node(idx)[0]
|
||||
self.check_bounds(node_expr(idx), offset, numel)
|
||||
|
||||
assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs()[0]
|
||||
assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0]
|
||||
self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel)
|
||||
idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)]
|
||||
idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)]
|
||||
idx2s = [(0,0), (0, min(1, st.shape[2]-1)), (0, st.shape[2]-1), (min(3, st.shape[2]-1), min(6, st.shape[2]-1)), (st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s]
|
||||
for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s):
|
||||
idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None]
|
||||
assert idxs_expr(idxs) == st.expr_idxs(idxs)[0]
|
||||
self.check_bounds(idxs_expr(idxs), offset, numel)
|
||||
|
||||
def default_idx(self, shape):
|
||||
return Variable("idx", 0, prod(shape)-1)
|
||||
|
||||
def default_idxs(self, shape):
|
||||
return [Variable(f"idx{i}", 0, d-1) for i,d in enumerate(shape)]
|
||||
|
||||
def check_bounds(self, expr, offset, numel):
|
||||
assert expr.min >= offset
|
||||
assert expr.max <= offset + numel - 1
|
||||
|
||||
def test_noop(self):
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset)
|
||||
|
||||
def test_permute(self):
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st = st.permute((1, 0))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
def test_reshape(self):
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st = st.reshape((base_shape[0], 1, base_shape[1]))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
def test_reshape_expand(self):
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st = st.reshape((base_shape[0], 1, base_shape[1]))
|
||||
st = st.expand((base_shape[0], base_shape[1], base_shape[1]))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
def test_permute_reshape_1(self): # This tests multiple views
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st = st.permute((1, 0))
|
||||
st = st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
def test_permute_reshape_2(self):
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st = st.permute((1, 0))
|
||||
st = st.reshape((1, base_shape[0]//5, base_shape[1]*5))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.st = CheckingShapeTracker((1, 10))
|
||||
|
||||
def tearDown(self):
|
||||
self.st.assert_same()
|
||||
|
||||
# multiview simplify
|
||||
def test_expand_contract_simple(self):
|
||||
self.st = self.st.expand((10, 10))
|
||||
self.st = self.st.reshape((100,))
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
self.st = self.st.reshape((10, 10))
|
||||
print(self.st.views)
|
||||
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 1)
|
||||
|
||||
# multiview simplify
|
||||
def test_expand_contract_different_shape(self):
|
||||
self.st.expand((10, 10))
|
||||
self.st.reshape((100,))
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
self.st.reshape((2, 5, 2, 5))
|
||||
print(self.st.views)
|
||||
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 1)
|
||||
|
||||
# multiview simplify
|
||||
def test_expand_contract_still_complex(self):
|
||||
self.st.expand((10, 10))
|
||||
self.st.reshape((100,))
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
self.st.reshape((5, 20))
|
||||
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
|
||||
# Tensor.zeros(2, 4).permute(1,0).reshape(2, 4)
|
||||
# (d1*4 + d0%4), d1=x//4, d0=x%4 = ((x//4)*4) + (x%4)%4
|
||||
|
||||
class TestComplexShapeTracker(unittest.TestCase):
|
||||
def test_add_1s(self):
|
||||
self.st = CheckingShapeTracker((4, 4))
|
||||
self.st.permute((1,0))
|
||||
self.st.reshape((1,4,1,4,1))
|
||||
assert not self.st.contiguous
|
||||
self.st.permute((0,3,2,1,4))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_permute_1s_simple(self):
|
||||
self.st = CheckingShapeTracker((1, 16, 9,9))
|
||||
self.st.permute((1,0,2,3))
|
||||
assert self.st.contiguous
|
||||
self.st = CheckingShapeTracker((2, 16, 9,9))
|
||||
self.st.permute((1,0,2,3))
|
||||
assert not self.st.contiguous
|
||||
|
||||
def test_remove_1s_simple(self):
|
||||
self.st = CheckingShapeTracker((1, 16, 1, 1))
|
||||
self.st.reshape((16,))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_remove_1s(self):
|
||||
self.st = CheckingShapeTracker((1, 4, 1, 4, 1))
|
||||
self.st.permute((0,3,2,1,4))
|
||||
self.st.reshape((4,4))
|
||||
assert not self.st.contiguous
|
||||
self.st.permute((1,0))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_permute_reshape(self):
|
||||
self.st = CheckingShapeTracker((4, 4))
|
||||
self.st.permute((1,0))
|
||||
self.st.reshape((2, 2, 2, 2))
|
||||
# TODO: should also be tested by test_super_complex
|
||||
assert len(self.st.views) == 1
|
||||
|
||||
def test_factorize_split(self):
|
||||
self.st = CheckingShapeTracker((4, 4))
|
||||
self.st.permute((1,0))
|
||||
self.st.reshape((2, 2, 2, 2))
|
||||
self.st.permute((2,3,0,1))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_factorize_combine(self):
|
||||
self.st = CheckingShapeTracker((4, 4, 4))
|
||||
self.st.permute((2, 0, 1))
|
||||
self.st.reshape((4, 16))
|
||||
self.st.permute((1, 0))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_factorize_combine_add_ones(self):
|
||||
self.st = CheckingShapeTracker((4, 4, 4))
|
||||
self.st.permute((2, 0, 1))
|
||||
self.st.reshape((4, 16, 1, 1))
|
||||
self.st.permute((1, 0, 2, 3))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_fancy_factorize(self):
|
||||
self.st = CheckingShapeTracker((32, 3, 3, 1))
|
||||
self.st.reshape((8, 4, 3, 3))
|
||||
assert len(self.st.views) == 1
|
||||
|
||||
def test_super_complex_2_fail(self):
|
||||
self.st = CheckingShapeTracker((4, 4, 4))
|
||||
self.st.permute((2, 0, 1))
|
||||
self.st.reshape((16, 4))
|
||||
assert len(self.st.views) != 1
|
||||
|
||||
def test_work(self):
|
||||
self.st = CheckingShapeTracker((64, 1024, 4))
|
||||
self.st.reshape((1, 64, 128, 32))
|
||||
self.st.permute((0, 3, 1, 2))
|
||||
self.st.reshape((1, 32, 1, 64, 128))
|
||||
self.st.permute((0, 3, 4, 1, 2))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_work2(self):
|
||||
self.st = CheckingShapeTracker((64, 1024, 4))
|
||||
self.st.reshape((1, 64, 128, 32))
|
||||
self.st.permute((0, 3, 1, 2))
|
||||
self.st.reshape((1, 1, 32, 64, 128))
|
||||
self.st.permute((0, 3, 4, 1, 2))
|
||||
self.st.reshape((64, 1024, 4))
|
||||
print(self.st.views)
|
||||
assert self.st.contiguous
|
||||
|
||||
class TestSingleShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.st = CheckingShapeTracker((7,4))
|
||||
|
||||
def tearDown(self):
|
||||
self.st.assert_same()
|
||||
|
||||
def test_reshape(self):
|
||||
self.st.reshape((7,1,4))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_permute(self):
|
||||
self.st.permute((1,0))
|
||||
assert not self.st.contiguous
|
||||
|
||||
def test_shrink(self):
|
||||
self.st.shrink(((1,2), (0,4)))
|
||||
assert not self.st.contiguous
|
||||
|
||||
def test_double_permute(self):
|
||||
self.st.permute((1,0))
|
||||
self.st.permute((1,0))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_reshape_permute(self):
|
||||
self.st.reshape((7,1,4))
|
||||
self.st.permute((0,1,2))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_reshape_permute_yes(self):
|
||||
self.st.reshape((7,1,4))
|
||||
self.st.permute((0,2,1))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_reshape_permute_no(self):
|
||||
self.st.reshape((4,7))
|
||||
self.st.permute((1,0))
|
||||
assert not self.st.contiguous
|
||||
|
||||
class TestShapeTrackerFuzzFailures(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.st = CheckingShapeTracker((3,3,3))
|
||||
def tearDown(self):
|
||||
self.st.assert_same()
|
||||
@unittest.skip("simplify doesn't work in this case")
|
||||
def test_case_1(self):
|
||||
self.st.shrink(((1, 2), (1, 3), (1, 3)))
|
||||
self.st.reshape((1, 4))
|
||||
self.st.shrink(((0, 1), (1, 3)))
|
||||
print(self.st.st)
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.st)
|
||||
def test_case_2(self):
|
||||
self.st.stride( (1, 1, -2) )
|
||||
self.st.reshape( (3, 6) )
|
||||
self.st.shrink( ((1, 2), (1, 5)) )
|
||||
self.st.stride( (1, -1) )
|
||||
def test_case_3(self):
|
||||
self.st.shrink( ((0, 2), (0, 2), (0, 1)) )
|
||||
self.st.permute( (1, 0, 2) )
|
||||
self.st.reshape( (4,) )
|
||||
self.st.shrink( ((0, 3),) )
|
||||
self.st.stride( (-1,) )
|
||||
def test_case_4(self):
|
||||
self.st.reshape( (3, 3, 3, 1) )
|
||||
self.st.pad( ((0, 0), (0, 0), (0, 0), (1, 1)) )
|
||||
self.st.shrink( ((0, 2), (1, 2), (0, 2), (0, 1)) )
|
||||
self.st.expand( (2, 1, 2, 3) )
|
||||
|
||||
class TestMaskedShapeTracker(unittest.TestCase):
|
||||
def test_pad_1x1(self):
|
||||
self.st = CheckingShapeTracker((1,1))
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
self.st.assert_same()
|
||||
|
||||
def test_pad_2x2(self):
|
||||
self.st = CheckingShapeTracker((2,2))
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
self.st.assert_same()
|
||||
|
||||
class TestShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.st = CheckingShapeTracker((7,4))
|
||||
self.apply = lambda fxn: [fxn(x) for x in [self.st]]
|
||||
|
||||
def tearDown(self):
|
||||
self.st.assert_same()
|
||||
|
||||
def test_noop(self):
|
||||
pass
|
||||
|
||||
def test_simple_split(self):
|
||||
self.test_permute()
|
||||
self.apply(lambda x: x.reshape((prod(self.st.shape), )))
|
||||
|
||||
def test_simple_pad(self):
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
|
||||
def test_pad_shrink(self):
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
self.st.shrink(((0,4), (0,4)))
|
||||
|
||||
def test_pad_one_sided(self):
|
||||
self.st.pad(((0,1), (0,0)))
|
||||
|
||||
def test_pad_reshape(self):
|
||||
self.st.pad(((0,1), (0,0)))
|
||||
self.st.reshape((8*4,))
|
||||
|
||||
def test_pad_pad(self):
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
|
||||
def test_pad_permute(self):
|
||||
self.st.pad(((1,1), (2,2)))
|
||||
self.st.permute((1,0))
|
||||
|
||||
def test_pad_expand(self):
|
||||
self.st.reshape((7,4,1))
|
||||
self.st.pad(((1,1), (1,1), (0,0)))
|
||||
self.st.expand((9,6,4))
|
||||
|
||||
def test_pad_expand_alt(self):
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
self.st.reshape((9,6,1))
|
||||
self.st.expand((9,6,4))
|
||||
|
||||
def test_pad_stride(self):
|
||||
self.st.pad(((1,4), (1,3)))
|
||||
self.st.stride((2,2))
|
||||
|
||||
def test_pad_stride_neg(self):
|
||||
self.st.pad(((1,2), (1,0)))
|
||||
self.st.stride((-1,-1))
|
||||
|
||||
def test_pad_stride_both(self):
|
||||
self.st.pad(((1,2), (1,0)))
|
||||
self.st.stride((-2,-2))
|
||||
|
||||
def test_shrink_pad(self):
|
||||
self.st.shrink(((0,4), (0,4)))
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
|
||||
def test_reshape(self):
|
||||
new_shape = self.st.shape[::-1]
|
||||
self.apply(lambda x: x.reshape(new_shape))
|
||||
|
||||
def test_permute(self):
|
||||
if len(self.st.shape) == 2: self.apply(lambda x: x.permute((1,0)))
|
||||
elif len(self.st.shape) == 3: self.apply(lambda x: x.permute((2,0,1)))
|
||||
|
||||
def test_reshape_with_1(self):
|
||||
new_shape = (self.st.shape[0], 1, self.st.shape[1])
|
||||
self.apply(lambda x: x.reshape(new_shape))
|
||||
|
||||
def test_expand(self):
|
||||
self.test_reshape_with_1()
|
||||
new_shape = list(self.st.shape)
|
||||
new_shape[1] = 2
|
||||
self.apply(lambda x: x.expand(tuple(new_shape)))
|
||||
|
||||
def test_flip_0(self):
|
||||
self.apply(lambda x: x.flip((0,)))
|
||||
|
||||
def test_flip_1(self):
|
||||
self.apply(lambda x: x.flip((1,)))
|
||||
|
||||
def test_flip_01(self):
|
||||
self.apply(lambda x: x.flip((0,1)))
|
||||
|
||||
def test_slice_0(self):
|
||||
self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1]))))
|
||||
|
||||
def test_slice_1(self):
|
||||
self.apply(lambda x: x.shrink(((0, x.shape[0]), (1, x.shape[1]))))
|
||||
|
||||
def test_slice_1c1(self):
|
||||
self.apply(lambda x: x.shrink(((0, 1), (0, 1))))
|
||||
|
||||
def test_slice_1c2(self):
|
||||
self.apply(lambda x: x.shrink(((1, 2), (1, 2))))
|
||||
|
||||
def test_double_permute(self):
|
||||
self.apply(lambda x: x.permute((1, 0)))
|
||||
self.apply(lambda x: x.permute((1, 0)))
|
||||
|
||||
def test_slice_permute(self):
|
||||
self.apply(lambda x: x.shrink(((0, 2), (2, 4))))
|
||||
self.apply(lambda x: x.permute((1, 0)))
|
||||
|
||||
def test_slice_expand(self):
|
||||
self.apply(lambda x: x.shrink(((0, 2), (3, 4))))
|
||||
self.apply(lambda x: x.expand((2, 10)))
|
||||
|
||||
def test_double_stride(self):
|
||||
self.apply(lambda x: x.stride((1, 2)))
|
||||
self.apply(lambda x: x.stride((2, 1)))
|
||||
|
||||
def test_stride(self): self.apply(lambda x: x.stride((2,1)))
|
||||
def test_stride_int(self): self.apply(lambda x: x.stride((1,2)))
|
||||
def test_stride_2(self): self.apply(lambda x: x.stride((2,2)))
|
||||
def test_stride_n(self): self.apply(lambda x: x.stride((-2,1)))
|
||||
def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2)))
|
||||
def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2)))
|
||||
|
||||
def test_reshape_then_permute(self):
|
||||
self.test_reshape()
|
||||
self.test_permute()
|
||||
|
||||
def test_reshape_then_expand(self):
|
||||
self.test_reshape()
|
||||
self.test_expand()
|
||||
|
||||
def test_permute_then_reshape(self):
|
||||
self.test_permute()
|
||||
self.test_reshape()
|
||||
|
||||
def test_expand_then_reshape(self):
|
||||
self.test_expand()
|
||||
self.test_reshape()
|
||||
|
||||
def test_combo(self):
|
||||
self.test_permute()
|
||||
self.test_reshape()
|
||||
self.test_slice_1()
|
||||
self.test_expand()
|
||||
self.test_permute()
|
||||
|
||||
class TestGetContraction(unittest.TestCase):
|
||||
def test_contraction(self):
|
||||
r = get_contraction((1,2,3,4), (2,3,4))
|
||||
self.assertEqual(r, [[0, 1], [2], [3]])
|
||||
|
||||
r = get_contraction((2,1,3,4), (2,3,4))
|
||||
self.assertEqual(r, [[0], [1, 2], [3]])
|
||||
|
||||
r = get_contraction((1,2,3,1,4), (1,2,3,4))
|
||||
self.assertEqual(r, [[0], [1], [2], [3, 4]])
|
||||
|
||||
r = get_contraction((1,2,3,1,4,1,1), (2,3,4))
|
||||
self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (1,2,3*4))
|
||||
self.assertEqual(r, [[0], [1], [2, 3]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (2,1,3,4))
|
||||
self.assertEqual(r, [[0, 1], [], [2], [3]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (1,1,2*3*4,1))
|
||||
self.assertEqual(r, [[0], [], [1,2,3], []])
|
||||
|
||||
r = get_contraction((2,1,3,4), (1,2,3,4))
|
||||
self.assertEqual(r, [[], [0], [1, 2], [3]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (2*3*4,1,1,1))
|
||||
self.assertEqual(r, [[0, 1, 2, 3], [], [], []])
|
||||
|
||||
r = get_contraction((4,4,4,4), (16,1,16))
|
||||
self.assertEqual(r, [[0, 1], [], [2, 3]])
|
||||
|
||||
r = get_contraction((1,2,3,4,1,1,1), (2,3,4))
|
||||
self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (1,2,3,4,1))
|
||||
self.assertEqual(r, [[0], [1], [2], [3], []])
|
||||
|
||||
r = get_contraction((14,1,384,14,1,1,1,1), (1,14,384,14))
|
||||
self.assertEqual(r, [[], [0], [1,2], [3,4,5,6,7]])
|
||||
|
||||
r = get_contraction((14,1,384,1,14,1,1,1,1), (1,14,384,14))
|
||||
self.assertEqual(r, [[], [0], [1,2], [3,4,5,6,7,8]])
|
||||
|
||||
r = get_contraction((512, 512), (1, 1, 512, 1, 1, 1, 1, 512))
|
||||
self.assertEqual(r, [[], [], [0], [], [], [], [], [1]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (1,2,6,2))
|
||||
self.assertEqual(r, None)
|
||||
|
||||
def test_contraction_ones(self):
|
||||
r = get_contraction((1,), (1,1,1))
|
||||
self.assertEqual(r, [[0], [], []])
|
||||
|
||||
r = get_contraction((1,1), (1,1,1))
|
||||
self.assertEqual(r, [[0], [1], []])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,))
|
||||
self.assertEqual(r, [[0,1,2,3]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,1))
|
||||
self.assertEqual(r, [[0], [1,2,3]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,1,1))
|
||||
self.assertEqual(r, [[0], [1], [2,3]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,1,1,1))
|
||||
self.assertEqual(r, [[0], [1], [2], [3]])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
39
tinygrad_repo/test/unit/test_shm_tensor.py
Normal file
39
tinygrad_repo/test/unit/test_shm_tensor.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import unittest
|
||||
import multiprocessing.shared_memory as shared_memory
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.runtime.ops_shm import RawShmBuffer
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
import numpy as np
|
||||
|
||||
class TestRawShmBuffer(unittest.TestCase):
|
||||
def test_e2e(self):
|
||||
t = Tensor.randn(2, 2, 2).realize()
|
||||
|
||||
# copy to shm
|
||||
shm_name = (s := shared_memory.SharedMemory(create=True, size=t.nbytes())).name
|
||||
s.close()
|
||||
t_shm = t.to(f"shm:{shm_name}").realize()
|
||||
|
||||
# copy from shm
|
||||
t2 = t_shm.to(Device.DEFAULT).realize()
|
||||
|
||||
assert np.allclose(t.numpy(), t2.numpy())
|
||||
s.unlink()
|
||||
|
||||
@unittest.skipIf(CI, "CI doesn't like big shared memory")
|
||||
def test_e2e_big(self):
|
||||
t = Tensor.randn(2048, 2048, 8).realize()
|
||||
|
||||
# copy to shm
|
||||
shm_name = (s := shared_memory.SharedMemory(create=True, size=t.nbytes())).name
|
||||
s.close()
|
||||
t_shm = t.to(f"shm:{shm_name}").realize()
|
||||
|
||||
# copy from shm
|
||||
t2 = t_shm.to(Device.DEFAULT).realize()
|
||||
|
||||
assert np.allclose(t.numpy(), t2.numpy())
|
||||
s.unlink()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
448
tinygrad_repo/test/unit/test_symbolic.py
Normal file
448
tinygrad_repo/test/unit/test_symbolic.py
Normal file
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.shape.symbolic import Node, MulNode, SumNode, Variable, NumNode, LtNode, sym_render, sym_infer, create_rednode
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s):
|
||||
self.assertEqual(v.render(), s)
|
||||
self.assertEqual(v.min, n)
|
||||
self.assertEqual(v.max, m)
|
||||
|
||||
def test_ge(self):
|
||||
self.helper_test_variable(Variable("a", 3, 8)>=77, 0, 0, "0")
|
||||
self.helper_test_variable(Variable("a", 3, 8)>=9, 0, 0, "0")
|
||||
self.helper_test_variable(Variable("a", 3, 8)>=8, 0, 1, "((a*-1)<-7)")
|
||||
self.helper_test_variable(Variable("a", 3, 8)>=4, 0, 1, "((a*-1)<-3)")
|
||||
self.helper_test_variable(Variable("a", 3, 8)>=3, 1, 1, "1")
|
||||
self.helper_test_variable(Variable("a", 3, 8)>=2, 1, 1, "1")
|
||||
|
||||
def test_lt(self):
|
||||
self.helper_test_variable(Variable("a", 3, 8)<77, 1, 1, "1")
|
||||
self.helper_test_variable(Variable("a", 3, 8)<9, 1, 1, "1")
|
||||
self.helper_test_variable(Variable("a", 3, 8)<8, 0, 1, "(a<8)")
|
||||
self.helper_test_variable(Variable("a", 3, 8)<4, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0")
|
||||
self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0")
|
||||
|
||||
def test_ge_divides(self):
|
||||
expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512
|
||||
self.helper_test_variable(expr, 0, 1, "(idx<128)")
|
||||
|
||||
def test_ge_divides_and(self):
|
||||
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
|
||||
self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))")
|
||||
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512])
|
||||
self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and ((idx1//4)<32))")
|
||||
|
||||
def test_lt_factors(self):
|
||||
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
|
||||
self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
|
||||
|
||||
def test_div_becomes_num(self):
|
||||
assert isinstance(Variable("a", 2, 3)//2, NumNode)
|
||||
|
||||
def test_var_becomes_num(self):
|
||||
assert isinstance(Variable("a", 2, 2), NumNode)
|
||||
|
||||
def test_equality(self):
|
||||
idx1 = Variable("idx1", 0, 3)
|
||||
idx2 = Variable("idx2", 0, 3)
|
||||
assert idx1 == idx1
|
||||
assert idx1 != idx2
|
||||
assert idx1*4 == idx1*4
|
||||
assert idx1*4 != idx1*3
|
||||
assert idx1*4 != idx1+4
|
||||
assert idx1*4 != idx2*4
|
||||
assert idx1+idx2 == idx1+idx2
|
||||
assert idx1+idx2 == idx2+idx1
|
||||
assert idx1+idx2 != idx2
|
||||
|
||||
def test_factorize(self):
|
||||
a = Variable("a", 0, 8)
|
||||
self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)")
|
||||
|
||||
def test_factorize_no_mul(self):
|
||||
a = Variable("a", 0, 8)
|
||||
self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)")
|
||||
|
||||
def test_neg(self):
|
||||
self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)")
|
||||
|
||||
def test_add_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(1+a)")
|
||||
|
||||
def test_add_num_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)+Variable.num(1), 1, 9, "(1+a)")
|
||||
|
||||
def test_sub_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(-1+a)")
|
||||
|
||||
def test_sub_num_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)-Variable.num(1), -1, 7, "(-1+a)")
|
||||
|
||||
def test_mul_0(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0")
|
||||
|
||||
def test_mul_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a")
|
||||
|
||||
def test_mul_neg_1(self):
|
||||
self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((((a*-1)+3)//3)+-1)")
|
||||
|
||||
def test_mul_2(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)")
|
||||
|
||||
def test_div_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)//1, 0, 8, "a")
|
||||
|
||||
def test_mod_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)%1, 0, 0, "0")
|
||||
|
||||
def test_add_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8) * 2 + 12, 12, 16+12, "((a*2)+12)")
|
||||
|
||||
def test_div_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
|
||||
|
||||
def test_div_neg_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)")
|
||||
|
||||
def test_sum_div_min_max(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
|
||||
|
||||
def test_sum_div_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
|
||||
|
||||
def test_sum_div_some_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
|
||||
|
||||
def test_sum_div_some_partial_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
||||
|
||||
def test_sum_div_no_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
|
||||
def test_mod_factor(self):
|
||||
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
|
||||
|
||||
def test_mod_to_sub(self):
|
||||
# This is mod reduction
|
||||
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render())
|
||||
|
||||
def test_sum_div_const(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable.num(3)]) // 4, 0, 7, "a")
|
||||
|
||||
def test_sum_div_const_big(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable.num(3)]) // 16, 0, 1, "(a//4)")
|
||||
|
||||
def test_sum_lt_fold(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)")
|
||||
|
||||
def test_mod_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")
|
||||
|
||||
def test_mod_mod(self):
|
||||
self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)")
|
||||
self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0")
|
||||
self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)")
|
||||
|
||||
def test_mul_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
|
||||
|
||||
def test_mul_lt(self):
|
||||
self.helper_test_variable((Variable("a", 0, 5)*4)<13, 0, 1, "(a<4)")
|
||||
self.helper_test_variable((Variable("a", 0, 5)*4)<16, 0, 1, "(a<4)")
|
||||
self.helper_test_variable((Variable("a", 0, 5)*4)>11, 0, 1, "((a*-1)<-2)")
|
||||
self.helper_test_variable((Variable("a", 0, 5)*4)>12, 0, 1, "((a*-1)<-3)")
|
||||
|
||||
def test_div_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
|
||||
|
||||
def test_distribute_mul(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
|
||||
|
||||
def test_mod_mul_sum(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
|
||||
|
||||
def test_sum_0(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a")
|
||||
|
||||
def test_mod_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
|
||||
|
||||
def test_big_mod(self):
|
||||
# NOTE: we no longer support negative variables
|
||||
#self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
|
||||
#self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
|
||||
#self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
|
||||
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
|
||||
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
|
||||
|
||||
def test_gt_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "0")
|
||||
|
||||
def test_lt_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 6) < -3, 0, 0, "0")
|
||||
self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "1")
|
||||
|
||||
def test_lt_sum_remove(self):
|
||||
self.helper_test_variable((Variable("a", 0, 6) + 2) < 3, 0, 1, "(a<1)")
|
||||
|
||||
def test_and_fold(self):
|
||||
self.helper_test_variable(Variable.ands([Variable.num(0), Variable("a", 0, 1)]), 0, 0, "0")
|
||||
|
||||
def test_and_remove(self):
|
||||
self.helper_test_variable(Variable.ands([Variable.num(1), Variable("a", 0, 1)]), 0, 1, "a")
|
||||
|
||||
def test_mod_factor_negative(self):
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
|
||||
|
||||
def test_sum_combine_num(self):
|
||||
self.helper_test_variable(Variable.sum([Variable.num(29), Variable("a", 0, 10), Variable.num(-23)]), 6, 16, "(6+a)")
|
||||
|
||||
def test_sum_num_hoisted_and_factors_cancel_out(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
|
||||
|
||||
def test_div_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
|
||||
|
||||
def test_mul_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
|
||||
|
||||
def test_mul_div_factor_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
|
||||
|
||||
def test_mul_div_factor_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
|
||||
|
||||
def test_div_remove(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
|
||||
|
||||
def test_div_numerator_negative(self):
|
||||
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
|
||||
|
||||
def test_div_into_mod(self):
|
||||
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
# TODO: why are the negative tests broken? (even if we did support negative variables)
|
||||
#MIN, MAX = -10, 10
|
||||
MIN, MAX = 0, 10
|
||||
# one number
|
||||
for i in range(MIN, MAX):
|
||||
v = f(Variable.num(i))
|
||||
#print(i, f(i), v.min, v.max)
|
||||
self.assertEqual(v.min, v.max)
|
||||
self.assertEqual(v.min, f(i))
|
||||
for kmin in range(MIN, MAX):
|
||||
for kmax in range(MIN, MAX):
|
||||
if kmin > kmax: continue
|
||||
v = f(Variable("tmp", kmin, kmax))
|
||||
values = [f(rv) for rv in range(kmin, kmax+1)]
|
||||
# the min and max may not be exact
|
||||
self.assertLessEqual(v.min, min(values))
|
||||
self.assertGreaterEqual(v.max, max(values))
|
||||
|
||||
def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
|
||||
def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
|
||||
def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: (x+1)//2)
|
||||
def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: (x+1)%2)
|
||||
def test_times_2(self): self.helper_test_numeric(lambda x: x*2)
|
||||
def test_times_2_plus_3(self): self.helper_test_numeric(lambda x: x*2 + 3)
|
||||
def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)%4)
|
||||
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
|
||||
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
|
||||
|
||||
class TestSymbolicVars(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
z = NumNode(0)
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert z.vars() == z.vars() == []
|
||||
assert a.vars() == a.vars() == [a]
|
||||
m = MulNode(a, 3)
|
||||
assert m.vars() == [a]
|
||||
s = SumNode([a, b, c])
|
||||
assert s.vars() == [a, b, c]
|
||||
|
||||
def test_compound(self):
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert (a + b * c).vars() == [a, b, c]
|
||||
assert (a % 3 + b // 5).vars() == [a, b]
|
||||
assert (a + b + c - a).vars() == [b, c]
|
||||
|
||||
class TestSymbolicMinMax(unittest.TestCase):
|
||||
def test_min_max_known(self):
|
||||
a = Variable("a", 1, 8)
|
||||
assert max(1, a) == max(a, 1) == a
|
||||
assert min(1, a) == min(a, 1) == 1
|
||||
|
||||
class TestSymRender(unittest.TestCase):
|
||||
def test_sym_render(self):
|
||||
a = Variable("a", 1, 8)
|
||||
b = Variable("b", 1, 10)
|
||||
assert sym_render(a) == "a"
|
||||
assert sym_render(1) == "1"
|
||||
assert sym_render(a+1) == "(1+a)"
|
||||
assert sym_render(a*b) == "(a*b)"
|
||||
|
||||
class TestSymInfer(unittest.TestCase):
|
||||
def test_sym_infer(self):
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
var_vals = {a: 2, b: 3, c: 4}
|
||||
assert sym_infer(5, var_vals) == 5
|
||||
assert sym_infer(a, var_vals) == 2
|
||||
assert sym_infer(b, var_vals) == 3
|
||||
assert sym_infer(a+b, var_vals) == 5
|
||||
assert sym_infer(a-b, var_vals) == -1
|
||||
assert sym_infer(a+b+c, var_vals) == 9
|
||||
assert sym_infer(a*b, var_vals) == 6
|
||||
assert sym_infer(a*b+c, var_vals) == 10
|
||||
|
||||
class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
def test_node_divmod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, i*3-1)
|
||||
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
|
||||
assert 127 // (Variable("i", 1, 10)*128) == 0
|
||||
assert 127 % (Variable("i", 1, 10)*128) == 127
|
||||
assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0
|
||||
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
|
||||
assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert 0 // (Variable("i", 1, 10)*128) == 0
|
||||
assert 0 % (Variable("i", 1, 10)*128) == 0
|
||||
assert idx0 // (i*3) == 0
|
||||
assert idx0 % (i*3) == idx0
|
||||
assert i // i == 1
|
||||
assert i % i == 0
|
||||
assert 128 // NumNode(4) == 32
|
||||
assert 128 % NumNode(4) == 0
|
||||
assert NumNode(128) // NumNode(4) == 32
|
||||
assert NumNode(128) % NumNode(4) == 0
|
||||
|
||||
def test_mulnode_divmod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, 31)
|
||||
assert (idx0*(i*4+4)) // (i+1) == (idx0*4)
|
||||
assert (idx0*(i*4+4)) % (i+1) == 0
|
||||
assert (idx0*i) % i == 0
|
||||
|
||||
def test_sumnode_divmod_sumnode(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, 7)
|
||||
idx1 = Variable("idx1", 0, 3)
|
||||
idx2 = Variable("idx2", 0, i)
|
||||
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
|
||||
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
|
||||
assert (i+1) // (i*128+128) == 0
|
||||
assert (i+1) % (i*128+128) == (i+1)
|
||||
assert (i+1+idx2) // (i+1) == 1
|
||||
assert (i+1+idx2) % (i+1) == idx2
|
||||
assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1
|
||||
assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2
|
||||
assert (i*128+128)*2 // (i*128+128) == 2
|
||||
assert (i*128+128)*2 % (i*128+128) == 0
|
||||
|
||||
def test_sumnode_divmod_sumnode_complex(self):
|
||||
i = Variable("i", 1, 1024)
|
||||
gidx0 = Variable("gidx0", 0, i)
|
||||
lidx1 = Variable("lidx1", 0, 7)
|
||||
ridx2 = Variable("ridx1", 0, 31)
|
||||
assert ((i*128+128)*2 + gidx0*128 + lidx1*(i*512+512) + ridx2*4) // (i*128+128) == 2 + lidx1*4
|
||||
assert ((i*128+128)*2 + gidx0*128 + lidx1*(i*512+512) + ridx2*4) % (i*128+128) == gidx0*128 + ridx2*4
|
||||
assert ((gidx0*128+i*128+ridx2*4+129)) // (i*128+128) == 1
|
||||
assert ((gidx0*128+i*128+ridx2*4+129)) % (i*128+128) == gidx0*128 + ridx2*4 + 1
|
||||
assert (ridx2*(i*4+4)+1+i+gidx0) // (i*128+128) == 0
|
||||
assert (ridx2*(i*4+4)+1+i+gidx0) % (i*128+128) == (ridx2*(i*4+4)+1+i+gidx0)
|
||||
|
||||
def test_node_lt_node(self):
|
||||
a = Variable("a", 1, 5)
|
||||
b = Variable("b", 6, 9)
|
||||
c = Variable("c", 1, 10)
|
||||
d = Variable("d", 5, 10)
|
||||
# if the value is always the same, it folds to num
|
||||
assert (a < b) == 1
|
||||
assert (b < a) == 0
|
||||
assert (d < a) == 0
|
||||
# if it remains as a LtNode, bool is always true and (min, max) == (0, 1)
|
||||
assert isinstance((a < c), LtNode) and (a < c).min == 0 and (a < c).max == 1
|
||||
assert a < c
|
||||
assert isinstance((a > c), LtNode) and (a > c).min == 0 and (a > c).max == 1
|
||||
# same when comparing with a constant
|
||||
assert a < 3 and (a < 3).min == 0 and (a < 3).max == 1
|
||||
assert a > 3 and (a > 3).min == 0 and (a > 3).max == 1
|
||||
|
||||
def test_num_node_mul_node(self):
|
||||
a = Variable("a", 1, 5)
|
||||
b = NumNode(2) * a
|
||||
assert b == a * 2
|
||||
assert isinstance(b, MulNode)
|
||||
b = NumNode(1) * a
|
||||
assert b == a
|
||||
assert isinstance(b, Variable)
|
||||
b = NumNode(0) * a
|
||||
assert b == 0
|
||||
assert isinstance(b, NumNode)
|
||||
|
||||
def test_num_node_expand(self):
|
||||
a = NumNode(42)
|
||||
assert a.expand() == [a]
|
||||
|
||||
def test_variable_expand(self):
|
||||
a = Variable("a", 5, 7)
|
||||
assert a.expand() == [a]
|
||||
|
||||
def test_variable_expand_expr_none(self):
|
||||
a = Variable(None, 5, 7)
|
||||
assert a.expand() == [NumNode(5), NumNode(6), NumNode(7)]
|
||||
|
||||
def test_mul_node_expand(self):
|
||||
a = Variable(None, 5, 7)
|
||||
m = MulNode(a, 3)
|
||||
assert m.expand() == [NumNode(15), NumNode(18), NumNode(21)]
|
||||
|
||||
b = Variable("b", 1, 3)
|
||||
n = MulNode(b, 3)
|
||||
assert n.expand() == [Variable("b", 1, 3)*3]
|
||||
|
||||
def test_sum_node_expand(self):
|
||||
a = Variable(None, 1, 3)
|
||||
b = Variable("b", 5, 7)
|
||||
|
||||
s1 = create_rednode(SumNode, [a, b])
|
||||
assert s1.expand() == [Variable.sum([NumNode(i),b]) for i in range(1,4)]
|
||||
|
||||
def test_multi_expand(self):
|
||||
a = Variable("a", 1, 3)
|
||||
b = Variable("b", 14, 17)
|
||||
s1 = create_rednode(SumNode, [a, b])
|
||||
# expand increments earlier variables faster than later variables (as specified in the argument)
|
||||
# this behavior was just copied from before, no idea why this should be true
|
||||
assert s1.expand((a, b)) == [NumNode(x + y) for x in range(b.min, b.max + 1) for y in range(a.min, a.max + 1)]
|
||||
|
||||
def test_substitute(self):
|
||||
a = Variable(None, 1, 3)
|
||||
b = a + 1
|
||||
c = b.substitute({a: NumNode(1)})
|
||||
assert c == NumNode(2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user