openpilot v0.9.6 release

date: 2024-02-21T23:02:42
master commit: 0b4d08fab8e35a264bc7383e878538f8083c33e5
committed by FrogAi on 2024-02-27 16:34:45 -07:00 (commit 2901597132)
1940 changed files with 647891 additions and 0 deletions

tinygrad_repo/extra/onnx.py (new file, 209 lines)

@@ -0,0 +1,209 @@
from __future__ import annotations
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
import importlib
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from typing import List,Dict
from onnx.onnx_pb import AttributeProto, ModelProto, TensorProto, TypeProto
try:
from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
# for onnx < 1.13
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x]
# global numpy cache for parameters
numpy_cache = {}
def safe_numpy(t) -> np.ndarray:
if not isinstance(t, Tensor): return t
global numpy_cache
if t not in numpy_cache:
if DEBUG >= 3: print("numpy cache miss", t)
tmp = t.numpy()
numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1)
assert len(numpy_cache[t].shape) > 0
return numpy_cache[t]
onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)
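# get_run_onnx parses the ModelProto once (initializers, node attributes, opset version)
# and returns a run_onnx(inputs, debug) closure that executes the graph with tinygrad Tensors.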
def get_run_onnx(onnx_model: ModelProto):
def type_parse(type_proto: TypeProto):
ret = []
while True:
attr = type_proto.WhichOneof('value')
if attr == 'tensor_type':
if "dim_value" not in getattr(type_proto, attr).shape.dim.__dir__(): return () # variable type, unable to determine shape
elif not ret:
return tuple([x.dim_value for x in getattr(type_proto, attr).shape.dim])
else:
ret.extend([(x.dim_value,) for x in getattr(type_proto, attr).shape.dim])
return tuple(ret)
elif attr == 'sequence_type':
type_proto = getattr(type_proto, attr).elem_type
ret.append(1)
elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
def buffer_parse(inp: TensorProto) -> Tensor:
if inp.data_type in (1,10,6,7):
# TODO: this is shared with below
if len(inp.float_data) > 0:
ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
elif len(inp.int64_data) > 0:
ret = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
elif len(inp.int32_data) > 0:
ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
else:
ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
else:
raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}")
return ret
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
# TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
if a.type == AttributeProto.FLOAT: return float(a.f)
elif a.type == AttributeProto.INT: return int(a.i)
elif a.type == AttributeProto.STRING: return a.s.decode("utf-8")
elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t) # TENSOR
elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}")
else: raise Exception(f"can't parse {a.type} {a}")
def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a}
tensors: Dict[str, Tensor] = {}
# get weights and biases
for inp in onnx_model.graph.initializer:
if len(inp.raw_data) > 0:
tensors[inp.name] = buffer_parse(inp)
elif len(inp.float_data) > 0:
tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
elif len(inp.int64_data) > 0:
tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
elif len(inp.raw_data) == 0:
tensors[inp.name] = Tensor(np.array([], dtype=np.float32), requires_grad=False)
else:
print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
print(inp)
raise Exception("no data")
# preparse the attributes
attribute_dict = {}
domain = ""
for num,n in enumerate(onnx_model.graph.node):
attribute_dict[num] = attribute_to_dict(n.attribute)
if n.domain: domain = n.domain
onnx_model_version = onnx_model.opset_import[0].version
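# debug levels (DEBUGONNX env var or the debug kwarg): 1 = per-op trace, 2 = output shapes and values, 3 = input tensors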
def run_onnx(inputs={}, debug=0):
debug = getenv("DEBUGONNX") or debug
input_tensors: Dict[str,Tensor] = {}
intermediate_tensors: Dict[str,Tensor] = {}
output_tensor_names = [x.name for x in onnx_model.graph.output]
# get inputs
for inp in onnx_model.graph.input:
if inp.name in tensors: continue
shape = type_parse(inp.type)
if inp.name in inputs:
if isinstance(inputs[inp.name], Tensor):
input_tensors[inp.name] = inputs[inp.name]
elif isinstance(inputs[inp.name], list):
input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]]
elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training"
input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
else:
input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
if shape: # only check the shape if the input is not variable-shaped
input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
else:
raise Exception(f"no data for {inp.name} with shape {shape}")
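# resolve a tensor name: initializers first, then intermediate results, then graph inputs;
# an empty name marks an omitted optional input and resolves to None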
def fetch_tensor(x: str):
if x in tensors: return tensors[x]
if x in intermediate_tensors: return intermediate_tensors[x]
if x != str(): return input_tensors[x]
return None
for num,n in enumerate(onnx_model.graph.node):
inp: List[Tensor] = []
if debug >= 3: print("inputs:")
for x in n.input:
t = fetch_tensor(x)
if debug >= 3: print(f"\t{x} - {t}")
inp.append(t)
opt: Dict = attribute_dict[num]
if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
# some ops live here because they require some local variables
if n.op_type == "Split": # have to use n.output for cases when num_outputs is absent
axis = opt.get("axis", 0)
split = None if len(inp) == 1 else [int(x) for x in safe_numpy(inp[1])]
if split is None:
split = [inp[0].shape[axis] // len(n.output)] * len(n.output)
for i in range(inp[0].shape[axis] % len(n.output)):
split[i] += 1
i, ret = 0, []
arg = [(0,x) for x in inp[0].shape]
for s in split:
arg[axis] = (i,i+s)
ret.append(inp[0].shrink(arg=tuple(arg)))
i = i+s
ret = tuple(ret)
elif n.op_type == "Slice": # need to check onnx_model_version
if onnx_model_version < 10:
axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
else:
starts, ends = inp[1:3]
axes = safe_numpy(Tensor.arange(inp[0].ndim, dtype=dtypes.int32) if len(inp) <= 3 else inp[3]).tolist()
steps = safe_numpy(inp[4]) if len(inp) > 4 else [1]*inp[0].ndim
starts, ends = safe_numpy(starts.ceil().cast(dtypes.int32)).tolist(), safe_numpy(ends.ceil().cast(dtypes.int32)).tolist()
arg = [(0,x,1) for x in inp[0].shape]
for i, axis in enumerate(axes):
axis = int(axis) + inp[0].ndim if axis < 0 else int(axis)
starts[i], ends[i] = starts[i] + inp[0].shape[axis] if starts[i] < 0 else starts[i], ends[i] + inp[0].shape[axis] if ends[i] < 0 else ends[i]
starts[i], ends[i] = max(0, min(starts[i], inp[0].shape[axis])), max(0, min(ends[i], inp[0].shape[axis]))
if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i]
arg[axis] = (starts[i], ends[i], steps[i])
new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
else: ret = inp[0].__getitem__(tuple([slice(s,e,st) for s,e,st in arg]))
elif n.op_type == "Gradient": # need to call backward on intermediate_tensors
assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)}, outputs and inputs must match"
y = opt["y"]
intermediate_tensors[y].backward()
ret = tuple([t.grad for t in inp])
elif hasattr(onnx_ops, n.op_type):
fxn = getattr(onnx_ops, n.op_type)
if isinstance(fxn, dict):
for k in sorted(fxn.keys()):
if k <= onnx_model_version:
real_fxn = fxn[k]
else:
real_fxn = fxn
ret = real_fxn(*inp, **opt)
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise Exception(f"op_type {n.op_type} not supported")
if not isinstance(ret, tuple): ret = (ret, )
assert len(n.output) <= len(ret), f"expected at most {len(ret)} outputs, got {len(n.output)}: {n.output}"
if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
if debug >= 2: print("outputs:")
for i in range(len(n.output)):
if debug >= 2: print(f"\t{n.output[i]} - {ret[i]}")
intermediate_tensors[n.output[i]] = ret[i]
if num == ONNXLIMIT:
output_tensor_names = n.output
break
return {outp:intermediate_tensors[outp] for outp in output_tensor_names}
return run_onnx
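
A minimal usage sketch of the loader above (illustrative only: the model path and input name are hypothetical, and it assumes onnx and numpy are installed alongside tinygrad):

import numpy as np
import onnx
from extra.onnx import get_run_onnx

# load a model and build the runner once; "model.onnx" and "input_0" are placeholder names
model = onnx.load("model.onnx")
run_onnx = get_run_onnx(model)

# feed numpy arrays (wrapped into tinygrad Tensors internally) and read back the named outputs
outputs = run_onnx({"input_0": np.zeros((1, 3, 224, 224), dtype=np.float32)}, debug=1)
for name, out in outputs.items():
  print(name, out.shape)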


@@ -0,0 +1,720 @@
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx.onnx_pb import TensorProto
import os
import numpy as np
import functools
from typing import Union, Tuple, Optional, List, Any
import math
# **************** Free Ops ****************
def Identity(input: Tensor): return input
def Neg(input: Tensor): return -input
def Add(input: Tensor, other: Tensor, broadcast=None): return input + other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input + other).cast(input.dtype)
def Sub(input: Union[Tensor, Any], other: Tensor): return input - other # some test has input as int
def Mul(input: Tensor, other: Tensor): return (input * other) if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input * other).cast(input.dtype)
# in openpilot, due to SHUFFLE_PAD_OPS issues, we are spending an extra kernel
def Div(input: Tensor, other: Tensor): return input / other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else input.div(other).floor()
def Pow(input: Tensor, other: Tensor): return (input.float() ** other.float()).cast(input.dtype)
def Reciprocal(input: Tensor): return input.reciprocal()
def Sqrt(input: Tensor): return input.sqrt()
def Sign(input: Tensor): return input.sign()
def Abs(input: Tensor): return input.abs()
def Exp(input: Tensor): return input.exp()
def Log(input: Tensor): return input.log()
def Mish(input: Tensor): return input.mish()
def Sin(x: Tensor): return x.sin()
def Cos(x: Tensor): return x.cos()
def Tan(x: Tensor): return x.tan()
def Relu(input: Tensor): return input.relu()
def Sigmoid(input: Tensor): return input.sigmoid()
def Tanh(input: Tensor): return input.tanh()
def MatMul(input: Tensor, other: Tensor): return input.matmul(other)
def Floor(x:Tensor): return x.floor()
def Ceil(x:Tensor): return x.ceil()
def Less(x:Tensor,y:Tensor): return (x<y).cast(dtypes.bool)
def LessOrEqual(x:Tensor,y:Tensor): return (x<=y).cast(dtypes.bool)
def Greater(x:Tensor,y:Tensor): return (x>y).cast(dtypes.bool)
def GreaterOrEqual(x:Tensor,y:Tensor): return (x>=y).cast(dtypes.bool)
def Equal(x:Tensor,y:Tensor): return (x==y).cast(dtypes.bool)
def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0): return functools.reduce(Tensor.__add__, data_0)
def Mean(*data_0): return functools.reduce(Tensor.__add__, data_0) / len(data_0)
def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y).cast(X.dtype)
def Cast(input: Tensor, to): return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))
# **************** Simple Ops ****************
def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
if value: return value
elif value_float: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
elif value_floats: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
elif value_int: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
elif value_ints: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
elif value_string or value_strings: raise NotImplementedError(f'value_string or value_strings not implemented for Constant op')
def Softsign(input: Tensor): return input / (1+input.abs())
def Cosh(x): return (math.e ** x + math.e ** -x) / 2
def Sinh(x): return (math.e ** x - math.e ** -x) / 2
def Tanh(x): return x.tanh()
def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1)
def HardSwish(input: Tensor): return input * HardSigmoid(input, 1/6, 0.5)
def Celu(X: Tensor, alpha=1.0): return X.relu() - (-alpha*(X/alpha).exp()+1).relu()
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
def Softplus(X: Tensor): return X.softplus()
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope
def LeakyRelu(X: Tensor, alpha=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X: Tensor, alpha=1.0): return (X-alpha).relu() + (X-alpha).relu().sign() * alpha
def Softmax_1(input: Tensor, axis=1): return input.softmax(axis)
def Softmax_13(input: Tensor, axis=-1): return input.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
def LogSoftmax(input: Tensor, axis=-1): return input.log_softmax(axis)
def Clip(input: Tensor, min=None, max=None): return input.clip(float('-inf') if min is None else min, float('inf') if max is None else max)
# NOTE ReduceProd would require a new llop
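# _axes normalizes the ONNX "axes" input: a tensor/list becomes a list of ints; a missing or
# empty axes means reduce over all axes (None) unless noop_with_empty_axes is set (then []).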
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)) else ([] if noop_with_empty_axes else None)
def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()
def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
def OptionalHasElement(x: Tensor=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([], dtype=dtypes.float32)
def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)])
def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start)), stop=int(safe_numpy(limit)), step=int(safe_numpy(delta))).cast(dtype=start.dtype)
def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64) # TODO: really?
def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))])
def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd)*(input+bias) + (input > lambd)*(input-bias)
def And(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
def Or(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.ones(*x.shape)).cast(dtypes.bool)
def Xor(x:Tensor, y:Tensor): return Where((x==y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def Not(x:Tensor): return Where((x==1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def Asin(x): return Atan(x / Tensor.sqrt(1 - x * x))
def Asinh(x): return Tensor.log(x + Tensor.sqrt(x * x + 1))
def Acosh(x): return Tensor.log(x + Tensor.sqrt(x * x - 1))
def Atanh(x): return 0.5 * Tensor.log((1 + x)/(1 - x))
def Acos(x: Tensor):
negate = (x < 0)
x = x.abs()
ret = ((((-0.0187293 * x) + 0.0742610)*x - 0.2121144) * x + 1.5707288) * Tensor.sqrt(1.0 - x)
ret = ret - 2 * negate * ret
return negate * 3.14159265358979 + ret
def Atan(y: Tensor):
x = Tensor.ones(y.shape)
t3 = x
t1 = y.abs()
t0 = (t3 > t1).where(t3, t1)
t1 = (t3 < t1).where(t3, t1)
t3 = t1 / t0
t4 = t3 * t3
t0 = ((((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 - 0.332994597) * t4 + 0.999995630
t3 = t0 * t3
t3 = (y.abs() > x.abs()).where(1.570796327 - t3, t3)
return (y < 0).where(-t3, t3)
def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
k = int(k.numpy().item()) if k != 0 else 0 # onnx passes k as a tensor int64 with one element, default is 0
return x.triu(k) if upper else x.tril(k)
def Squeeze(input: Tensor, axes):
if isinstance(axes, Tensor): axes = safe_numpy(axes)
axes = [int(x) if x >= 0 else int(x+input.ndim) for x in axes]
return input.reshape([s for i,s in enumerate(input.shape) if i not in axes])
def Unsqueeze(data: Tensor, axes):
axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
new_shape = [1] * (len(data.shape) + len(axes))
ptr = iter(data.shape)
for i in range(len(new_shape)):
if i not in axes:
new_shape[i] = next(ptr)
return data.reshape(new_shape)
def Binarizer(input, threshold=0.0): return input > threshold
def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
axis = axis + x.ndim if axis < 0 else axis
m = x == (x.max(axis=axis, keepdim=keepdims) if keepdims else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis))
c = Tensor.arange(x.shape[axis]).reshape(*[1]*(axis), x.shape[axis], *[1]*(x.ndim - axis-1)) * m
return c.max(axis=axis,keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
def Elu(input: Tensor, alpha=1.0): return input.elu(alpha=alpha)
def Concat(*inputs: List[Tensor], axis): return inputs[0].cat(*inputs[1:], dim=axis)
def Transpose(input: Tensor, perm=None): return input.permute(order=list(range(len(input.shape))[::-1]) if perm is None else perm)
# NOTE: since we only have one type, this is valid!
def CastLike(input, target_type):
assert isinstance(target_type, Tensor), "can only CastLike Tensor"
return input
def ConstantOfShape(input, value:Tensor=None):
if value is None: value=Tensor([0.0])
shape = [int(x) for x in safe_numpy(input)]
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
# TODO: abstract out the broadcast logic in tensor
def Expand(input: Tensor, shape):
x_shape, y_shape = input.shape, [int(x) for x in safe_numpy(shape)]
# copied from _broadcasted
x_shape, y_shape = [([1]*(max(len(x_shape), len(y_shape))-len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]]
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
return input.reshape(x_shape).expand(shape_ret)
# **************** Complex Ops ****************
def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
ret = alpha * (A.transpose(transA) @ B.transpose(transB))
if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]))
return ret
# works with Tensors.ndim != 4
def _batchnorm(self:Tensor, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor):
shape = [1, -1] + [1] * (self.ndim-2)
x = (self - mean.reshape(shape=shape))
if weight: x = x * weight.reshape(shape=shape)
ret = x.mul(invstd.reshape(shape=shape) if len(invstd.shape) == 1 else invstd)
return (ret + bias.reshape(shape=shape)) if bias else ret
# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1, is_test=0):
if training_mode:
x_detached = X.detach()
current_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
current_var = (y*y).mean(axis=(0,2,3))
current_invstd = current_var.add(epsilon).pow(-0.5)
running_mean = input_mean * momentum + current_mean * (1 - momentum)
running_var = input_var * momentum + current_var * (1 - momentum)
return _batchnorm(X, scale, B, current_mean, current_invstd), running_mean, running_var
else:
invstd = (input_var + epsilon)**-0.5
return _batchnorm(X, scale, B, input_mean, invstd)
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
axis = tuple(range(2, len(x.shape)))
mean = x.mean(axis=axis, keepdim=True)
invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
assert stash_type == 1, "only float32 is supported"
axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)))
mean = x.mean(axis=axis, keepdim=True)
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal()
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
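# e.g. onnx_pads=[1, 2, 3, 4] over 2 dims -> np_pads=[(1, 3), (2, 4)]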
def _format_padding(onnx_pads, ndims=None, axes=None):
if ndims and len(onnx_pads)//2 != ndims: onnx_pads = onnx_pads * ndims # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
if ndims is None: ndims = len(onnx_pads) // 2
if axes is None: axes = list(range(ndims))
num_axes = len(axes)
np_pads = [(0,0)] * ndims
for i in range(num_axes):
np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
return np_pads
def _padding(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None):
if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
if pads is None: return X
pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
return X.pad(tuple(pads), value=constant_value)
def _auto_pad(X, auto_pad, strides, kernel_shape, dilations):
strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(strides):], strides, kernel_shape, dilations)]
if auto_pad == "SAME_UPPER": return [pad_shape[0]//2, pad_shape[1]//2, pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2]
elif auto_pad == "SAME_LOWER": return [pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2, pad_shape[0]//2, pad_shape[1]//2]
else: raise NotImplementedError(f"auto_pad={auto_pad} not implemented, yet")
def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
constant_value = value if constant_value is None else float(safe_numpy(constant_value)[0])
seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads)
seq_pads = [math.ceil(i) for i in seq_pads]
seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None
base_shape = x.shape
pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes)
if mode == "wrap":
repeat_args = [math.ceil(dim[0]/sh) + math.ceil(dim[1]/sh) + 1 for dim, sh in zip(pads, base_shape)]
new_shape = [s*r for s,r in zip(base_shape, repeat_args)]
shrink_args = [(sh-dim[0]%sh if dim[0]%sh != 0 else 0, nsh-(sh-dim[1]%sh) if dim[1]%sh != 0 else nsh) for dim, sh, nsh in zip(pads, base_shape, new_shape)]
return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args))
elif mode == "reflect":
for i,s in enumerate(x.shape):
if pads[i] == (0,0): continue
elif pads[i][0] and not pads[i][1]:
x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (s-pads[i][0]-1, s_-1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (0,s) for i_ in range(x.ndim)])) + \
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
elif not pads[i][0] and pads[i][1]:
x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (1, pads[i][1]+1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (s,0) for i_ in range(x.ndim)])) + \
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
else:
x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (s-pads[i][0]-1, s_-1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (0,s+pads[i][1]) for i_ in range(x.ndim)])) + \
x.flip(i).shrink(tuple([(0,s_) if i_ != i else (1, pads[i][1]+1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
return x
elif mode == "edge":
for i,s in enumerate(x.shape):
if pads[i] == (0,0): continue
elif pads[i][0] and not pads[i][1]:
x = x.shrink(tuple([(0,s_) if i_ != i else (0,1) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (0,s) for i_ in range(x.ndim)])) + \
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
elif not pads[i][0] and pads[i][1]:
x = x.shrink(tuple([(0,s_) if i_ != i else (s_-1, s_) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
else:
x = x.shrink(tuple([(0,s_) if i_ != i else (0,1) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (0,s+pads[i][1]) for i_ in range(x.ndim)])) + \
x.shrink(tuple([(0,s_) if i_ != i else (s_-1, s_) for i_,s_ in enumerate(x.shape)])).expand([pads[i][1] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
return x
elif mode == "constant":
return _padding(x, seq_pads, axes=seq_axes, constant_value=constant_value)
def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
if dilations != 1: raise NotImplementedError(f"dilations != 1 not supported, dilations:{dilations}")
pixel_axes = tuple(range(len(X.shape)))[-2:]
if ceil_mode: auto_pad = "SAME_UPPER"
padding_included = _padding(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations).avg_pool2d(kernel_shape, stride=strides)
if count_include_pad:
return padding_included
else:
div = _padding(Tensor.ones(*X.shape), pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations).avg_pool2d(kernel_shape, stride=strides)
return padding_included / div
def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
if ceil_mode: auto_pad = "SAME_UPPER"
ret = _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-len(kernel_shape):], strides=strides, kernel_shape=kernel_shape, dilations=dilations)
ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations)
ret_len, X_len = ret.numel(), X.numel()
indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().reshape(1, X_len).expand(ret_len, X_len)) * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)).sum(1).reshape(ret.shape).cast(dtypes.int64)
if storage_order: indices = indices.transpose(indices.ndim-2, indices.ndim-1)
return ret, indices
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Tensor=None, kernel_shape=None, pads=None, strides=None):
out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
outlength = prod(out_sh)
xI = xI.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
xT = xT.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
if outshape is not None:
outshape = safe_numpy(outshape).tolist()
if outshape != ret.shape:
diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2]
ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
return ret
def Conv(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
if auto_pad != "NOTSET": padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
else: padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0 # reorder padding
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
def ConvTranspose(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
if not kernel_shape: kernel_shape = W.shape
if pads is None and auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
elif pads is None and auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2)
strides_ = [1]*(W.ndim-1) + [strides] if isinstance(strides, int) else [1]*(W.ndim-len(strides)) + list(strides)
dilations_ = [1]*(W.ndim-1) + [dilations] if isinstance(dilations, int) else [1]*(W.ndim-len(dilations)) + list(dilations)
if output_shape and not output_padding:
out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides_, X.shape, kernel_shape, dilations_))]
output_padding = [os - rs for os, rs in zip(output_shape, out_sh[-len(output_shape):])]
return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
if isinstance(ratio, Tensor) and not ratio.shape: ratio = safe_numpy(ratio) # ratio is sometimes passed in as a Tensor with shape ()
if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = safe_numpy(training_mode)
if not training_mode: return data, Tensor.ones(*data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
rng = np.random.RandomState(seed)
ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio
mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device)
return data * mask * (1/(1.0 - ratio)), mask
def LRN(input: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
bs, c, iy, ix = input.shape
return input / input.mul(input).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta)
def MeanVarianceNormalization(input: Tensor, axis=(0, 2, 3)):
data_mean = input.mean(axis=axis, keepdim=True)
std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
return (input - data_mean) / (std + 1e-9)
def NegativeLogLikelihoodLoss(input: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"):
target = target.cast(dtypes.float32)
N, C, i_shape = input.shape[0], input.shape[1], input.shape
t_shape = target.shape
if len(input.shape) != 3:
input = input.reshape((N, C, -1))
target = target.reshape((N, -1))
if weight is not None:
mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1))
weight = (mask * weight).sum(axis=-1)
if ignore_index is not None:
cond = target == ignore_index
weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(input.shape) -2))
loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
elif reduction == "sum": return loss.sum()
return loss.reshape(t_shape) if len(i_shape) != 3 else loss
def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
N, C, *s_dimensions = scores.shape
if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
y = scores.log_softmax(axis=1)
if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)]))
loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights
if reduction == "mean": loss = loss.sum() / (loss == 0).where(0, 1).sum() if weights is None else loss.sum() / weights.sum()
elif reduction == "sum": loss = loss.sum()
return loss, y
def ArrayFeatureExtractor(input: Tensor, indices: Tensor): return input.__getitem__(tuple([slice(None) if i != (input.ndim-1) else indices for i in range(input.ndim)]))
def Gather(input: Tensor, indices: Tensor, axis=0):
if indices.numel() < 9: # NOTE: fewer kernels for small indices, but the kernel count grows with the number of indices
input_sh = list(input.shape)
ret_shape = input_sh[:axis] + list(indices.shape) + input_sh[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [int(safe_numpy(indices))] if indices.shape == () else [input_sh[axis]+int(x) if x<0 else int(x) for x in safe_numpy(indices)]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(input_sh)] for i in indices]
return input.shrink(arg=tuple(args[0])).cat(*[input.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
else: # NOTE: faster gather with a fixed number of kernels, but exceeds openpilot's kernel limit
return input.__getitem__(tuple([slice(None) if i != axis else indices for i in range(input.ndim)]))
def GatherElements(input: Tensor, indices: Tensor, axis):
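# map negative indices to positive ones by adding input.shape[axis] where the index is < 0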
indices = indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis] + indices
return input.gather(indices, axis)
def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
assert n <= 1, f"n:{n} shouldn't be larger than 1"
b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
b = (b >= 0).where(b+n, b-n)
if equidistant_case == "round_down":
return (x > b).where(b+1-n, b-n)
elif equidistant_case == "round_up":
return (x >= b).where(b+1-n, b-n)
elif equidistant_case == "round_to_even":
x_ceil_fraction = x.ceil()/2
cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
x = (x > b).where(b+1-n, b-n)
return x
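# e.g. with round-half-to-even: Round(2.5) -> 2.0, Round(3.5) -> 4.0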
def Round(X:Tensor): return _round(X, 0.5, "round_to_even")
def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel', cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch', mode='nearest', nearest_mode='round_prefer_floor'):
def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out]
def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
if nearest_mode == "round_prefer_floor": ret = _round(x_resized, 0.5, "round_down")
elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, 0.5, "round_up")
elif nearest_mode == "floor": ret = x_resized.floor()
elif nearest_mode == "ceil": ret = x_resized.ceil()
return ret.clip(0, x_len-1)
def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None):
if coordinate_transformation_mode == "half_pixel":
x_out = (x_out + 0.5)/Tensor(scales_lol[-1]) - 0.5 # TODO: wrap in Tensor() because ((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5) is inaccurate on LLVM and METAL
y_out = (y_out + 0.5)/Tensor(scales_lol[-2]) - 0.5
elif coordinate_transformation_mode == "align_corners":
x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
elif coordinate_transformation_mode == "asymmetric":
x_out = x_out/scales_lol[-1]
y_out = y_out/scales_lol[-2]
elif coordinate_transformation_mode == "half_pixel_symmetric":
x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_lol[-1] - 0.5
y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_lol[-2] - 0.5
elif coordinate_transformation_mode == "pytorch_half_pixel":
x_out = (x_out + 0.5)/scales_lol[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
y_out = (y_out + 0.5)/scales_lol[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
elif coordinate_transformation_mode == "tf_crop_and_resize":
x_out = roi[-1][0] * (X.shape[-1] - 1) + x_out * ((roi[-1][1] - roi[-1][0]) * (X.shape[-1] - 1) / (output_shape[-1] - 1)) if output_shape[-1] > 1 else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
return x_out.clip(0, X.shape[-1]-1), y_out.clip(0, X.shape[-2]-1)
if roi is not None:
roi = safe_numpy(roi)
roi = [(st,ed) for st, ed in zip(roi[:len(roi)//2], roi[len(roi)//2:])]
roi_ = [(1,1)] * 4
if axes is not None:
for a,r in zip(axes, roi):
roi_[a] = r
roi = roi_
if scales is not None:
scales = safe_numpy(scales).tolist()
if axes is not None:
scales_ = [1]*X.ndim
for a,s in zip(axes, scales):
scales_[a] = s
scales = scales_
elif sizes is not None:
sizes = [int(i) for i in safe_numpy(sizes)]
scales = []
if axes is not None:
sizes_ = [1]*X.ndim
for a,s in zip(axes, sizes):
sizes_[a] = s
scales.append(s/X.shape[a])
sizes = sizes_
else: scales = [si/xs for xs, si in zip(X.shape, sizes)]
if keep_aspect_ratio_policy == "not_larger":
scale = min(scales)
sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
elif keep_aspect_ratio_policy == "not_smaller":
scale = max(scales)
sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)]
x_out = Tensor.arange(output_shape[-1])
y_out = Tensor.arange(output_shape[-2])
if mode == "nearest":
x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi)
x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
y_out = _nearest_mode(y_out, nearest_mode, X.shape[-2])
return _nearest_gather(X, x_out, y_out)
elif mode == "linear":
x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape_, scales, roi)
ret = []
for y in safe_numpy(y_out):
for x in safe_numpy(x_out):
x_floor, y_floor = int(x), int(y)
y_shrink = (0, X.shape[2]) if X.shape[2] == 1 else (y_floor, y_floor+2) if y != y_floor else (y_floor, y_floor+1)
x_shrink = (x_floor, x_floor+2) if x != x_floor else (x_floor, x_floor+1)
shrink_args = ((0, X.shape[0]), (0, X.shape[1]), y_shrink, x_shrink)
corners = safe_numpy(X.shrink(shrink_args))
x1, x2, y1, y2 = x_floor, x_floor+1, y_floor, y_floor+1
if x == x_floor and y == y_floor: # TODO https://en.wikipedia.org/wiki/Bilinear_interpolation#Weighted_mean maybe do weighted mean?
ret.append(corners[0,0,0,0])
elif x == x_floor:
ret.append((corners[0,0,0,0] * (y2 - y) + corners[0,0,1,0] * (y - y1)) / (y2 - y1))
elif y == y_floor:
ret.append((corners[0,0,0,0] * (x2 - x) + corners[0,0,0,1] * (x - x1)) / (x2 - x1))
else:
ret.append((corners[0,0,0,0] * (x2 - x) * (y2 - y) + corners[0,0,0,1] * (x - x1) * (y2 - y) + corners[0,0,1,0] * (x2 - x) * (y - y1) + corners[0,0,1,1] * (x - x1) * (y - y1)) / ((x2 - x1) * (y2 - y1)))
return Tensor(ret).reshape(output_shape)
elif mode == "cubic":
raise Exception("cubic interpolation is not implemented")
def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
if not axes: axes = list(range(input.ndim))
shrink_arg = [(0,i) for i in input.shape]
pad_arg = [(0,0) for _ in range(input.ndim)]
shape = safe_numpy(shape).tolist()
for s, x in zip(shape, axes):
if s < input.shape[x]: shrink_arg[x] = (input.shape[x]//2 - s//2, input.shape[x]//2 + s//2) if s%2 == 0 else (input.shape[x]//2 - s//2 - 1, input.shape[x]//2 + s//2)
elif s > input.shape[x]: pad_arg[x] = ((s - input.shape[x])//2, (s - input.shape[x])//2) if (s - input.shape[x])% 2 == 0 else ((s - input.shape[x])//2, (s - input.shape[x])//2 + 1)
return input.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
depth = int(safe_numpy(depth).item())
indices, rank = (indices < 0).where(indices+depth, indices), len(indices.shape)
if axis < 0: axis += rank + 1
ls, rs = indices.shape[0:axis], indices.shape[axis: rank]
cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
return cond.where(values[1], values[0]).cast(values.dtype)
def Erf(x: Tensor):
sign = x.sign()
x = x.abs()
t = 1.0 / (1.0 + 0.3275911 * x)
term1 = 0.254829592 * t
term2 = -0.284496736 * t ** 2
term3 = 1.421413741 * t ** 3
term4 = -1.453152027 * t ** 4
term5 = 1.061405429 * t ** 5
y = (term1 + term2 + term3 + term4 + term5)
return sign * (1.0 - y * Tensor.exp(-x * x))
def Compress(inp: Tensor, condition: Tensor, axis=None):
if axis is None:
inp = inp.flatten()
axis = 0
axis = axis + inp.ndim if axis < 0 else axis
con_np = safe_numpy(condition)
con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor
return inp.__getitem__(tuple([slice(None) if i != axis else con for i in range(inp.ndim)]))
type_map = {TensorProto.DOUBLE: dtypes.double, TensorProto.FLOAT: dtypes.float32}
def EyeLike(x: Tensor, dtype=None, k=0):
if dtype is None: dtype = x.dtype
else: dtype = type_map[dtype]
shape = x.shape
dim = min(x.shape)
if shape[0] == shape[1]: return Tensor.eye(dim=dim, dtype=dtype)
else:
diff = (shape[0]-dim, shape[1]-dim)
padarg = tuple([(d, d) if d == 0 else (k, d-k) for d in diff])
return Tensor.eye(dim=dim, dtype=dtype).pad(padarg)
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)
# Needs work
def IsInf(x,detect_negative=1,detect_positive=1):
ret = (x == float("inf"))*detect_positive + (x == float("-inf"))*detect_negative + Tensor.zeros(*x.shape)
return ret.cast(dtypes.bool)
# Needs work
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point=0, axis=1):
axis = axis + x.ndim if axis < 0 else axis
x_sc = x_scale.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
x_zer = x_zero_point.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point
return (x - x_zer) * x_sc
# Needs work
def IsNaN(x):
return (x < float("-inf")).cast(dtypes.bool)
# **************** com.microsoft Ops ****************
def SkipLayerNormalization(input:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):
if epsilon is None: epsilon=1e-12
x = input + skip + bias
return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
x = x + bias
return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
assert (segment_ids is None) is (segment_embedding is None)
assert (mask is None) is (mask_index_type is None)
assert mask is None, "functionality not supported yet" # TODO
input_shape = input_ids.shape
bsz, seq_length = input_shape[0], input_shape[1]
compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight
# bert embedding layer
if epsilon is None: epsilon = 1e-12
if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
return out, None, embedding_sum
def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
assert num_heads is not None # required
assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, "functionality not supported yet" # TODO strange params
hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
if unidirectional: # gpt-style
assert hidden_size == v_hidden_size
xqkv = input.linear(weights, bias)
xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
else: # bert-style
wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else (None, None, None)
xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
if past is not None:
xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
def attn(query, key, value, attn_mask):
query_length, key_length = query.shape[-2], key.shape[-2]
cdim = max(query_length, key_length) + 1
attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
# This is where Tensor.scaled_dot_product_attention differs:
causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool)
return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value
bsz, _, seq_len, _ = xq.shape
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
return out, present
# **************** ai.onnx.preview.training Ops ****************
# TODO not entirely sure these optimizers are correct
def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
groups = len(inputs) // 3
grouped_inputs = [inputs[i::groups] for i in range(groups)]
T, R = safe_numpy(T)[0], safe_numpy(R)[0]
r = R / (1 + T * decay_factor)
ret = []
for input in grouped_inputs:
X, G, H = input
X.grad = norm_coefficient * X + G
X.grad.requires_grad, H.requires_grad = False, False # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py
H.assign(H.detach() + X.grad * X.grad).realize()
H_adaptive = H.sqrt() + epsilon
X.assign(X.detach() - r * X.grad / H_adaptive)
ret.extend([X, H])
ret = ret[::2] + ret[1::2]
return tuple(ret)
def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
groups = len(inputs) // 3
grouped_inputs = [inputs[i::groups] for i in range(groups)]
T, R = safe_numpy(T)[0], safe_numpy(R)[0]
beta_adjusted = beta if T > 0 else 1
ret = []
for input in grouped_inputs:
X, G, V = input
X.grad = (norm_coefficient * X + G).realize()
X.grad.requires_grad, V.requires_grad = False, False
V.assign(alpha * V + beta_adjusted * X.grad).realize()
if mode == "standard": X.assign(X.detach() - R * V).realize()
elif mode == "nesterov": X.assign(X.detach() - R * (X.grad + alpha * V)).realize()
ret.extend([X, V])
ret = ret[::2] + ret[1::2]
return tuple(ret)
# copied from tinygrad/nn/optim.py: LAMB with some edits
def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0, norm_coefficient_post=0.0):
groups = len(inputs) // 4
grouped_inputs = [inputs[i::groups] for i in range(groups)]
T, R = safe_numpy(T)[0], safe_numpy(R)[0]
ret = []
for input in grouped_inputs:
X, G, V, H = input
X.grad = (norm_coefficient * X + G).realize()
V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False
V.assign(alpha * V + (1.0 - alpha) * X.grad).realize()
H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize()
up = (V / (1.0 - alpha**T)) / ((H / (1.0 - beta**T)).sqrt() + epsilon) if T > 0 else V / (H.sqrt() + epsilon)
X.assign(X.detach() - R * up).realize()
X = (1 - norm_coefficient_post) * X
ret.extend([X, V, H])
ret = ret[::3] + ret[1::3] + ret[2::3]
return tuple(ret)
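
For orientation, each op above is a plain function over tinygrad Tensors, so it can be exercised directly without going through the graph runner. A small sketch, assuming these definitions live in extra/onnx_ops.py (the module onnx.py dispatches to); the values are arbitrary:

from tinygrad.tensor import Tensor
from extra.onnx_ops import Gemm, Clip

A, B, C = Tensor.ones(2, 3), Tensor.ones(3, 4), Tensor.zeros(2, 4)
y = Gemm(A, B, C, alpha=1.0, beta=1.0)                 # (2, 4): alpha*(A @ B) + beta*C
z = Clip(Tensor([-2.0, 0.5, 3.0]), min=-1.0, max=1.0)  # clamp elementwise to [-1, 1]
print(y.shape, z.numpy())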


@@ -0,0 +1,285 @@
# this can be constructed from a cl_cache or loaded from a thneed file
import time
import struct
import json
import traceback
import numpy as np
from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
from tinygrad.helpers import DEBUG, getenv
from collections import defaultdict
import pyopencl as cl
from tinygrad.runtime.ops_gpu import CL, OSX_TIMING_RATIO
DEBUGCL = getenv("DEBUGCL", 0)
FLOAT16 = getenv("FLOAT16", 0)
class Thneed:
def __init__(self, cl_cache=[], inputs={}):
self.cl_cache, self.inputs = cl_cache[:], inputs
self.gobj = 0
# build graph
# NOTE: if CLCACHE=1, this is wrong!
nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []})
for _, args in self.cl_cache:
# output is always the first parameter
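# args layout per cache entry: args[0]=global_work_size, args[1]=local_work_size,
# args[2]=output buffer, args[3:]=input buffers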
for a in args[3:]:
nodes[a]['out_edges'].append(args[2])
nodes[args[2]]['in_edges'].append(a)
# get buffers to save
self.buffers_to_save = set()
self.outputs = []
for n in nodes.keys():
if len(nodes[n]['in_edges']) == 0:
self.buffers_to_save.add(n)
if len(nodes[n]['out_edges']) == 0:
self.outputs.append(n)
fake_inputs = []
for k,n in self.inputs.items():
if n in self.buffers_to_save:
self.buffers_to_save.remove(n)
else:
print(f"WARNING: {k} was not a used input, removing it")
fake_inputs.append(k)
for k in fake_inputs:
del self.inputs[k]
def load(self, input_fn):
float32 = not FLOAT16
mf = cl.mem_flags
image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT)
image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT)
with open(input_fn, "rb") as f:
json_len = struct.unpack("I", f.read(4))[0]
jdat = json.loads(f.read(json_len).decode('latin_1'))
weights = f.read()
# load in the buffers
bufs = {'\x00\x00\x00\x00\x00\x00\x00\x00': None}
bufs_loaded = {}
ptr = 0
for o in jdat['objects']:
#print(o)
if o['needs_load']:
nptr = ptr + o['size']
o['data'] = weights[ptr:nptr]
ptr = nptr
if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t":
tfmt = image_fmt_32 if 'float32' in o and o['float32'] else image_fmt
if o['arg_type'] == "image2d_t":
if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
# hack: use an image1d since we can back that with a buffer
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
else:
# buffer isn't supported in image2d, copy buffer into image
if 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
cl.enqueue_copy(CL.cl_queue[0], arr, bufs[o['buffer_id']])
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
elif o['needs_load']:
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
else:
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
if o['arg_type'] == "image1d_t":
assert not o['needs_load']
assert not bufs_loaded[o['buffer_id']]
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
else:
if 'data' in o:
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
else:
# zero out buffers
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
bufs[o['id']] = buf
bufs_loaded[o['id']] = 'data' in o
# if it's loaded, it's saved
if 'data' in o:
self.buffers_to_save.add(buf)
# load binaries
prgs = {}
for o in jdat['binaries']:
nptr = ptr + o['length']
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr])
ptr = nptr
# populate the cl_cache
for i,k in enumerate(jdat['kernels']):
kernel = prgs[k['name']]
aaa = []
for j,(a,sz) in enumerate(zip(k['args'], k['args_size'])):
if len(a) == 0:
aa = cl.LocalMemory(sz)
elif len(a) == 4:
a = a.encode('latin_1')
aa = np.uint32(struct.unpack("I", a)[0])
elif len(a) == 2:
a = a.encode('latin_1')
aa = np.uint16(struct.unpack("H", a)[0])
elif len(a) == 8:
#print(i,j,struct.unpack("Q", a.encode('latin_1'))[0])
aa = bufs[a]
aaa.append(aa)
self.cl_cache.append((kernel, [k['global_work_size'], k['local_work_size'], *aaa]))
if DEBUG >= 1: print(f"thneed: total bufs loaded: {len(bufs.keys())}")
# load inputs
for k in jdat['inputs']:
self.inputs[k['name']] = bufs[k['buffer_id']]
# load outputs
for k in jdat['outputs']:
self.outputs.append(bufs[k['buffer_id']])
def save(self, output_fn):
# this is the struct that will be saved
jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}
# build the pieces of this struct
weights = []
binaries = []
saved_objs = set()
saved_binaries = set()
for prg, args in self.cl_cache:
# get binaries for saving
if prg.name not in saved_binaries:
binary = prg.clprograms[0].get_info(cl.program_info.BINARIES)
assert len(binary) == 1
jdat['binaries'].append({"name":prg.name, "length":len(binary[0])})
binaries.append(binary[0])
saved_binaries.add(prg.name)
# get the args from the kernel, some need the data saved
targs, args_size = [], []
argdtypes = prg.argdtypes if prg.argdtypes is not None else [None]*(len(args)-2)
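# mirror of the loader's decoding: int16/int32 scalars are struct-packed into 2/4 latin-1 bytes, local memory becomes an empty string
# with its size recorded in args_size, and buffers/images are referenced by an 8-byte packed global_id (their contents are appended to weights the first time they are seen)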
for a,d in zip(args[2:], argdtypes):
if d == np.int16:
targs.append(struct.pack("H", a).decode("latin_1"))
args_size.append(2)
elif d == np.int32:
targs.append(struct.pack("I", a).decode("latin_1"))
args_size.append(4)
elif isinstance(a, cl.LocalMemory):
targs.append("")
args_size.append(a.size)
elif d is None:
if getattr(a, "global_id", None) is None:
setattr(a, "global_id", self.gobj)
self.gobj += 1
ptr = struct.pack("Q", a.global_id).decode("latin_1")
if ptr not in saved_objs:
if isinstance(a, cl.Buffer):
needs_load = a in self.buffers_to_save
jdat['objects'].append({
"id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size,
})
if needs_load:
data = np.empty(a.size//4, dtype=np.float32)
cl.enqueue_copy(CL.cl_queue[0], data, a, is_blocking=True)
weights.append(data.tobytes())
elif isinstance(a, cl.Image):
assert a.format == cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT), "wrong type"
needs_load = a in self.buffers_to_save
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
size = row_pitch * a.shape[1]
# size assumes 2 bytes per channel for float16 and 4 for float32; the copy kernel below always writes float32, so the staging buffer needs size*2 bytes for half images and size bytes otherwise
buf = cl.Buffer(CL.cl_ctxs[0], cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
# zero out the buffer
cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True)
CLProgram("from_image_strided", compile_gpu("""
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 l;
l.y = get_global_id(1);
l.x = get_global_id(0);
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
}
"""), argdtypes=(None, None, np.int32))(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape)
# multiple of 32 isn't enough
jdat['objects'].append({
"id": ptr, "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
"width": a.shape[0], "height": a.shape[1], "row_pitch": row_pitch, "float32": not FLOAT16,
})
if needs_load:
data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
cl.enqueue_copy(CL.cl_queue[0], data, buf, is_blocking=True)
if FLOAT16: data = data.astype(np.float16)
weights.append(data.tobytes())
else:
raise Exception("unknown object", a)
#print(jdat['objects'][-1])
saved_objs.add(ptr)
targs.append(ptr)
args_size.append(8)
else:
raise Exception("idk this type")
# save the kernel itself
jdat['kernels'].append({
"name": prg.name,
"work_dim": len(args[0]),
"global_work_size": args[0],
# TODO: C++ thneed requires a local_work_size, so we fill it with ones
"local_work_size": [1 for _ in args[0]] if args[1] is None else args[1],
"num_args": len(args)-2,
"args": targs,
"args_size": args_size
})
jdat['outputs'] = [{
"buffer_id": struct.pack("Q", x.global_id).decode("latin_1"),
"size": x.size,
} for x in self.outputs]
jdat['inputs'] = [{
"buffer_id": struct.pack("Q", v.global_id).decode("latin_1"),
"size": v.size,
"name": k
} for k,v in self.inputs.items()][::-1]
print(f"saving thneed to {output_fn}")
with open(output_fn, "wb") as f:
j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
f.write(struct.pack("I", len(j)))
f.write(j)
f.write(b''.join(weights))
f.write(b''.join(binaries))
def run(self):
events = []
st = time.monotonic()
for prg, args in self.cl_cache:
events.append(prg.clprgs[0](CL.cl_queue[0], *args))
mt = time.monotonic()
CL.synchronize()
et = time.monotonic() - st
print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")
if DEBUGCL >= 2:
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
print(f"{i:3d} {prg.name:25s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end]))
if DEBUGCL >= 1:
total_runtime = 0
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3):
print(prg.prg)
total_runtime += runtime
print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")
return total_runtime/1e9
return et

View File

@@ -0,0 +1,205 @@
# type: ignore
import pickle, hashlib, zipfile, io, requests, struct, tempfile, platform, concurrent.futures
import numpy as np
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict
from typing import Union
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from tinygrad.helpers import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import Device
from tinygrad.shape.view import strides_for_shape
OSX = platform.system() == "Darwin"
WINDOWS = platform.system() == "Windows"
def temp(x:str) -> str: return (Path(tempfile.gettempdir()) / x).as_posix()
def fetch(url):
if url.startswith("/") or url.startswith("."):
with open(url, "rb") as f:
return f.read()
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
with open(fp, "rb") as f:
return f.read()
def fetch_as_file(url):
if url.startswith("/") or url.startswith("."):
with open(url, "rb") as f:
return f.read()
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
return fp
def download_file(url, fp, skip_if_exists=True):
if skip_if_exists and Path(fp).is_file() and Path(fp).stat().st_size > 0:
return
r = requests.get(url, stream=True)
assert r.status_code == 200
progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url)
(path := Path(fp).parent).mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
for chunk in r.iter_content(chunk_size=16384):
progress_bar.update(f.write(chunk))
f.close()
Path(f.name).rename(fp)
def my_unpickle(fb0):
key_prelookup = defaultdict(list)
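# key_prelookup maps each torch storage key to the (storage_type, obj_size, tensor, shape, stride, storage_offset) tuples that reference it,
# so the raw bytes can be streamed into those tensors later by load_single_weight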
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
ident, storage_type, obj_key, location, obj_size = storage[0:5]
assert ident == 'storage'
assert prod(size) <= (obj_size - storage_offset)
if storage_type not in [np.float16, np.float32]:
if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}")
ret = None
else:
ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type))
key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset))
return ret
def _rebuild_parameter(*args):
#print(args)
pass
class Dummy: pass
class MyPickle(pickle.Unpickler):
def find_class(self, module, name):
#print(module, name)
if name == 'FloatStorage': return np.float32
if name == 'LongStorage': return np.int64
if name == 'IntStorage': return np.int32
if name == 'HalfStorage': return np.float16
if module == "torch._utils":
if name == "_rebuild_tensor_v2": return _rebuild_tensor_v2
if name == "_rebuild_parameter": return _rebuild_parameter
else:
if module.startswith('pytorch_lightning'): return Dummy
try:
return super().find_class(module, name)
except Exception:
return Dummy
def persistent_load(self, pid):
return pid
return MyPickle(fb0).load(), key_prelookup
def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset, mmap_allowed=False):
bytes_size = np.dtype(dtype).itemsize
if t is None:
myfile.seek(prod(shape) * bytes_size, 1)
return
bytes_offset = 0
if storage_offset is not None:
bytes_offset = storage_offset * bytes_size
myfile.seek(bytes_offset)
assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}"
assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size
if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)):
# slow path
buffer_size = sum(strides[i]*t.dtype.itemsize * (shape[i] - 1) for i in range(len(shape)))
buffer_size += t.dtype.itemsize
np_array = np.frombuffer(myfile.read(buffer_size), t.dtype.np)
np_array = np.lib.stride_tricks.as_strided(
np_array, shape=shape, strides=[i*t.dtype.itemsize for i in strides])
lna = t.lazydata.op.arg
lna.fxn = lambda _: np_array
t.realize()
return
# ["METAL", "CLANG", "LLVM"] support readinto for more speed
# ["GPU", "CUDA"] use _mmap since they have to copy in to the GPU anyway
# this needs real APIs
if t.device in ["METAL", "CLANG", "LLVM"]:
del t.lazydata.op
t.lazydata.realized = Device[t.lazydata.device].buffer(prod(t.shape), dtype=t.dtype)
myfile.readinto(t.lazydata.realized._buffer())
else:
def _mmap(lna):
assert myfile._compress_type == 0, "compressed data can't be mmaped"
return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r', offset=myfile._orig_compress_start + bytes_offset, shape=lna.shape)
def _read(lna):
ret = np.empty(lna.shape, dtype=lna.dtype)
myfile.readinto(ret.data)
return ret
if mmap_allowed and not OSX and t.device in ["GPU", "CUDA"]: t.lazydata.op.arg.fxn = _mmap
else: t.lazydata.op.arg.fxn = _read
t.realize()
def fake_torch_load_zipped(fb0, load_weights=True, multithreaded=True):
if Device.DEFAULT in ["TORCH", "GPU", "CUDA"]: multithreaded = False # multithreaded doesn't work with CUDA or TORCH. for GPU it's a wash with _mmap
with zipfile.ZipFile(fb0, 'r') as myzip:
base_name = myzip.namelist()[0].split('/', 1)[0]
with myzip.open(f'{base_name}/data.pkl') as myfile:
ret = my_unpickle(myfile)
if load_weights:
def load_weight(k, vv):
with myzip.open(f'{base_name}/data/{k}') as myfile:
for v in vv:
load_single_weight(v[2], myfile, v[3], v[4], v[0], v[5], mmap_allowed=True)
if multithreaded:
# 2 seems fastest
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
futures = {executor.submit(load_weight, k, v):k for k,v in ret[1].items()}
for future in (t:=tqdm(concurrent.futures.as_completed(futures), total=len(futures))):
if future.exception() is not None: raise future.exception()
k = futures[future]
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
else:
for k,v in (t := tqdm(ret[1].items())):
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
load_weight(k,v)
return ret[0]
def fake_torch_load(b0):
# convert it to a file
fb0 = io.BytesIO(b0)
if b0[0:2] == b"\x50\x4b":
return fake_torch_load_zipped(fb0)
# skip the three legacy-format header pickles (magic number, protocol version, system info)
pickle.load(fb0)
pickle.load(fb0)
pickle.load(fb0)
ret, key_prelookup = my_unpickle(fb0)
# create key_lookup
key_lookup = pickle.load(fb0)
key_real = [None] * len(key_lookup)
for k,v in key_prelookup.items():
assert len(v) == 1
key_real[key_lookup.index(k)] = v[0]
# read in the actual data
for storage_type, obj_size, tensor, np_shape, np_strides, storage_offset in key_real:
ll = struct.unpack("Q", fb0.read(8))[0]
assert ll == obj_size, f"size mismatch {ll} != {obj_size}"
assert storage_offset == 0, "not implemented"
load_single_weight(tensor, fb0, np_shape, np_strides, storage_type, None)
return ret
def get_child(parent, key):
obj = parent
for k in key.split('.'):
if k.isnumeric():
obj = obj[int(k)]
elif isinstance(obj, dict):
obj = obj[k]
else:
obj = getattr(obj, k)
return obj
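# usage sketch (hypothetical path and key, assuming the checkpoint pickle is a dict of tensors):
#   weights = fake_torch_load_zipped(io.BytesIO(fetch("./model.pt")))
#   conv_w = get_child(weights, "encoder.conv1.weight")  # walks nested dicts/attrs/integer indices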

View File

@@ -0,0 +1,166 @@
#!/usr/bin/env python3
import os, sys, io, pathlib
sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
if "OPT" not in os.environ: os.environ["OPT"] = "99"
os.environ["PREREALIZE"] = "0"
OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
import onnx
from typing import Tuple, List, Dict
from extra.utils import fetch
from extra.onnx import get_run_onnx
from tinygrad.graph import print_tree, log_schedule_item
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv, ImageDType, GRAPH
from tinygrad.realize import run_schedule
from tinygrad.ops import LoadOps, Device, ScheduleItem
from tinygrad.features.image import fix_schedule_for_images
Device.DEFAULT = "GPU"
def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem], Dict[str, Tensor]]:
Tensor.no_grad = True
Tensor.training = False
# load the model
onnx_model = onnx.load(io.BytesIO(onnx_data))
run_onnx = get_run_onnx(onnx_model)
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
# run the model
inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
schedule = ret.lazydata.schedule()
# figure out which schedule items depend on the inputs
input_lb = [x.lazydata.base for x in inputs.values()]
depends = set(input_lb)
for si in schedule:
if any(b in depends for b in si.inputs):
depends.add(si.out)
# run all kernels that don't depend on the inputs
# NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized
schedule, schedule_independent = partition(schedule, lambda si: si.out in depends)
print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
# confirm no loadops in the (non independent) schedule except for the ones that load the input buffers
assert all(si.ast.op not in LoadOps or si.out in input_lb for si in schedule), "has loadops, can't compile to Thneed"
return schedule, schedule_independent, inputs
def schedule_to_thneed(schedule, output_fn):
from extra.thneed import Thneed
# transform to CL.CACHE
used_ops = 0
cl_cache = []
for si in schedule:
prg = Device["GPU"].method_cache[si.ast]
args = (si.out,) + si.inputs
# pass these to thneed
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
setattr(prg.clprg, 'prg', prg.prg)
global_size = prg.global_size + [1]*(3-len(prg.global_size))
local_size = prg.local_size + [1]*(3-len(prg.local_size))
cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x.realized._buf for x in args]]))
used_ops += prg.op_estimate
# NOTE: inputs is read from module scope; it is assigned in __main__ below before schedule_to_thneed is called
input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
t = Thneed(cl_cache, {k:v._buf for k,v in input_rawbuffers.items()})
# save thneed (before run)
t.save(output_fn)
print(f"buffers to save: {len(t.buffers_to_save)}, inputs: {list(t.inputs.keys())}, outputs: {t.outputs}")
runtime = t.run()
print(f"network using {used_ops/1e9:.2f} GOPS with runtime {runtime*1e3:.2f} ms that's {used_ops/runtime*1e-9:.2f} GFLOPS")
def thneed_test_onnx(onnx_data, output_fn):
import onnx
import pyopencl as cl
from tinygrad.runtime.ops_gpu import CL
import numpy as np
from extra.thneed import Thneed
onnx_model = onnx.load(io.BytesIO(onnx_data))
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()}
new_np_inputs = {k:v.realize().numpy() for k,v in inputs.items()}
if getenv("ORT"):
# test with onnxruntime
import onnxruntime as ort
onnx_session = ort.InferenceSession(onnx_data)
onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_np_inputs.items()})
new_torch_out = onnx_output[0]
else:
# test with torch
from test.models.test_onnx import run_onnx_torch
new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
if output_fn is None:
# non thneed
run_onnx = get_run_onnx(onnx_model)
new_tinygrad_out = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).numpy()
np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
print("classic self-test passed!")
else:
# load thneed and try that
nt = Thneed()
nt.load(output_fn)
# inputs
for k,v in nt.inputs.items():
cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True)
nt.run()
new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(new_torch_out.shape)
cl.enqueue_copy(CL.cl_queue[0], new_thneed_out, nt.outputs[0], is_blocking=True)
# compare torch to thneed
np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
print("thneed self-test passed!")
if __name__ == "__main__":
onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL)
# quick test for ONNX issues
#thneed_test_onnx(onnx_data, None)
#exit(0)
schedule, schedule_independent, inputs = get_schedule(onnx_data)
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps)
print(f"{len(schedule_input)} inputs")
run_schedule(schedule_independent, disable_logging=True)
run_schedule(schedule_input)
with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
schedule = fix_schedule_for_images(schedule)
image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
if GRAPH:
for si in schedule_input: log_schedule_item(si)
for si in schedule: log_schedule_item(si)
GlobalCounters.reset()
run_schedule(schedule[:])
output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
schedule_to_thneed(schedule, output_fn)
FLOAT16 = getenv("FLOAT16", 0)
if FLOAT16 == 0:
try:
thneed_test_onnx(onnx_data, output_fn)
except ModuleNotFoundError as e:
print(f"TEST NOT HAPPENING {e}")

View File

@@ -0,0 +1,585 @@
from __future__ import annotations
import os, math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen, getenv, prod, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import View, strides_for_shape
from dataclasses import dataclass
from enum import Enum, auto
class OptOps(Enum):
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@dataclass(frozen=True, order=True)
class Opt:
op: OptOps
axis: Optional[int] = None
amt: Optional[int] = None
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
@dataclass(frozen=True)
class TensorCore:
device: str
dims: List[int]
dtype_in: DType
dtype_out: DType
threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
upcast_dim: int # which TC dim to upcast
thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim
thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
arch: Optional[str] = None
def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
tensor_cores: Dict[str, List[TensorCore]] = {
"METAL": [
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"),
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"),
],
"HIP": [
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]),
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]),
]
}
class LocalBuffer(NamedTuple):
name: str
size: int
dtype: DType = dtypes.float32
realized: None = None
def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
class LinearizerOptions(NamedTuple):
device: str = ""
# TODO: make this generic with a list of supported types
supports_float4: bool = True
supports_float4_alu: bool = True
has_local: bool = True
has_shared: bool = True
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
global_max: Optional[List[int]] = None
local_max: Optional[List[int]] = None
class Kernel:
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
self.opts = opts if opts else (cast(Compiled, Device[Device.DEFAULT]).linearizer_opts if isinstance(Device[Device.DEFAULT], Compiled) else LinearizerOptions())
self.ast = ast
# fetch lazyop info
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
# there's only allowed to be one reduceop
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
self.reduceop = reduceops[0] if reduceops else None
# create new shapetrackers inside this kernel, we will permute them
self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps])
# get earlybufs, before the one reduce op
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] if self.reduceop else []
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
# create the (permuted) shapetrackers
self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]
# move all reduce axes to the end
reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
self.reshape_and_permute(None, permute)
# parameters for optimization
self.applied_opts: List[Opt] = []
self.group_for_reduce: List[int] = []
self.upcasted: int = 0
self.local_dims: int = 0
self.local_alias: Dict[int, LocalBuffer] = {}
self.tensor_core: Optional[TensorCore] = None
self.dont_use_locals: bool = False
# group simplifies
self.simplify_ones()
self.simplify_merge_adjacent()
# cache
self.applied_opts_cache: Optional[List[Opt]] = None
def copy(self):
ret = type(self).__new__(type(self))
# base linearizer params
ret.opts, ret.ast = self.opts, self.ast
# things downstream of the AST
# NOTE: we copy bufs for local buffers and sts for optimizations
ret.info, ret.reduceop, ret.bufs, ret.earlybufs, ret.full_buf_index, ret.sts = \
self.info, self.reduceop, self.bufs[:], self.earlybufs, self.full_buf_index, self.sts[:]
# parameters for optimizations
ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals
# uncached since linearize didn't run
ret.applied_opts_cache = None
return ret
@property
def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
def has_variable_shape(self) -> bool:
for b in self.bufs:
if not isinstance(b, LocalBuffer) and not all_int(b.st.views[-1].shape): return True
return False
def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
def upcasted_axis(self, i):
return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
self.sts[i].real_strides()[self.shape_len-self.upcasted:],
[x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
# TODO: is there a better way to write this?
def acc_offsets(self, i):
if self.upcasted == 0: return [0]
upcasted_i = self.upcasted_axis(i)
acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
def get_upcast_dim(self, i) -> List[int]:
should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]
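# first_reduce: index of the first non-upcasted axis where the output shape differs from the full shape, i.e. the first reduce axis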
@property
def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
@property
def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
@property
def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape
@property
def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted]
@property
def shape_len(self) -> int: return len(self.sts[0].shape)
@property
def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
@property
def global_dims(self) -> int: return self.first_reduce-self.local_dims
# there's eight chunks of the shape
# blue -- global dims
# cyan -- local dims (warp ones first)
# *** self.first_reduce
# green -- reduce-local dims
# white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
# red -- reduce loops
# *** self.upcasted
# purple -- reduce upcasted
# yellow -- normal upcasted dimensions
def colors(self) -> List[str]:
# first non local non reduce dims are global (blue)
colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
# after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
colors += ["cyan"] * self.local_dims
# between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green)
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
# between first_reduce + group_for_reduce and upcasted, they are reduce (red)
colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
# upcasted dimensions are reduce (magenta) or normal (yellow)
colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
assert len(colors) == self.shape_len, "colors size mismatch"
return colors
def colored_shape(self, pad=None, dense=False) -> str:
ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors()))
if pad: ret += ' '*(pad-ansilen(ret))
return ret
# ******************** base simplifiers ********************
# apply reshape and permute to all shapetrackers
def reshape_and_permute(self, new_shape_fxn, axis):
new_sts = []
for st in self.sts:
if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
if axis is not None: st = st.permute(tuple(axis))
new_sts.append(st)
self.sts = new_sts
# drops the final dimension
def upcast(self):
assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
self.upcasted += 1
# axis : the axis to pull from
# amount : the amount to take
# top : if you want to pull that amount from the top
# insert_before : place to insert the new stuff
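# example: with full_shape (8,12,3), shift_to(1, 4) splits axis 1 into (3,4) giving (8,3,4,3),
# then permutes the new size-4 axis to the end (insert_before defaults to the old shape_len): (8,3,3,4)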
def shift_to(self, axis, amount, top=False, insert_before=None):
if insert_before is None: insert_before = self.shape_len
move_axis = axis if top else axis+1
if move_axis < insert_before: insert_before += 1
self.reshape_and_permute(
lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
[i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
# ******************** complex simplifiers ********************
def simplify_ones(self) -> bool:
# remove places where the shape is all ones
# TODO: this should be factored in to multi shape stride
if self.shape_len == 0: return False
all_ones = [s==1 for s in self.full_shape]
self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
return any(all_ones)
def simplify_merge_adjacent(self):
if self.shape_len == 0: return
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
if isinstance(self.bufs[0].dtype, ImageDType):
base_shape = self.bufs[0].dtype.shape
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
special_strides: Tuple[int, ...] = tuple()
for i,g in enumerate(shape_idx_groups):
shape_piece = tuple(self.output_shape[x] for x in g)
assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
special_strides += strides_for_shape(shape_piece)
# adding the fake image shape
shapes.append(self.output_shape)
strides.append(special_strides)
# merge dimensions if we can, multi get_shape_strides
# TODO: does this always preserve the reduce dimension, NO
# TODO: move this into shapetracker, with tests!
rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
for i in range(1, len(shapes[0])):
can_merge = []
for j in range(len(shapes)):
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
# more can merge than this
mergeable = all(can_merge) and i != self.first_reduce
for j in range(len(shapes)):
if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
else: rets[j].append((shapes[j][i], strides[j][i]))
# do the reshapes
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
# ******************** GPU simplifiers ********************
def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
new_shape,dims = list(x), len(x)
for i in range(dims):
next_idx = (i + 1) % dims
while new_shape[i] > max_size[i]:
new_shape[i] = new_shape[i] // 2
if (new_shape[next_idx] <= max_size[next_idx]):
new_shape[next_idx] = new_shape[next_idx] * 2
else:
next_idx = (next_idx + 1) % dims
new_shape[next_idx] = new_shape[next_idx] * 2
return tuple(new_shape)
def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
# Check the global allocation limit; currently the global_size will be flipped during codegen
# and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
global_dims = self.first_reduce-self.local_dims
if global_dims > 0:
if global_max:
tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
for i in range(global_dims-1):
if i < len(global_max) and self.full_shape[i] > global_max[i]:
order = list(range(len(self.full_shape)))
order[i], order[global_dims-1] = order[global_dims-1], order[i]
self.reshape_and_permute(None, order)
if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
def alias_buffer(self, i, pattern):
assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
bst = 1
real_strides = self.sts[i].real_strides()
shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
for j,p in enumerate(pattern):
if priority == p and real_strides[j] != 0:
stride[j] = bst
bst *= shp[j]
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
# ******************** high level optimizers ********************
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None):
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
for tc in tensor_cores[self.opts.device]:
if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue
has_cast = tc.dtype_in != tc.dtype_out
if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue
if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.MEM and mul_op.src[0].arg.dtype == tc.dtype_in): continue
if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.MEM and mul_op.src[1].arg.dtype == tc.dtype_in): continue
buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg))
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0]
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0]
if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way
s0_exists, s1_exists = True, True
assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0
def fix(needed, ax):
nonlocal s0, s1, s0_exists, s1_exists
if not needed: return
if s0_exists and ax == s0:
if s1_exists and s0 < s1: s1 -= 1
s0_exists = False
elif s1_exists and ax == s1:
if s0_exists and s1 < s0: s0 -= 1
s1_exists = False
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
for (tc_dim, tc_amt) in tc.threads:
fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
# assert tensor core and prevent extra_opts from altering the key shape structure
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
if extra_opts is not None:
for opt in extra_opts:
self.apply_opt(opt)
else:
# hand-coded TC opts
if s1_exists:
s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0]
if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
if s0_exists:
s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0]
if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
if self.tensor_core and s0_exists:
for upc in [4,2]:
if self.full_shape[s0] % upc == 0:
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
break
# alias buffer
alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
self.alias_buffer(buf0, alias_pattern)
self.alias_buffer(buf1, alias_pattern)
return True
return False
def apply_opt(self, opt:Opt):
assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals"
self.applied_opts.append(opt)
if opt.axis is not None:
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
else:
axis = -1
if opt.amt is not None:
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
else:
amt = -1
if opt.op == OptOps.LOCAL: # cyan
assert axis < self.first_reduce, "can't local a reduce"
assert not(self.tensor_core), "can't local with tensor cores"
self.shift_to(axis, amt, insert_before=self.first_reduce)
self.local_dims += 1
elif opt.op == OptOps.LASTLOCAL: # cyan
assert axis < self.first_reduce, "can't local a reduce"
self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
self.local_dims += 1
elif opt.op == OptOps.GROUP: # green
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
assert not(self.tensor_core), "can't group with tensor cores"
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(amt)
elif opt.op == OptOps.GROUPTOP: # green
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
assert not(self.tensor_core), "can't group with tensor cores"
self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(amt)
elif opt.op == OptOps.UNROLL: # purple
assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
assert amt <= 32, "don't unroll more than 32"
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op == OptOps.UPCAST: # yellow
assert axis < self.first_reduce, "upcast is for non-reduce"
assert amt <= 8, "don't upcast more than 8"
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op == OptOps.UPCASTMID: # white
assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce"
axes = self.sts[0].unit_stride_axes()
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
assert axes[0] == axis, "wrong axis"
assert amt == 4, "don't upcast mid anything but 4"
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(amt)
elif opt.op == OptOps.NOLOCALS:
assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
assert not self.dont_use_locals, "already not using locals"
self.dont_use_locals = True
return self.simplify_ones()
def required_optimizations(self, early_only=False):
for buf_index,buf in enumerate(self.bufs):
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
if unit_stride_axes_mul_4[0] < self.first_reduce:
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
else:
self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
def hand_coded_optimizations(self):
# if there's images in the earlybufs, we have to make an axis the 4 loading one
self.required_optimizations(early_only=True)
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM:
buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg)
buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg)
buf0_strides = self.sts[buf0].real_strides()
buf1_strides = self.sts[buf1].real_strides()
def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st))
if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)):
for global_idx in range(self.global_dims):
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
if MV_THREADS_PER_ROW > 1:
self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
if MV_BLOCKSIZE > 1:
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
if MV_ROWS_PER_THREAD > 1:
self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
return
if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
# are we grouping? (requires local shape support)
if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
break
# are we upcasting in mid reduce? (only for images)
if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
axes = self.sts[0].unit_stride_axes()
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
if self.sts[0].shape[axes[0]]%4 == 0:
self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
# now do everything required
self.required_optimizations()
# no more opt if we are grouping
if self.group_for_reduce: return
# **** below this line need to be optional and benchmarked ****
# TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
# to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
# expression and run test/test_ops.py with IMAGE=2
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
# this can be made much smarter
to_upcast: List[int] = []
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
for axis in range(self.first_reduce):
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
# for now skip upcasting here if there is a symbolic axis
if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis)
for axis in to_upcast[::-1]:
self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
# potentially do more upcasts of non reduce axes based on a heuristic
upcasted_axis = set()
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
xb_choices = []
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non-reduce axes, and an upcast amount of 3 or 4
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)):
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
if xb_choices:
xb_choices = sorted(xb_choices)
if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
upcasted_axis.add(xb_choices[0][2])
else:
break
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64):
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
# if it's small, upcast a second reduce dimension too
if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
else:
for splits in [4]:
if self.full_unupcasted_shape[-1]%splits == 0:
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
break
# if nothing at all is upcasted and it's easy to, do an upcast
# TODO: this is breaking the tests
for splits in [4]:
if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))
# **** local groups ****
if self.opts.has_local:
if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
self.apply_opt(Opt(OptOps.NOLOCALS))
else:
# prioritize making expand axes local
local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
to_local: List[Tuple[int, int]] = []
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
local_size = prod(sz for _, sz in to_local)
local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
if local_sz is not None: to_local.append((axis, local_sz))
deleted_shape = 0
for axis, local_sz in sorted(to_local[:3]):
axis = axis - deleted_shape
will_delete_shape = local_sz == self.full_shape[axis]
self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
if will_delete_shape: deleted_shape += 1

View File

@@ -0,0 +1,441 @@
from __future__ import annotations
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Union, Sequence, Final, Set
import itertools, math, functools
from collections import defaultdict
from enum import Enum, auto
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
from tinygrad.codegen.kernel import LocalBuffer, Kernel
from tinygrad.lazy import vars_from_ast
from tinygrad.features.image import to_image_idx
# bottom ones are asm only
class UOps(Enum):
LOOP = auto(); IF = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
class UOp(NamedTuple):
uop: UOps
dtype: Optional[DType]
vin: Tuple[UOp, ...]
arg: Any
def __repr__(self): return f"{self.num:4d} {str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.num for x in self.vin]):32s} {self.arg}"
#def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str(self.vin):32s} {self.arg}"
# UOps are unique
num: int
def __hash__(self): return self.num
def __eq__(self, x): return self.num == x.num
def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
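# example: prefix="lidx", start_dim=3, local_dims=(4,8,2,2), maxdim=3 creates Variables lidx3,lidx4,lidx5 with sizes (4,8,4) as the launch/loop dims,
# and lidx5 is split back with div/mod into the original (2,2) axes for indexing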
local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)]
if maxdim != 0 and len(local_dims) > maxdim:
dd = local_idxs[maxdim-1]
nli = []
for s in local_dims[maxdim-1:][::-1]:
nli.append(dd % s)
dd //= s
local_idxs = local_idxs[0:maxdim-1] + nli[::-1]
return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
class Linearizer(Kernel):
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
return self.uop(UOps.ALU, dtype, (a, render_b), op)
# NOTE: the consts have to be cached for deduping of downstream uops to work
def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b)
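# render_ops maps each symbolic Node type to a uop builder, so ShapeTracker index/valid expressions are lowered into ALU uops
# (Variables resolve through loop_uops, NumNodes become CONSTs)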
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT, dtype=dtypes.bool),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def global_load(self, i:int, idxs:Sequence[Node], acc=None) -> List[UOp]:
buf = self.bufs[i]
const = buf.val if isinstance(buf, ConstBuffer) else acc
def rename_var(v: VariableOrNum, expr: str): return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max)
amt, dim = 1, None
upcast_dim = self.get_upcast_dim(i)
if len(upcast_dim) == 1 and len(float4_expand := idxs[upcast_dim[0]].expand()) in [4,2]:
dim, amt = upcast_dim[0], len(float4_expand)
expand_vars = tuple([rename_var(idx.expand_idx(), f"_uidx{j}") for j, idx in enumerate(idxs)])
fake_idxs = [idx.substitute({idx.expand_idx(): ev}) for idx, ev in zip(idxs, expand_vars)]
if dim is not None:
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs[:dim] + [float4_expand[0]] + fake_idxs[dim+1:])
if (g_idx // amt * amt).render() != g_idx.render():
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
else:
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
localtype = dtypes.float32 if amt == 1 else dtypes._float4 if amt == 4 else dtypes._float2
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
ret = []
invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0
for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)):
this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"
if key not in self.load_cache:
if acc is not None:
assert valid.min == 1
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
elif this_const is not None:
self.load_cache[key] = self.const(this_const, localtype)
if valid.min == 0 and valid.max == 1:
valid_rendered = valid.render(self.render_ops, self)
self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE)
else:
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
if isinstance(buf.dtype, ImageDType):
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
else:
rendered_idx = idx.render(self.render_ops, self)
if valid.min == 0:
valid_rendered = valid.render(self.render_ops, self)
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)))
else:
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx))
ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
return ret
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None:
buf = self.bufs[i]
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
expanded_nodes = [idx.expand() for idx in idxs]
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
store_offset = dict(zip(_idxs, store))
# float4 grouping
upcast_dim = self.get_upcast_dim(i)
if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [2,4]:
grouped_store_offset = defaultdict(list)
for k in store_offset:
_idx = k[:upcast_dim[0]] + (expanded_nodes[upcast_dim[0]][0],) + k[upcast_dim[0]+1:]
grouped_store_offset[_idx].append(store_offset[k])
store_offset_new = {}
for k,out_tokens in grouped_store_offset.items():
amt = len(out_tokens)
idx, valid = self.sts[i].expr_idxs(k)
assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
assert valid.min == 1, "stores are always valid"
store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
store_offset = store_offset_new
for idx, var in store_offset.items():
idx, valid = self.sts[i].expr_idxs(idx)
if isinstance(buf.dtype, ImageDType):
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
else:
rendered_idx = idx.render(self.render_ops, self)
self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
def linearize(self):
# no new opts and we already ran? skip relinearizing
if self.applied_opts == self.applied_opts_cache: return self
# save backups
sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduce[:], self.upcasted
# global uop cache
self.saved_exprs: Dict[Tuple, UOp] = dict()
# limit dims if we need to
if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
# uops
self.uops: List[UOp] = []
self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
self.loop_uops: Dict[str, UOp] = {}
# add global buffers
for i,buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
# add var vals
for var in sorted(vars_from_ast(self.ast), key=lambda k: k.key):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
# add a local buffer for multistage reduce. # TODO: use local alias
if self.group_for_reduce:
# TODO: the strides of this can be controlled
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
# kernel name (before late upcast)
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
# name the function something unique
Linearizer.kernel_cnt[self.function_name] += 1
suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
# define indexes
global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0)
full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
# global and local loops
def render_loop(xx:List[Variable]):
self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
def end_loop(xx:List[Variable]):
for x in xx[::-1]:
if not isinstance(x, NumNode) and x.expr is not None:
loop_uop = self.loop_uops[x.expr]
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
# set global/local size
self.global_size: Optional[List[int]] = None
self.local_size: Optional[List[int]] = None
if self.dont_use_locals:
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
elif self.opts.has_local:
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
self.global_size += [1]*(3-len(self.global_size))
self.local_size += [1]*(3-len(self.local_size))
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
else:
render_loop(loop_global_idxs+loop_local_idxs)
# parse AST
loaded_buffers = {}
acc = []
self.load_cache: Dict[str, UOp] = {}
if_gate: Optional[UOp] = None
# reduce op
fake_reduce_idxs: List[Variable] = []
if self.reduceop is not None:
# define indexes
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
fake_reduce_idxs = [x*0 for x in reduce_idxs]
# define accumulator
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
if self.tensor_core:
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
replace_idxs = []
for alias in aliases:
full_var, full_var_sz = Variable.num(0), 1
if alias[0] != 0:
for i in alias:
next_var = local_idxs[-i] if i > 0 else Variable(None, 0, local_size-1)
full_var += next_var * full_var_sz
full_var_sz *= next_var.max+1
replace_idxs.append(full_var)
return replace_idxs
replace_acc_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[2], self.tensor_core.thread_local_aliases[2])
for n in range(len(self.tensor_core.threads)):
local_idxs[self.local_dims-len(self.tensor_core.threads)+n] = replace_acc_idxs[n] # replace locals
for n in range(len(replace_acc_idxs)-len(self.tensor_core.threads)):
upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts
# reduce loop
render_loop(reduce_idxs)
# barrier for fast GEMM
if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
# compute local aliases
locals_to_store = []
for i in self.local_alias:
localbuf_idx = self.bufs.index(self.local_alias[i])
buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
if self.tensor_core:
min_alias_idx = min(self.local_alias.keys())
replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx])
for n in range(len(self.tensor_core.threads)):
buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals
for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)):
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(self.tensor_core.threads)+n] # replace upcasts
if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
ll = self.global_load(i, buf_idxs)
locals_to_store.append((localbuf_idx, buf_idxs, ll))
# copy in any global buffers
if self.tensor_core:
wmma_sz = self.tensor_core.thread_local_sizes
# calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else
nx, ny, nacc = (len(locals_to_store[0][2])//wmma_sz[0]), (len(locals_to_store[1][2])//wmma_sz[1]), (len(acc)//wmma_sz[2])
acc_reds = math.isqrt((nx*ny)//nacc)
i, bx, by = 0, nx//acc_reds, ny//acc_reds
for y in range(by):
for x in range(bx):
for j in range(acc_reds):
self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
i += wmma_sz[2]
else:
if locals_to_store:
self.uop(UOps.BARRIER, None, (), cachable=False)
for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll)
self.uop(UOps.BARRIER, None, (), cachable=False)
# load earlybufs
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
# run early AST (with reduce)
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True)
# end the reduce loop
end_loop(reduce_idxs)
self.load_cache.clear()
# end the local loop, do the local reduce
if self.group_for_reduce:
fake_global_idxs = [x*0 for x in global_idxs]
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
self.uop(UOps.BARRIER, None, (), cachable=False)
end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
if self.opts.has_local:
fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False)
# create new late reduce local loops and replace local_idxs that have been used
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
# if any group_for_reduce items aren't reduces, upcast them here
for j in self.upcast_in_mid_reduce_axes:
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
self.upcast()
self.group_for_reduce.pop()
local_idxs = local_idxs[:-1]
end_local_idxs = end_local_idxs[:-1]
# regenerate upcast_idxs
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
# NOTE: this structure is the same as the reduce op above
# define late accumulator
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
# late reduce loop
render_loop(end_local_idxs)
# load localbufs
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True) # type: ignore
# end the late reduce loop
end_loop(end_local_idxs)
self.load_cache.clear()
# load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
# run late AST
val = self.ast_parse(self.ast, acc, None, loaded_buffers)
# store
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
# end the global (and maybe local) loop
if if_gate: self.uop(UOps.END, None, (if_gate,))
end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
# (recursively) remove childless uops
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
while 1:
has_child: Set[UOp] = set()
for ru in self.uops:
for vu in ru.vin:
has_child.add(vu)
nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
if len(nu) == len(self.uops): break
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu
# restore backups
self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup
# set cache and return
self.applied_opts_cache = self.applied_opts[:]
return self
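  # uop appends (or reuses) a UOp. It applies a few local simplifications: self-referential
  # PHIs and GEP-of-CONST are collapsed, ADD of a NEG becomes SUB, and ALU ops get
  # constant/identity/zero folding; cachable uops are deduped through self.saved_exprs.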
def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp:
key = (uop, dtype, vin, arg)
if uop == UOps.PHI and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self phi is noop
if uop == UOps.CAST and all(x.uop == UOps.GEP for x in vin) and all_same([x.vin[0] for x in vin]) and all(x.arg == i for i,x in enumerate(vin)): return vin[0].vin[0]
if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype)
if uop == UOps.ALU:
# rewrites. NOTE: the rewritten NEG op is still around...
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable)
# constant folding
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype)
# zero folding
for x in [0,1]:
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
if DEBUG >= 5: print(self.uops[-1])
if cachable: self.saved_exprs[key] = self.uops[-1]
return self.uops[-1]
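  # ast_parse recursively lowers a LazyOp tree into ALU uops. Reduce ops accumulate into
  # `acc` through PHI nodes, SUM of a MUL (optionally through a CAST) is fused into a
  # single MULACC, and the results are scattered back into per-upcast-index order.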
def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False) -> List[UOp]:
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
if x.op in BufferOps: return loaded_buffers[x.arg]
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, offs, loaded_buffers) # cast isn't an ALU op
if x.op in ReduceOps and not do_reduce:
assert offs is None, "not available if we aren't doing reduce"
return acc
# MULACC fusion. TODO: this is copied from Interpreted
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
values = [self.ast_parse(v, acc, offs, loaded_buffers) for v in x.src]
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
if x.op in ops:
ret = []
for idx, val, off in zip([[i] for i in range(len(values[0]))], zip(*values), offs):
new_val = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
# NOTE: we could apply the phi node to only the last change, but this breaks CLANG with nested max(x,y)
acc[off] = self.uop(UOps.PHI, dtypes.float32, (acc[off], new_val))
ret.append((idx, acc[off]))
else:
ret = [(idx, self.uop(UOps.ALU, dtypes.float32, val, x.op)) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values))]
ordered_ret: List[Optional[UOp]] = [None]*len(values[0])
# scatter
for i,j in ret:
for k in i:
ordered_ret[k] = j
assert all(isinstance(x, UOp) for x in ordered_ret), "some tokens didn't get scattered?"
return cast(List[UOp], ordered_ret)


@@ -0,0 +1,204 @@
from typing import List, Tuple, Dict, Any
from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG, flatten
# *** image Tensor function replacements ***
from tinygrad.lazy import get_single_root
def image_dot(self, w):
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = self.shape[0:-2] + (cout,-1)
if len(self.shape) > 1:
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
else:
order, out_shape_t = (0,), (cout, )
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
# NOTE: with NHWC we can remove the transposes
# bs x groups*cin x H x W
cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
  # groups*cout x cin x H x W
cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
return image_conv2d(cx, cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
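# image_conv2d lowers a conv2d onto image textures: channel counts are padded up to
# multiples of 4, input and weight are reshaped/permuted into (rows, cols, 4) image
# layouts, and the convolution itself is a broadcasted multiply followed by a sum over
# the window and input-channel axes, with the channel-padding hacks undone at the end.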
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
rcout = cout//groups
x, w = self, weight.reshape(groups, rcout, cin, H, W)
# hack for non multiples of 4 on cin
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
added_input_channels = 4 - (cin % 4)
w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
cin = cin + added_input_channels
x = x.reshape(bs, groups*cin, iy, ix)
# hack for non multiples of 4 on rcout
added_output_channels = 0
if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
added_output_channels = 4 - (rcout % 4)
rcout += added_output_channels
cout = groups * rcout
w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape)))
# packed (note: flipping bs and iy would make the auto-padding work)
x = x.permute(0,2,3,1)
cin_last = iy == 1 and ix == 1
if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
x, w = x.contiguous(), w.contiguous()
if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize()
# expand out
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
# padding
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
# prepare input
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
oy, ox = x.shape[4:6]
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
# prepare weights
w = w.permute(0,4,2,5,1,3)
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)
# the conv! (+ the bias)
ret = x*w
if IMAGE >= 2: ret = ret.cast(base_image_type((bs*oy, ox*cout//4, 4)))
ret = ret.sum((-4, -3, -2, -1))
# undo hack for non multiples of 4 on C.rcout
if added_output_channels != 0:
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
rcout -= added_output_channels
cout = groups * rcout
# NCHW output
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
# *** schedules with images need to be fixed to be valid ***
import dataclasses
from tinygrad.ops import ScheduleItem, BufferOps, LazyOp, UnaryOps, LoadOps, MemBuffer, get_lazyop_info
def fix_schedule_for_images(schedule:List[ScheduleItem]):
# this is the fundamental fix, find unwritable or unreadable images and convert them to normal float32 (TODO: should it be float16?)
replace_inputs = {}
for i, si in enumerate(schedule):
if isinstance(si.out.dtype, ImageDType) and (prod(si.out.shape) != prod(si.out.dtype.shape) or not any(si.out.shape[x]%4 == 0 for x in si.out.st.unit_stride_axes())):
if DEBUG >= 1: print(f"{i:3d}: rewrite output, output shape {prod(si.out.shape)}, image dtype {si.out.dtype} prod {prod(si.out.dtype.shape)}")
si.out.dtype = dtypes.float32
for b in si.ast.get_lazyops():
if b.op != BufferOps.MEM: continue
      # TODO: unit_stride axes will fail if there's a mask, even if the mask is divisible by four. this is too aggressive
if isinstance(si.inputs[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
if DEBUG >= 1: print(f"{i:3d}: rewrite input, image dtype {si.inputs[b.arg.idx-1].dtype}, {b.arg.st.views}")
if si.inputs[b.arg.idx-1].realized:
# have to copy it
replace_inputs[si.inputs[b.arg.idx-1]] = si.inputs[b.arg.idx-1].cast(dtypes.float32)
else:
# change it before it's created
si.inputs[b.arg.idx-1].dtype = dtypes.float32
# now fix up the schedule to reflect the new dtypes
fixed_schedule:List[ScheduleItem] = []
for i,si in enumerate(schedule):
ast = si.ast
inputs = si.inputs
# replace inputs with casted versions
if any(x in replace_inputs for x in inputs):
fixed_schedule += flatten([replace_inputs[x].schedule() for x in inputs if x in replace_inputs])
inputs = tuple(replace_inputs.get(x, x) for x in inputs)
# fix input dtypes to match what they actually are
replacements = {}
for b in si.ast.get_lazyops():
if b.op != BufferOps.MEM: continue
if b.arg.dtype != inputs[b.arg.idx-1].dtype:
replacements[b] = LazyOp(BufferOps.MEM, (), MemBuffer(b.arg.idx, inputs[b.arg.idx-1].dtype, b.arg.st))
if replacements: ast = ast.map_buffers(replacements)
# fix the ops to create the output dtype
if ast.op not in LoadOps:
info = get_lazyop_info(ast)
if info.dtype != si.out.dtype:
if DEBUG >= 3: print(f"{i:3d}: info.dtype {info.dtype} != {si.out.dtype} -> {si.out.dtype}")
ast = LazyOp(UnaryOps.CAST, (ast,), (si.out.dtype, False))
# put this in the fixed schedule
fixed_schedule.append(dataclasses.replace(si, ast=ast, inputs=inputs))
return fixed_schedule
# *** images have weird indexing requirements ***
from tinygrad.shape.symbolic import Node, AndNode, Variable, NumNode, SumNode, LtNode
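# to_image_idx converts a flat 4-component index into 2D image coordinates. Assuming
# base_shape is (rows, cols, 4) as produced by dtypes.imagef/imageh, idx = (idxy // 4) % cols
# is the column and idy = idxy // (4 * cols) is the row. When the valid expression might be
# false, bounded sub-expressions are substituted with fake variables so that redundant
# bound checks can be dropped from `valid`.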
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
idx = (idxy // 4) % base_shape[1]
idy = (idxy // (4 * base_shape[1]))
if valid.min == 0 and isinstance(idxy, SumNode):
nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
val_dict: Dict[Node, Any] = {}
idxy_flat_var = [(i, i.vars()[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]
for node in nodes:
assert isinstance(node, LtNode)
node_flat, node_vars = node.a.flat_components if isinstance(node.a, SumNode) else [node.a], node.vars()
same_sym = [i for (i, var) in idxy_flat_var if var in node_vars]
if len(same_sym) == 0: continue
first, second = sorted(same_sym)[0], sorted(node_flat)[0]
f_b = 1 if isinstance(first, Variable) else first.b
s_b = 1 if isinstance(second, Variable) else second.b
sig = -1 if s_b < 0 else 1
key_node = sig*node.a
if key_node not in val_dict: val_dict[key_node] = [key_node.min, key_node.max, abs(f_b//s_b)]
val_dict[key_node][(sig + 1)//2] = sig*(node.b - 1)
fakes = {}
for cnt, (key_node, (mnn, mxn, multip)) in enumerate(val_dict.items()):
fake_var = Variable("fake_" + str(cnt), mnn, mxn)
fakes[fake_var] = key_node
idxy += multip*(fake_var - key_node)
idx = (idxy // 4) % base_shape[1]
idy = (idxy // (4 * base_shape[1]))
fake_rep = {fake: node for fake, node in fakes.items()}
idx = idx.substitute(fake_rep)
idy = idy.substitute(fake_rep)
idy_vars, idx_vars, ones = set(idy.vars()), set(idx.vars()), []
for node in nodes:
node_vars = set(node.vars())
      if not node_vars & (idx_vars | idy_vars): continue  # a node simplified to a NumNode cannot go outside the bounds
      # NOTE: why is only idy problematic here, and not idx?
if idy_vars == node_vars or idy_vars & node_vars == set(): ones.append(node)
valid = Variable.ands([i for i in nodes if i not in ones])
if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
return (idx, idy), valid


@@ -0,0 +1,151 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import Device, Compiled, MemBuffer
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict
from tinygrad.tensor import Tensor
from tinygrad.codegen.kernel import Opt, OptOps
actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)])
actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
actions += [
Opt(op=OptOps.LOCAL, axis=0, amt=32),
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
Opt(op=OptOps.NOLOCALS),
]
# returns time in seconds
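# time_linearizer compiles the kernel and runs it cnt times, returning the minimum wall
# time in seconds. If allow_test_size is set, the global size is shrunk to at most
# max_global_size and the measured time is scaled back up by the shrink factor; results
# are memoized in the disk cache (CACHELEVEL >= 2) and failures return inf.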
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
try:
lin.linearize()
prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin)
real_global_size = prg.global_size
if allow_test_size and prg.global_size:
test_global_size = prg.global_size[:]
while prod(test_global_size) > max_global_size:
for j in range(2,-1,-1):
if test_global_size[j] > 16:
test_global_size[j] //= 2
break
factor = prod(prg.global_size) / prod(test_global_size)
prg.global_size = test_global_size
#print(real_global_size, test_global_size, factor)
else:
factor = 1
# TODO: this is super broken for var_vals
# TODO: this is copied from prg.__call__
global_size, local_size = prg.launch_dims(var_vals)
if global_size is not None and local_size is None:
local_size = prg.optimize_local_size(global_size, rawbufs)
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
tms = []
for _ in range(cnt):
if clear_l2:
# TODO: this is too small for many L2 caches
with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
lra = prg.runtime_args.copy()
if global_size: lra['global_size'] = global_size
if local_size: lra['local_size'] = local_size
tms.append(prg.clprg(*rawbufs, *var_vals.values(), **lra, wait=True)*factor)
prg.global_size = real_global_size
except Exception:
if DEBUG >= 4:
import traceback
traceback.print_exc()
print("FAILED")
print(lin.ast)
print(lin.applied_opts)
tms = [float('inf')]
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
return min(tms)
# get (scrap) buffers for timing the linearizer
def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
for x in lin.membufs: bufsts[x.idx].append(x)
rawbufs:List[Optional[RawBuffer]] = [None]*len(bufsts)
for k,lx in bufsts.items():
rawbufs[k] = cast(Compiled, Device[Device.DEFAULT]).buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
assert all(r is not None for r in rawbufs)
return cast(List[RawBuffer], rawbufs)
# get dictionary of all possible actions
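# get_linearizer_actions applies each candidate Opt from `actions` to a copy of the
# linearizer, keeping only the ones that apply cleanly and don't push the total upcast
# or local dimensions past 256.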
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
acted_lins = {0:lin} if include_0 else {}
for i,a in enumerate(actions):
if a.axis is not None and a.axis >= lin.shape_len: continue
if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
lin2 = lin.copy()
try:
lin2.apply_opt(a)
up, lcl = 1, 1
for s,c in zip(lin2.full_shape, lin2.colors()):
if c in {"magenta", "yellow"}: up *= s
if c in {"cyan", "green", "white"}: lcl *= s
if up > 256 or lcl > 256: continue
acted_lins[i+1] = lin2
except Exception:
pass
return acted_lins
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size}
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
ret = lin.copy()
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
return ret
# init the BEAM with the base linearizer
beam: List[Tuple[Linearizer, float]] = [(lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size))]
# NOTE: real uops use a weird compare method that's only valid inside a linearizer
def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}
while 1:
acted_lins = lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
# dedup with uops (TODO: double linearize not needed)
acted_lins_dedup = []
for lin in acted_lins:
tuops = tuplize_uops(lin.linearize().uops)
if tuops in seen_uops:
#print(seen_uops[tuops], lin.applied_opts)
continue
seen_uops[tuops] = tuple(lin.applied_opts)
acted_lins_dedup.append(lin)
acted_lins = acted_lins_dedup
# time linearizers
timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins]
opts = sorted(timed_lins, key=lambda x: x[1])
if len(opts) == 0 or beam[0][1] <= opts[0][1]: break # we didn't get faster
# keep the BEAM best
beam = opts[:amt]
if DEBUG >= 2: print(f"{opts[0][1]*1e6:12.2f} us from {len(lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
if DEBUG >= 3: print(beam[0][0].applied_opts)
return beam[0][0]
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
def try_exec(local_size):
try:
return clprg(*test_rawbuffers, global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)
except Exception:
return float('inf')
return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
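# rough usage sketch (assumes a Compiled backend that exposes linearizer_opts; the names
# below are illustrative, not part of this module):
#   lin = Linearizer(ast, cast(Compiled, Device[Device.DEFAULT]).linearizer_opts)
#   rawbufs = bufs_from_lin(lin)
#   best = beam_search(lin, rawbufs, amt=4)
#   print(f"{time_linearizer(best, rawbufs)*1e6:.2f} us")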


@@ -0,0 +1,117 @@
import os, atexit, functools
try:
import networkx as nx # type: ignore
except ImportError:
nx = None # graph won't work
from collections import defaultdict
from typing import Dict, List
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
from tinygrad.codegen.linearizer import UOps
# **** debugging and graphing ****
G = nx.DiGraph() if nx is not None else None
cnts: Dict[OpType, int] = defaultdict(int)
if DEBUG >= 2:
def print_globalcounters():
if GlobalCounters.time_sum_s == 0: return
print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
atexit.register(print_globalcounters)
if GRAPH:
def save_graph_exit():
for k,v in cnts.items(): print(k, v)
print("saving", G)
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
    # -Gnslimit=100 can make it finish, but you won't like the results
os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
atexit.register(save_graph_exit)
node_count = 0
def nm(x):
global node_count
if not hasattr(x, 'node_id'):
setattr(x, 'node_id', node_count)
node_count += 1
return x.node_id
def get_sop(op: List[Op]):
op = [x for x in op if x not in BufferOps]
if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
if len(op) <= 6: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
return str(len(op))
def str_dtype(dtyp):
ret = str(dtyp)[7:]
return "" if ret == 'float' else f"\n{ret}"
@functools.lru_cache(None)
def add_st_node(nmx, nmo, label, st):
global node_count
inter_node = node_count
node_count += 1
G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=f"{st.shape}\n{st.real_strides()}" + (f"\n{st.real_offset()}" if st.real_offset() != 0 else ""))
G.add_edge(nmx, inter_node, color='#00000060')
G.add_edge(inter_node, nmo, label=label, color='#00000060')
logops = open(getenv("LOGOPS", ""),"a") if getenv("LOGOPS", "") else None
def log_schedule_item(si: ScheduleItem):
if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
show_graph = bool(GRAPH)
if not DEBUG and not show_graph: return
if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base))
if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
op: List[Op] = [x.op for x in si.ast.get_lazyops()]
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
if show_graph:
assert si.out.base == si.out, "all outputs based"
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
# get inputs for shapetrackers
input_to_st = defaultdict(list)
for lo in si.ast.get_lazyops():
if lo.op != BufferOps.MEM: continue
input_to_st[si.inputs[lo.arg.idx-1]].append(lo.arg.st)
    # add them to the graph, potentially with a movement op separating them
for x in input_to_st:
for st in dedup(input_to_st[x]):
if st.contiguous:
G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060')
else:
add_st_node(nm(x), nm(si.out), get_sop(op), st)
if 'label' not in G.nodes[nm(x)]:
G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(si.out.dtype)
if nm(si.out) not in G.nodes: G.add_node(nm(si.out))
G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype]
G.nodes[nm(si.out)]['color'] = 'black'
G.nodes[nm(si.out)]['style'] = 'filled'
def _tree(lazydata, prefix=""):
if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
childs = [_tree(c) for c in lazydata.src[:]]
for c in childs[:-1]: lines += [f"{c[0]}"] + [f"{l}" for l in c[1:]]
return lines + [""+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))]))
def graph_uops(uops):
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0"}
G = nx.DiGraph()
for u in uops:
G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
for v in u.vin: G.add_edge(v.num, u.num)
GRAPHPATH = "/tmp/uops"
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
os.system(f'dot -Grankdir=LR -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')


@@ -0,0 +1,208 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3
import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
T = TypeVar("T")
# NOTE: it returns int 1 if x is empty regardless of the type of x
def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.__mul__, x, 1)
# NOTE: helpers is not allowed to import from anything else in tinygrad
OSX = platform.system() == "Darwin"
CI = os.getenv("CI", "") != ""
def dedup(x): return list(dict.fromkeys(x)) # retains list order
def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items): return all(x == items[0] for x in items)
def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
def colored(st, color, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
def ansistrip(s): return re.sub('\x1b\\[(K|.*?m)', '', s)
def ansilen(s): return len(ansistrip(s))
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l:Union[List, Iterator]): return [item for sublist in l for item in sublist]
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
def merge_dicts(ds:Iterable[Dict]) -> Dict:
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
return {k:v for d in ds for k,v in d.items()}
def partition(lst, fxn):
a: list[Any] = []
b: list[Any] = []
for s in lst: (a if fxn(s) else b).append(s)
return a,b
@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
class Context(contextlib.ContextDecorator):
stack: ClassVar[List[dict[str, int]]] = [{}]
def __init__(self, **kwargs): self.kwargs = kwargs
def __enter__(self):
Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state.
for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state.
Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
def __exit__(self, *args):
for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
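# usage sketch: Context temporarily overrides ContextVars (e.g. DEBUG and BEAM below) and
# restores the previous values on exit:
#   with Context(DEBUG=2, BEAM=0):
#     ...  # code here sees DEBUG=2 and BEAM=0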
class ContextVar:
_cache: ClassVar[Dict[str, ContextVar]] = {}
value: int
def __new__(cls, key, default_value):
if key in ContextVar._cache: return ContextVar._cache[key]
instance = ContextVar._cache[key] = super().__new__(cls)
instance.value = getenv(key, default_value)
return instance
def __bool__(self): return bool(self.value)
def __ge__(self, x): return self.value >= x
def __gt__(self, x): return self.value > x
def __lt__(self, x): return self.value < x
DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
class Timing(contextlib.ContextDecorator):
def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
def __enter__(self): self.st = time.perf_counter_ns()
def __exit__(self, exc_type, exc_val, exc_tb):
self.et = time.perf_counter_ns() - self.st
if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
# **** tinygrad now supports dtypes! *****
class DType(NamedTuple):
priority: int # this determines when things get upcasted
itemsize: int
name: str
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
sz: int = 1
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}"
# dependent typing?
class ImageDType(DType):
def __new__(cls, priority, itemsize, name, np, shape):
return super().__new__(cls, priority, itemsize, name, np)
def __init__(self, priority, itemsize, name, np, shape):
self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
super().__init__()
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
# TODO: fix this to not need these
def __hash__(self): return hash((super().__hash__(), self.shape))
def __eq__(self, x): return super().__eq__(x) and self.shape == x.shape
def __ne__(self, x): return super().__ne__(x) or self.shape != x.shape
class PtrDType(DType):
def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
def __repr__(self): return f"ptr.{super().__repr__()}"
class dtypes:
  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
@staticmethod
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
float16: Final[DType] = DType(0, 2, "half", np.float16)
half = float16
float32: Final[DType] = DType(4, 4, "float", np.float32)
float = float32
float64: Final[DType] = DType(0, 8, "double", np.float64)
double = float64
int8: Final[DType] = DType(0, 1, "char", np.int8)
int16: Final[DType] = DType(1, 2, "short", np.int16)
int32: Final[DType] = DType(2, 4, "int", np.int32)
int64: Final[DType] = DType(3, 8, "long", np.int64)
uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
# NOTE: bfloat16 isn't supported in numpy
bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
# NOTE: these are internal dtypes, should probably check for that
_int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
# NOTE: these are image dtypes
@staticmethod
def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp)
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp)
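  # e.g. dtypes.imagef((16, 64, 4)) is a float32-backed image dtype whose .shape is
  # (16, 64, 4), i.e. 4 components per pixel (the shape here is illustrative)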
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
class GlobalCounters:
global_ops: ClassVar[int] = 0
global_mem: ClassVar[int] = 0
time_sum_s: ClassVar[float] = 0.0
kernel_count: ClassVar[int] = 0
mem_used: ClassVar[int] = 0 # NOTE: this is not reset
mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
# *** universal database cache ***
CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
CACHELEVEL = getenv("CACHELEVEL", 2)
VERSION = 6
_db_connection = None
def db_connection():
global _db_connection
if _db_connection is None:
_db_connection = sqlite3.connect(CACHEDB)
if DEBUG >= 5: _db_connection.set_trace_callback(print)
if diskcache_get("meta", "version") != VERSION:
print("cache is out of date, clearing it")
os.unlink(CACHEDB)
_db_connection = sqlite3.connect(CACHEDB)
if DEBUG >= 5: _db_connection.set_trace_callback(print)
diskcache_put("meta", "version", VERSION)
return _db_connection
def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
if isinstance(key, (str,int)): key = {"key": key}
try:
res = db_connection().cursor().execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
except sqlite3.OperationalError:
return None # table doesn't exist
if (val:=res.fetchone()) is not None:
return pickle.loads(val[0])
return None
_db_tables = set()
def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
if isinstance(key, (str,int)): key = {"key": key}
conn = db_connection()
cur = conn.cursor()
if table not in _db_tables:
TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
cur.execute(f"CREATE TABLE IF NOT EXISTS {table} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
_db_tables.add(table)
cur.execute(f"REPLACE INTO {table} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
conn.commit()
cur.close()
return val
def diskcache(func):
def wrapper(*args, **kwargs) -> bytes:
table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
if (ret:=diskcache_get(table, key)): return ret
return diskcache_put(table, key, func(*args, **kwargs))
setattr(wrapper, "__wrapped__", func)
return wrapper
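# usage sketch (the function name is illustrative): cache the result of an expensive pure
# function on disk, keyed by a sha256 of its pickled args:
#   @diskcache
#   def compile_source(src: str) -> bytes:
#     ...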


@@ -0,0 +1,77 @@
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
from collections import defaultdict
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.ops import RawBuffer, Device
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
self.cnt: int = 0
self.jit_cache: List[Tuple[Any, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
self.ret: Any = None
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
self.updatable_entries: Dict[int, List[int]] = defaultdict(list) # (kernel_number) -> list(argument id). These are buffers from input + variables.
# add support for instance methods
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
def __call__(self, *args, **kwargs) -> Any:
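    # call 0 runs the function eagerly, call 1 runs it again while CacheCollector captures
    # the launched kernels, and calls >= 2 replay the captured kernels, only swapping in the
    # new input RawBuffers and updated symbolic variable values.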
if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
# NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
if self.cnt >= 2:
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
# NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
for j in self.updatable_entries.keys():
for k in self.jit_cache[j][2].keys():
try: self.jit_cache[j][2][k] = var_vals[k]
except KeyError: pass
for prg, pargs, variables in self.jit_cache: prg(pargs, variables, jit=True)
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 1:
CacheCollector.start()
self.ret = self.fxn(*args, **kwargs)
self.jit_cache = CacheCollector.finish()
assert len(self.jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# get the inputs for replacement
for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
for i,a in enumerate(cache[1]):
if a in [v[0] for v in input_rawbuffers.values()]:
self.input_replace[(j_,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
self.updatable_entries[j_].append(i)
for i in range(len(cache[2])): self.updatable_entries[j_].append(len(cache[1])+i)
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 0:
self.ret = self.fxn(*args, **kwargs)
self.cnt += 1
return self.ret
class _CacheCollector:
def __init__(self): self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None
def start(self): self.cache = []
def add(self, prg, rawbufs, var_vals):
if self.cache is None: return
self.cache.append((prg, rawbufs, var_vals))
def finish(self):
if self.cache is None: return []
ret = self.cache
self.cache = None
return ret
CacheCollector = _CacheCollector()
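# usage sketch (model/step names are illustrative): decorate a function whose Tensor
# outputs are realized; the first call runs eagerly, the second captures the kernels,
# and later calls replay them with the new inputs:
#   @TinyJit
#   def step(x: Tensor) -> Tensor:
#     return model(x).realize()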


@@ -0,0 +1,347 @@
from __future__ import annotations
import sys, operator, math, functools
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping
from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
# lazy can recurse a lot
sys.setrecursionlimit(10000)
OPT = getenv("OPT", 2)
LAZYCACHE = getenv("LAZYCACHE", 1)
# TODO: movement ops that only change shape are really nops. treat them as such
REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
PUSH_RESHAPES = OPT>=4
# **** ast fixing functions ****
def _ast_reduceops(op:LazyOp) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
src = op.src[0]
if not src.realized:
assert isinstance(src.op, LazyOp), "if not src.realized, then src.op must be a LazyOp"
if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1: src = src.op
return LazyOp(op.op, (src,), op.arg)
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _ast_binaryops(op:LazyOp, shape: Tuple[sint, ...]) -> LazyOp:
real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {x:None for x in op.buffers}
  # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable, but requires more thought on how
# TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
intermediate_shape: Tuple[sint, ...] = shape
if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs:
psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop
if psrc[1].optype == ReduceOps:
top = _ast_reduceops(psrc[1].op)
real_srcs[psrc[0]] = top
real_srcs.update({x:x for x in top.buffers}) # the reduce op buffers are not modified
# if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
if psrc[0].shape != psrc[1].shape:
intermediate_shape = psrc[1].shape
assert psrc[0].shape == shape, f"shape mismatch {psrc[0].shape} != {shape}"
# reshape all the late ops into the output shape
# NOTE: these RESHAPEs will return self if they don't change the shape
for x in real_srcs.keys():
if real_srcs[x] is None: real_srcs[x] = x.reshape(intermediate_shape)
# NOTE: cast the type to remove the Optional
ast = op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs))
return LazyOp(MovementOps.RESHAPE, (ast, ), shape) if intermediate_shape != shape else ast
def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
replacements:Dict[LazyBuffer, LazyOp] = {}
base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()])
for x in op.buffers:
st = x.st.simplify().unbind()
if x.base in base_bufs:
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
elif not x.realized and x.base.op.op == LoadOps.CONST:
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st))
else:
raise NotImplementedError(f"not handled {x}")
return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), base_bufs
# **** lazy operations ****
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 and isinstance(root.op.src[0], LazyBuffer) else root
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
def vars_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], []))
lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
# fromcpu aren't cached
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)
# wop is the deduping key. i feel this used to compare more deeply
wop = (device, dtype, optype, ref(op), ref(base) if base else None)
if wop in lazycache:
for x in op.buffers: x.children.add(lazycache[wop])
return lazycache[wop]
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base)
return ret
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
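# these are ops f for which zero padding is not preserved: f(0) or f(0, x) need not be 0
# (exp2(0) = 1, log2(0) = -inf, division by 0, comparisons against nonzero constants),
# so _push_movement_ops below refuses to push a PAD past an elementwise chain containing them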
class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
self.st: ShapeTracker = st
self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
self._realized: Optional[RawBuffer] = src
self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
self.children: WeakSet = WeakSet()
self.views: WeakSet = WeakSet()
# NOTE: op should be read-only after construction of a LazyBuffer; with the schedule-based realization it now is
if op is not None:
self.op: LazyOp = op
for x in op.buffers: x.children.add(self)
assert optype != MovementOps or (base is not None and base.optype != MovementOps), "MovementOps must be based"
self._base = base
if base: base.views.add(self)
else: assert st.contiguous, "unbased LazyBuffers must be contiguous"
@property
def base(self): return self._base if self._base is not None else self
def is_unrealized_const(self): return not self.realized and self.base.op.op == LoadOps.CONST
@property
def realized(self): return self.base._realized
@realized.setter
def realized(self, val):
assert self._base is None, "no setting realized of based LazyBuffers"
self._realized = val
@property
def dtype(self): return self.base._dtype
@dtype.setter
def dtype(self, val):
assert self._base is None, "no setting dtype of based LazyBuffers"
self._dtype = val
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>"
@property
def key(self):
if self.realized: return (self.dtype, self.realized.key, self.st)
return (self.dtype, self.op.op, self.st)
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
@property
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]): return real_srcs.get(self, self)
def get_lazyops(self) -> List[LazyOp]: return []
# *** scheduling ***
def schedule(self, seen=None) -> List[ScheduleItem]:
if seen is None: seen = set()
if self in seen or self.realized or self.is_unrealized_const(): return []
seen.add(self)
if self.base != self: return self.base.schedule(seen)
# rewrite unbased CONTIGUOUS into UnaryOps.NOOP
op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps: op = _ast_reduceops(op)
# schedule the past
ret = []
for x in op.buffers: ret += x.schedule(seen)
var_vals = dict(sorted(merge_dicts([self.st.var_vals] + [buf.st.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
# run the ast and log the op
op, base_bufs = _replace_bufferops(op)
return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})]
# *** creation/special ops ***
@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
# create a constant with the shape and dtype of self
def const(self, val:Union[float, int]) -> LazyBuffer:
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
def copy_to_device(self, device:str) -> LazyBuffer:
# back off a FROM if it's a double FROM
if not self.realized and self.op.op == LoadOps.FROM and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0])
return LazyBuffer.loadop(LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous())
def contiguous(self:LazyBuffer) -> LazyBuffer:
if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: return self # all LoadOps are already contiguous (except CONST)
if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const():
# this will turn into nothing, it's based and a copy
# TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, base=self.base)
# real contiguous, this will turn into a UnaryOps.NOOP
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self)
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
def cast(self, dtype:DType, bitcast:bool=False):
return self.e(UnaryOps.CAST, arg=(dtype, bitcast))
# *** elementwise ops ***
def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
# srcs includes self
srcs = (self,)+srcs
# if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
# get outputs now
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0]
# push all contiguous to the end of BinaryOps. kernels 198 -> 196
if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
new_srcs: List[LazyBuffer] = []
for x in srcs:
if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
x.op.src[0].children.discard(x)
new_srcs.append(cast(LazyBuffer, x.op.src[0]))
else:
new_srcs.append(x)
return new_srcs[0].e(op, *new_srcs[1:], arg=arg).contiguous()
if MERGE_ELEMENTWISE_OPS:
# remove the buffers from any (childless) BinaryOps that feed into this
_srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
# TODO: needs general merge limiting
if out_device != "WEBGPU" or len(dedup([x.base for _src in _srcs for x in _src.buffers if not x.is_unrealized_const()])) < 7: srcs = _srcs # type: ignore
return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)
# *** reduce ops ***
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)
def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if not all_int(self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # The amount of work should be big enough to benefit from the "2 kernels" approach.
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
# *** movement ops ***
def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and not self.realized and self.optype == BinaryOps and not self.children:
if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)):
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)
def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.RESHAPE:
assert isinstance(self.op.src[0], LazyBuffer)
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].reshape(arg)
return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg)
def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
if all(b == 0 and e == 0 for b,e in arg): return self
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg)
def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg)
return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg)
def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if arg == tuple(range(len(self.shape))): return self
if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg]))
if SHUFFLE_MOVEMENT_OPS and not self.realized:
if PUSH_PERMUTES and self.optype == ReduceOps:
# reduceops have one buffer input, permute it
narg = tuple([self.op.arg[a] for a in arg])
src, rop = self.op.src[0], self.op.op
src.children.discard(self)
del self # TODO: why doesn't this del remove it from the children?
return src.permute(arg).r(cast(ReduceOps, rop), narg)
# move permutes before expands (always, this is safe)
if self.op.op == MovementOps.EXPAND:
return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg]))
# move permutes before reshapes if we can
if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape)
return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg)
def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg)
def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if all(a == 1 for a in arg): return self
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self
for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
return y
def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
new_srcs = []
for x in srcs:
mops: List[Tuple[MovementOps, Any]] = []
bx = x
# backwalk all the movement ops. don't push PAD or EXPAND
while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1:
assert isinstance(bx.op.op, MovementOps)
mops.append((bx.op.op, bx.op.arg))
assert isinstance(bx.op.src[0], LazyBuffer)
bx = bx.op.src[0]
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(y[0] is not MovementOps.PAD for y in mops) or all(y.op not in UNSAFE_PAD_OPS for y in bx.op.get_lazyops())):
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
else:
new_srcs.append(x)
return tuple(new_srcs)
MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
MovementOps.RESHAPE: LazyBuffer.reshape,
MovementOps.EXPAND: LazyBuffer.expand,
MovementOps.SHRINK: LazyBuffer.shrink,
MovementOps.PERMUTE: LazyBuffer.permute,
MovementOps.PAD: LazyBuffer.pad,
MovementOps.STRIDE: LazyBuffer.stride,
}
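# Minimal usage sketch of the lazy pipeline above (illustrative only; it assumes Tensor
# exposes its LazyBuffer as `.lazydata` and a `.realize()` method, as in this tinygrad snapshot).
if __name__ == "__main__":
  from tinygrad.tensor import Tensor
  a, b = Tensor.randn(4, 4), Tensor.randn(4, 4)
  out = (a * b).sum(axis=1)            # BinaryOps.MUL feeding a ReduceOps.SUM
  sched = out.lazydata.schedule()      # list of ScheduleItem; nothing has executed yet
  for si in sched: print(si.ast.op, [x.shape for x in si.inputs])
  out.realize()                        # now the kernels actually run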

View File

@@ -0,0 +1,211 @@
import math
from typing import Tuple, Optional, cast
from tinygrad.helpers import argsort, DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.symbolic import sint
class Contiguous(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output
class ContiguousBackward(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
class Cast(Function):
def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
self.input_dtype, self.bitcast = x.dtype, bitcast
return x.cast(dtype, bitcast)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.cast(self.input_dtype, self.bitcast)
# ************* unary ops *************
class Zero(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.const(0)
def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.const(0)
class Neg(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.e(UnaryOps.NEG)
class Sin(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
return x.e(UnaryOps.SIN)
def backward(self, grad:LazyBuffer) -> LazyBuffer:
return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)
# NOTE: maximum(x, 0) behaves differently where x=0
class Relu(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.e(BinaryOps.MAX, x.const(0))
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)
class Log(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.e(BinaryOps.DIV, self.x)
class Exp(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return self.ret.e(BinaryOps.MUL, grad_output)
class Sqrt(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.e(UnaryOps.SQRT)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.e(BinaryOps.DIV, self.ret.e(BinaryOps.MUL, self.ret.const(2)))
# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
class Sigmoid(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)
# ************* binary ops *************
class Less(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
return x.e(BinaryOps.CMPLT, y)
class Add(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
return x.e(BinaryOps.ADD, y)
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return grad_output if self.needs_input_grad[0] else None, \
grad_output if self.needs_input_grad[1] else None
class Sub(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
return x.e(BinaryOps.SUB, y)
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return grad_output if self.needs_input_grad[0] else None, \
grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None
class Mul(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y = x, y
return x.e(BinaryOps.MUL, y)
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
class Div(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y = x, y
return x.e(BinaryOps.DIV, y)
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
# ************* ternary ops *************
class Where(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
self.x = x
return x.e(TernaryOps.WHERE, y, z)
def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
return None, \
self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
# ************* reduce ops *************
class Sum(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
return x.r(ReduceOps.SUM, new_shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.expand(self.input_shape)
class Max(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
# ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
return x.expand(shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.r(ReduceOps.SUM, self.input_shape)
class Reshape(Function):
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
return x.reshape(shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.reshape(self.input_shape)
class Permute(Function):
def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
self.input_order = order
return x.permute(order)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.permute(argsort(self.input_order))
class Pad(Function):
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
return x.pad(arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.shrink(self.narg)
class Shrink(Function):
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
return x.shrink(arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
# need this cast because mypy cannot narrow the type even with assert
return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))
class Flip(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
return x.stride(self.arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.stride(self.arg)
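# Minimal sketch of how these Function subclasses are driven: Tensor ops call Function.apply
# (a classmethod in this snapshot's tensor.py), which records forward/backward on the autograd graph.
if __name__ == "__main__":
  from tinygrad.tensor import Tensor
  x = Tensor([[-1.0, 2.0, -3.0]], requires_grad=True)
  y = Relu.apply(x)                    # the same path Tensor.relu() takes internally
  y.sum().backward()
  print(x.grad.numpy())                # 1.0 where x > 0, else 0.0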

View File

@@ -0,0 +1,128 @@
import math
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, all_int
class BatchNorm2d:
def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
else: self.weight, self.bias = None, None
self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
def __call__(self, x:Tensor):
if Tensor.training:
# This requires two full memory accesses to x
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
batch_mean = x.mean(axis=(0,2,3))
y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
batch_var = (y*y).mean(axis=(0,2,3))
batch_invstd = batch_var.add(self.eps).pow(-0.5)
# NOTE: wow, this is done all throughout training in most PyTorch models
if self.track_running_stats:
self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach() )
self.num_batches_tracked += 1
else:
batch_mean = self.running_mean
# NOTE: this can be precomputed for static inference. we expand it here so it fuses
batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
# TODO: these Conv lines are terrible
def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)
class Conv2d:
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
self.weight = self.initialize_weight(out_channels, in_channels, groups)
assert all_int(self.weight.shape), "does not support symbolic shape"
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
def __call__(self, x:Tensor):
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
class ConvTranspose2d(Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
self.output_padding = output_padding
def __call__(self, x:Tensor):
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
class Linear:
def __init__(self, in_features, out_features, bias=True):
self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
# TODO: remove this once we can represent Tensor with int shape in typing
assert isinstance(self.weight.shape[1], int), "does not support symbolic shape"
bound = 1 / math.sqrt(self.weight.shape[1])
self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
def __call__(self, x:Tensor):
return x.linear(self.weight.transpose(), self.bias)
class GroupNorm:
def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
def __call__(self, x:Tensor):
# reshape for layernorm to work as group norm
# subtract mean and divide stddev
x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
if self.weight is None or self.bias is None: return x
# elementwise_affine on channels
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
class InstanceNorm:
def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
self.num_features, self.eps = num_features, eps
self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
def __call__(self, x:Tensor):
x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
if self.weight is None or self.bias is None: return x
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
class LayerNorm:
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
def __call__(self, x:Tensor):
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
x = x.layernorm(eps=self.eps, axis=self.axis)
if not self.elementwise_affine: return x
return x * self.weight + self.bias
class LayerNorm2d(LayerNorm):
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
class Embedding:
def __init__(self, vocab_size:int, embed_size:int):
self.vocab_size = vocab_size
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
def __call__(self, idx:Tensor) -> Tensor:
if not hasattr(self, 'vocab_counter'): self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size)
return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight
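# Minimal usage sketch for the layers above; the list-of-callables style (including the
# unbound Tensor.relu) is a common tinygrad idiom, and the shapes are illustrative only.
if __name__ == "__main__":
  model = [Conv2d(3, 8, 3, padding=1), Tensor.relu, BatchNorm2d(8)]
  x = Tensor.randn(1, 3, 32, 32)
  for layer in model: x = layer(x)
  print(x.shape)                       # (1, 8, 32, 32)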

View File

@@ -0,0 +1,68 @@
# sorted in order of increasing complexity
from typing import List
from tinygrad.helpers import dedup
from tinygrad.tensor import Tensor
class Optimizer:
def __init__(self, params: List[Tensor], lr: float):
# if it's None, but being put into an optimizer, set it to True
for x in params:
if x.requires_grad is None: x.requires_grad = True
self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
self.lr = Tensor([lr], requires_grad=False).contiguous()
def zero_grad(self):
for param in self.params: param.grad = None
def realize(self, extra=None):
# NOTE: in extra is too late for most of the params due to issues with assign
Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)
class SGD(Optimizer):
def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
super().__init__(params, lr)
self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
# https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
def step(self) -> None:
for i, t in enumerate(self.params):
assert t.grad is not None
g = t.grad.realize() + self.wd * t.detach()
if self.momentum:
self.b[i].assign(self.momentum * self.b[i] + g).realize() # NOTE: self.b[i] is zero on the first run, no if required
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
t.assign(t.detach() - g * self.lr)
self.realize(self.b)
# LAMB is essentially just the trust-ratio part of LARS applied to Adam/W, so if we just set the trust ratio to 1.0 it's just Adam/W.
def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): return LAMB(params, lr, b1, b2, eps, wd, adam=True)
def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
class LAMB(Optimizer):
def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
super().__init__(params, lr)
self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize()
self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
def step(self) -> None:
self.t.assign(self.t + 1).realize()
for i, t in enumerate(self.params):
assert t.grad is not None
g = t.grad.realize()
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
m_hat = self.m[i] / (1.0 - self.b1**self.t)
v_hat = self.v[i] / (1.0 - self.b2**self.t)
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
if not self.adam:
r1 = t.detach().square().sum().sqrt()
r2 = up.square().sum().sqrt()
r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
else:
r = 1.0
t.assign(t.detach() - self.lr * r * up)
self.realize([self.t] + self.m + self.v)
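# Minimal sketch of a single optimization step with the classes above; the Linear layer and
# random data are placeholders, only the zero_grad/backward/step pattern matters.
if __name__ == "__main__":
  from tinygrad.tensor import Tensor
  from tinygrad.nn import Linear
  layer = Linear(10, 2)
  opt = SGD([layer.weight, layer.bias], lr=0.1, momentum=0.9)
  loss = layer(Tensor.randn(4, 10)).mean()
  opt.zero_grad()
  loss.backward()
  opt.step()                           # updates weight/bias in place via assign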

View File

@@ -0,0 +1,124 @@
import os, json, pathlib, zipfile, pickle
from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
from tinygrad.shape.view import strides_for_shape
from tinygrad.ops import Device
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
json_len = t[0:1].cast(dtypes.int64).numpy()[0]
return (t, json_len, json.loads(t[8:8+json_len].numpy().tobytes()))
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
t, json_len, metadata = safe_load_metadata(fn)
return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
headers, offset = {}, 0
if metadata: headers['__metadata__'] = metadata
for k,v in tensors.items():
headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
offset += v.nbytes()
j = json.dumps(headers, separators=(',', ':'))
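# pad the JSON header with spaces so the tensor data that follows the 8-byte length prefix and the header starts on an 8-byte boundary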
j += "\x20"*((8-len(j)%8)%8)
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:1].cast(dtypes.int64).assign([len(j)])
t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
for k,v in safe_load(t).items(): v.assign(tensors[k])
# state dict
from collections import OrderedDict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
state_dict = {}
if isinstance(obj, (list, tuple)):
for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
elif isinstance(obj, dict):
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
return state_dict
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
def load_state_dict(model, state_dict, strict=True, verbose=True):
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
model_state_dict = get_state_dict(model)
if DEBUG >= 1 and len(state_dict) > len(model_state_dict): print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
if k not in state_dict and not strict:
if DEBUG >= 1: print(f"WARNING: not loading {k}")
continue
v.assign(state_dict[k].to(v.device)).realize()
# torch support!
def torch_load(fn:str):
t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
offsets: Dict[str, int] = {}
lens: Dict[str, int] = {}
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
lens[storage[2]] = storage[4] * storage[1].itemsize
if storage[2] not in offsets: return None
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
# convert bfloat16 -> float16 using LLVM for Llama 2
# upstream LLaMA also does this conversion:
# https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
# TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
if storage[1] == dtypes.bfloat16:
ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
#ret = ret.to("LLVM").half().to(Device.DEFAULT)
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
if DEBUG >= 2: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
# TODO: find a nice way to support all shapetracker on disktensors
ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
return ret.reshape(size)
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
class Dummy: pass
class TorchPickle(pickle.Unpickler):
def find_class(self, module, name):
module_root = module.split(".")[0]
if module_root not in whitelist:
if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
return Dummy
return intercept[name] if module_root == "torch" else super().find_class(module, name)
def persistent_load(self, pid): return pid
if tuple(t[0:2].numpy()) == (0x50, 0x4b):
myzip = zipfile.ZipFile(fn, 'r')
base_name = myzip.namelist()[0].split('/', 1)[0]
for n in myzip.namelist():
if n.startswith(f'{base_name}/data/'):
with myzip.open(n) as myfile:
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
with myzip.open(f'{base_name}/data.pkl') as myfile:
return TorchPickle(myfile).load()
else:
with open(fn, "rb") as f:
pkl = TorchPickle(f)
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
for i in ids:
offsets[i] = base_offset + 8
base_offset += 8 + lens[i]
f.seek(rwd)
return TorchPickle(f).load()
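# Minimal safetensors round-trip sketch with the helpers above; "/tmp/model.safetensors" is an illustrative path.
if __name__ == "__main__":
  from tinygrad.nn import Linear
  net = Linear(4, 4)
  safe_save(get_state_dict(net), "/tmp/model.safetensors")
  load_state_dict(net, safe_load("/tmp/model.safetensors"))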

View File

@@ -0,0 +1,300 @@
from __future__ import annotations
import importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
from tinygrad.runtime.lib import RawBuffer
from tinygrad.shape.symbolic import Variable, sym_infer
from dataclasses import dataclass
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.lazy import LazyBuffer
@dataclass(frozen=True)
class MemBuffer:
idx: int
dtype: DType
st: ShapeTracker
@dataclass(frozen=True)
class ConstBuffer:
val: Any
dtype: DType
st: ShapeTracker
@dataclass(frozen=True)
class ScheduleItem:
ast: LazyOp
out: LazyBuffer
inputs: Tuple[LazyBuffer, ...]
var_vals: Dict[Variable, int]
@dataclass(frozen=True)
class LazyOp:
op: Op
src: Tuple[Union[LazyOp, LazyBuffer], ...]
arg: Any = None
def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
@property
def buffers(self):
buffers: Tuple[Union[LazyOp, LazyBuffer], ...] = ()
try: # NOTE: the linearizer's key function maps the buffers to ints, and LOCAL_BUFFER is used. we don't care about buffers in these cases
for x in self.src: buffers += x.buffers
except AttributeError: buffers = ()
return buffers
@property
def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] for y in self.src]), self.arg)
def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
assert self.op in BinaryOps or self.op in UnaryOps or self.op in TernaryOps
srcs = [z.replace_with_movement_ops(ops) for z in self.src]
return srcs[0].e(self.op, *srcs[1:], arg=self.arg) # type: ignore
@property
def st(self): raise NotImplementedError
@property
def realized(self): raise NotImplementedError
@property
def children(self): raise NotImplementedError
# movement ops
def reshape(self, _): raise NotImplementedError
def pad(self, _): raise NotImplementedError
def expand(self, _): raise NotImplementedError
def permute(self, _): raise NotImplementedError
def shrink(self, _): raise NotImplementedError
def stride(self, _): raise NotImplementedError
# **************** Device ****************
class _Device:
def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
x = x.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
@functools.cached_property
def DEFAULT(self) -> str:
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None)
if device_from_env: return device_from_env
for device in ["METAL", "CUDA", "GPU"]:
try:
if self[device]: return device
except Exception: pass
return "CPU"
Device = _Device()
# **************** for Interpreted Buffers ****************
class Interpreted:
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
self.synchronize = lambda: None
self.codegen = None
self.method_cache: Dict[LazyOp, Callable] = {}
def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
tglob: Dict[str, Any] = {}
lines: List[str] = []
f = self.fxn_for_op
@functools.lru_cache(None)
def gstr(x:Any, nm=None) -> str:
ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
tglob[ret] = x
return ret
@functools.lru_cache(None)
def _interpret_ast(ast:LazyOp) -> str:
if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
else:
inp = [_interpret_ast(src) for src in ast.src]
tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
ret = f"a{len(lines)}"
lines.append(f" {ret} = {tmp}")
return ret
ret = _interpret_ast(ast)
src = '\n'.join(['def run(inputs):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f" return {ret}"])
if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
return tglob['run']
def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, **kwargs):
if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None)
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.fxn_for_op[BufferOps.MEM](ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
# TODO: is this used?
if output is not None and output.output_buffer is not None:
assert output.output_buffer.dtype == ret.dtype
output.output_buffer._buf = ret._buf
return output.output_buffer
return ret
@dataclass
class FlopCounter:
shape: Tuple[int, ...]
dtype: DType
flops: int
mem: Dict[int, int]
@property
def mem_estimate(self): return sum(self.mem.values()) + self.dtype.itemsize*prod(self.shape)
def consume_flops(self):
self.flops, ret = 0, self.flops
return ret
InterpretedFlopCounter = Interpreted(FlopCounter, {
BufferOps.MEM: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST},
**{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps},
**{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})})
@functools.lru_cache(None)
def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
# **************** for Compiled Buffers ****************
class ASTRunner:
def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
def build(self, compiler, runtime):
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
self.clprg = runtime(self.name, self.lib)
return self
def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
from tinygrad.jit import CacheCollector
CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
return self(rawbufs, var_vals, force_wait=force_wait)
def launch_dims(self, var_vals):
global_size = ([sym_infer(sz, var_vals) for sz in self.global_size] + [1]*(3-len(self.global_size))) if self.global_size is not None else self.global_size
local_size = ([sym_infer(sz, var_vals) for sz in self.local_size] + [1]*(3-len(self.local_size))) if self.local_size is not None else self.local_size
return global_size, local_size
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
if var_vals is None: var_vals = {}
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None:
# TODO: this is copied from get_program
from tinygrad.features.search import optimize_local_size
local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
lra = self.runtime_args.copy()
if global_size: lra['global_size'] = global_size
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(37-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += mem_estimate
return et
class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
self.method_cache: Dict[LazyOp, ASTRunner] = {}
def to_program(self, k):
k.linearize()
src, runtime_args = self.renderer(k.function_name, k.uops)
return ASTRunner(k.function_name, src, k.global_size, k.local_size,
op_estimate=k.info.flops, mem_estimate=k.info.mem_estimate,
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)
def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
# check if we can reuse the output buffer
# if it's aliased, don't use it
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
output.realized = output.output_buffer
if output.realized:
for i,a in enumerate(inputs):
# TODO: if this is contiguous it's fine
if a.realized == output.realized:
if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == BufferOps.MEM and x.arg.idx == i+1):
output.realized = None
break
# we don't have an output buffer, so we have to create it, at max size if it has a symbolic shape
if not output.realized:
output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
# all the rawbuffers
rawbuffers = [output.realized] + [x.realized for x in inputs]
# extract real vars used in ast
from tinygrad.lazy import vars_from_ast
ast_vars = vars_from_ast(ast)
assert all(v.val is None for v in ast_vars), f"ast contains bound Variable {ast_vars}"
# compilation time
def get_program():
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, self.linearizer_opts)
assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1 and not vars_from_ast(ast):
lins = [(("tc" if used_tensor_cores else "hc"), k)]
# allocate a scratch buffer if output buffer is also input
test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
kb = Linearizer(ast, self.linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
else:
k.required_optimizations()
return self.to_program(k)
if getenv("ENABLE_METHOD_CACHE", 1):
if ast not in self.method_cache: self.method_cache[ast] = get_program()
prg = self.method_cache[ast]
else:
prg = get_program()
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in ast_vars})
return output.realized
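# Minimal sketch of the Device registry above: it maps a device string to the Interpreted/Compiled
# backend object named like "<DEVICE>Buffer" in its runtime module. Which DEFAULT you get depends
# on what runtimes are available on the machine.
if __name__ == "__main__":
  print(Device.DEFAULT)                # e.g. "METAL", "CUDA", "GPU" or "CPU"
  print(Device.canonicalize("gpu:0"))  # -> "GPU"
  backend = Device[Device.DEFAULT]     # Interpreted or Compiled instance for that device
  print(type(backend).__name__)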

View File

@@ -0,0 +1,74 @@
from typing import List, cast, Dict, Callable
import numpy as np
from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, BufferOps
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE
from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.features.image import fix_schedule_for_images
def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
# HACK: images can be unusable due to their shape
if IMAGE >= 2: schedule = fix_schedule_for_images(schedule)
# NOTE: if you iterate the schedule with a plain for loop it's slow because nothing gets freed
while len(schedule):
si = schedule.pop(0)
if not disable_logging: log_schedule_item(si)
assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
if DEBUG >= 3: print_tree(si.ast)
if si.ast.op in LoadOps:
# confirm the LoadOps are contiguous and in order
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
else:
si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
del si.out.op
for v in si.out.views: del v.op
assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
assert si.out.realized.dtype == si.out.dtype, "realized dtype is incorrect"
# *** zero op LoadOps ***
def _realize_empty(buffer: LazyBuffer) -> None:
assert all_int(buffer.shape), "does not support symbolic shape"
if DEBUG >= 2: print(f"*** empty {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
def _realize_rand(buffer: LazyBuffer) -> None:
assert all_int(buffer.shape), "does not support symbolic shape"
if DEBUG >= 2: print(f"*** rand {buffer.device} seed {buffer.op.arg:<10d} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
rng = np.random.default_rng(buffer.op.arg)
buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())
# *** one op LoadOps ***
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}"
assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
if DEBUG >= 2: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size:<16d} shape {str(buffer.shape):23s} dtype {src.realized.dtype}")
# TODO: make this generic
if isinstance(src.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
assert all_int(buffer.shape), "does not support symbolic shape"
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
src.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer())
elif isinstance(src.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and getenv("P2P", 0) >= 1:
buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(src.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args())
else:
# TODO: schedule this as FROM to go to CPU, and a FROM to go to device
buffer.realized = Device[buffer.device].buffer.fromCPU(src.realized.toCPU(), **buffer._device_extra_args())
# *** n op LoadOps ***
def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
if DEBUG >= 2: print(f"*** custom {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
buffer.realized = buffer.op.arg(buffer, *inputs)
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
LoadOps.EMPTY: _realize_empty,
LoadOps.RAND: _realize_rand,
LoadOps.FROM: _realize_from,
LoadOps.CUSTOM: _realize_custom,
}
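# Sketch (not part of the diff): run_schedule resolves each LoadOp with a single dict lookup
# into the table above, e.g. LoadOps.EMPTY maps straight to _realize_empty.
assert LOAD_OPS_DISPATCHER[LoadOps.EMPTY] is _realize_empty
assert LOAD_OPS_DISPATCHER[LoadOps.FROM] is _realize_from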

View File

@@ -0,0 +1,212 @@
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict
import math
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens
class CStyleLanguage(NamedTuple):
size_prefix: str = "int"
generic_var_prefix: str = ""
kernel_prefix: str = ""
buffer_prefix: str = ""
buffer_suffix: str = ""
smem_align: str = ""
smem_prefix: str = ""
smem_prefix_for_cast: bool = True
arg_int_prefix: str = ""
barrier: str = ""
xid: List[str] = []
gid: List[str] = []
lid: List[str] = []
global_max: List[int] = []
local_max: List[int] = []
extra_args: List[str] = []
float4: Optional[str] = None
half_prekernel: Optional[str] = None
uses_vload: bool = False
external_local_bufs: bool = False
uses_ptr_arithmetic: bool = False
launch_bounds: bool = False
code_for_op: Dict = {
UnaryOps.NEG: lambda x: f"(-{x})",
UnaryOps.EXP2: lambda x: f"exp2({x})",
UnaryOps.LOG2: lambda x: f"log2({x})",
UnaryOps.SIN: lambda x: f"sin({x})",
UnaryOps.SQRT: lambda x: f"sqrt({x})",
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
}
# returns a str expression of the casted xs with the given type
def render_cast(self, x:List[str], var_dtype:DType) -> str:
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert self.float4 is not None, "cast is not supported on this platform"
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})"
raise NotImplementedError(f"no cast for {var_dtype}")
# returns a str expression of the const with the given type
def render_const(self, x:Union[float,int], var_dtype) -> str:
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
else: val = f"{x}f" if dtypes.is_float(var_dtype) and isinstance(x, float) else f"{int(x)}"
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
if isinstance(buf_dtype, ImageDType):
assert output_dtype == dtypes._float4, f"images must be float4, getting {output_dtype}"
return f"read_imagef({buf_name}, smp, {idx})"
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
if output_dtype.sz > 1:
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
else:
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
def render_local(self, name:str, size:int):
return self.smem_align + self.smem_prefix + f"float {name}[{size}];"
def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{"
def render_if(self, cond: str):
return f"if ({cond}) {{"
def render_conditional(self, cond: str, x:str, y:str) -> str:
return f"({cond})?({x}):{y}"
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
self.arg_int_prefix if dtype == dtypes._arg_int32 else
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
return prg
# returns a str statement that does the store
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
if isinstance(buf_dtype, ImageDType):
assert var_dtype == dtypes._float4, "images must be float4"
return f"write_imagef({buf_name}, {idx}, {var_name});"
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
if var_dtype.sz > 1:
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
local_size: List[int] = []
kernel,prekernel,bufs = [],[],[]
#pend_close = None
depth = 1
def kk(s): kernel.append(" "*depth+s)
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
def ssa(u, prefix="t"):
nonlocal c, r
c[prefix] += 1
r[u]=f"{prefix}{c[prefix]-1}"
return r[u]
child_count: DefaultDict[UOp, int] = defaultdict(int)
for ru in uops:
for v in ru.vin:
child_count[v] += 1
for u in uops:
uop,dtype,vin,args,_ = u
if uop == UOps.LOOP:
kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
depth += 1
elif uop == UOps.IF:
kk(lang.render_if(r[vin[0]]))
depth += 1
elif uop == UOps.BARRIER:
kk(lang.barrier)
elif uop == UOps.END:
depth -= 1
kk("}")
elif uop == UOps.WMMA:
if args[0] == "METAL":
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
kk("{ simdgroup_float8x8 a,b,c;")
kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
kk("simdgroup_multiply_accumulate(c, a, b, c);")
kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}")
elif args[0] == "HIP":
kk("{")
kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};")
kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};")
kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};")
kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];")
kk("}")
else:
raise NotImplementedError(f"WMMA not implemented for {args}")
elif uop == UOps.ALU:
assert dtype is not None
# remove parens if ALU types are the same. TODO: can do more here
if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]])
else:
val = lang.code_for_op[args](*[r[x] for x in vin])
assert child_count[u] != 0, f"childless ALU op found {u}"
if child_count[u] <= 1 or dtypes.is_int(dtype): # fix index rendering issue
r[u] = val
else:
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")
elif uop == UOps.DEFINE_ACC:
assert dtype is not None
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
elif uop == UOps.SPECIAL:
xid = lang.gid if args[1].startswith("g") else (lang.xid if args[1].startswith("i") else lang.lid)
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
if args[1].startswith("l"): local_size.append(args[2])
r[u] = args[1]
elif uop == UOps.CONST:
r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
elif uop == UOps.LOAD:
assert dtype is not None
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};")
elif uop == UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop == UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
elif uop == UOps.CAST and dtype is not None and dtype.sz > 1:
val = lang.render_cast([r[x] for x in vin], dtype)
if child_count[u] <= 1: r[u] = val
else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
elif uop == UOps.DEFINE_LOCAL:
if lang.external_local_bufs:
prekernel.append(lang.render_local(args[0], args[1]))
else:
kk(lang.render_local(args[0], args[1]))
r[u] = args[0]
elif uop == UOps.DEFINE_GLOBAL:
bufs.append(args)
r[u] = args[0]
elif uop == UOps.GEP:
r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
else:
raise RuntimeError(f"failed to render {uop}")
return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {}
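# Sketch (not part of the diff): the CStyleLanguage helpers are plain string builders, so the
# defaults can be exercised directly without any device.
lang = CStyleLanguage()
assert lang.render_for("ridx0", 0, 16) == "for (int ridx0 = 0; ridx0 < 16; ++ridx0) {"
assert lang.render_conditional("p", "a", "b") == "(p)?(a):b"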

View File

@@ -0,0 +1,23 @@
import functools
from tinygrad.helpers import dtypes
from tinygrad.ops import TernaryOps
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint64: "ulong" }
class OpenCLLanguage(CStyleLanguage):
kernel_prefix = "__kernel "
buffer_prefix = "__global "
smem_align = "__attribute__ ((aligned (16))) "
smem_prefix = "__local "
arg_int_prefix = "const int"
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable"
barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
float4 = "(float4)"
gid = [f'get_group_id({i})' for i in range(3)]
lid = [f'get_local_id({i})' for i in range(3)]
xid = [f'get_global_id({i})' for i in range(3)]
uses_vload = True
# NOTE: mad is used so the loads aren't reordered into the math on 845
code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c: f"mad({a},{b},{c})"}
OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
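# Sketch (not part of the diff): the only per-op override here is MULACC, rendered as mad() on OpenCL.
assert CStyleLanguage().code_for_op[TernaryOps.MULACC]("a", "b", "c") == "((a*b)+c)"
assert OpenCLLanguage().code_for_op[TernaryOps.MULACC]("a", "b", "c") == "mad(a,b,c)"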

View File

@@ -0,0 +1,111 @@
import ctypes
import numpy as np
from collections import defaultdict, deque
from typing import TypeVar, Type, Any, Dict, Deque, Tuple
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType
_T = TypeVar("_T")
class RawBuffer: # pylint: disable=abstract-method
def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
self.size: int = size
self.dtype: DType = dtype
self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
self._memsz: int = size*dtype.itemsize
self._allocator = allocator
self._device = kwargs.get('device', None)
GlobalCounters.mem_used += self._memsz
def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
def __repr__(self): return f"buffer<{self.size}, {self.dtype}, {id(self)}>"
@property
def key(self): return (self.size, self.dtype)
# NOTE: this interface allows for 0 copy
@classmethod
def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
class RawBufferCopyIn(RawBuffer):
def _copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
@classmethod
def fromCPU(cls, x:np.ndarray, **kwargs):
ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
if x.size > 0: ret._copyin(x)
return ret
class RawBufferMapped(RawBufferCopyIn):
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
def buffer_view(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size) # type: ignore
def toCPU(self) -> np.ndarray: return self.buffer_view().copy() # Need a copy, since jit will write to the same buffer.
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.buffer_view(), x.reshape(-1))
# this one is simple enough that I moved it out of the runtimes
class RawMallocBuffer(RawBufferMapped):
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64, dtypes.int16: ctypes.c_int16, dtypes.uint16: ctypes.c_uint16}[dtype] * size)())
def _buffer(self): return memoryview(self._buf)
class RawBufferCopyInOut(RawBufferCopyIn):
def _copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
def toCPU(self) -> np.ndarray:
x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
if x.size > 0: self._copyout(x)
return x
class RawBufferTransfer(RawBuffer):
def _transfer(self, x) -> None: raise NotImplementedError("must be implemented")
@classmethod
def transfer(cls, x, shape, dtype, **kwargs):
ret = cls(prod(shape), dtype, **kwargs)
ret._transfer(x)
return ret
class LRUAllocator:
def __init__(self, dev_memsz=(4<<30)):
self.epoch = 0
self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
    self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, split by type and size, newest first.
self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
return rawbufs.popleft()[0]
def ensure_has_free_space(self, size, dtype, device):
    while len(self.aging_order[device]) and (self.free_space[device]-size*dtype.itemsize) < 0: # When OOM, evict LRU buffers.
bucket, epoch = self.aging_order[device].popleft()
if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
def _alloc_buffer(self, size, dtype, device, **kwargs):
self.ensure_has_free_space(size, dtype, device)
self.free_space[device] -= size*dtype.itemsize
newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
self.buffer_info[newbuf] = (size, dtype, device)
return newbuf
def _free_buffer(self, buf_to_free):
self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
self.buffer_info.pop(buf_to_free)
self._do_free(buf_to_free)
def alloc(self, size, dtype, device='0', **kwargs):
rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
  def free(self, buf): # free() just caches the buffer. It may actually be freed later if an allocation hits OOM.
self.epoch += 1
size, dtype, device = self.buffer_info[buf]
self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys.
def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
def _do_free(self, buf): pass
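# A toy sketch (not part of the diff): a fake backend allocator showing the LRU caching behavior.
# _do_alloc/_do_free are the only hooks a real backend has to provide; names below are made up.
class _FakeAllocator(LRUAllocator):
  def _do_alloc(self, size, dtype, device, **kwargs): return object()  # stand-in for a device buffer
  def _do_free(self, buf): pass

def _lru_demo():
  alloc = _FakeAllocator(dev_memsz=1<<20)
  a = alloc.alloc(16, dtypes.float32)  # real allocation through _do_alloc
  alloc.free(a)                        # not freed, just cached
  b = alloc.alloc(16, dtypes.float32)  # same size/dtype/device -> cached buffer is reused
  assert a is b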

View File

@@ -0,0 +1,52 @@
import numpy as np
import operator
from typing import Callable, Dict, Tuple, Optional
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
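# Sketch (not part of the diff): the reduced axes are simply the dimensions that change size.
assert shape_to_axis((4, 5, 6), (4, 1, 6)) == (1,)
assert shape_to_axis((2, 3), (1, 1)) == (0, 1)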
base_fxn_for_op: Dict[Op, Callable] = {
BufferOps.MEM: lambda x: x._buf, UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
}
def match_types(x, y):
up = x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
return x.astype(up, copy=False), y.astype(up, copy=False)
def einsum_mulacc(einsum, get_strides, expand):
def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
def axes_slice(strides): return [i for i,s in enumerate(strides) if s != 0], tuple([slice(None) if s != 0 else 0 for i,s in enumerate(strides)])
def mulacc(a, b, new_shape):
(a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
return mulacc
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
TernaryOps.WHERE: np.where,
}}
class RawNumpyBuffer(RawBuffer):
def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf if buf is not None else np.empty([size], dtype.np))
@classmethod
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
def toCPU(self): return self._buf
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)

View File

@@ -0,0 +1,41 @@
import os, mmap
from typing import Optional
from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps
class RawDiskBuffer(RawBufferMapped):
def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0): # pylint: disable=super-init-not-called
self.shape = (size, ) if shape is None else shape
self.offset = offset # this is an offset in bytes
assert device is not None or buf is not None, "disk tensor needs a path or a buf"
if device is not None:
f = open(device, "a+b")
if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
buf = [f, mmap.mmap(f.fileno(), size * dtype.itemsize), 1]
else:
buf[2] += 1
# NOTE: we don't call super since disk tensors don't use RAM
self.size, self.dtype, self._buf = size, dtype, buf
def __del__(self):
self._buf[2] -= 1
if self._buf[2] == 0: self._buf[0].close()
def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
def shrink(self, arg):
assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])
def as_strided(self, arg):
return RawDiskBuffer(prod(arg[0]), self.dtype, buf=self._buf, offset=self.offset+arg[2]*self.dtype.itemsize, shape=arg[0])
def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
def readinto(self, buf):
self._buf[0].seek(self.offset)
self._buf[0].readinto(buf)
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, from_underlying=lambda x:x)
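# Sketch (not part of the diff, assumes a writable /tmp): a disk tensor is just an mmap'd file,
# and shrink() on the first dim only moves the byte offset into that file.
def _disk_demo():
  import numpy as np
  from tinygrad.helpers import dtypes
  buf = RawDiskBuffer(6, dtypes.float32, device="/tmp/disk_demo.bin", shape=(3, 2))
  buf._copyin(np.arange(6, dtype=np.float32))
  sub = buf.shrink(((1, 3), (0, 2)))  # keep rows 1..2 -> offset advances by one row of bytes
  assert np.array_equal(sub.toCPU(), np.arange(2, 6, dtype=np.float32))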

View File

@@ -0,0 +1,110 @@
from __future__ import annotations
import os
os.environ['PYOPENCL_NO_CACHE'] = '1'
import pathlib
import numpy as np
import pyopencl as cl # type: ignore
from typing import Optional, List, Tuple
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache
from tinygrad.ops import Compiled
from tinygrad.renderer.opencl import OpenCLRenderer
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
from tinygrad.codegen.kernel import LinearizerOptions
OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin")
#ROCM_LLVM_PATH = pathlib.Path(__file__).parents[3] / "extra/rocm/build/llvm-project/bin"
if DEBUG >= 5:
early_exec = fromimport("extra.helpers", "enable_early_exec")()
class CLAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs):
if isinstance(dtype, ImageDType):
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
else:
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
return buf
class _CL:
def __init__(self):
cl_platforms = cl.get_platforms()
platform_devices: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if y]
self.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")]
self.cl_platform = self.devices[0].platform
def post_init(self, device=None):
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in self.devices] if device is None else [cl.Context(devices=[self.devices[device]])]
if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
def synchronize(self):
for q in self.cl_queue: q.finish()
CL = _CL()
if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init()
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
def _copyin(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements=['C', 'A']), is_blocking=False)
def _copyout(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
CL.cl_allocator.ensure_has_free_space(self.size, self.dtype, self._device)
buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else []))
def _transfer(self, x):
if "gfx" in CL.cl_ctxs[x._buf.device].devices[0].name:
cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait()
else: raise NotImplementedError("p2p transfer between devices not implemented on non-amd")
@diskcache
def compile_gpu(prg:str) -> bytes:
clprg = cl.Program(CL.cl_ctxs[0], prg)
clprg.build()
return clprg.get_info(cl.program_info.BINARIES)[0]
class CLProgram:
def __init__(self, name:str, prg:bytes, argdtypes=None, options=None):
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) for ctx in CL.cl_ctxs] # type: ignore
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs]
if DEBUG >= 5 and not OSX:
if 'Adreno' in CL.cl_ctxs[0].devices[0].name:
fromimport('disassemblers.adreno', 'disasm')(prg)
elif CL.cl_ctxs[0].devices[0].name.startswith('gfx'):
asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], prg))
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
else:
# print the PTX for NVIDIA. TODO: probably broken for everything else
print(prg.decode('utf-8'))
if argdtypes is not None: self.set_argdtypes(argdtypes)
def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_scalar_arg_dtypes(argdtypes) for clprg in self.clprgs]
@staticmethod
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Optional[Tuple[int,int,int]]=None, wait=False) -> Optional[float]:
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if x.__class__ is CLBuffer else np.int32 for x in bufs))
cl_bufs, wait_for = [], []
for x in bufs:
if x.__class__ is CLBuffer:
cl_bufs.append(x._buf)
if hasattr(x, "event"): wait_for.append(x.event)
else: cl_bufs.append(x)
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [int(g*l) for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
if wait:
e.wait()
try:
return ((e.profile.end - e.profile.start) * OSX_TIMING_RATIO) * 1e-9
except cl.RuntimeError: # no profiling info available
return None
return None
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), OpenCLRenderer, compile_gpu, CLProgram, CL.synchronize)

View File

@@ -0,0 +1,221 @@
# ShapeTracker allows movement operations on a buffer without requiring a copy to be made.
from __future__ import annotations
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.ops import MovementOps
from tinygrad.helpers import prod, DEBUG, dedup
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
from tinygrad.shape.view import View
@functools.lru_cache(maxsize=None)
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
assert len(shape) == len(strides)
ret = [(shape[0], strides[0])] if shape else []
for i in range(1, len(shape)):
if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
ret[-1] = (ret[-1][0] * shape[i], strides[i])
elif shape[i] == 1:
continue
else:
ret.append((shape[i], strides[i]))
return tuple(ret)
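# Sketch (not part of the diff): contiguous dimensions collapse into a single (shape, stride) pair.
assert to_shape_strides((2, 3, 4), (12, 4, 1)) == ((24, 1),)
assert to_shape_strides((2, 3, 4), (12, 4, 2)) == ((6, 4), (4, 2))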
def expr_node_mask(view:View, idx, valid=None) -> Node:
expr = [valid] if valid is not None else []
if view.mask is not None:
acc = 1
for ns,(x,y) in reversed(list(zip(view.shape, view.mask))):
if x != 0 or y != ns:
base = ((idx//acc) % ns)
expr += [base >= x, base < y]
acc *= ns
return Variable.ands(expr)
# generate an expression if you have a single idx variable
def expr_node(view:View, idx=None) -> Node:
if idx is None: idx = Variable('idx', 0, prod(view.shape)-1)
ret: List[Node] = [Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] if view.offset else []
acc = 1
for d,s in reversed(to_shape_strides(view.shape, view.strides)):
ret.append(((idx//acc)%d)*s)
acc *= d
return Variable.sum(ret)
# generate an expression if you have a variable or expression for each index
def expr_idxs(view:View, idxs) -> Node:
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
return Variable.sum([Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0])
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
if vm2.mask: return None # this isn't supported yet
mst = ShapeTracker((vm2, vm1))
strides = mst.real_strides()
if None in strides: return None
return View.create(vm1.shape, cast(Tuple[sint, ...], strides), mst.real_offset(), vm1.mask)
@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
assert len(idxs) == len(shape), "need an idx for all dimensions"
acc = 1
ret = []
for tidx,d in reversed(list(zip(idxs, shape))):
ret.append(tidx * acc)
acc *= d
return Variable.sum(ret)
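# Sketch (not part of the diff): per-dimension indices flatten into one row-major index expression.
_idx0, _idx1 = Variable("idx0", 0, 1), Variable("idx1", 0, 2)
assert idxs_to_idx((2, 3), (_idx0, _idx1)).render() == "((idx0*3)+idx1)"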
@dataclass(frozen=True)
class ShapeTracker:
views: Tuple[View, ...]
def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
@staticmethod
def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
@property
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
@property
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
# this is the real size (ish)
def size(self): return self.views[-1].size()
def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], []))
@property
def var_vals(self) -> Dict[Variable, int]:
ret:Dict[Variable, int] = {}
for v in self.vars():
var, val = v.unbind()
assert var not in ret or ret[var] == val, f"{var} has conflicted values {val} and {ret[var]}"
ret[var] = val
return ret
def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views))
def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
to_apply:List[Tuple[MovementOps, Tuple]] = []
for v in self.views:
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
# first, we apply the offset
# then, we make it the correct shape
# then, we apply permutations
# TODO: don't use as_strided
to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
# then, we apply pre expand pads
if v.mask is not None:
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
if any(x != (0,0) for x in pre_expand_pads):
to_apply.append((MovementOps.PAD, pre_expand_pads))
real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
# then, we do any expands
if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
# lastly, we apply post expand pads
if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
return to_apply
# these are multiview strides, value is None if it's not a simple strided dimension
# TODO: this can be shared code between simplify and merge_views
def real_offset(self) -> sint:
real_offset, _ = self.expr_node(Variable('zero', 0, 0))
return real_offset.b if isinstance(real_offset, NumNode) else real_offset
# NOTE: if a stride is not always valid, it will be None
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
idx, valid = self.expr_idxs(idxs)
ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
ret[idxs.index(this_dim.a)] = this_dim.b
elif isinstance(this_dim, Variable) and this_dim in idxs:
ret[idxs.index(this_dim)] = 1
idx_vars, valid_vars = idx.vars(), valid.vars()
for i,tidx in enumerate(idxs):
if tidx in valid_vars and not ignore_valid: ret[i] = None
elif tidx not in idx_vars: ret[i] = 0
return tuple(ret)
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
def _expr_idx(self, idx, valid) -> Tuple[Node, Node]:
for v in reversed(self.views[0:-1]):
if valid.max == 0: return Variable.num(-1), valid
valid = expr_node_mask(v, idx, valid)
idx = expr_node(v, idx)
return idx, valid
def simplify(self) -> ShapeTracker:
if len(self.views) >= 2:
new_view = merge_views(self.views[-2], self.views[-1])
if new_view:
if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
return self
def expr_idxs(self, idxs=None):
if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
idx = expr_idxs(self.views[-1], tuple(idxs))
valid = expr_node_mask(self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs)))
return self._expr_idx(idx, valid)
def expr_node(self, idx='idx'):
if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1)
return self._expr_idx(expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx))
def axis_is_masked(self, axis) -> bool:
_, valid = self.expr_idxs()
return f'idx{axis}' in [v.expr for v in valid.vars()]
# *** under this line are the movement ops ***
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
def permute(self, axis: Tuple[int, ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
def stride(self, mul: Tuple[int, ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
new_view = self.views[-1].reshape(new_shape)
if new_view is None:
extra_view = View.create(new_shape)
# last chance to merge. TODO: move into View
if (merged_view := merge_views(self.views[-1], extra_view)) is not None:
return ShapeTracker(self.views[0:-1] + (merged_view,))
return ShapeTracker(self.views + (extra_view, ))
return ShapeTracker(self.views[0:-1] + (new_view,))
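# Sketch (not part of the diff): movement ops only stack or rewrite Views, no data moves.
# A permute of a fresh (2, 3) tracker just swaps the recoverable strides.
def _st_demo():
  st = ShapeTracker.from_shape((2, 3)).permute((1, 0))
  assert st.shape == (3, 2) and st.real_strides() == (1, 3)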
# returns the axes to create new_shape if new_shape can be created by combining axes from old_shape
# TODO: if we remove movementops from lazy.py we can delete this
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
# Pre-allocate all groups.
axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
# Index for new_shape and axis_groups.
i: int = 0
old_shape_i: int = 0
while old_shape_i < len(old_shape):
    # 1s that exist only in new_shape lead to empty axis groups.
if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
if i < len(new_shape) - 1: i += 1
else:
axis_groups[i].append(old_shape_i)
axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
      # Move to the next axis group once the total size of its dimensions matches.
if axis_group_size == new_shape[i]:
if i < len(new_shape) - 1: i += 1
elif axis_group_size > new_shape[i]: return None
old_shape_i += 1
return axis_groups
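# Sketch (not part of the diff): old axes are grouped so their products match each new axis.
assert get_contraction((2, 3, 4), (6, 4)) == [[0, 1], [2]]
assert get_contraction((2, 3, 4), (24,)) == [[0, 1, 2]]
assert get_contraction((4,), (2, 2)) is None  # 4 cannot be built by combining whole old axes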

View File

@@ -0,0 +1,352 @@
from __future__ import annotations
from abc import abstractmethod
import functools
from math import gcd
from itertools import product
from tinygrad.helpers import partition
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator
# NOTE: Python has different behavior for negative mod and floor div than C
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
class Node:
b: Union[Node, int]
min: int
max: int
def render(self, ops=None, ctx=None) -> Any:
if ops is None: ops = render_python
assert self.__class__ in (Variable, NumNode) or self.min != self.max
return ops[type(self)](self, ops, ctx)
def vars(self): return []
def expand_idx(self) -> VariableOrNum: return next((v for v in self.vars() if v.expr is None), NumNode(0))
# expand a Node into List[Node] that enumerates the underlying Variables from min to max
# expand increments earlier variables faster than later variables (as specified in the argument)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def expand(self, idxs:Optional[Tuple[VariableOrNum, ...]]=None) -> List[Node]:
if idxs is None: idxs = (self.expand_idx(),)
return [self.substitute(dict(zip(idxs, (NumNode(x) for x in rep)))) for rep in Node.iter_idxs(idxs)]
@staticmethod
def iter_idxs(idxs:Tuple[VariableOrNum, ...]) -> Iterator[Tuple[int,...]]:
yield from (x[::-1] for x in product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
# substitute Variables with the values in var_vals
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
def hash(self) -> int: return hash(self.key)
def __repr__(self): return self.render(ctx="REPR")
def __str__(self): return "<"+self.key+">"
def __hash__(self): return self.hash
def __bool__(self): return not (self.max == self.min == 0)
def __eq__(self, other:object) -> bool:
if not isinstance(other, Node): return NotImplemented
return self.key == other.key
def __neg__(self): return self*-1
def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
def __radd__(self, b:int): return self+b
def __sub__(self, b:Union[Node,int]): return self+-b
def __rsub__(self, b:int): return -self+b
def __le__(self, b:Union[Node,int]): return self < (b+1)
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
def __mul__(self, b:Union[Node, int]):
if b == 0: return NumNode(0)
if b == 1: return self
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
def __rmul__(self, b:int): return self*b
# *** complex ops ***
def __rfloordiv__(self, b:int):
if self.min > b >= 0: return NumNode(0)
if isinstance(self, NumNode): return NumNode(b // self.b)
raise RuntimeError(f"not supported: {b} // {self}")
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
if isinstance(b, Node):
if b.__class__ is NumNode: return self // b.b
if self == b: return NumNode(1)
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
raise RuntimeError(f"not supported: {self} // {b}")
assert b != 0
if b < 0: return (self//-b)*-1
if b == 1: return self
# the numerator of div is not allowed to be negative
if self.min < 0:
offset = self.min//b
      # factor out an "offset" to make the numerator positive. don't allow factoring again
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
return create_node(DivNode(self, b))
def __rmod__(self, b:int):
if self.min > b >= 0: return NumNode(b)
if isinstance(self, NumNode): return NumNode(b % self.b)
raise RuntimeError(f"not supported: {b} % {self}")
def __mod__(self, b:Union[Node,int]):
if isinstance(b, Node):
if b.__class__ is NumNode: return self % b.b
if self == b: return NumNode(0)
if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
raise RuntimeError(f"not supported: {self} % {b}")
assert b > 0
if b == 1: return NumNode(0)
if self.min >= 0 and self.max < b: return self
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
if self.min < 0: return (self - ((self.min//b)*b)) % b
return create_node(ModNode(self, b))
@staticmethod
def num(num:int) -> NumNode: return NumNode(num)
@staticmethod
def factorize(nodes:List[Node]) -> List[Node]:
mul_groups: Dict[Node, int] = {}
for x in nodes:
a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
mul_groups[a] = mul_groups.get(a, 0) + b
return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
@staticmethod
def sum(nodes:List[Node]) -> Node:
nodes = [x for x in nodes if x.max or x.min]
if not nodes: return NumNode(0)
if len(nodes) == 1: return nodes[0]
new_nodes: List[Node] = []
num_node_sum = 0
for node in SumNode(nodes).flat_components:
if node.__class__ is NumNode: num_node_sum += node.b
else: new_nodes.append(node)
if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
new_nodes = Node.factorize(new_nodes)
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
@staticmethod
def ands(nodes:List[Node]) -> Node:
if not nodes: return NumNode(1)
if len(nodes) == 1: return nodes[0]
if any(not x for x in nodes): return NumNode(0)
# filter 1s
nodes = [x for x in nodes if x.min != x.max]
return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
# 4 basic node types
class Variable(Node):
def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
assert nmin >= 0 and nmin <= nmax
if nmin == nmax: return NumNode(nmin)
return super().__new__(cls)
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
self.val:Optional[int] = None
def bind(self, val):
assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
self.val = val
return self
def unbind(self) -> Tuple[Variable, int]:
assert self.val is not None, f"cannot unbind {self}"
return Variable(self.expr, self.min, self.max), self.val
def vars(self): return [self]
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self
class NumNode(Node):
def __init__(self, num:int):
assert isinstance(num, int), f"{num} is not an int"
self.b:int = num
self.min, self.max = num, num
def bind(self, val):
assert self.b == val, f"cannot bind {val} to {self}"
return self
def __eq__(self, other): return self.b == other
def __hash__(self): return self.hash # needed with __eq__ override
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
def create_node(ret:Node):
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
if ret.min == ret.max: return NumNode(ret.min)
return ret
class OpNode(Node):
def __init__(self, a:Node, b:Union[Node, int]):
self.a, self.b = a, b
self.min, self.max = self.get_bounds()
def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
@abstractmethod
def get_bounds(self) -> Tuple[int, int]: pass
class LtNode(OpNode):
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
def get_bounds(self) -> Tuple[int, int]:
if isinstance(self.b, int):
return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
class MulNode(OpNode):
def __lt__(self, b: Union[Node, int]):
if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b)
sgn = 1 if self.b > 0 else -1
return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b))
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
if self.b % b == 0: return self.a*(self.b//b)
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: Union[Node, int]):
a = (self.a * (self.b%b))
return Node.__mod__(a, b)
def get_bounds(self) -> Tuple[int, int]:
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
class DivNode(OpNode):
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0 and isinstance(self.b, int)
return self.a.min//self.b, self.a.max//self.b
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) // self.b
class ModNode(OpNode):
def __mod__(self, b: Union[Node, int]):
if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b)
return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b)
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
return Node.__floordiv__(self, b, factoring_allowed)
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0 and isinstance(self.b, int)
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) % self.b
class RedNode(Node):
def __init__(self, nodes:List[Node]): self.nodes = nodes
def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
class SumNode(RedNode):
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
fully_divided: List[Node] = []
rest: List[Node] = []
if isinstance(b, SumNode):
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
if isinstance(b, Node):
for x in self.flat_components:
if x % b == 0: fully_divided.append(x // b)
else: rest.append(x)
if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b
return Node.__floordiv__(self, b, False)
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
fully_divided, rest = [], []
_gcd = b
divisor = 1
for x in self.flat_components:
if x.__class__ in (NumNode, MulNode):
if x.b%b == 0: fully_divided.append(x//b)
else:
rest.append(x)
_gcd = gcd(_gcd, x.b)
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
else:
rest.append(x)
_gcd = 1
if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __mod__(self, b: Union[Node, int]):
if isinstance(b, SumNode):
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
new_nodes: List[Node] = []
for x in self.nodes:
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
else: new_nodes.append(x)
return Node.__mod__(Node.sum(new_nodes), b)
def __lt__(self, b:Union[Node,int]):
lhs: Node = self
if isinstance(b, int):
new_sum = []
for x in self.nodes:
# TODO: should we just force the last one to always be the number
if isinstance(x, NumNode): b -= x.b
else: new_sum.append(x)
lhs = Node.sum(new_sum)
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
if muls:
# NOTE: gcd in python 3.8 takes exactly 2 args
mul_gcd = b
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here
all_others = Variable.sum(others)
if all_others.min >= 0 and all_others.max < mul_gcd:
lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
return Node.__lt__(lhs, b)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes])
@property
def flat_components(self): # recursively expand sumnode components
new_nodes = []
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
return new_nodes
class AndNode(RedNode):
def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes])
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
subed = []
for node in self.nodes:
if not (sub:=node.substitute(var_vals)): return NumNode(0)
subed.append(sub)
return Variable.ands(subed)
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
ret = typ(nodes)
if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
return create_node(ret)
@functools.lru_cache(maxsize=None)
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
if isinstance(a, (int, float)): return a
ret = a.substitute({k:Variable.num(v) for k, v in var_vals.items()})
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
return ret.b
# symbolic int
sint = Union[Node, int]
VariableOrNum = Union[Variable, NumNode]
render_python: Dict[Type, Callable] = {
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
NumNode: lambda self,ops,ctx: f"{self.b}",
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
}
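
# Illustrative sketch (added, not part of the original file): a tiny __main__ demo of the symbolic
# API above. It assumes the Variable(name, min, max) constructor and that render() defaults to the
# render_python ops; run the module directly to see the simplification.
if __name__ == "__main__":
  i = Variable("i", 1, 10)        # a bounded symbolic integer
  expr = (i*4 + 2) // 4           # SumNode.__floordiv__ folds the +2 away, leaving just i
  print(expr.render())            # -> "i"
  print(sym_infer(expr, {i: 3}))  # substitute i=3 and reduce to a plain int -> 3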

View File

@@ -0,0 +1,132 @@
from __future__ import annotations
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.helpers import prod, all_int, dedup
from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, is_sym_int, sint
@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
strides = [1] if shape else []
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
return filter_strides(shape, tuple(strides))
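# Worked example (added note): strides_for_shape((2, 3, 4)) -> (12, 4, 1) for a C-contiguous layout;
# filter_strides then zeroes the stride of any size-1 dim, e.g. filter_strides((2, 1, 4), (4, 4, 1)) -> (4, 0, 1).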
@dataclass(frozen=True)
class View:
shape:Tuple[sint, ...]
strides:Tuple[sint, ...]
offset:sint
mask:Optional[Tuple[Tuple[sint, sint], ...]]
contiguous:bool
@staticmethod
@functools.lru_cache(maxsize=None)
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
strides = filter_strides(shape, strides) if strides else strides_for_shape(shape)
contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape)))
return View(shape, strides, offset, mask, contiguous)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def size(self): return prod([s.max if isinstance(s, Node) else s for s,st in zip(self.shape, self.strides) if st != 0])
def vars(self) -> List[Variable]:
flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], []))
def unbind(self) -> View:
unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
return View.create(new_shape, new_strides, new_offset, new_mask)
# MovementOps live here now
def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
if self.mask:
# move the old mask
nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.mask, arg)])
# merge the masks if we have two
mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
shape = [y-x for x,y in arg]
return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
if any(b or e for b, e in arg):
zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
return self.__unsafe_resize(zvarg, mask=mask)
return self
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
return self.__unsafe_resize(arg)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def expand(self, new_shape: Tuple[sint, ...]) -> View:
assert len(new_shape) == len(self.shape)
assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
# NOTE: can the mask ever be (0,0)?
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
return View.create(new_shape, self.strides, self.offset, mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def permute(self, axis: Tuple[int, ...]) -> View:
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def stride(self, mul: Tuple[int, ...]) -> View:
# except for the negative case, you can build this from the others. invertible in the negative case
assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
strides = tuple([z*m for z,m in zip(self.strides, mul)])
new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
return View.create(new_shape, strides, self.offset + offset, mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
if self.shape == new_shape: return self
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
# check for the same size
if all_int(self.shape):
if all_int(new_shape):
assert prod(self.shape) == prod(new_shape), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
else:
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
assert prod(self.shape) == prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
# after the asserts, it's okay to check contiguous
if self.contiguous: return View.create(new_shape)
# check if this is adding or removing 1s (only)
# NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
if [x for x in self.shape if x != 1] == [x for x in new_shape if x != 1]:
new_strides: List[sint] = [y for x,y in zip(self.shape, self.strides) if x != 1]
new_strides_tuple: Tuple[sint, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
new_mask_tuple: Optional[Tuple[Tuple[sint, sint], ...]] = None
if self.mask:
for x,y in zip(self.shape, self.mask):
if x == 1 and y != (0, 1):
new_mask_tuple = ((0,0),) * len(new_shape)
break
else:
new_mask: List[Tuple[sint, sint]] = [y for x,y in zip(self.shape, self.mask) if x != 1]
new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
return View.create(new_shape, new_strides_tuple, self.offset, new_mask_tuple)
# TODO: bring the merge_views logic here for more caching
return None
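
# Illustrative sketch (added, not part of the original file): a tiny __main__ demo of the View API
# above, showing how strides, contiguity and the reshape fast path behave.
if __name__ == "__main__":
  v = View.create((2, 3, 4))               # contiguous: strides (12, 4, 1), offset 0, no mask
  print(v.strides, v.contiguous)           # -> (12, 4, 1) True
  p = v.permute((2, 0, 1))                 # shape (4, 2, 3), strides (1, 12, 4), no longer contiguous
  print(p.shape, p.strides, p.contiguous)  # -> (4, 2, 3) (1, 12, 4) False
  print(v.reshape((6, 4)))                 # contiguous views reshape for free
  print(p.reshape((24,)))                  # -> None, the caller has to fall back to merge_views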

View File

@@ -0,0 +1,790 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math
from collections import defaultdict
from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import Device, LoadOps
from tinygrad.shape.symbolic import sint
from tinygrad.realize import run_schedule
# An instantiation of the Function is the Context
class Function:
def __init__(self, device:str, *tensors:Tensor):
self.device = device
self.needs_input_grad = [t.requires_grad for t in tensors]
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
if self.requires_grad: self.parents = tensors
def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
@classmethod
def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
ctx = fxn(x[0].device, *x)
ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine
return ret
import tinygrad.mlops as mlops
# **** start with two base classes, Tensor and Function ****
class Tensor:
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
__deletable__ = ('_ctx',)
training: ClassVar[bool] = False
class train:
def __init__(self, val=True): self.val = val
def __enter__(self):
self.prev = Tensor.training
Tensor.training = self.val
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
no_grad: ClassVar[bool] = False
default_type: ClassVar[DType] = dtypes.float32
def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
device = Device.canonicalize(device)
# tensors have gradients, buffers do not
self.grad: Optional[Tensor] = None
# NOTE: this can be in three states. False and None: no gradient, True: gradient
# None (the default) will be updated to True if it's put in an optimizer
self.requires_grad: Optional[bool] = requires_grad
# internal variables used for autograd graph construction
self._ctx: Optional[Function] = None
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, (int, float)):
data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
elif data.__class__ is list:
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
elif isinstance(data, np.ndarray):
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
if data.shape == ():
data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
else:
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
else: raise RuntimeError(f"can't create Tensor from {data}")
# data is a LazyBuffer, but it might be on the wrong device
self.lazydata = data if data.device == device else data.copy_to_device(device)
def __repr__(self):
return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
  # Python has a non-moving GC, so this should be okay
def __hash__(self): return id(self)
@property
def device(self) -> str: return self.lazydata.device
@property
def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape
@property
def dtype(self) -> DType: return self.lazydata.dtype
# ***** data handlers ****
@staticmethod
def corealize(lst:Iterable[Tensor]):
seen:Set[LazyBuffer] = set()
sched = []
for t in lst: sched += t.lazydata.schedule(seen)
run_schedule(sched)
def realize(self) -> Tensor:
run_schedule(self.lazydata.schedule())
return self
def assign(self, x) -> Tensor:
# TODO: this is a hack for writing to DISK
if self.device.startswith("DISK"):
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
self.contiguous().realize().lazydata.realized._copyin(x.numpy()) # type: ignore
return self
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
assert not x.requires_grad # self requires_grad is okay?
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
self.lazydata = x.lazydata
return self
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
def numpy(self) -> np.ndarray:
assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
# TODO: if things are realized this won't work
def to_(self, device:str):
assert self.lazydata.realized is None
self.lazydata.device = device
if self.grad: self.grad.to_(device)
def to(self, device:str) -> Tensor:
ret = Tensor(self.lazydata, device)
if self.grad: ret.grad = self.grad.to(device)
return ret
# ***** creation llop entrypoint *****
@staticmethod
def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
@staticmethod
def empty(*shape, **kwargs):
assert all_int(shape), f"cannot create with symbolic shape {shape}"
return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)
_seed: int = int(time.time())
@staticmethod
def manual_seed(seed=0): Tensor._seed = seed
@staticmethod
def rand(*shape, **kwargs):
assert all_int(shape), f"cannot create with symbolic shape {shape}"
Tensor._seed += 1
return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
# ***** creation helper functions *****
@staticmethod
def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
@staticmethod
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
@staticmethod
def arange(start, stop=None, step=1, **kwargs):
if stop is None: stop, start = start, 0
return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
@staticmethod
def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
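  # Worked examples (added note): arange(5) -> [0, 1, 2, 3, 4] via full + cumsum; eye(3) builds the identity
  # by padding a ones column (3,1) -> (3,4), flattening to 12, shrinking to the first 9 values and reshaping to (3,3).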
def full_like(self, fill_value, **kwargs):
return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
# ***** rng hlops *****
@staticmethod
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
src = Tensor.rand(2, *shape, **kwargs)
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
@staticmethod
def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
dtype = kwargs.pop("dtype", Tensor.default_type)
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(shape)**-0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
# ***** toposort and backward pass *****
def deepwalk(self):
def _deepwalk(node, visited, nodes):
visited.add(node)
if getattr(node, "_ctx", None):
for i in node._ctx.parents:
if i not in visited: _deepwalk(i, visited, nodes)
nodes.append(node)
return nodes
return _deepwalk(self, set(), [])
def backward(self):
    assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape}"
# fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
# this is "implicit gradient creation"
self.grad = Tensor(1, device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk()):
assert (t0.grad is not None)
grads = t0._ctx.backward(t0.grad.lazydata)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
for t, g in zip(t0._ctx.parents, grads):
if g is not None and t.requires_grad:
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
t.grad = g if t.grad is None else (t.grad + g)
del t0._ctx
# ***** movement mlops *****
def reshape(self, shape, *args) -> Tensor:
new_shape = argfix(shape, *args)
assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
# ***** movement hlops *****
# - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
# - A slice i:j returns the elements with indices in [i, j)
# - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
# - Negative values for i and j are taken relative to the end of the sequence
  # - Both i and j will be clamped to the range (-N, N], where N is the length of the sequence
# - Indexing with None on a given axis will add a new dimension of size one before that axis
# - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
# - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
# - Strides > 1 and < 0 are now allowed!:
# - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
# - Idea of stride < 0 support:
  # - Do the slice first, flip the axes where slice.step is negative, do slice.step -> -slice.step. Go to the steps below.
# - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
# - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
# - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
# is possible.
# - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
  # - Fancy indexing and combined indexing are supported
  # - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t. the Tensors passed in -> fancy indexing
# - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively
# - The first iteration will expand the dim of self while consecutive iterations will reduce the dim
# - There's a special case where a permute is needed at the end:
  # - if the first Tensor passed in (expand dims) is not at dim 0
  # - and the following Tensors do not follow consecutively to the end of fancy indexing's dims
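  # - Illustrative examples (added note), for a Tensor t of shape (4, 4):
  # -   t[1] -> shape (4,);  t[:, None] -> (4, 1, 4);  t[::2] -> (2, 4);  t[..., 1:3] -> (4, 2)
  # -   t[Tensor([0, 2])] -> (2, 4) via the arange/CMPEQ/SUM_REDUCE scheme described above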
def __getitem__(self, val): # val: Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
def normalize_int(e, i, dim_sz):
if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
orig_slices = list(val) if isinstance(val, tuple) else [val]
count = defaultdict(list)
for i,v in enumerate(orig_slices): count[type(v)].append(i)
if (num_slices := len(count[int]) + len(count[slice]) + len(count[Tensor])) > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
if len(ellipsis_found := count[type(Ellipsis)]) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices)
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
valid_slices = [v for v in orig_slices if v is not None]
valid_slices = [v if isinstance(v, slice) else slice(y_ := normalize_int(v, i, dim_sz), y_+1) if isinstance(v, int) else slice(None) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
new_shape = sliced_tensor.shape
if any(abs(s) != 1 for s in strides):
strides = tuple(abs(s) for s in strides)
# Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
padded_tensor = sliced_tensor.pad(tuple((0, s-(dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape)))
# Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)))
new_shape = reshaped_tensor.shape[::2]
# Shrink: do [:, 0]
sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)))
final_shape, it_shape, dim, tensors, dim_collapsed = [], iter(new_shape), [], [], 0
for i,s in enumerate(orig_slices):
if s is None: final_shape.append(1)
else: # s is int or slice or Tensor
dim_shape = next(it_shape)
if isinstance(s, int):
dim_collapsed += 1
else:
assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}"
final_shape.append(dim_shape)
if isinstance(s, Tensor):
tensors.append(s)
dim.append(i-dim_collapsed)
ret = sliced_tensor.reshape(tuple(final_shape))
if tensors: # Fancy/tensor indexing
# normalize idx
# TODO: first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
idx = [t.sign().contiguous().__neg__().contiguous().relu() * ret.shape[d] + t for d,t in zip(dim, tensors)]
max_dim = max(i.ndim for i in idx)
# compute sum_dim, arange, and idx
sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(dim)]
arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, dim))]
first_idx = [idx[0].reshape(*[1]*dim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - dim[0] - 1))]
rest_idx = [i.reshape(*[1]*dim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - dim[0] - n)) for n,i in enumerate(idx[1:], 1)]
idx = first_idx + rest_idx
ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
# iteratively fancy index
for a,i,sd in zip(arange, idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
# special permute case
if dim[0] != 0 and len(dim) != 1 and dim != list(range(dim[0], dim[-1]+1)):
ret_dims = list(range(ret.ndim))
ret = ret.permute(ret_dims[dim[0]:dim[0]+max_dim] + ret_dims[:dim[0]] + ret_dims[dim[0]+max_dim:])
return ret
def __setitem__(self,s,v): return self.__getitem__(s).assign(v)
# NOTE: using slice is discouraged and things should migrate to pad and shrink
def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
def gather(self: Tensor, idx: Tensor, dim: int):
assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
if dim < 0: dim += self.ndim
idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
permarg = list(range(self.ndim))
permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
catargs = [self, *args]
assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
shapes = [s.shape[dim] for s in catargs]
shape_cumsum = [0, *accumulate(shapes)]
slc = [[(0, 0) for _ in self.shape] for _ in catargs]
for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
s[dim] = (k, shape_cumsum[-1] - k - shp)
return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
@staticmethod
def stack(tensors, dim=0):
first = tensors[0].unsqueeze(dim)
unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]]
# checks for shapes and number of dimensions delegated to cat
return first.cat(*unsqueezed_tensors, dim=dim)
def repeat(self, repeats):
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
new_shape = [x for b in base_shape for x in [1, b]]
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
final_shape = [r*s for r,s in zip(repeats, base_shape)]
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
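  # Worked example for repeat (added note): shape (2, 3) with repeats (2, 2): reshape -> (1, 2, 1, 3),
  # expand -> (2, 2, 2, 3), reshape -> (4, 6), i.e. the data tiled twice along each axis.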
def chunk(self, num:int, dim:int) -> List[Tensor]:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
return [self[tuple(sl)] for sl in slice_params]
def squeeze(self, dim=None):
if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
if dim < 0: dim += self.ndim
return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])
def unsqueeze(self, dim):
if dim < 0: dim = len(self.shape) + dim + 1
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
# (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0):
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
@property
def T(self) -> Tensor: return self.transpose()
def transpose(self, ax1=1, ax2=0) -> Tensor:
order = list(range(len(self.shape)))
order[ax1], order[ax2] = order[ax2], order[ax1]
return self.permute(order)
def flatten(self, start_dim=0): return self.reshape(shape=self.shape[:start_dim] + (-1,))
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis)) # type: ignore
axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
shape = [s for i,s in enumerate(self.shape) if i not in axis_]
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
def mean(self, axis=None, keepdim=False):
assert all_int(self.shape), "does not support symbolic shape"
out = self.sum(axis=axis, keepdim=keepdim)
return out.mul(prod(out.shape)/prod(self.shape))
def std(self, axis=None, keepdim=False, correction=1):
assert all_int(self.shape), "does not support symbolic shape"
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction).sqrt()
def _softmax(self, axis):
m = self - self.max(axis=axis, keepdim=True)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)
def softmax(self, axis=-1):
_, e, ss = self._softmax(axis)
return e.div(ss)
def log_softmax(self, axis=-1):
m, _, ss = self._softmax(axis)
return m - ss.log()
def argmax(self, axis=None, keepdim=False):
if axis is None:
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
return prod(self.shape) - idx.max() - 1
axis = axis + len(self.shape) if axis < 0 else axis
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
# ***** processing ops *****
def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
# slide by dilation
xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))
# handle stride, and permute to move reduce to the end
xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
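  # Shape examples (added note): for x of shape (1, 1, 4, 4), x.avg_pool2d() -> (1, 1, 2, 2) with the default
  # 2x2 kernel and stride; x.conv2d(w) with w of shape (8, 1, 3, 3) -> (1, 8, 2, 2), since (4 - 3) + 1 = 2.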
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
stride = make_pair(stride, len(HW))
if any(s>1 for s in stride):
x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
wino = int(getenv("WINO", "0"))
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}"
padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1])
# conv2d is a pooling op (with padding)
x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino:
# normal conv
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
    # winograd conv for 3x3 kernels, F(4x4,3x3), see: http://arxiv.org/abs/1509.09308
def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat])
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order almost doubles compilation time
# todo: stride == dilation
# use padding to round up to 4x4 output tiles
d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # (bs, cin_, tyx, HWI)
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx)
tyx = d.shape[-len(HWI):] # dim of tiling
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
# compute 6x6 winograd tiles: GgGt, BtdB
gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW))) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
def dot(self, w:Tensor) -> Tensor:
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1)
def cumsum(self, axis:int=0) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-1,0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
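  # Worked example for cumsum (added note): Tensor([1, 2, 3]).cumsum() -> [1, 3, 6]; the axis is left-padded
  # with shape[axis]-1 zeros and _pool forms length-shape[axis] sliding windows whose sums are the prefix sums.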
# ***** mlops (unary) *****
def __neg__(self): return mlops.Neg.apply(self)
def contiguous(self): return mlops.Contiguous.apply(self)
def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
def log(self): return mlops.Log.apply(self)
def log2(self): return mlops.Log.apply(self)/math.log(2)
def exp(self): return mlops.Exp.apply(self)
def exp2(self): return mlops.Exp.apply(self*math.log(2))
def relu(self): return mlops.Relu.apply(self)
def sigmoid(self): return mlops.Sigmoid.apply(self)
def sin(self): return mlops.Sin.apply(self)
def sqrt(self): return mlops.Sqrt.apply(self)
def rsqrt(self): return (1/self).sqrt()
def cos(self): return ((math.pi/2)-self).sin()
def tan(self): return self.sin() / self.cos()
@staticmethod
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
def tril(self, k:int=0) -> Tensor:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
# ***** math functions (unary) *****
def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
def square(self): return self*self
def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
def abs(self): return self.relu() + (-self).relu()
def sign(self): return self / (self.abs() + 1e-10)
def reciprocal(self): return 1.0/self
# ***** activation functions (unary) *****
def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
def swish(self): return self * self.sigmoid()
def silu(self): return self.swish() # The SiLU function is also known as the swish function.
def relu6(self): return self.relu() - (self-6).relu()
def hardswish(self): return self * (self+3).relu6() * (1/6)
def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
def quick_gelu(self): return self * (self * 1.702).sigmoid()
def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
def mish(self): return self * self.softplus().tanh()
def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
def softsign(self): return self / (1 + self.abs())
# ***** broadcasted binary mlops *****
def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]:
x: Tensor = self
if not isinstance(y, Tensor):
y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32)
if reverse: x, y = y, x
if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
shape_delta = len(xshape) - len(yshape)
if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape)
elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape)
if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)])
if xshape != shape_ret: x = x.expand(shape_ret)
if yshape != shape_ret: y = y.expand(shape_ret)
return (x, y)
def _to_float(self, x:Union[Tensor, float]):
return x.lazydata.op.arg if isinstance(x, Tensor) and not x.lazydata.realized and x.lazydata.op.op == LoadOps.CONST and not x.requires_grad \
and x.lazydata.st.contiguous and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:
x = self._to_float(x)
return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor:
x = self._to_float(x)
return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self)
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor:
x = self._to_float(x)
if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
if x.__class__ is not Tensor and x == -1.0: return -self
return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor:
x = self._to_float(x)
return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
x = self._to_float(x)
if x.__class__ is not Tensor and not reverse:
# simple pow identities
if x < 0: return self.reciprocal().pow(-x)
if x == 3.0: return self*self*self
if x == 2.0: return self*self
if x == 1.0: return self
if x == 0.5: return self.sqrt()
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
# correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
# we only need to correct the sign if the base is negative
base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
# we need 0 to be positive so we need to correct base_sign when the base is 0
base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
# inject nan if the base is negative and the power is not an integer
to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
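  # Worked example for pow (added note): with a negative base the general path fixes the sign with cos(pi*x),
  # e.g. Tensor(-2.0)**4 -> 16 and Tensor(-2.0)**5 -> -32, while Tensor(-2.0)**2.5 -> nan (non-integer exponent).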
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
x_,y = self._broadcasted(input_)
x,z = x_._broadcasted(other)
return mlops.Where.apply(x, *y._broadcasted(z))
# ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
# NOTE: __pow__ and friends are broken in mypyc with the ** operator
def __add__(self, x) -> Tensor: return self.add(x)
def __sub__(self, x) -> Tensor: return self.sub(x)
def __mul__(self, x) -> Tensor: return self.mul(x)
def __pow__(self, x) -> Tensor: return self.pow(x)
def __truediv__(self, x) -> Tensor: return self.div(x)
def __matmul__(self, x) -> Tensor: return self.matmul(x)
def __radd__(self, x) -> Tensor: return self.add(x, True)
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
def __rmul__(self, x) -> Tensor: return self.mul(x, True)
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
def __ge__(self, x) -> Tensor: return 1.0-(self<x)
def __le__(self, x) -> Tensor: return 1.0-(self>x)
def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore
def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore
# ***** functional nn ops *****
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
x = (self - mean.reshape(shape=[1, -1, 1, 1]))
if weight: x = x * weight.reshape(shape=[1, -1, 1, 1])
ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd)
return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret
def dropout(self, p=0.5) -> Tensor:
if not Tensor.training or p == 0: return self
mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool)
return self * mask * (1/(1.0 - p))
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
# NOTE: it works if key, value have symbolic shape
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
def binary_crossentropy(self, y:Tensor) -> Tensor:
return (-y*self.log() - (1-y)*(1-self).log()).mean()
def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
return (self.maximum(0) - y * self + (1 + self.abs().__neg__().exp()).log()).mean()
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
loss_mask = Y != ignore_index
y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
return self.log_softmax().mul(y).sum() / loss_mask.sum()
# ***** cast ops *****
def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
def bitcast(self, dtype:DType) -> Tensor:
assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
def float(self) -> Tensor: return self.cast(dtypes.float32)
def half(self) -> Tensor: return self.cast(dtypes.float16)
# ***** convenience stuff *****
@property
def ndim(self) -> int: return len(self.shape)
def numel(self) -> sint: return prod(self.shape)
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
# register functions to move between devices
for device in Device._buffers:
setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
setattr(Tensor, f"{device.lower()}_", partialmethod(Tensor.to_, device))
if IMAGE:
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
from tinygrad.features.image import image_conv2d, image_dot
setattr(Tensor, "conv2d", image_conv2d)
setattr(Tensor, "dot", image_dot)