openpilot v0.9.6 release

date: 2024-01-12T10:13:37
master commit: ba792d576a49a0899b88a753fa1c52956bedf9e6
commit: 08e9fb1edc by FrogAi, 2024-01-12 22:39:28 -07:00
1881 changed files with 653708 additions and 0 deletions


@@ -0,0 +1,221 @@ (likely tinygrad/shape/shapetracker.py, inferred from its imports)
# ShapeTracker allows movement operations to be applied to a buffer without requiring a copy.
from __future__ import annotations
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.ops import MovementOps
from tinygrad.helpers import prod, DEBUG, dedup
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
from tinygrad.shape.view import View
@functools.lru_cache(maxsize=None)
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
assert len(shape) == len(strides)
ret = [(shape[0], strides[0])] if shape else []
for i in range(1, len(shape)):
if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
ret[-1] = (ret[-1][0] * shape[i], strides[i])
elif shape[i] == 1:
continue
else:
ret.append((shape[i], strides[i]))
return tuple(ret)
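# e.g. (illustrative): contiguous runs merge into a single (size, stride) pair:
# to_shape_strides((2, 3, 4), (12, 4, 1)) == ((24, 1),)
# to_shape_strides((2, 3, 4), (24, 4, 1)) == ((2, 24), (12, 1))  # gap after the first dim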
def expr_node_mask(view:View, idx, valid=None) -> Node:
expr = [valid] if valid is not None else []
if view.mask is not None:
acc = 1
for ns,(x,y) in reversed(list(zip(view.shape, view.mask))):
if x != 0 or y != ns:
base = ((idx//acc) % ns)
expr += [base >= x, base < y]
acc *= ns
return Variable.ands(expr)
# generate an expression if you have a single idx variable
def expr_node(view:View, idx=None) -> Node:
if idx is None: idx = Variable('idx', 0, prod(view.shape)-1)
ret: List[Node] = [Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] if view.offset else []
acc = 1
for d,s in reversed(to_shape_strides(view.shape, view.strides)):
ret.append(((idx//acc)%d)*s)
acc *= d
return Variable.sum(ret)
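# e.g. (illustrative): for a transposed view View.create((4, 3), (1, 4)) with idx in [0, 11],
# expr_node yields roughly (idx%3)*4 + idx//3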
# generate an expression if you have a variable or expression for each index
def expr_idxs(view:View, idxs) -> Node:
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
return Variable.sum([Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0])
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
if vm2.mask: return None # this isn't supported yet
mst = ShapeTracker((vm2, vm1))
strides = mst.real_strides()
if None in strides: return None
return View.create(vm1.shape, cast(Tuple[sint, ...], strides), mst.real_offset(), vm1.mask)
@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
assert len(idxs) == len(shape), "need an idx for all dimensions"
acc = 1
ret = []
for tidx,d in reversed(list(zip(idxs, shape))):
ret.append(tidx * acc)
acc *= d
return Variable.sum(ret)
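# e.g. (illustrative): idxs_to_idx((3, 4), (idx0, idx1)) == idx0*4 + idx1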
@dataclass(frozen=True)
class ShapeTracker:
views: Tuple[View, ...]
def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
@staticmethod
def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
@property
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
@property
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
# this is the real size (ish)
def size(self): return self.views[-1].size()
def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], []))
@property
def var_vals(self) -> Dict[Variable, int]:
ret:Dict[Variable, int] = {}
for v in self.vars():
var, val = v.unbind()
assert var not in ret or ret[var] == val, f"{var} has conflicted values {val} and {ret[var]}"
ret[var] = val
return ret
def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views))
def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
to_apply:List[Tuple[MovementOps, Tuple]] = []
for v in self.views:
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
# first, we apply the offset
# then, we make it the correct shape
# then, we apply permutations
# TODO: don't use as_strided
to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
# then, we apply pre expand pads
if v.mask is not None:
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
if any(x != (0,0) for x in pre_expand_pads):
to_apply.append((MovementOps.PAD, pre_expand_pads))
real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
# then, we do any expands
if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
# lastly, we apply post expand pads
if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
return to_apply
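# e.g. (illustrative): for a single masked view View.create((4,), None, 0, ((1, 3),)),
# to_movement_ops emits [(MovementOps.AS_STRIDED, ((2,), (1,), 1)), (MovementOps.PAD, ((1, 1),))]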
# these are multiview strides, value is None if it's not a simple strided dimension
# TODO: this can be shared code between simplify and merge_views
def real_offset(self) -> sint:
real_offset, _ = self.expr_node(Variable('zero', 0, 0))
return real_offset.b if isinstance(real_offset, NumNode) else real_offset
# NOTE: if a stride is not always valid, it will be None
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
idx, valid = self.expr_idxs(idxs)
ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
ret[idxs.index(this_dim.a)] = this_dim.b
elif isinstance(this_dim, Variable) and this_dim in idxs:
ret[idxs.index(this_dim)] = 1
idx_vars, valid_vars = idx.vars(), valid.vars()
for i,tidx in enumerate(idxs):
if tidx in valid_vars and not ignore_valid: ret[i] = None
elif tidx not in idx_vars: ret[i] = 0
return tuple(ret)
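# e.g. (illustrative): padding hides a stride unless the mask is ignored:
# ShapeTracker.from_shape((4,)).pad(((1, 1),)).real_strides() == (None,)
# ShapeTracker.from_shape((4,)).pad(((1, 1),)).real_strides(ignore_valid=True) == (1,)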
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
def _expr_idx(self, idx, valid) -> Tuple[Node, Node]:
for v in reversed(self.views[0:-1]):
if valid.max == 0: return Variable.num(-1), valid
valid = expr_node_mask(v, idx, valid)
idx = expr_node(v, idx)
return idx, valid
def simplify(self) -> ShapeTracker:
if len(self.views) >= 2:
new_view = merge_views(self.views[-2], self.views[-1])
if new_view:
if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
return self
def expr_idxs(self, idxs=None):
if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
idx = expr_idxs(self.views[-1], tuple(idxs))
valid = expr_node_mask(self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs)))
return self._expr_idx(idx, valid)
def expr_node(self, idx='idx'):
if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1)
return self._expr_idx(expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx))
def axis_is_masked(self, axis) -> bool:
_, valid = self.expr_idxs()
return f'idx{axis}' in [v.expr for v in valid.vars()]
# *** under this line are the movement ops ***
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
def permute(self, axis: Tuple[int, ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
def stride(self, mul: Tuple[int, ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
new_view = self.views[-1].reshape(new_shape)
if new_view is None:
extra_view = View.create(new_shape)
# last chance to merge. TODO: move into View
if (merged_view := merge_views(self.views[-1], extra_view)) is not None:
return ShapeTracker(self.views[0:-1] + (merged_view,))
return ShapeTracker(self.views + (extra_view, ))
return ShapeTracker(self.views[0:-1] + (new_view,))
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
# TODO: if we remove movementops from lazy.py we can delete this
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
# Pre-allocate all groups.
axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
# Index for new_shape and axis_groups.
i: int = 0
old_shape_i: int = 0
while old_shape_i < len(old_shape):
# 1s that exist only in new_shape lead to empty axis groups.
if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
if i < len(new_shape) - 1: i += 1
else:
axis_groups[i].append(old_shape_i)
axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
# Move to the next axis group once the total size of its dimensions matches.
if axis_group_size == new_shape[i]:
if i < len(new_shape) - 1: i += 1
elif axis_group_size > new_shape[i]: return None
old_shape_i += 1
return axis_groups
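# Hedged usage sketch (illustrative, not part of the upstream file): a reshape that cannot be
# folded into the permuted View stacks a second View instead of copying the buffer.
if __name__ == "__main__":
  st = ShapeTracker.from_shape((3, 4)).permute((1, 0)).reshape((6, 2))
  assert st.shape == (6, 2) and len(st.views) == 2  # two Views, still zero-copy
  assert get_contraction((2, 3, 4), (6, 4)) == [[0, 1], [2]]  # (2, 3) folds into 6, (4,) stays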


@@ -0,0 +1,352 @@ (likely tinygrad/shape/symbolic.py, inferred from the imports above)
from __future__ import annotations
from abc import abstractmethod
import functools
from math import gcd
from itertools import product
from tinygrad.helpers import partition
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator
# NOTE: Python has different behavior for negative mod and floor div than C
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
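# e.g. in Python -7 // 2 == -4 and -7 % 2 == 1, while C truncates toward zero (-3 and -1)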
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
class Node:
b: Union[Node, int]
min: int
max: int
def render(self, ops=None, ctx=None) -> Any:
if ops is None: ops = render_python
assert self.__class__ in (Variable, NumNode) or self.min != self.max
return ops[type(self)](self, ops, ctx)
def vars(self): return []
def expand_idx(self) -> VariableOrNum: return next((v for v in self.vars() if v.expr is None), NumNode(0))
# expand a Node into List[Node] that enumerates the underlying Variables from min to max
# expand increments earlier variables faster than later variables (as specified in the argument)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def expand(self, idxs:Optional[Tuple[VariableOrNum, ...]]=None) -> List[Node]:
if idxs is None: idxs = (self.expand_idx(),)
return [self.substitute(dict(zip(idxs, (NumNode(x) for x in rep)))) for rep in Node.iter_idxs(idxs)]
@staticmethod
def iter_idxs(idxs:Tuple[VariableOrNum, ...]) -> Iterator[Tuple[int,...]]:
yield from (x[::-1] for x in product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
# substitute Variables with the values in var_vals
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
def hash(self) -> int: return hash(self.key)
def __repr__(self): return self.render(ctx="REPR")
def __str__(self): return "<"+self.key+">"
def __hash__(self): return self.hash
def __bool__(self): return not (self.max == self.min == 0)
def __eq__(self, other:object) -> bool:
if not isinstance(other, Node): return NotImplemented
return self.key == other.key
def __neg__(self): return self*-1
def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
def __radd__(self, b:int): return self+b
def __sub__(self, b:Union[Node,int]): return self+-b
def __rsub__(self, b:int): return -self+b
def __le__(self, b:Union[Node,int]): return self < (b+1)
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
def __mul__(self, b:Union[Node, int]):
if b == 0: return NumNode(0)
if b == 1: return self
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
def __rmul__(self, b:int): return self*b
# *** complex ops ***
def __rfloordiv__(self, b:int):
if self.min > b >= 0: return NumNode(0)
if isinstance(self, NumNode): return NumNode(b // self.b)
raise RuntimeError(f"not supported: {b} // {self}")
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
if isinstance(b, Node):
if b.__class__ is NumNode: return self // b.b
if self == b: return NumNode(1)
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
raise RuntimeError(f"not supported: {self} // {b}")
assert b != 0
if b < 0: return (self//-b)*-1
if b == 1: return self
# the numerator of div is not allowed to be negative
if self.min < 0:
offset = self.min//b
# factor out an "offset" to make the numerator positive; don't allow factoring again
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
return create_node(DivNode(self, b))
def __rmod__(self, b:int):
if self.min > b >= 0: return NumNode(b)
if isinstance(self, NumNode): return NumNode(b % self.b)
raise RuntimeError(f"not supported: {b} % {self}")
def __mod__(self, b:Union[Node,int]):
if isinstance(b, Node):
if b.__class__ is NumNode: return self % b.b
if self == b: return NumNode(0)
if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
raise RuntimeError(f"not supported: {self} % {b}")
assert b > 0
if b == 1: return NumNode(0)
if self.min >= 0 and self.max < b: return self
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
if self.min < 0: return (self - ((self.min//b)*b)) % b
return create_node(ModNode(self, b))
@staticmethod
def num(num:int) -> NumNode: return NumNode(num)
@staticmethod
def factorize(nodes:List[Node]) -> List[Node]:
mul_groups: Dict[Node, int] = {}
for x in nodes:
a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
mul_groups[a] = mul_groups.get(a, 0) + b
return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
@staticmethod
def sum(nodes:List[Node]) -> Node:
nodes = [x for x in nodes if x.max or x.min]
if not nodes: return NumNode(0)
if len(nodes) == 1: return nodes[0]
new_nodes: List[Node] = []
num_node_sum = 0
for node in SumNode(nodes).flat_components:
if node.__class__ is NumNode: num_node_sum += node.b
else: new_nodes.append(node)
if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
new_nodes = Node.factorize(new_nodes)
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
@staticmethod
def ands(nodes:List[Node]) -> Node:
if not nodes: return NumNode(1)
if len(nodes) == 1: return nodes[0]
if any(not x for x in nodes): return NumNode(0)
# filter 1s
nodes = [x for x in nodes if x.min != x.max]
return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
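# e.g. (illustrative), with x = Variable("x", 0, 9):
# Variable.sum([x, x, Variable.num(3)]) renders as ((x*2)+3)  # duplicate terms are factorized
# Variable.ands([x < 5, NumNode(1)]) is (x < 5)               # constant-true terms are dropped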
# 4 basic node types
class Variable(Node):
def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
assert nmin >= 0 and nmin <= nmax
if nmin == nmax: return NumNode(nmin)
return super().__new__(cls)
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
self.val:Optional[int] = None
def bind(self, val):
assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
self.val = val
return self
def unbind(self) -> Tuple[Variable, int]:
assert self.val is not None, f"cannot unbind {self}"
return Variable(self.expr, self.min, self.max), self.val
def vars(self): return [self]
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self
class NumNode(Node):
def __init__(self, num:int):
assert isinstance(num, int), f"{num} is not an int"
self.b:int = num
self.min, self.max = num, num
def bind(self, val):
assert self.b == val, f"cannot bind {val} to {self}"
return self
def __eq__(self, other): return self.b == other
def __hash__(self): return self.hash # needed with __eq__ override
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
def create_node(ret:Node):
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
if ret.min == ret.max: return NumNode(ret.min)
return ret
class OpNode(Node):
def __init__(self, a:Node, b:Union[Node, int]):
self.a, self.b = a, b
self.min, self.max = self.get_bounds()
def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
@abstractmethod
def get_bounds(self) -> Tuple[int, int]: pass
class LtNode(OpNode):
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
def get_bounds(self) -> Tuple[int, int]:
if isinstance(self.b, int):
return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
class MulNode(OpNode):
def __lt__(self, b: Union[Node, int]):
if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b)
sgn = 1 if self.b > 0 else -1
return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b))
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
if self.b % b == 0: return self.a*(self.b//b)
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: Union[Node, int]):
a = (self.a * (self.b%b))
return Node.__mod__(a, b)
def get_bounds(self) -> Tuple[int, int]:
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
class DivNode(OpNode):
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0 and isinstance(self.b, int)
return self.a.min//self.b, self.a.max//self.b
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) // self.b
class ModNode(OpNode):
def __mod__(self, b: Union[Node, int]):
if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b)
return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b)
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
return Node.__floordiv__(self, b, factoring_allowed)
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0 and isinstance(self.b, int)
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) % self.b
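# e.g. (illustrative), with x = Variable("x", 0, 9):
# (x % 10) is x          # max < b, so the mod is dropped
# ((x*4) // 2) == x*2    # MulNode.__floordiv__ factors out the constant
# ((x*4) % 2) == 0       # MulNode.__mod__: 4 % 2 == 0, the whole term vanishes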
class RedNode(Node):
def __init__(self, nodes:List[Node]): self.nodes = nodes
def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
class SumNode(RedNode):
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
fully_divided: List[Node] = []
rest: List[Node] = []
if isinstance(b, SumNode):
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
if isinstance(b, Node):
for x in self.flat_components:
if x % b == 0: fully_divided.append(x // b)
else: rest.append(x)
if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b
return Node.__floordiv__(self, b, False)
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
fully_divided, rest = [], []
_gcd = b
divisor = 1
for x in self.flat_components:
if x.__class__ in (NumNode, MulNode):
if x.b%b == 0: fully_divided.append(x//b)
else:
rest.append(x)
_gcd = gcd(_gcd, x.b)
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
else:
rest.append(x)
_gcd = 1
if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __mod__(self, b: Union[Node, int]):
if isinstance(b, SumNode):
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
new_nodes: List[Node] = []
for x in self.nodes:
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
else: new_nodes.append(x)
return Node.__mod__(Node.sum(new_nodes), b)
def __lt__(self, b:Union[Node,int]):
lhs: Node = self
if isinstance(b, int):
new_sum = []
for x in self.nodes:
# TODO: should we just force the last one to always be the number
if isinstance(x, NumNode): b -= x.b
else: new_sum.append(x)
lhs = Node.sum(new_sum)
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
if muls:
# NOTE: gcd in python 3.8 takes exactly 2 args
mul_gcd = b
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here
all_others = Variable.sum(others)
if all_others.min >= 0 and all_others.max < mul_gcd:
lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
return Node.__lt__(lhs, b)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes])
@property
def flat_components(self): # recursively expand sumnode components
new_nodes = []
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
return new_nodes
class AndNode(RedNode):
def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes])
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
subed = []
for node in self.nodes:
if not (sub:=node.substitute(var_vals)): return NumNode(0)
subed.append(sub)
return Variable.ands(subed)
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
ret = typ(nodes)
if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
return create_node(ret)
@functools.lru_cache(maxsize=None)
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
if isinstance(a, (int, float)): return a
ret = a.substitute({k:Variable.num(v) for k, v in var_vals.items()})
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
return ret.b
# symbolic int
sint = Union[Node, int]
VariableOrNum = Union[Variable, NumNode]
render_python: Dict[Type, Callable] = {
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
NumNode: lambda self,ops,ctx: f"{self.b}",
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
}
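# Hedged usage sketch (illustrative, not part of the upstream file): build a symbolic expression,
# render it, then resolve it with sym_infer once the Variable is bound to a concrete value.
if __name__ == "__main__":
  i = Variable("i", 0, 15)
  expr = (i*4 + 2) % 8
  assert expr.render() == "(((i*4)+2)%8)"
  assert sym_infer(expr, {i: 3}) == 6  # (3*4 + 2) % 8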


@@ -0,0 +1,132 @@ (likely tinygrad/shape/view.py, inferred from the imports above)
from __future__ import annotations
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.helpers import prod, all_int, dedup
from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, is_sym_int, sint
@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
strides = [1] if shape else []
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
return filter_strides(shape, tuple(strides))
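# e.g. strides_for_shape((2, 3, 4)) == (12, 4, 1); size-1 dims get stride 0, so
# strides_for_shape((2, 1, 4)) == (4, 0, 1)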
@dataclass(frozen=True)
class View:
shape:Tuple[sint, ...]
strides:Tuple[sint, ...]
offset:sint
mask:Optional[Tuple[Tuple[sint, sint], ...]]
contiguous:bool
@staticmethod
@functools.lru_cache(maxsize=None)
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
strides = filter_strides(shape, strides) if strides else strides_for_shape(shape)
contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape)))
return View(shape, strides, offset, mask, contiguous)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def size(self): return prod([s.max if isinstance(s, Node) else s for s,st in zip(self.shape, self.strides) if st != 0])
def vars(self) -> List[Variable]:
flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], []))
def unbind(self) -> View:
unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
return View.create(new_shape, new_strides, new_offset, new_mask)
# MovementOps live here now
def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
if self.mask:
# move the old mask
nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.mask, arg)])
# merge the masks if we have two
mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
shape = [y-x for x,y in arg]
return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
if any(b or e for b, e in arg):
zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
return self.__unsafe_resize(zvarg, mask=mask)
return self
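# e.g. (illustrative): View.create((4,)).pad(((1, 2),)) has shape (7,), offset -1, mask ((1, 5),)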
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
return self.__unsafe_resize(arg)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def expand(self, new_shape: Tuple[sint, ...]) -> View:
assert len(new_shape) == len(self.shape)
assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
# NOTE: can the mask ever be (0,0)?
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
return View.create(new_shape, self.strides, self.offset, mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def permute(self, axis: Tuple[int, ...]) -> View:
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def stride(self, mul: Tuple[int, ...]) -> View:
# except for the negative case, you can build this from the others. invertible in the negative case
assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
strides = tuple([z*m for z,m in zip(self.strides, mul)])
new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
return View.create(new_shape, strides, self.offset + offset, mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
if self.shape == new_shape: return self
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
# check for the same size
if all_int(self.shape):
if all_int(new_shape):
assert prod(self.shape) == prod(new_shape), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
else:
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
assert prod(self.shape) == prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
# after the asserts, it's okay to check contiguous
if self.contiguous: return View.create(new_shape)
# check if this is adding or removing 1s (only)
# NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
if [x for x in self.shape if x != 1] == [x for x in new_shape if x != 1]:
new_strides: List[sint] = [y for x,y in zip(self.shape, self.strides) if x != 1]
new_strides_tuple: Tuple[sint, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
new_mask_tuple: Optional[Tuple[Tuple[sint, sint], ...]] = None
if self.mask:
for x,y in zip(self.shape, self.mask):
if x == 1 and y != (0, 1):
new_mask_tuple = ((0,0),) * len(new_shape)
break
else:
new_mask: List[Tuple[sint, sint]] = [y for x,y in zip(self.shape, self.mask) if x != 1]
new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
return View.create(new_shape, new_strides_tuple, self.offset, new_mask_tuple)
# TODO: bring the merge_views logic here for more caching
return None
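# Hedged usage sketch (illustrative, not part of the upstream file): reshapes that only add or
# remove unit dimensions take the fast path above; anything else returns None and is handled
# by merge_views in ShapeTracker.
if __name__ == "__main__":
  v = View.create((3, 4))                           # contiguous, strides (4, 1)
  assert v.reshape((3, 1, 4)) == View.create((3, 1, 4))
  p = v.permute((1, 0))                             # shape (4, 3), strides (1, 4)
  assert p.reshape((4, 3, 1)).strides == (1, 4, 0)
  assert p.reshape((12,)) is None                   # needs a second View upstream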