openpilot v0.9.6 release
date: 2024-01-12T10:13:37 master commit: ba792d576a49a0899b88a753fa1c52956bedf9e6
This commit is contained in:
221
tinygrad_repo/tinygrad/shape/shapetracker.py
Normal file
221
tinygrad_repo/tinygrad/shape/shapetracker.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
import functools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, cast
|
||||
from tinygrad.ops import MovementOps
|
||||
from tinygrad.helpers import prod, DEBUG, dedup
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
|
||||
assert len(shape) == len(strides)
|
||||
ret = [(shape[0], strides[0])] if shape else []
|
||||
for i in range(1, len(shape)):
|
||||
if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
|
||||
ret[-1] = (ret[-1][0] * shape[i], strides[i])
|
||||
elif shape[i] == 1:
|
||||
continue
|
||||
else:
|
||||
ret.append((shape[i], strides[i]))
|
||||
return tuple(ret)
|
||||
|
||||
def expr_node_mask(view:View, idx, valid=None) -> Node:
|
||||
expr = [valid] if valid is not None else []
|
||||
if view.mask is not None:
|
||||
acc = 1
|
||||
for ns,(x,y) in reversed(list(zip(view.shape, view.mask))):
|
||||
if x != 0 or y != ns:
|
||||
base = ((idx//acc) % ns)
|
||||
expr += [base >= x, base < y]
|
||||
acc *= ns
|
||||
return Variable.ands(expr)
|
||||
|
||||
# generate an expression if you have a single idx variable
|
||||
def expr_node(view:View, idx=None) -> Node:
|
||||
if idx is None: idx = Variable('idx', 0, prod(view.shape)-1)
|
||||
ret: List[Node] = [Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] if view.offset else []
|
||||
acc = 1
|
||||
for d,s in reversed(to_shape_strides(view.shape, view.strides)):
|
||||
ret.append(((idx//acc)%d)*s)
|
||||
acc *= d
|
||||
return Variable.sum(ret)
|
||||
|
||||
# generate an expression if you have a variable or expression for each index
|
||||
def expr_idxs(view:View, idxs) -> Node:
|
||||
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
|
||||
return Variable.sum([Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0])
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
if vm2.mask: return None # this isn't supported yet
|
||||
mst = ShapeTracker((vm2, vm1))
|
||||
strides = mst.real_strides()
|
||||
if None in strides: return None
|
||||
return View.create(vm1.shape, cast(Tuple[sint, ...], strides), mst.real_offset(), vm1.mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
|
||||
assert len(idxs) == len(shape), "need an idx for all dimensions"
|
||||
acc = 1
|
||||
ret = []
|
||||
for tidx,d in reversed(list(zip(idxs, shape))):
|
||||
ret.append(tidx * acc)
|
||||
acc *= d
|
||||
return Variable.sum(ret)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
views: Tuple[View, ...]
|
||||
def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
|
||||
|
||||
@staticmethod
|
||||
def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
|
||||
|
||||
@property
|
||||
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
|
||||
|
||||
# this is the real size (ish)
|
||||
def size(self): return self.views[-1].size()
|
||||
|
||||
def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], []))
|
||||
|
||||
@property
|
||||
def var_vals(self) -> Dict[Variable, int]:
|
||||
ret:Dict[Variable, int] = {}
|
||||
for v in self.vars():
|
||||
var, val = v.unbind()
|
||||
assert var not in ret or ret[var] == val, f"{var} has conflicted values {val} and {ret[var]}"
|
||||
ret[var] = val
|
||||
return ret
|
||||
|
||||
def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views))
|
||||
|
||||
def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
|
||||
to_apply:List[Tuple[MovementOps, Tuple]] = []
|
||||
for v in self.views:
|
||||
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
|
||||
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
|
||||
# first, we apply the offset
|
||||
# then, we make it the correct shape
|
||||
# then, we apply permutations
|
||||
# TODO: don't use as_strided
|
||||
to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
|
||||
# then, we apply pre expand pads
|
||||
if v.mask is not None:
|
||||
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
|
||||
post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
|
||||
if any(x != (0,0) for x in pre_expand_pads):
|
||||
to_apply.append((MovementOps.PAD, pre_expand_pads))
|
||||
real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
|
||||
# then, we do any expands
|
||||
if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
|
||||
# lastly, we apply post expand pads
|
||||
if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
|
||||
return to_apply
|
||||
|
||||
# these are multiview strides, value is None if it's not a simple strided dimension
|
||||
# TODO: this can be shared code between simplify and merge_views
|
||||
def real_offset(self) -> sint:
|
||||
real_offset, _ = self.expr_node(Variable('zero', 0, 0))
|
||||
return real_offset.b if isinstance(real_offset, NumNode) else real_offset
|
||||
|
||||
# NOTE: if a stride is not always valid, it will be None
|
||||
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
|
||||
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
|
||||
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
|
||||
idx, valid = self.expr_idxs(idxs)
|
||||
ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
|
||||
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
|
||||
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
|
||||
ret[idxs.index(this_dim.a)] = this_dim.b
|
||||
elif isinstance(this_dim, Variable) and this_dim in idxs:
|
||||
ret[idxs.index(this_dim)] = 1
|
||||
idx_vars, valid_vars = idx.vars(), valid.vars()
|
||||
for i,tidx in enumerate(idxs):
|
||||
if tidx in valid_vars and not ignore_valid: ret[i] = None
|
||||
elif tidx not in idx_vars: ret[i] = 0
|
||||
return tuple(ret)
|
||||
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
||||
|
||||
def _expr_idx(self, idx, valid) -> Tuple[Node, Node]:
|
||||
for v in reversed(self.views[0:-1]):
|
||||
if valid.max == 0: return Variable.num(-1), valid
|
||||
valid = expr_node_mask(v, idx, valid)
|
||||
idx = expr_node(v, idx)
|
||||
return idx, valid
|
||||
|
||||
def simplify(self) -> ShapeTracker:
|
||||
if len(self.views) >= 2:
|
||||
new_view = merge_views(self.views[-2], self.views[-1])
|
||||
if new_view:
|
||||
if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
|
||||
return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
|
||||
return self
|
||||
|
||||
def expr_idxs(self, idxs=None):
|
||||
if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
|
||||
idx = expr_idxs(self.views[-1], tuple(idxs))
|
||||
valid = expr_node_mask(self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs)))
|
||||
return self._expr_idx(idx, valid)
|
||||
|
||||
def expr_node(self, idx='idx'):
|
||||
if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1)
|
||||
return self._expr_idx(expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx))
|
||||
|
||||
def axis_is_masked(self, axis) -> bool:
|
||||
_, valid = self.expr_idxs()
|
||||
return f'idx{axis}' in [v.expr for v in valid.vars()]
|
||||
|
||||
# *** under this line are the movement ops ***
|
||||
|
||||
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
|
||||
|
||||
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
|
||||
|
||||
def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
|
||||
|
||||
def permute(self, axis: Tuple[int, ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
|
||||
|
||||
def stride(self, mul: Tuple[int, ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
|
||||
|
||||
def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
|
||||
new_view = self.views[-1].reshape(new_shape)
|
||||
if new_view is None:
|
||||
extra_view = View.create(new_shape)
|
||||
# last chance to merge. TODO: move into View
|
||||
if (merged_view := merge_views(self.views[-1], extra_view)) is not None:
|
||||
return ShapeTracker(self.views[0:-1] + (merged_view,))
|
||||
return ShapeTracker(self.views + (extra_view, ))
|
||||
return ShapeTracker(self.views[0:-1] + (new_view,))
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
# TODO: if we remove movementops from lazy.py we can delete this
|
||||
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
|
||||
# Pre-allocate all groups.
|
||||
axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
|
||||
# Index for new_shape and axis_groups.
|
||||
i: int = 0
|
||||
old_shape_i: int = 0
|
||||
while old_shape_i < len(old_shape):
|
||||
# 1s exist in new_shape only will lead to empty axes group creations.
|
||||
if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
|
||||
if i < len(new_shape) - 1: i += 1
|
||||
else:
|
||||
axis_groups[i].append(old_shape_i)
|
||||
axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
|
||||
# Move to next axes group if total size of all dimensions match.
|
||||
if axis_group_size == new_shape[i]:
|
||||
if i < len(new_shape) - 1: i += 1
|
||||
elif axis_group_size > new_shape[i]: return None
|
||||
old_shape_i += 1
|
||||
return axis_groups
|
||||
352
tinygrad_repo/tinygrad/shape/symbolic.py
Normal file
352
tinygrad_repo/tinygrad/shape/symbolic.py
Normal file
@@ -0,0 +1,352 @@
|
||||
from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
import functools
|
||||
from math import gcd
|
||||
from itertools import product
|
||||
from tinygrad.helpers import partition
|
||||
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator
|
||||
|
||||
# NOTE: Python has different behavior for negative mod and floor div than c
|
||||
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
||||
|
||||
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
|
||||
|
||||
class Node:
|
||||
b: Union[Node, int]
|
||||
min: int
|
||||
max: int
|
||||
def render(self, ops=None, ctx=None) -> Any:
|
||||
if ops is None: ops = render_python
|
||||
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
def vars(self): return []
|
||||
|
||||
def expand_idx(self) -> VariableOrNum: return next((v for v in self.vars() if v.expr is None), NumNode(0))
|
||||
# expand a Node into List[Node] that enumerates the underlying Variables from min to max
|
||||
# expand increments earlier variables faster than later variables (as specified in the argument)
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def expand(self, idxs:Optional[Tuple[VariableOrNum, ...]]=None) -> List[Node]:
|
||||
if idxs is None: idxs = (self.expand_idx(),)
|
||||
return [self.substitute(dict(zip(idxs, (NumNode(x) for x in rep)))) for rep in Node.iter_idxs(idxs)]
|
||||
@staticmethod
|
||||
def iter_idxs(idxs:Tuple[VariableOrNum, ...]) -> Iterator[Tuple[int,...]]:
|
||||
yield from (x[::-1] for x in product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
|
||||
# substitute Variables with the values in var_vals
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
|
||||
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
|
||||
|
||||
@functools.cached_property
|
||||
def key(self) -> str: return self.render(ctx="DEBUG")
|
||||
@functools.cached_property
|
||||
def hash(self) -> int: return hash(self.key)
|
||||
def __repr__(self): return self.render(ctx="REPR")
|
||||
def __str__(self): return "<"+self.key+">"
|
||||
def __hash__(self): return self.hash
|
||||
def __bool__(self): return not (self.max == self.min == 0)
|
||||
def __eq__(self, other:object) -> bool:
|
||||
if not isinstance(other, Node): return NotImplemented
|
||||
return self.key == other.key
|
||||
def __neg__(self): return self*-1
|
||||
def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
|
||||
def __radd__(self, b:int): return self+b
|
||||
def __sub__(self, b:Union[Node,int]): return self+-b
|
||||
def __rsub__(self, b:int): return -self+b
|
||||
def __le__(self, b:Union[Node,int]): return self < (b+1)
|
||||
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
|
||||
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
|
||||
def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
|
||||
def __mul__(self, b:Union[Node, int]):
|
||||
if b == 0: return NumNode(0)
|
||||
if b == 1: return self
|
||||
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
|
||||
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
|
||||
def __rmul__(self, b:int): return self*b
|
||||
|
||||
# *** complex ops ***
|
||||
|
||||
def __rfloordiv__(self, b:int):
|
||||
if self.min > b >= 0: return NumNode(0)
|
||||
if isinstance(self, NumNode): return NumNode(b // self.b)
|
||||
raise RuntimeError(f"not supported: {b} // {self}")
|
||||
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
|
||||
if isinstance(b, Node):
|
||||
if b.__class__ is NumNode: return self // b.b
|
||||
if self == b: return NumNode(1)
|
||||
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
|
||||
raise RuntimeError(f"not supported: {self} // {b}")
|
||||
assert b != 0
|
||||
if b < 0: return (self//-b)*-1
|
||||
if b == 1: return self
|
||||
|
||||
# the numerator of div is not allowed to be negative
|
||||
if self.min < 0:
|
||||
offset = self.min//b
|
||||
# factor out an "offset" to make the numerator positive. don't allowing factoring again
|
||||
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
|
||||
return create_node(DivNode(self, b))
|
||||
|
||||
def __rmod__(self, b:int):
|
||||
if self.min > b >= 0: return NumNode(b)
|
||||
if isinstance(self, NumNode): return NumNode(b % self.b)
|
||||
raise RuntimeError(f"not supported: {b} % {self}")
|
||||
def __mod__(self, b:Union[Node,int]):
|
||||
if isinstance(b, Node):
|
||||
if b.__class__ is NumNode: return self % b.b
|
||||
if self == b: return NumNode(0)
|
||||
if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
|
||||
raise RuntimeError(f"not supported: {self} % {b}")
|
||||
assert b > 0
|
||||
if b == 1: return NumNode(0)
|
||||
if self.min >= 0 and self.max < b: return self
|
||||
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
|
||||
if self.min < 0: return (self - ((self.min//b)*b)) % b
|
||||
return create_node(ModNode(self, b))
|
||||
|
||||
@staticmethod
|
||||
def num(num:int) -> NumNode: return NumNode(num)
|
||||
|
||||
@staticmethod
|
||||
def factorize(nodes:List[Node]) -> List[Node]:
|
||||
mul_groups: Dict[Node, int] = {}
|
||||
for x in nodes:
|
||||
a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
|
||||
mul_groups[a] = mul_groups.get(a, 0) + b
|
||||
return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
|
||||
|
||||
@staticmethod
|
||||
def sum(nodes:List[Node]) -> Node:
|
||||
nodes = [x for x in nodes if x.max or x.min]
|
||||
if not nodes: return NumNode(0)
|
||||
if len(nodes) == 1: return nodes[0]
|
||||
|
||||
new_nodes: List[Node] = []
|
||||
num_node_sum = 0
|
||||
for node in SumNode(nodes).flat_components:
|
||||
if node.__class__ is NumNode: num_node_sum += node.b
|
||||
else: new_nodes.append(node)
|
||||
|
||||
if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
|
||||
new_nodes = Node.factorize(new_nodes)
|
||||
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
|
||||
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
|
||||
|
||||
@staticmethod
|
||||
def ands(nodes:List[Node]) -> Node:
|
||||
if not nodes: return NumNode(1)
|
||||
if len(nodes) == 1: return nodes[0]
|
||||
if any(not x for x in nodes): return NumNode(0)
|
||||
|
||||
# filter 1s
|
||||
nodes = [x for x in nodes if x.min != x.max]
|
||||
return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
|
||||
|
||||
# 4 basic node types
|
||||
|
||||
class Variable(Node):
|
||||
def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
|
||||
assert nmin >= 0 and nmin <= nmax
|
||||
if nmin == nmax: return NumNode(nmin)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
|
||||
self.expr, self.min, self.max = expr, nmin, nmax
|
||||
self.val:Optional[int] = None
|
||||
def bind(self, val):
|
||||
assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
|
||||
self.val = val
|
||||
return self
|
||||
def unbind(self) -> Tuple[Variable, int]:
|
||||
assert self.val is not None, f"cannot unbind {self}"
|
||||
return Variable(self.expr, self.min, self.max), self.val
|
||||
def vars(self): return [self]
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self
|
||||
|
||||
class NumNode(Node):
|
||||
def __init__(self, num:int):
|
||||
assert isinstance(num, int), f"{num} is not an int"
|
||||
self.b:int = num
|
||||
self.min, self.max = num, num
|
||||
def bind(self, val):
|
||||
assert self.b == val, f"cannot bind {val} to {self}"
|
||||
return self
|
||||
def __eq__(self, other): return self.b == other
|
||||
def __hash__(self): return self.hash # needed with __eq__ override
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
|
||||
|
||||
def create_node(ret:Node):
|
||||
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
|
||||
if ret.min == ret.max: return NumNode(ret.min)
|
||||
return ret
|
||||
|
||||
class OpNode(Node):
|
||||
def __init__(self, a:Node, b:Union[Node, int]):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = self.get_bounds()
|
||||
def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
|
||||
@abstractmethod
|
||||
def get_bounds(self) -> Tuple[int, int]: pass
|
||||
|
||||
class LtNode(OpNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
if isinstance(self.b, int):
|
||||
return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
|
||||
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
||||
|
||||
class MulNode(OpNode):
|
||||
def __lt__(self, b: Union[Node, int]):
|
||||
if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b)
|
||||
sgn = 1 if self.b > 0 else -1
|
||||
return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b))
|
||||
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
|
||||
if self.b % b == 0: return self.a*(self.b//b)
|
||||
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
a = (self.a * (self.b%b))
|
||||
return Node.__mod__(a, b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
||||
|
||||
class DivNode(OpNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0 and isinstance(self.b, int)
|
||||
return self.a.min//self.b, self.a.max//self.b
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) // self.b
|
||||
|
||||
class ModNode(OpNode):
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b)
|
||||
return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b)
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
||||
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0 and isinstance(self.b, int)
|
||||
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) % self.b
|
||||
|
||||
class RedNode(Node):
|
||||
def __init__(self, nodes:List[Node]): self.nodes = nodes
|
||||
def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
|
||||
|
||||
class SumNode(RedNode):
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
||||
fully_divided: List[Node] = []
|
||||
rest: List[Node] = []
|
||||
if isinstance(b, SumNode):
|
||||
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
||||
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
||||
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
|
||||
if isinstance(b, Node):
|
||||
for x in self.flat_components:
|
||||
if x % b == 0: fully_divided.append(x // b)
|
||||
else: rest.append(x)
|
||||
if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b
|
||||
return Node.__floordiv__(self, b, False)
|
||||
if b == 1: return self
|
||||
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
||||
fully_divided, rest = [], []
|
||||
_gcd = b
|
||||
divisor = 1
|
||||
for x in self.flat_components:
|
||||
if x.__class__ in (NumNode, MulNode):
|
||||
if x.b%b == 0: fully_divided.append(x//b)
|
||||
else:
|
||||
rest.append(x)
|
||||
_gcd = gcd(_gcd, x.b)
|
||||
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
|
||||
else:
|
||||
rest.append(x)
|
||||
_gcd = 1
|
||||
if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
|
||||
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
|
||||
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
if isinstance(b, SumNode):
|
||||
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
||||
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
||||
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
|
||||
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
|
||||
new_nodes: List[Node] = []
|
||||
for x in self.nodes:
|
||||
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
|
||||
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
|
||||
else: new_nodes.append(x)
|
||||
return Node.__mod__(Node.sum(new_nodes), b)
|
||||
|
||||
def __lt__(self, b:Union[Node,int]):
|
||||
lhs: Node = self
|
||||
if isinstance(b, int):
|
||||
new_sum = []
|
||||
for x in self.nodes:
|
||||
# TODO: should we just force the last one to always be the number
|
||||
if isinstance(x, NumNode): b -= x.b
|
||||
else: new_sum.append(x)
|
||||
lhs = Node.sum(new_sum)
|
||||
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
|
||||
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
||||
if muls:
|
||||
# NOTE: gcd in python 3.8 takes exactly 2 args
|
||||
mul_gcd = b
|
||||
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here
|
||||
all_others = Variable.sum(others)
|
||||
if all_others.min >= 0 and all_others.max < mul_gcd:
|
||||
lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
||||
return Node.__lt__(lhs, b)
|
||||
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes])
|
||||
|
||||
@property
|
||||
def flat_components(self): # recursively expand sumnode components
|
||||
new_nodes = []
|
||||
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
|
||||
return new_nodes
|
||||
|
||||
class AndNode(RedNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes])
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
|
||||
subed = []
|
||||
for node in self.nodes:
|
||||
if not (sub:=node.substitute(var_vals)): return NumNode(0)
|
||||
subed.append(sub)
|
||||
return Variable.ands(subed)
|
||||
|
||||
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
|
||||
ret = typ(nodes)
|
||||
if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
|
||||
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
|
||||
return create_node(ret)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
|
||||
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
|
||||
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
|
||||
if isinstance(a, (int, float)): return a
|
||||
ret = a.substitute({k:Variable.num(v) for k, v in var_vals.items()})
|
||||
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
|
||||
return ret.b
|
||||
|
||||
# symbolic int
|
||||
sint = Union[Node, int]
|
||||
VariableOrNum = Union[Variable, NumNode]
|
||||
|
||||
render_python: Dict[Type, Callable] = {
|
||||
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
|
||||
NumNode: lambda self,ops,ctx: f"{self.b}",
|
||||
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
|
||||
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
|
||||
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
|
||||
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
|
||||
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
||||
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
|
||||
}
|
||||
132
tinygrad_repo/tinygrad/shape/view.py
Normal file
132
tinygrad_repo/tinygrad/shape/view.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
import functools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, cast
|
||||
from tinygrad.helpers import prod, all_int, dedup
|
||||
from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, is_sym_int, sint
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
strides = [1] if shape else []
|
||||
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
|
||||
return filter_strides(shape, tuple(strides))
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class View:
|
||||
shape:Tuple[sint, ...]
|
||||
strides:Tuple[sint, ...]
|
||||
offset:sint
|
||||
mask:Optional[Tuple[Tuple[sint, sint], ...]]
|
||||
contiguous:bool
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
|
||||
strides = filter_strides(shape, strides) if strides else strides_for_shape(shape)
|
||||
contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape)))
|
||||
return View(shape, strides, offset, mask, contiguous)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def size(self): return prod([s.max if isinstance(s, Node) else s for s,st in zip(self.shape, self.strides) if st != 0])
|
||||
|
||||
def vars(self) -> List[Variable]:
|
||||
flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
|
||||
return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], []))
|
||||
|
||||
def unbind(self) -> View:
|
||||
unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
|
||||
new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
|
||||
new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
|
||||
new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
|
||||
new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
|
||||
return View.create(new_shape, new_strides, new_offset, new_mask)
|
||||
|
||||
# MovementOps live here now
|
||||
|
||||
def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
|
||||
offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
|
||||
if self.mask:
|
||||
# move the old mask
|
||||
nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.mask, arg)])
|
||||
# merge the masks if we have two
|
||||
mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
|
||||
shape = [y-x for x,y in arg]
|
||||
return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
|
||||
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
|
||||
if any(b or e for b, e in arg):
|
||||
zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
|
||||
mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
|
||||
return self.__unsafe_resize(zvarg, mask=mask)
|
||||
return self
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
|
||||
assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
|
||||
return self.__unsafe_resize(arg)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def expand(self, new_shape: Tuple[sint, ...]) -> View:
|
||||
assert len(new_shape) == len(self.shape)
|
||||
assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
|
||||
# NOTE: can the mask ever be (0,0)?
|
||||
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
|
||||
return View.create(new_shape, self.strides, self.offset, mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def permute(self, axis: Tuple[int, ...]) -> View:
|
||||
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
|
||||
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
|
||||
return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def stride(self, mul: Tuple[int, ...]) -> View:
|
||||
# except for the negative case, you can build this from the others. invertible in the negative case
|
||||
assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
|
||||
strides = tuple([z*m for z,m in zip(self.strides, mul)])
|
||||
new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
|
||||
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
|
||||
mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
|
||||
return View.create(new_shape, strides, self.offset + offset, mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
|
||||
if self.shape == new_shape: return self
|
||||
|
||||
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
|
||||
# check for the same size
|
||||
if all_int(self.shape):
|
||||
if all_int(new_shape):
|
||||
assert prod(self.shape) == prod(new_shape), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
|
||||
else:
|
||||
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
|
||||
assert prod(self.shape) == prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
|
||||
|
||||
# after the asserts, it's okay to check contiguous
|
||||
if self.contiguous: return View.create(new_shape)
|
||||
|
||||
# check if this is adding or removing 1s (only)
|
||||
# NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
|
||||
if [x for x in self.shape if x != 1] == [x for x in new_shape if x != 1]:
|
||||
new_strides: List[sint] = [y for x,y in zip(self.shape, self.strides) if x != 1]
|
||||
new_strides_tuple: Tuple[sint, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
|
||||
new_mask_tuple: Optional[Tuple[Tuple[sint, sint], ...]] = None
|
||||
if self.mask:
|
||||
for x,y in zip(self.shape, self.mask):
|
||||
if x == 1 and y != (0, 1):
|
||||
new_mask_tuple = ((0,0),) * len(new_shape)
|
||||
break
|
||||
else:
|
||||
new_mask: List[Tuple[sint, sint]] = [y for x,y in zip(self.shape, self.mask) if x != 1]
|
||||
new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
|
||||
return View.create(new_shape, new_strides_tuple, self.offset, new_mask_tuple)
|
||||
|
||||
# TODO: bring the merge_views logic here for more caching
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user