openpilot v0.9.6 release
date: 2024-01-12T10:13:37 master commit: ba792d576a49a0899b88a753fa1c52956bedf9e6
This commit is contained in:
209
tinygrad_repo/extra/onnx.py
Normal file
209
tinygrad_repo/extra/onnx.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from __future__ import annotations
|
||||
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
|
||||
import importlib
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
|
||||
from typing import List,Dict
|
||||
from onnx.onnx_pb import AttributeProto, ModelProto, TensorProto, TypeProto
|
||||
try:
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
except ImportError:
|
||||
# for onnx < 1.13
|
||||
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
|
||||
tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x]
|
||||
|
||||
# global numpy cache for parameters
|
||||
numpy_cache = {}
|
||||
def safe_numpy(t) -> np.ndarray:
|
||||
if not isinstance(t, Tensor): return t
|
||||
global numpy_cache
|
||||
if t not in numpy_cache:
|
||||
if DEBUG >= 3: print("numpy cache miss", t)
|
||||
tmp = t.numpy()
|
||||
numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1)
|
||||
assert len(numpy_cache[t].shape) > 0
|
||||
return numpy_cache[t]
|
||||
|
||||
onnx_ops = importlib.import_module('extra.onnx_ops')
|
||||
|
||||
ONNXLIMIT = getenv("ONNXLIMIT", -1)
|
||||
|
||||
def get_run_onnx(onnx_model: ModelProto):
|
||||
def type_parse(type_proto: TypeProto):
|
||||
ret = []
|
||||
while True:
|
||||
attr = type_proto.WhichOneof('value')
|
||||
if attr == 'tensor_type':
|
||||
if "dim_value" not in getattr(type_proto, attr).shape.dim.__dir__(): return () # variable type, unable to determine shape
|
||||
elif not ret:
|
||||
return tuple([x.dim_value for x in getattr(type_proto, attr).shape.dim])
|
||||
else:
|
||||
ret.extend([(x.dim_value,) for x in getattr(type_proto, attr).shape.dim])
|
||||
return tuple(ret)
|
||||
elif attr == 'sequence_type':
|
||||
type_proto = getattr(type_proto, attr).elem_type
|
||||
ret.append(1)
|
||||
elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
|
||||
elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
|
||||
elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
|
||||
elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
|
||||
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
|
||||
|
||||
def buffer_parse(inp: TensorProto) -> Tensor:
|
||||
if inp.data_type in (1,10,6,7):
|
||||
# TODO: this is shared with below
|
||||
if len(inp.float_data) > 0:
|
||||
ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.int64_data) > 0:
|
||||
ret = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.int32_data) > 0:
|
||||
ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
|
||||
else:
|
||||
ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
|
||||
else:
|
||||
raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}")
|
||||
return ret
|
||||
|
||||
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
|
||||
# TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
|
||||
if a.type == AttributeProto.FLOAT: return float(a.f)
|
||||
elif a.type == AttributeProto.INT: return int(a.i)
|
||||
elif a.type == AttributeProto.STRING: return a.s.decode("utf-8")
|
||||
elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t) # TENSOR
|
||||
elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
|
||||
elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
|
||||
elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
|
||||
elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}")
|
||||
else: raise Exception(f"can't parse {a.type} {a}")
|
||||
def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a}
|
||||
|
||||
tensors: Dict[str, Tensor] = {}
|
||||
|
||||
# get weights and biases
|
||||
for inp in onnx_model.graph.initializer:
|
||||
if len(inp.raw_data) > 0:
|
||||
tensors[inp.name] = buffer_parse(inp)
|
||||
elif len(inp.float_data) > 0:
|
||||
tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.int64_data) > 0:
|
||||
tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.raw_data) == 0:
|
||||
tensors[inp.name] = Tensor(np.array([], dtype=np.float32), requires_grad=False)
|
||||
else:
|
||||
print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
|
||||
print(inp)
|
||||
raise Exception("no data")
|
||||
|
||||
# preparse the attributes
|
||||
attribute_dict = {}
|
||||
domain = ""
|
||||
for num,n in enumerate(onnx_model.graph.node):
|
||||
attribute_dict[num] = attribute_to_dict(n.attribute)
|
||||
if n.domain: domain = n.domain
|
||||
|
||||
onnx_model_version = onnx_model.opset_import[0].version
|
||||
|
||||
def run_onnx(inputs={}, debug=0):
|
||||
debug = getenv("DEBUGONNX") or debug
|
||||
input_tensors: Dict[str,Tensor] = {}
|
||||
intermediate_tensors: Dict[str,Tensor] = {}
|
||||
output_tensor_names = [x.name for x in onnx_model.graph.output]
|
||||
|
||||
# get inputs
|
||||
for inp in onnx_model.graph.input:
|
||||
if inp.name in tensors: continue
|
||||
shape = type_parse(inp.type)
|
||||
if inp.name in inputs:
|
||||
if isinstance(inputs[inp.name], Tensor):
|
||||
input_tensors[inp.name] = inputs[inp.name]
|
||||
elif isinstance(inputs[inp.name], list):
|
||||
input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]]
|
||||
elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training"
|
||||
input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
|
||||
else:
|
||||
input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
|
||||
if shape: # if only input_tensor is not variable type
|
||||
input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
|
||||
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
|
||||
else:
|
||||
raise Exception(f"no data for {inp.name} with shape {shape}")
|
||||
|
||||
def fetch_tensor(x: str):
|
||||
if x in tensors: return tensors[x]
|
||||
if x in intermediate_tensors: return intermediate_tensors[x]
|
||||
if x != str(): return input_tensors[x]
|
||||
return None
|
||||
|
||||
for num,n in enumerate(onnx_model.graph.node):
|
||||
inp: List[Tensor] = []
|
||||
if debug >= 3: print("inputs:")
|
||||
for x in n.input:
|
||||
t = fetch_tensor(x)
|
||||
if debug >= 3: print(f"\t{x} - {t}")
|
||||
inp.append(t)
|
||||
opt: Dict = attribute_dict[num]
|
||||
if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
|
||||
# some ops live here because they require some local variables
|
||||
if n.op_type == "Split": # have to use n.output for cases when num_outputs is absent
|
||||
axis = opt.get("axis", 0)
|
||||
split = None if len(inp) == 1 else [int(x) for x in safe_numpy(inp[1])]
|
||||
if split is None:
|
||||
split = [inp[0].shape[axis] // len(n.output)] * len(n.output)
|
||||
for i in range(inp[0].shape[axis] % len(n.output)):
|
||||
split[i] += 1
|
||||
i, ret = 0, []
|
||||
arg = [(0,x) for x in inp[0].shape]
|
||||
for s in split:
|
||||
arg[axis] = (i,i+s)
|
||||
ret.append(inp[0].shrink(arg=tuple(arg)))
|
||||
i = i+s
|
||||
ret = tuple(ret)
|
||||
elif n.op_type == "Slice": # need to check onnx_model_version
|
||||
if onnx_model_version < 10:
|
||||
axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
|
||||
else:
|
||||
starts, ends = inp[1:3]
|
||||
axes = safe_numpy(Tensor.arange(inp[0].ndim, dtype=dtypes.int32) if len(inp) <= 3 else inp[3]).tolist()
|
||||
steps = safe_numpy(inp[4]) if len(inp) > 4 else [1]*inp[0].ndim
|
||||
starts, ends = safe_numpy(starts.ceil().cast(dtypes.int32)).tolist(), safe_numpy(ends.ceil().cast(dtypes.int32)).tolist()
|
||||
arg = [(0,x,1) for x in inp[0].shape]
|
||||
for i, axis in enumerate(axes):
|
||||
axis = int(axis) + inp[0].ndim if axis < 0 else int(axis)
|
||||
starts[i], ends[i] = starts[i] + inp[0].shape[axis] if starts[i] < 0 else starts[i], ends[i] + inp[0].shape[axis] if ends[i] < 0 else ends[i]
|
||||
starts[i], ends[i] = max(0, min(starts[i], inp[0].shape[axis])), max(0, min(ends[i], inp[0].shape[axis]))
|
||||
if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i]
|
||||
arg[axis] = (starts[i], ends[i], steps[i])
|
||||
new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
|
||||
if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
|
||||
else: ret = inp[0].__getitem__(tuple([slice(s,e,st) for s,e,st in arg]))
|
||||
elif n.op_type == "Gradient": # need to call backward on intermediate_tensors
|
||||
assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match"
|
||||
y = opt["y"]
|
||||
intermediate_tensors[y].backward()
|
||||
ret = tuple([t.grad for t in inp])
|
||||
elif hasattr(onnx_ops, n.op_type):
|
||||
fxn = getattr(onnx_ops, n.op_type)
|
||||
if isinstance(fxn, dict):
|
||||
for k in sorted(fxn.keys()):
|
||||
if k <= onnx_model_version:
|
||||
real_fxn = fxn[k]
|
||||
else:
|
||||
real_fxn = fxn
|
||||
ret = real_fxn(*inp, **opt)
|
||||
else:
|
||||
print("UNSUPPORTED", n.op_type, n.input, n.output)
|
||||
raise Exception(f"op_type {n.op_type} not supported")
|
||||
if not isinstance(ret, tuple): ret = (ret, )
|
||||
assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
|
||||
if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
|
||||
if debug >= 2: print("outputs:")
|
||||
for i in range(len(n.output)):
|
||||
if debug >= 2: print(f"\t{n.output[i]} - {ret[i]}")
|
||||
intermediate_tensors[n.output[i]] = ret[i]
|
||||
if num == ONNXLIMIT:
|
||||
output_tensor_names = n.output
|
||||
break
|
||||
|
||||
return {outp:intermediate_tensors[outp] for outp in output_tensor_names}
|
||||
return run_onnx
|
||||
720
tinygrad_repo/extra/onnx_ops.py
Normal file
720
tinygrad_repo/extra/onnx_ops.py
Normal file
@@ -0,0 +1,720 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import prod, dtypes, ImageDType
|
||||
from extra.onnx import safe_numpy
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
from onnx.onnx_pb import TensorProto
|
||||
import os
|
||||
import numpy as np
|
||||
import functools
|
||||
from typing import Union, Tuple, Optional, List, Any
|
||||
import math
|
||||
|
||||
# **************** Free Ops ****************
|
||||
|
||||
def Identity(input: Tensor): return input
|
||||
def Neg(input: Tensor): return -input
|
||||
def Add(input: Tensor, other: Tensor, broadcast=None): return input + other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input + other).cast(input.dtype)
|
||||
def Sub(input: Union[Tensor, Any], other: Tensor): return input - other # some test has input as int
|
||||
def Mul(input: Tensor, other: Tensor): return (input * other) if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input * other).cast(input.dtype)
|
||||
# in openpilot, due to SHUFFLE_PAD_OPS issues, we are spending an extra kernel
|
||||
def Div(input: Tensor, other: Tensor): return input / other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else input.div(other).floor()
|
||||
def Pow(input: Tensor, other: Tensor): return (input.float() ** other.float()).cast(input.dtype)
|
||||
def Reciprocal(input: Tensor): return input.reciprocal()
|
||||
def Sqrt(input: Tensor): return input.sqrt()
|
||||
def Sign(input: Tensor): return input.sign()
|
||||
def Abs(input: Tensor): return input.abs()
|
||||
def Exp(input: Tensor): return input.exp()
|
||||
def Log(input: Tensor): return input.log()
|
||||
def Mish(input: Tensor): return input.mish()
|
||||
def Sin(x: Tensor): return x.sin()
|
||||
def Cos(x: Tensor): return x.cos()
|
||||
def Tan(x: Tensor): return x.tan()
|
||||
def Relu(input: Tensor): return input.relu()
|
||||
def Sigmoid(input: Tensor): return input.sigmoid()
|
||||
def Tanh(input: Tensor): return input.tanh()
|
||||
def MatMul(input: Tensor, other: Tensor): return input.matmul(other)
|
||||
def Floor(x:Tensor): return x.floor()
|
||||
def Ceil(x:Tensor): return x.ceil()
|
||||
def Less(x:Tensor,y:Tensor): return (x<y).cast(dtypes.bool)
|
||||
def LessOrEqual(x:Tensor,y:Tensor): return (x<=y).cast(dtypes.bool)
|
||||
def Greater(x:Tensor,y:Tensor): return (x>y).cast(dtypes.bool)
|
||||
def GreaterOrEqual(x:Tensor,y:Tensor): return (x>=y).cast(dtypes.bool)
|
||||
def Equal(x:Tensor,y:Tensor): return (x==y).cast(dtypes.bool)
|
||||
def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
|
||||
def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
|
||||
def Sum(*data_0): return functools.reduce(Tensor.__add__, data_0)
|
||||
def Mean(*data_0): return functools.reduce(Tensor.__add__, data_0) / len(data_0)
|
||||
def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y).cast(X.dtype)
|
||||
def Cast(input: Tensor, to): return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))
|
||||
|
||||
# **************** Simple Ops ****************
|
||||
|
||||
def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
|
||||
if value: return value
|
||||
elif value_float: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
|
||||
elif value_floats: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
|
||||
elif value_int: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
|
||||
elif value_ints: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
|
||||
elif value_string or value_strings: raise NotImplementedError(f'value_string or value_strings not implemented for Constant op')
|
||||
|
||||
def Softsign(input: Tensor): return input / (1+input.abs())
|
||||
def Cosh(x): return (math.e ** x + math.e ** -x) / 2
|
||||
def Sinh(x): return (math.e ** x - math.e ** -x) / 2
|
||||
def Tanh(x): return x.tanh()
|
||||
|
||||
def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1)
|
||||
def HardSwish(input: Tensor): return input * HardSigmoid(input, 1/6, 0.5)
|
||||
def Celu(X: Tensor, alpha=1.0): return X.relu() - (-alpha*(X/alpha).exp()+1).relu()
|
||||
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
|
||||
def Softplus(X: Tensor): return X.softplus()
|
||||
def PRelu(X:Tensor, slope:Tensor):
|
||||
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
|
||||
return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope
|
||||
def LeakyRelu(X: Tensor, alpha=0.01): return X.leakyrelu(alpha)
|
||||
def ThresholdedRelu(X: Tensor, alpha=1.0): return (X-alpha).relu() + (X-alpha).relu().sign() * alpha
|
||||
def Softmax_1(input: Tensor, axis=1): return input.softmax(axis)
|
||||
def Softmax_13(input: Tensor, axis=-1): return input.softmax(axis)
|
||||
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
|
||||
def LogSoftmax(input: Tensor, axis=-1): return input.log_softmax(axis)
|
||||
def Clip(input: Tensor, min=None, max=None): return input.clip(float('-inf') if min is None else min, float('inf') if max is None else max)
|
||||
|
||||
# NOTE ReduceProd would require a new llop
|
||||
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)) else ([] if noop_with_empty_axes else None)
|
||||
def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()
|
||||
def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
|
||||
def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
|
||||
|
||||
def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
|
||||
def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
|
||||
def OptionalHasElement(x: Tensor=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
|
||||
def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([], dtype=dtypes.float32)
|
||||
|
||||
def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)])
|
||||
def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start)), stop=int(safe_numpy(limit)), step=int(safe_numpy(delta))).cast(dtype=start.dtype)
|
||||
def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64) # TODO: really?
|
||||
def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
|
||||
def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
|
||||
def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))])
|
||||
def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd)*(input+bias) + (input > lambd)*(input-bias)
|
||||
def And(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
|
||||
def Or(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.ones(*x.shape)).cast(dtypes.bool)
|
||||
def Xor(x:Tensor, y:Tensor): return Where((x==y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
|
||||
def Not(x:Tensor): return Where((x==1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
|
||||
|
||||
def Asin(x): return Atan(x / Tensor.sqrt(1 - x * x))
|
||||
def Asinh(x): return Tensor.log(x + Tensor.sqrt(x * x + 1))
|
||||
def Acosh(x): return Tensor.log(x + Tensor.sqrt(x * x - 1))
|
||||
def Atanh(x): return 0.5 * Tensor.log((1 + x)/(1 - x))
|
||||
def Acos(x: Tensor):
|
||||
negate = (x < 0)
|
||||
x = x.abs()
|
||||
ret = ((((-0.0187293 * x) + 0.0742610)*x - 0.2121144) * x + 1.5707288) * Tensor.sqrt(1.0 - x)
|
||||
ret = ret - 2 * negate * ret
|
||||
return negate * 3.14159265358979 + ret
|
||||
def Atan(y: Tensor):
|
||||
x = Tensor.ones(y.shape)
|
||||
t3 = x
|
||||
t1 = y.abs()
|
||||
t0 = (t3 > t1).where(t3, t1)
|
||||
t1 = (t3 < t1).where(t3, t1)
|
||||
t3 = t1 / t0
|
||||
t4 = t3 * t3
|
||||
t0 = ((((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 - 0.332994597) * t4 + 0.999995630
|
||||
t3 = t0 * t3
|
||||
t3 = (y.abs() > x.abs()).where(1.570796327 - t3, t3)
|
||||
return (y < 0).where(-t3, t3)
|
||||
|
||||
def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
|
||||
k = int(k.numpy().item()) if k != 0 else 0 # onnx passes k as a tensor int64 with one element, default is 0
|
||||
return x.triu(k) if upper else x.tril(k)
|
||||
|
||||
def Squeeze(input: Tensor, axes):
|
||||
if isinstance(axes, Tensor): axes = safe_numpy(axes)
|
||||
axes = [int(x) if x >= 0 else int(x+input.ndim) for x in axes]
|
||||
return input.reshape([s for i,s in enumerate(input.shape) if i not in axes])
|
||||
def Unsqueeze(data: Tensor, axes):
|
||||
axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
|
||||
new_shape = [1] * (len(data.shape) + len(axes))
|
||||
ptr = iter(data.shape)
|
||||
for i in range(len(new_shape)):
|
||||
if i not in axes:
|
||||
new_shape[i] = next(ptr)
|
||||
return data.reshape(new_shape)
|
||||
|
||||
def Binarizer(input, threshold=0.0): return input > threshold
|
||||
|
||||
def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
|
||||
axis = axis + x.ndim if axis < 0 else axis
|
||||
m = x == (x.max(axis=axis, keepdim=keepdims) if keepdims else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis))
|
||||
c = Tensor.arange(x.shape[axis]).reshape(*[1]*(axis), x.shape[axis], *[1]*(x.ndim - axis-1)) * m
|
||||
return c.max(axis=axis,keepdim=keepdims).cast(dtypes.int64)
|
||||
def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
|
||||
|
||||
def Elu(input: Tensor, alpha=1.0): return input.elu(alpha=alpha)
|
||||
def Concat(*inputs: List[Tensor], axis): return inputs[0].cat(*inputs[1:], dim=axis)
|
||||
def Transpose(input: Tensor, perm=None): return input.permute(order=list(range(len(input.shape))[::-1]) if perm is None else perm)
|
||||
|
||||
# NOTE: since we only have one type, this is valid!
|
||||
def CastLike(input, target_type):
|
||||
assert isinstance(target_type, Tensor), "can only CastLike Tensor"
|
||||
return input
|
||||
|
||||
def ConstantOfShape(input, value:Tensor=None):
|
||||
if value is None: value=Tensor([0.0])
|
||||
shape = [int(x) for x in safe_numpy(input)]
|
||||
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
|
||||
|
||||
# TODO: abstract out the broadcast logic in tensor
|
||||
def Expand(input: Tensor, shape):
|
||||
x_shape, y_shape = input.shape, [int(x) for x in safe_numpy(shape)]
|
||||
# copied from _broadcasted
|
||||
x_shape, y_shape = [([1]*(max(len(x_shape), len(y_shape))-len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]]
|
||||
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
|
||||
return input.reshape(x_shape).expand(shape_ret)
|
||||
|
||||
# **************** Complex Ops ****************
|
||||
|
||||
def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
|
||||
ret = alpha * (A.transpose(transA) @ B.transpose(transB))
|
||||
if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]))
|
||||
return ret
|
||||
|
||||
# works with Tensors.ndim != 4
|
||||
def _batchnorm(self:Tensor, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor):
|
||||
shape = [1, -1] + [1] * (self.ndim-2)
|
||||
x = (self - mean.reshape(shape=shape))
|
||||
if weight: x = x * weight.reshape(shape=shape)
|
||||
ret = x.mul(invstd.reshape(shape=shape) if len(invstd.shape) == 1 else invstd)
|
||||
return (ret + bias.reshape(shape=shape)) if bias else ret
|
||||
|
||||
# TODO: this is copied from tinygrad/nn/__init__.py
|
||||
# spatial is from opset 7 and has since been removed
|
||||
def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1, is_test=0):
|
||||
if training_mode:
|
||||
x_detached = X.detach()
|
||||
current_mean = x_detached.mean(axis=(0,2,3))
|
||||
y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
current_var = (y*y).mean(axis=(0,2,3))
|
||||
current_invstd = current_var.add(epsilon).pow(-0.5)
|
||||
|
||||
running_mean = input_mean * momentum + current_mean * (1 - momentum)
|
||||
running_var = input_var * momentum + current_var * (1 - momentum)
|
||||
|
||||
return _batchnorm(X, scale, B, current_mean, current_invstd), running_mean, running_var
|
||||
else:
|
||||
invstd = (input_var + epsilon)**-0.5
|
||||
return _batchnorm(X, scale, B, input_mean, invstd)
|
||||
|
||||
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
|
||||
axis = tuple(range(2, len(x.shape)))
|
||||
mean = x.mean(axis=axis, keepdim=True)
|
||||
invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
|
||||
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
|
||||
|
||||
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
|
||||
assert stash_type == 1, "only float32 is supported"
|
||||
axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)))
|
||||
mean = x.mean(axis=axis, keepdim=True)
|
||||
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal()
|
||||
|
||||
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
|
||||
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
|
||||
|
||||
# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
|
||||
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
|
||||
def _format_padding(onnx_pads, ndims=None, axes=None):
|
||||
if ndims and len(onnx_pads)//2 != ndims: onnx_pads = onnx_pads * ndims # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
|
||||
if ndims is None: ndims = len(onnx_pads) // 2
|
||||
if axes is None: axes = list(range(ndims))
|
||||
num_axes = len(axes)
|
||||
np_pads = [(0,0)] * ndims
|
||||
for i in range(num_axes):
|
||||
np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
|
||||
return np_pads
|
||||
|
||||
def _padding(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None):
|
||||
if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
|
||||
if pads is None: return X
|
||||
pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
|
||||
return X.pad(tuple(pads), value=constant_value)
|
||||
|
||||
def _auto_pad(X, auto_pad, strides, kernel_shape, dilations):
|
||||
strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
|
||||
dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
|
||||
pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(strides):], strides, kernel_shape, dilations)]
|
||||
if auto_pad == "SAME_UPPER": return [pad_shape[0]//2, pad_shape[1]//2, pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2]
|
||||
elif auto_pad == "SAME_LOWER": return [pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2, pad_shape[0]//2, pad_shape[1]//2]
|
||||
else: raise NotImplementedError(f"auto_pad={auto_pad} not implemented, yet")
|
||||
|
||||
def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
|
||||
constant_value = value if constant_value is None else float(safe_numpy(constant_value)[0])
|
||||
seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads)
|
||||
seq_pads = [math.ceil(i) for i in seq_pads]
|
||||
seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None
|
||||
base_shape = x.shape
|
||||
pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes)
|
||||
if mode == "wrap":
|
||||
repeat_args = [math.ceil(dim[0]/sh) + math.ceil(dim[1]/sh) + 1 for dim, sh in zip(pads, base_shape)]
|
||||
new_shape = [s*r for s,r in zip(base_shape, repeat_args)]
|
||||
shrink_args = [(sh-dim[0]%sh if dim[0]%sh != 0 else 0, nsh-(sh-dim[1]%sh) if dim[1]%sh != 0 else nsh) for dim, sh, nsh in zip(pads, base_shape, new_shape)]
|
||||
return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args))
|
||||
elif mode == "reflect":
|
||||
for i,s in enumerate(x.shape):
|
||||
if pads[i] == (0,0): continue
|
||||
elif pads[i][0] and not pads[i][1]:
|
||||
x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (s-pads[i][0]-1, s_-1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (0,s) for i_ in range(x.ndim)])) + \
|
||||
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
|
||||
elif not pads[i][0] and pads[i][1]:
|
||||
x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (1, pads[i][1]+1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (s,0) for i_ in range(x.ndim)])) + \
|
||||
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
|
||||
else:
|
||||
x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (s-pads[i][0]-1, s_-1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (0,s+pads[i][1]) for i_ in range(x.ndim)])) + \
|
||||
x.flip(i).shrink(tuple([(0,s_) if i_ != i else (1, pads[i][1]+1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \
|
||||
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
|
||||
return x
|
||||
elif mode == "edge":
|
||||
for i,s in enumerate(x.shape):
|
||||
if pads[i] == (0,0): continue
|
||||
elif pads[i][0] and not pads[i][1]:
|
||||
x = x.shrink(tuple([(0,s_) if i_ != i else (0,1) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (0,s) for i_ in range(x.ndim)])) + \
|
||||
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
|
||||
elif not pads[i][0] and pads[i][1]:
|
||||
x = x.shrink(tuple([(0,s_) if i_ != i else (s_-1, s_) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \
|
||||
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
|
||||
else:
|
||||
x = x.shrink(tuple([(0,s_) if i_ != i else (0,1) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (0,s+pads[i][1]) for i_ in range(x.ndim)])) + \
|
||||
x.shrink(tuple([(0,s_) if i_ != i else (s_-1, s_) for i_,s_ in enumerate(x.shape)])).expand([pads[i][1] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \
|
||||
x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
|
||||
return x
|
||||
elif mode == "constant":
|
||||
return _padding(x, seq_pads, axes=seq_axes, constant_value=constant_value)
|
||||
|
||||
def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
|
||||
if dilations != 1: raise NotImplementedError(f"dilations != 1 not supported, dilations:{dilations}")
|
||||
pixel_axes = tuple(range(len(X.shape)))[-2:]
|
||||
if ceil_mode: auto_pad = "SAME_UPPER"
|
||||
padding_included = _padding(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations).avg_pool2d(kernel_shape, stride=strides)
|
||||
if count_include_pad:
|
||||
return padding_included
|
||||
else:
|
||||
div = _padding(Tensor.ones(*X.shape), pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations).avg_pool2d(kernel_shape, stride=strides)
|
||||
return padding_included / div
|
||||
|
||||
def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
|
||||
if ceil_mode: auto_pad = "SAME_UPPER"
|
||||
ret = _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-len(kernel_shape):], strides=strides, kernel_shape=kernel_shape, dilations=dilations)
|
||||
ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations)
|
||||
ret_len, X_len = ret.numel(), X.numel()
|
||||
indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().reshape(1, X_len).expand(ret_len, X_len)) * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)).sum(1).reshape(ret.shape).cast(dtypes.int64)
|
||||
if storage_order: indices = indices.transpose(indices.ndim-2, indices.ndim-1)
|
||||
return ret, indices
|
||||
|
||||
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Tensor=None, kernel_shape=None, pads=None, strides=None):
|
||||
out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
|
||||
outlength = prod(out_sh)
|
||||
xI = xI.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
|
||||
arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
|
||||
xT = xT.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
|
||||
ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
|
||||
if outshape is not None:
|
||||
outshape = safe_numpy(outshape).tolist()
|
||||
if outshape != ret.shape:
|
||||
diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
|
||||
pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2]
|
||||
ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
|
||||
return ret
|
||||
|
||||
def Conv(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
|
||||
if auto_pad != "NOTSET": padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
|
||||
else: padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0 # reorder padding
|
||||
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
|
||||
|
||||
def ConvTranspose(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
|
||||
if not kernel_shape: kernel_shape = W.shape
|
||||
if pads is None and auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
|
||||
elif pads is None and auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2)
|
||||
strides_ = [1]*(W.ndim-1) + [strides] if isinstance(strides, int) else [1]*(W.ndim-len(strides)) + list(strides)
|
||||
dilations_ = [1]*(W.ndim-1) + [dilations] if isinstance(dilations, int) else [1]*(W.ndim-len(dilations)) + list(dilations)
|
||||
if output_shape and not output_padding:
|
||||
out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides_, X.shape, kernel_shape, dilations_))]
|
||||
output_padding = [os - rs for os, rs in zip(output_shape, out_sh[-len(output_shape):])]
|
||||
return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)
|
||||
|
||||
# Reimplemented here because you need legacy RNG for passing ONNX tests.
|
||||
def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
|
||||
if isinstance(ratio, Tensor) and not ratio.shape: ratio = safe_numpy(ratio) # ratio and tensor is passed in as Tensor with shape: ()
|
||||
if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = safe_numpy(training_mode)
|
||||
if not training_mode: return data, Tensor.ones(*data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
|
||||
rng = np.random.RandomState(seed)
|
||||
ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio
|
||||
mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device)
|
||||
return data * mask * (1/(1.0 - ratio)), mask
|
||||
|
||||
def LRN(input: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
|
||||
bs, c, iy, ix = input.shape
|
||||
return input / input.mul(input).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta)
|
||||
|
||||
def MeanVarianceNormalization(input: Tensor, axis=(0, 2, 3)):
|
||||
data_mean = input.mean(axis=axis, keepdim=True)
|
||||
std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
|
||||
return (input - data_mean) / (std + 1e-9)
|
||||
|
||||
def NegativeLogLikelihoodLoss(input: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"):
|
||||
target = target.cast(dtypes.float32)
|
||||
N, C, i_shape = input.shape[0], input.shape[1], input.shape
|
||||
t_shape = target.shape
|
||||
if len(input.shape) != 3:
|
||||
input = input.reshape((N, C, -1))
|
||||
target = target.reshape((N, -1))
|
||||
if weight is not None:
|
||||
mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1))
|
||||
weight = (mask * weight).sum(axis=-1)
|
||||
if ignore_index is not None:
|
||||
cond = target == ignore_index
|
||||
weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
|
||||
mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(input.shape) -2))
|
||||
loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
|
||||
if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
|
||||
elif reduction == "sum": return loss.sum()
|
||||
return loss.reshape(t_shape) if len(i_shape) != 3 else loss
|
||||
|
||||
def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
|
||||
N, C, *s_dimensions = scores.shape
|
||||
if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
|
||||
mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
|
||||
y = scores.log_softmax(axis=1)
|
||||
if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)]))
|
||||
loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights
|
||||
if reduction == "mean": loss = loss.sum() / (loss == 0).where(0, 1).sum() if weights is None else loss.sum() / weights.sum()
|
||||
elif reduction == "sum": loss = loss.sum()
|
||||
return loss, y
|
||||
|
||||
def ArrayFeatureExtractor(input: Tensor, indices: Tensor): return input.__getitem__(tuple([slice(None) if i != (input.ndim-1) else indices for i in range(input.ndim)]))
|
||||
def Gather(input: Tensor, indices: Tensor, axis=0):
|
||||
if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
|
||||
input_sh = list(input.shape)
|
||||
ret_shape = input_sh[:axis] + list(indices.shape) + input_sh[axis+1:]
|
||||
if indices.ndim > 1: indices = indices.flatten()
|
||||
indices = [int(safe_numpy(indices))] if indices.shape == () else [input_sh[axis]+int(x) if x<0 else int(x) for x in safe_numpy(indices)]
|
||||
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(input_sh)] for i in indices]
|
||||
return input.shrink(arg=tuple(args[0])).cat(*[input.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
|
||||
else: # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
|
||||
return input.__getitem__(tuple([slice(None) if i != axis else indices for i in range(input.ndim)]))
|
||||
|
||||
def GatherElements(input: Tensor, indices: Tensor, axis):
|
||||
indices = indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis] + indices
|
||||
return input.gather(indices, axis)
|
||||
|
||||
def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
|
||||
def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
|
||||
assert n <= 1, f"n:{n} shouldn't be larger than 1"
|
||||
b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
|
||||
b = (b >= 0).where(b+n, b-n)
|
||||
if equidistant_case == "round_down":
|
||||
return (x > b).where(b+1-n, b-n)
|
||||
elif equidistant_case == "round_up":
|
||||
return (x >= b).where(b+1-n, b-n)
|
||||
elif equidistant_case == "round_to_even":
|
||||
x_ceil_fraction = x.ceil()/2
|
||||
cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
|
||||
x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
|
||||
x = (x > b).where(b+1-n, b-n)
|
||||
return x
|
||||
|
||||
def Round(X:Tensor): return _round(X, 0.5, "round_to_even")
|
||||
|
||||
def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel', cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch', mode='nearest', nearest_mode='round_prefer_floor'):
|
||||
def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out]
|
||||
def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
|
||||
if nearest_mode == "round_prefer_floor": ret = _round(x_resized, 0.5, "round_down")
|
||||
elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, 0.5, "round_up")
|
||||
elif nearest_mode == "floor": ret = x_resized.floor()
|
||||
elif nearest_mode == "ceil": ret = x_resized.ceil()
|
||||
return ret.clip(0, x_len-1)
|
||||
def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None):
|
||||
if coordinate_transformation_mode == "half_pixel":
|
||||
x_out = (x_out + 0.5)/Tensor(scales_lol[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuacy.
|
||||
y_out = (y_out + 0.5)/Tensor(scales_lol[-2]) - 0.5
|
||||
elif coordinate_transformation_mode == "align_corners":
|
||||
x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
|
||||
y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
|
||||
elif coordinate_transformation_mode == "asymmetric":
|
||||
x_out = x_out/scales_lol[-1]
|
||||
y_out = y_out/scales_lol[-2]
|
||||
elif coordinate_transformation_mode == "half_pixel_symmetric":
|
||||
x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_lol[-1] - 0.5
|
||||
y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_lol[-2] - 0.5
|
||||
elif coordinate_transformation_mode == "pytorch_half_pixel":
|
||||
x_out = (x_out + 0.5)/scales_lol[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
|
||||
y_out = (y_out + 0.5)/scales_lol[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
|
||||
elif coordinate_transformation_mode == "tf_crop_and_resize":
|
||||
x_out = roi[-1][0] * (X.shape[-1] - 1) + x_out * ((roi[-1][1] - roi[-1][0]) * (X.shape[-1] - 1) / (output_shape[-1] - 1)) if output_shape[-1] > 1 else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
|
||||
y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
|
||||
return x_out.clip(0, X.shape[-1]-1), y_out.clip(0, X.shape[-2]-1)
|
||||
if roi is not None:
|
||||
roi = safe_numpy(roi)
|
||||
roi = [(st,ed) for st, ed in zip(roi[:len(roi)//2], roi[len(roi)//2:])]
|
||||
roi_ = [(1,1)] * 4
|
||||
if axes is not None:
|
||||
for a,r in zip(axes, roi):
|
||||
roi_[a] = r
|
||||
roi = roi_
|
||||
if scales is not None:
|
||||
scales = safe_numpy(scales).tolist()
|
||||
if axes is not None:
|
||||
scales_ = [1]*X.ndim
|
||||
for a,s in zip(axes, scales):
|
||||
scales_[a] = s
|
||||
scales = scales_
|
||||
elif sizes is not None:
|
||||
sizes = [int(i) for i in safe_numpy(sizes)]
|
||||
scales = []
|
||||
if axes is not None:
|
||||
sizes_ = [1]*X.ndim
|
||||
for a,s in zip(axes, sizes):
|
||||
sizes_[a] = s
|
||||
scales.append(s/X.shape[a])
|
||||
sizes = sizes_
|
||||
else: scales = [si/xs for xs, si in zip(X.shape, sizes)]
|
||||
if keep_aspect_ratio_policy == "not_larger":
|
||||
scale = min(scales)
|
||||
sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
|
||||
sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
|
||||
elif keep_aspect_ratio_policy == "not_smaller":
|
||||
scale = max(scales)
|
||||
sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
|
||||
sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
|
||||
output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
|
||||
output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
|
||||
scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)]
|
||||
x_out = Tensor.arange(output_shape[-1])
|
||||
y_out = Tensor.arange(output_shape[-2])
|
||||
if mode == "nearest":
|
||||
x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi)
|
||||
x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
|
||||
y_out = _nearest_mode(y_out, nearest_mode, X.shape[-1])
|
||||
return _nearest_gather(X, x_out, y_out)
|
||||
elif mode == "linear":
|
||||
x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape_, scales, roi)
|
||||
ret = []
|
||||
for y in safe_numpy(y_out):
|
||||
for x in safe_numpy(x_out):
|
||||
x_floor, y_floor = int(x), int(y)
|
||||
y_shrink = (0, X.shape[2]) if X.shape[2] == 1 else (y_floor, y_floor+2) if y != y_floor else (y_floor, y_floor+1)
|
||||
x_shrink = (x_floor, x_floor+2) if x != x_floor else (x_floor, x_floor+1)
|
||||
shrink_args = ((0, X.shape[0]), (0, X.shape[1]), y_shrink, x_shrink)
|
||||
corners = safe_numpy(X.shrink(shrink_args))
|
||||
x1, x2, y1, y2 = x_floor, x_floor+1, y_floor, y_floor+1
|
||||
if x == x_floor and y == y_floor: # TODO https://en.wikipedia.org/wiki/Bilinear_interpolation#Weighted_mean maybe do weighted mean?
|
||||
ret.append(corners[0,0,0,0])
|
||||
elif x == x_floor:
|
||||
ret.append((corners[0,0,0,0] * (y2 - y) + corners[0,0,1,0] * (y - y1)) / (y2 - y1))
|
||||
elif y == y_floor:
|
||||
ret.append((corners[0,0,0,0] * (x2 - x) + corners[0,0,0,1] * (x - x1)) / (x2 - x1))
|
||||
else:
|
||||
ret.append((corners[0,0,0,0] * (x2 - x) * (y2 - y) + corners[0,0,0,1] * (x - x1) * (y2 - y) + corners[0,0,1,0] * (x2 - x) * (y - y1) + corners[0,0,1,1] * (x - x1) * (y - y1)) / ((x2 - x1) * (y2 - y1)))
|
||||
return Tensor(ret).reshape(output_shape)
|
||||
elif mode == "cubic":
|
||||
raise Exception("cubic interpolation is not implemented")
|
||||
|
||||
def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
|
||||
if not axes: axes = list(range(input.ndim))
|
||||
shrink_arg = [(0,i) for i in input.shape]
|
||||
pad_arg = [(0,0) for _ in range(input.ndim)]
|
||||
shape = safe_numpy(shape).tolist()
|
||||
for s, x in zip(shape, axes):
|
||||
if s < input.shape[x]: shrink_arg[x] = (input.shape[x]//2 - s//2, input.shape[x]//2 + s//2) if s%2 == 0 else (input.shape[x]//2 - s//2 - 1, input.shape[x]//2 + s//2)
|
||||
elif s > input.shape[x]: pad_arg[x] = ((s - input.shape[x])//2, (s - input.shape[x])//2) if (s - input.shape[x])% 2 == 0 else ((s - input.shape[x])//2, (s - input.shape[x])//2 + 1)
|
||||
return input.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
|
||||
|
||||
def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
|
||||
depth = int(safe_numpy(depth).item())
|
||||
indices, rank = (indices < 0).where(indices+depth, indices), len(indices.shape)
|
||||
if axis < 0: axis += rank + 1
|
||||
ls, rs = indices.shape[0:axis], indices.shape[axis: rank]
|
||||
cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
|
||||
return cond.where(values[1], values[0]).cast(values.dtype)
|
||||
|
||||
def Erf(x: Tensor):
|
||||
sign = x.sign()
|
||||
x = x.abs()
|
||||
t = 1.0 / (1.0 + 0.3275911 * x)
|
||||
term1 = 0.254829592 * t
|
||||
term2 = -0.284496736 * t ** 2
|
||||
term3 = 1.421413741 * t ** 3
|
||||
term4 = -1.453152027 * t ** 4
|
||||
term5 = 1.061405429 * t ** 5
|
||||
y = (term1 + term2 + term3 + term4 + term5)
|
||||
return sign * (1.0 - y * Tensor.exp(-x * x))
|
||||
|
||||
def Compress(inp: Tensor, condition: Tensor, axis=None):
|
||||
if axis is None:
|
||||
inp = inp.flatten()
|
||||
axis = 0
|
||||
|
||||
axis = axis + inp.ndim if axis < 0 else axis
|
||||
|
||||
con_np = safe_numpy(condition)
|
||||
con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor
|
||||
return inp.__getitem__(tuple([slice(None) if i != axis else con for i in range(inp.ndim)]))
|
||||
|
||||
type_map = {TensorProto.DOUBLE: dtypes.double, TensorProto.FLOAT: dtypes.float32}
|
||||
def EyeLike(x: Tensor, dtype=None, k=0):
|
||||
if dtype is None: dtype = x.dtype
|
||||
else: dtype = type_map[dtype]
|
||||
shape = x.shape
|
||||
dim = min(x.shape)
|
||||
if shape[0] == shape[1]: return Tensor.eye(dim=dim, dtype=dtype)
|
||||
else:
|
||||
diff = (shape[0]-dim, shape[1]-dim)
|
||||
padarg = tuple([(d, d) if d == 0 else (k, d-k) for d in diff])
|
||||
return Tensor.eye(dim=dim, dtype=dtype).pad(padarg)
|
||||
|
||||
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)
|
||||
|
||||
# Needs work
|
||||
def IsInf(x,detect_negative=1,detect_positive=1):
|
||||
ret = (x == float("inf"))*detect_positive + (x == float("-inf"))*detect_negative + Tensor.zeros(*x.shape)
|
||||
return ret.cast(dtypes.bool)
|
||||
|
||||
# Needs work
|
||||
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point=0, axis=1):
|
||||
axis = axis + x.ndim if axis < 0 else axis
|
||||
x_sc = x_scale.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
|
||||
x_zer = x_zero_point.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point
|
||||
return (x - x_zer) * x_sc
|
||||
|
||||
# Needs work
|
||||
def IsNaN(x):
|
||||
return (x < float("-inf")).cast(dtypes.bool)
|
||||
|
||||
# **************** com.microsoft Ops ****************
|
||||
|
||||
def SkipLayerNormalization(input:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):
|
||||
if epsilon is None: epsilon=1e-12
|
||||
x = input + skip + bias
|
||||
return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
|
||||
|
||||
def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
|
||||
x = x + bias
|
||||
return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())
|
||||
|
||||
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
|
||||
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
|
||||
assert (segment_ids is None) is (segment_embedding is None)
|
||||
assert (mask is None) is (mask_index_type is None)
|
||||
assert mask is None, "functionality not supported yet" # TODO
|
||||
input_shape = input_ids.shape
|
||||
bsz, seq_length = input_shape[0], input_shape[1]
|
||||
compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
|
||||
vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
|
||||
|
||||
def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
|
||||
vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
|
||||
return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight
|
||||
|
||||
# bert embedding layer
|
||||
if epsilon is None: epsilon = 1e-12
|
||||
if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
|
||||
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
|
||||
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
|
||||
seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
|
||||
|
||||
embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
|
||||
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
|
||||
return out, None, embedding_sum
|
||||
|
||||
def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
|
||||
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
|
||||
assert num_heads is not None # required
|
||||
assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
|
||||
assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, "functionality not supported yet" # TODO strange params
|
||||
hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
|
||||
|
||||
if unidirectional: # gpt-style
|
||||
assert hidden_size == v_hidden_size
|
||||
xqkv = input.linear(weights, bias)
|
||||
xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
|
||||
else: # bert-style
|
||||
wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
|
||||
bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None
|
||||
xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
|
||||
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
|
||||
|
||||
if past is not None:
|
||||
xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
|
||||
present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
|
||||
|
||||
def attn(query, key, value, attn_mask):
|
||||
query_length, key_length = query.shape[-2], key.shape[-2]
|
||||
cdim = max(query_length, key_length) + 1
|
||||
attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
|
||||
# This is where Tensor.scaled_dot_product_attention differs:
|
||||
causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool)
|
||||
return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value
|
||||
|
||||
bsz, _, seq_len, _ = xq.shape
|
||||
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
|
||||
return out, present
|
||||
|
||||
# **************** ai.onnx.preview.training Ops ****************

# TODO not entirely sure these optimizers are correct
def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
  groups = len(inputs) // 3
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
  T, R = safe_numpy(T)[0], safe_numpy(R)[0]
  r = R / (1 + T * decay_factor)
  ret = []
  for input in grouped_inputs:
    X, G, H = input
    X.grad = norm_coefficient * X + G
    X.grad.requires_grad, H.requires_grad = False, False  # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py
    H.assign(H.detach() + X.grad * X.grad).realize()
    H_adaptive = H.sqrt() + epsilon
    X.assign(X.detach() - r * X.grad / H_adaptive)
    ret.extend([X, H])
  ret = ret[::2] + ret[1::2]
  return tuple(ret)

def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
  groups = len(inputs) // 3
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
  T, R = safe_numpy(T)[0], safe_numpy(R)[0]
  beta_adjusted = beta if T > 0 else 1
  ret = []
  for input in grouped_inputs:
    X, G, V = input
    X.grad = (norm_coefficient * X + G).realize()
    X.grad.requires_grad, V.requires_grad = False, False
    V.assign(alpha * V + beta_adjusted * X.grad).realize()
    if mode == "standard": X.assign(X.detach() - R * V).realize()
    elif mode == "nesterov": X.assign(X.detach() - R * (X.grad + alpha * V)).realize()  # nesterov update: X -= R * (G + alpha * V)
    ret.extend([X, V])
  ret = ret[::2] + ret[1::2]
  return tuple(ret)

# copied from tinygrad/nn/optim.py: LAMB with some edits
def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0, norm_coefficient_post=0.0):
  groups = len(inputs) // 4
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
  T, R = safe_numpy(T)[0], safe_numpy(R)[0]
  ret = []
  for input in grouped_inputs:
    X, G, V, H = input
    X.grad = (norm_coefficient * X + G).realize()
    V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False
    V.assign(alpha * V + (1.0 - alpha) * X.grad).realize()
    H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize()
    up = (V / (1.0 - alpha**T)) / ((H / (1.0 - beta**T)).sqrt() + epsilon) if T > 0 else V / (H.sqrt() + epsilon)
    X.assign(X.detach() - R * up).realize()
    X = (1 - norm_coefficient_post) * X
    ret.extend([X, V, H])
  ret = ret[::3] + ret[1::3] + ret[2::3]
  return tuple(ret)

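# Illustrative only, not part of the original file: how the variadic inputs to these training ops are grouped.
# The slicing above assumes that for n optimized tensors the runtime passes
#   inputs = (X_1..X_n, G_1..G_n, V_1..V_n[, H_1..H_n])
# so with groups = n, inputs[i::groups] recovers (X_i, G_i, V_i[, H_i]) for tensor i. For example, a
# two-tensor Adam call would look like (all values made up, epsilon set to avoid a zero denominator):
#   new_X1, new_X2, new_V1, new_V2, new_H1, new_H2 = Adam(R, T, X1, X2, G1, G2, V1, V2, H1, H2, epsilon=1e-8)
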
285
tinygrad_repo/extra/thneed.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# this can be constructed from a cl_cache or loaded from a thneed file
|
||||
import time
|
||||
import struct
|
||||
import json
|
||||
import traceback
|
||||
import numpy as np
|
||||
from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
|
||||
from tinygrad.helpers import DEBUG, getenv
|
||||
from collections import defaultdict
|
||||
import pyopencl as cl
|
||||
from tinygrad.runtime.ops_gpu import CL, OSX_TIMING_RATIO
|
||||
|
||||
DEBUGCL = getenv("DEBUGCL", 0)
|
||||
FLOAT16 = getenv("FLOAT16", 0)
|
||||
|
||||
class Thneed:
|
||||
def __init__(self, cl_cache=[], inputs={}):
|
||||
self.cl_cache, self.inputs = cl_cache[:], inputs
|
||||
self.gobj = 0
|
||||
|
||||
# build graph
|
||||
# NOTE: if CLCACHE=1, this is wrong!
|
||||
nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []})
|
||||
for _, args in self.cl_cache:
|
||||
# output is always the first parameter
|
||||
for a in args[3:]:
|
||||
nodes[a]['out_edges'].append(args[2])
|
||||
nodes[args[2]]['in_edges'].append(a)
|
||||
|
||||
# get buffers to save
|
||||
self.buffers_to_save = set()
|
||||
self.outputs = []
|
||||
for n in nodes.keys():
|
||||
if len(nodes[n]['in_edges']) == 0:
|
||||
self.buffers_to_save.add(n)
|
||||
if len(nodes[n]['out_edges']) == 0:
|
||||
self.outputs.append(n)
|
||||
|
||||
fake_inputs = []
|
||||
for k,n in self.inputs.items():
|
||||
if n in self.buffers_to_save:
|
||||
self.buffers_to_save.remove(n)
|
||||
else:
|
||||
print(f"WARNING: {k} was not a used input, removing it")
|
||||
fake_inputs.append(k)
|
||||
for k in fake_inputs:
|
||||
del self.inputs[k]
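# Illustrative only, not part of the original file: each self.cl_cache entry has the shape
#   (CLProgram, [global_work_size, local_work_size, output_buf, *input_bufs])
# which is why the graph construction above reads args[2] as a node's output and args[3:] as its inputs.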
|
||||
|
||||
def load(self, input_fn):
|
||||
float32 = not FLOAT16
|
||||
|
||||
mf = cl.mem_flags
|
||||
image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT)
|
||||
image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT)
|
||||
|
||||
with open(input_fn, "rb") as f:
|
||||
json_len = struct.unpack("I", f.read(4))[0]
|
||||
jdat = json.loads(f.read(json_len).decode('latin_1'))
|
||||
weights = f.read()
|
||||
|
||||
# load in the buffers
|
||||
bufs = {'\x00\x00\x00\x00\x00\x00\x00\x00': None}
|
||||
bufs_loaded = {}
|
||||
ptr = 0
|
||||
for o in jdat['objects']:
|
||||
#print(o)
|
||||
if o['needs_load']:
|
||||
nptr = ptr + o['size']
|
||||
o['data'] = weights[ptr:nptr]
|
||||
ptr = nptr
|
||||
|
||||
if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t":
|
||||
tfmt = image_fmt_32 if 'float32' in o and o['float32'] else image_fmt
|
||||
if o['arg_type'] == "image2d_t":
|
||||
if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
|
||||
# hack: use an image1d since we can back that with a buffer
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
|
||||
else:
|
||||
# buffer isn't supported in image2d, copy buffer into image
|
||||
if 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
|
||||
arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
|
||||
cl.enqueue_copy(CL.cl_queue[0], arr, bufs[o['buffer_id']])
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
|
||||
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
|
||||
elif o['needs_load']:
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
|
||||
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
|
||||
else:
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
|
||||
if o['arg_type'] == "image1d_t":
|
||||
assert not o['needs_load']
|
||||
assert not bufs_loaded[o['buffer_id']]
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
|
||||
else:
|
||||
if 'data' in o:
|
||||
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
|
||||
else:
|
||||
# zero out buffers
|
||||
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
|
||||
|
||||
bufs[o['id']] = buf
|
||||
bufs_loaded[o['id']] = 'data' in o
|
||||
# if it's loaded, it's saved
|
||||
if 'data' in o:
|
||||
self.buffers_to_save.add(buf)
|
||||
|
||||
# load binaries
|
||||
prgs = {}
|
||||
for o in jdat['binaries']:
|
||||
nptr = ptr + o['length']
|
||||
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr])
|
||||
ptr = nptr
|
||||
|
||||
# populate the cl_cache
|
||||
for i,k in enumerate(jdat['kernels']):
|
||||
kernel = prgs[k['name']]
|
||||
aaa = []
|
||||
for j,(a,sz) in enumerate(zip(k['args'], k['args_size'])):
|
||||
if len(a) == 0:
|
||||
aa = cl.LocalMemory(sz)
|
||||
elif len(a) == 4:
|
||||
a = a.encode('latin_1')
|
||||
aa = np.uint32(struct.unpack("I", a)[0])
|
||||
elif len(a) == 2:
|
||||
a = a.encode('latin_1')
|
||||
aa = np.uint16(struct.unpack("H", a)[0])
|
||||
elif len(a) == 8:
|
||||
#print(i,j,struct.unpack("Q", a.encode('latin_1'))[0])
|
||||
aa = bufs[a]
|
||||
aaa.append(aa)
|
||||
self.cl_cache.append((kernel, [k['global_work_size'], k['local_work_size'], *aaa]))
|
||||
|
||||
if DEBUG >= 1: print(f"thneed: total bufs loaded: {len(bufs.keys())}")
|
||||
|
||||
# load inputs
|
||||
for k in jdat['inputs']:
|
||||
self.inputs[k['name']] = bufs[k['buffer_id']]
|
||||
|
||||
# load outputs
|
||||
for k in jdat['outputs']:
|
||||
self.outputs.append(bufs[k['buffer_id']])
|
||||
|
||||
|
||||
def save(self, output_fn):
|
||||
# this is the struct that will be saved
|
||||
jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}
|
||||
|
||||
# build the pieces of this struct
|
||||
weights = []
|
||||
binaries = []
|
||||
saved_objs = set()
|
||||
saved_binaries = set()
|
||||
for prg, args in self.cl_cache:
|
||||
# get binaries for saving
|
||||
if prg.name not in saved_binaries:
|
||||
binary = prg.clprograms[0].get_info(cl.program_info.BINARIES)
|
||||
assert len(binary) == 1
|
||||
jdat['binaries'].append({"name":prg.name, "length":len(binary[0])})
|
||||
binaries.append(binary[0])
|
||||
saved_binaries.add(prg.name)
|
||||
|
||||
# get the args from the kernel, some need the data saved
|
||||
targs, args_size = [], []
|
||||
argdtypes = prg.argdtypes if prg.argdtypes is not None else [None]*(len(args)-2)
|
||||
for a,d in zip(args[2:], argdtypes):
|
||||
if d == np.int16:
|
||||
targs.append(struct.pack("H", a).decode("latin_1"))
|
||||
args_size.append(2)
|
||||
elif d == np.int32:
|
||||
targs.append(struct.pack("I", a).decode("latin_1"))
|
||||
args_size.append(4)
|
||||
elif isinstance(a, cl.LocalMemory):
|
||||
targs.append("")
|
||||
args_size.append(a.size)
|
||||
elif d is None:
|
||||
if getattr(a, "global_id", None) is None:
|
||||
setattr(a, "global_id", self.gobj)
|
||||
self.gobj += 1
|
||||
ptr = struct.pack("Q", a.global_id).decode("latin_1")
|
||||
if ptr not in saved_objs:
|
||||
if isinstance(a, cl.Buffer):
|
||||
needs_load = a in self.buffers_to_save
|
||||
jdat['objects'].append({
|
||||
"id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size,
|
||||
})
|
||||
if needs_load:
|
||||
data = np.empty(a.size//4, dtype=np.float32)
|
||||
cl.enqueue_copy(CL.cl_queue[0], data, a, is_blocking=True)
|
||||
weights.append(data.tobytes())
|
||||
elif isinstance(a, cl.Image):
|
||||
assert a.format == cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT), "wrong type"
|
||||
needs_load = a in self.buffers_to_save
|
||||
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
|
||||
size = row_pitch * a.shape[1]
|
||||
# this is *2 if float16 and *4 if float32
|
||||
buf = cl.Buffer(CL.cl_ctxs[0], cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
|
||||
|
||||
# zero out the buffer
|
||||
cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True)
|
||||
|
||||
CLProgram("from_image_strided", compile_gpu("""
|
||||
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
|
||||
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
int2 l;
|
||||
l.y = get_global_id(1);
|
||||
l.x = get_global_id(0);
|
||||
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
|
||||
}
|
||||
"""), argdtypes=(None, None, np.int32))(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape)
|
||||
|
||||
# multiple of 32 isn't enough
|
||||
jdat['objects'].append({
|
||||
"id": ptr, "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
|
||||
"width": a.shape[0], "height": a.shape[1], "row_pitch": row_pitch, "float32": not FLOAT16,
|
||||
})
|
||||
|
||||
if needs_load:
|
||||
data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
|
||||
cl.enqueue_copy(CL.cl_queue[0], data, buf, is_blocking=True)
|
||||
if FLOAT16: data = data.astype(np.float16)
|
||||
weights.append(data.tobytes())
|
||||
else:
|
||||
raise Exception("unknown object", a)
|
||||
#print(jdat['objects'][-1])
|
||||
saved_objs.add(ptr)
|
||||
targs.append(ptr)
|
||||
args_size.append(8)
|
||||
else:
|
||||
raise Exception("idk this type")
|
||||
|
||||
# save the kernel itself
|
||||
jdat['kernels'].append({
|
||||
"name": prg.name,
|
||||
"work_dim": len(args[0]),
|
||||
"global_work_size": args[0],
|
||||
# TODO: C++ thneed requires a local_work_size, so we fill it with ones
|
||||
"local_work_size": [1 for _ in args[0]] if args[1] is None else args[1],
|
||||
"num_args": len(args)-2,
|
||||
"args": targs,
|
||||
"args_size": args_size
|
||||
})
|
||||
|
||||
jdat['outputs'] = [{
|
||||
"buffer_id": struct.pack("Q", x.global_id).decode("latin_1"),
|
||||
"size": x.size,
|
||||
} for x in self.outputs]
|
||||
|
||||
jdat['inputs'] = [{
|
||||
"buffer_id": struct.pack("Q", v.global_id).decode("latin_1"),
|
||||
"size": v.size,
|
||||
"name": k
|
||||
} for k,v in self.inputs.items()][::-1]
|
||||
|
||||
print(f"saving thneed to {output_fn}")
|
||||
with open(output_fn, "wb") as f:
|
||||
j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
|
||||
f.write(struct.pack("I", len(j)))
|
||||
f.write(j)
|
||||
f.write(b''.join(weights))
|
||||
f.write(b''.join(binaries))
|
||||
|
||||
def run(self):
|
||||
events = []
|
||||
st = time.monotonic()
|
||||
for prg, args in self.cl_cache:
|
||||
events.append(prg.clprgs[0](CL.cl_queue[0], *args))
|
||||
mt = time.monotonic()
|
||||
CL.synchronize()
|
||||
et = time.monotonic() - st
|
||||
print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")
|
||||
|
||||
if DEBUGCL >= 2:
|
||||
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
|
||||
print(f"{i:3d} {prg.name:25s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end]))
|
||||
if DEBUGCL >= 1:
|
||||
total_runtime = 0
|
||||
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
|
||||
runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
|
||||
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
|
||||
if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3):
|
||||
print(prg.prg)
|
||||
total_runtime += runtime
|
||||
print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")
|
||||
return total_runtime/1e9
|
||||
return et
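# Illustrative only, not part of the original file: typical use of this class mirrors openpilot/compile2.py
# (buffer names and paths below are examples):
#   t = Thneed(cl_cache, {"input_imgs": input_buf}); t.save("/tmp/output.thneed"); t.run()
# or, to replay an existing file:
#   t = Thneed(); t.load("/tmp/output.thneed"); t.run()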
|
||||
205
tinygrad_repo/extra/utils.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# type: ignore
|
||||
import pickle, hashlib, zipfile, io, requests, struct, tempfile, platform, concurrent.futures
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
from typing import Union
|
||||
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import Device
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
OSX = platform.system() == "Darwin"
|
||||
WINDOWS = platform.system() == "Windows"
|
||||
|
||||
def temp(x:str) -> str: return (Path(tempfile.gettempdir()) / x).as_posix()
|
||||
|
||||
def fetch(url):
|
||||
if url.startswith("/") or url.startswith("."):
|
||||
with open(url, "rb") as f:
|
||||
return f.read()
|
||||
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
|
||||
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
|
||||
with open(fp, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def fetch_as_file(url):
|
||||
if url.startswith("/") or url.startswith("."):
|
||||
with open(url, "rb") as f:
|
||||
return f.read()
|
||||
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
|
||||
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
|
||||
return fp
|
||||
|
||||
def download_file(url, fp, skip_if_exists=True):
|
||||
if skip_if_exists and Path(fp).is_file() and Path(fp).stat().st_size > 0:
|
||||
return
|
||||
r = requests.get(url, stream=True)
|
||||
assert r.status_code == 200
|
||||
progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url)
|
||||
(path := Path(fp).parent).mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
|
||||
for chunk in r.iter_content(chunk_size=16384):
|
||||
progress_bar.update(f.write(chunk))
|
||||
f.close()
|
||||
Path(f.name).rename(fp)
|
||||
|
||||
def my_unpickle(fb0):
|
||||
key_prelookup = defaultdict(list)
|
||||
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
|
||||
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
|
||||
ident, storage_type, obj_key, location, obj_size = storage[0:5]
|
||||
assert ident == 'storage'
|
||||
assert prod(size) <= (obj_size - storage_offset)
|
||||
|
||||
if storage_type not in [np.float16, np.float32]:
|
||||
if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}")
|
||||
ret = None
|
||||
else:
|
||||
ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type))
|
||||
key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset))
|
||||
return ret
|
||||
|
||||
def _rebuild_parameter(*args):
|
||||
#print(args)
|
||||
pass
|
||||
|
||||
class Dummy: pass
|
||||
|
||||
class MyPickle(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
#print(module, name)
|
||||
if name == 'FloatStorage': return np.float32
|
||||
if name == 'LongStorage': return np.int64
|
||||
if name == 'IntStorage': return np.int32
|
||||
if name == 'HalfStorage': return np.float16
|
||||
if module == "torch._utils":
|
||||
if name == "_rebuild_tensor_v2": return _rebuild_tensor_v2
|
||||
if name == "_rebuild_parameter": return _rebuild_parameter
|
||||
else:
|
||||
if module.startswith('pytorch_lightning'): return Dummy
|
||||
try:
|
||||
return super().find_class(module, name)
|
||||
except Exception:
|
||||
return Dummy
|
||||
|
||||
def persistent_load(self, pid):
|
||||
return pid
|
||||
|
||||
return MyPickle(fb0).load(), key_prelookup
|
||||
|
||||
def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset, mmap_allowed=False):
|
||||
bytes_size = np.dtype(dtype).itemsize
|
||||
if t is None:
|
||||
myfile.seek(prod(shape) * bytes_size, 1)
|
||||
return
|
||||
|
||||
bytes_offset = 0
|
||||
if storage_offset is not None:
|
||||
bytes_offset = storage_offset * bytes_size
|
||||
myfile.seek(bytes_offset)
|
||||
|
||||
assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}"
|
||||
assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size
|
||||
if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)):
|
||||
# slow path
|
||||
buffer_size = sum(strides[i]*t.dtype.itemsize * (shape[i] - 1) for i in range(len(shape)))
|
||||
buffer_size += t.dtype.itemsize
|
||||
np_array = np.frombuffer(myfile.read(buffer_size), t.dtype.np)
|
||||
|
||||
np_array = np.lib.stride_tricks.as_strided(
|
||||
np_array, shape=shape, strides=[i*t.dtype.itemsize for i in strides])
|
||||
|
||||
lna = t.lazydata.op.arg
|
||||
lna.fxn = lambda _: np_array
|
||||
t.realize()
|
||||
return
|
||||
|
||||
# ["METAL", "CLANG", "LLVM"] support readinto for more speed
|
||||
# ["GPU", "CUDA"] use _mmap since they have to copy in to the GPU anyway
|
||||
# this needs real APIs
|
||||
if t.device in ["METAL", "CLANG", "LLVM"]:
|
||||
del t.lazydata.op
|
||||
t.lazydata.realized = Device[t.lazydata.device].buffer(prod(t.shape), dtype=t.dtype)
|
||||
myfile.readinto(t.lazydata.realized._buffer())
|
||||
else:
|
||||
def _mmap(lna):
|
||||
assert myfile._compress_type == 0, "compressed data can't be mmaped"
|
||||
return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r', offset=myfile._orig_compress_start + bytes_offset, shape=lna.shape)
|
||||
def _read(lna):
|
||||
ret = np.empty(lna.shape, dtype=lna.dtype)
|
||||
myfile.readinto(ret.data)
|
||||
return ret
|
||||
if mmap_allowed and not OSX and t.device in ["GPU", "CUDA"]: t.lazydata.op.arg.fxn = _mmap
|
||||
else: t.lazydata.op.arg.fxn = _read
|
||||
t.realize()
|
||||
|
||||
def fake_torch_load_zipped(fb0, load_weights=True, multithreaded=True):
|
||||
if Device.DEFAULT in ["TORCH", "GPU", "CUDA"]: multithreaded = False # multithreaded doesn't work with CUDA or TORCH. for GPU it's a wash with _mmap
|
||||
with zipfile.ZipFile(fb0, 'r') as myzip:
|
||||
base_name = myzip.namelist()[0].split('/', 1)[0]
|
||||
with myzip.open(f'{base_name}/data.pkl') as myfile:
|
||||
ret = my_unpickle(myfile)
|
||||
if load_weights:
|
||||
def load_weight(k, vv):
|
||||
with myzip.open(f'{base_name}/data/{k}') as myfile:
|
||||
for v in vv:
|
||||
load_single_weight(v[2], myfile, v[3], v[4], v[0], v[5], mmap_allowed=True)
|
||||
if multithreaded:
|
||||
# 2 seems fastest
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
futures = {executor.submit(load_weight, k, v):k for k,v in ret[1].items()}
|
||||
for future in (t:=tqdm(concurrent.futures.as_completed(futures), total=len(futures))):
|
||||
if future.exception() is not None: raise future.exception()
|
||||
k = futures[future]
|
||||
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
|
||||
else:
|
||||
for k,v in (t := tqdm(ret[1].items())):
|
||||
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
|
||||
load_weight(k,v)
|
||||
return ret[0]
|
||||
|
||||
def fake_torch_load(b0):
|
||||
|
||||
# convert it to a file
|
||||
fb0 = io.BytesIO(b0)
|
||||
|
||||
if b0[0:2] == b"\x50\x4b":
|
||||
return fake_torch_load_zipped(fb0)
|
||||
|
||||
# skip three junk pickles
|
||||
pickle.load(fb0)
|
||||
pickle.load(fb0)
|
||||
pickle.load(fb0)
|
||||
|
||||
ret, key_prelookup = my_unpickle(fb0)
|
||||
|
||||
# create key_lookup
|
||||
key_lookup = pickle.load(fb0)
|
||||
key_real = [None] * len(key_lookup)
|
||||
for k,v in key_prelookup.items():
|
||||
assert len(v) == 1
|
||||
key_real[key_lookup.index(k)] = v[0]
|
||||
|
||||
# read in the actual data
|
||||
for storage_type, obj_size, tensor, np_shape, np_strides, storage_offset in key_real:
|
||||
ll = struct.unpack("Q", fb0.read(8))[0]
|
||||
assert ll == obj_size, f"size mismatch {ll} != {obj_size}"
|
||||
assert storage_offset == 0, "not implemented"
|
||||
load_single_weight(tensor, fb0, np_shape, np_strides, storage_type, None)
|
||||
|
||||
return ret
|
||||
|
||||
def get_child(parent, key):
|
||||
obj = parent
|
||||
for k in key.split('.'):
|
||||
if k.isnumeric():
|
||||
obj = obj[int(k)]
|
||||
elif isinstance(obj, dict):
|
||||
obj = obj[k]
|
||||
else:
|
||||
obj = getattr(obj, k)
|
||||
return obj
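# Illustrative only, not part of the original file: get_child walks dotted, state-dict style keys, e.g.
#   get_child(model, "blocks.0.attn.weight")
# indexes into lists/dicts for numeric or dict keys and falls back to getattr for objects
# ("model" and the key are made-up names for this sketch).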
|
||||
166
tinygrad_repo/openpilot/compile2.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python3
|
||||
import os, sys, io, pathlib
|
||||
sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
|
||||
|
||||
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
|
||||
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
|
||||
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
|
||||
if "OPT" not in os.environ: os.environ["OPT"] = "99"
|
||||
os.environ["PREREALIZE"] = "0"
|
||||
|
||||
OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
|
||||
|
||||
import onnx
|
||||
from typing import Tuple, List
|
||||
from extra.utils import fetch
|
||||
from extra.onnx import get_run_onnx
|
||||
from tinygrad.graph import print_tree, log_schedule_item
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv, ImageDType, GRAPH
|
||||
from tinygrad.realize import run_schedule
|
||||
from tinygrad.ops import LoadOps, Device, ScheduleItem
|
||||
from tinygrad.features.image import fix_schedule_for_images
|
||||
Device.DEFAULT = "GPU"
|
||||
|
||||
def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
|
||||
Tensor.no_grad = True
|
||||
Tensor.training = False
|
||||
|
||||
# load the model
|
||||
onnx_model = onnx.load(io.BytesIO(onnx_data))
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
||||
|
||||
# run the model
|
||||
inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
|
||||
ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
|
||||
schedule = ret.lazydata.schedule()
|
||||
|
||||
# find which schedule items depend on the inputs (the rest are filtered out below)
|
||||
input_lb = [x.lazydata.base for x in inputs.values()]
|
||||
depends = set(input_lb)
|
||||
for si in schedule:
|
||||
if any(b in depends for b in si.inputs):
|
||||
depends.add(si.out)
|
||||
|
||||
# run all kernels that don't depend on the inputs
|
||||
# NOTE: there are two extra kernels due to fusions that now happen since the weights aren't realized
|
||||
schedule, schedule_independent = partition(schedule, lambda si: si.out in depends)
|
||||
print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
|
||||
|
||||
# confirm no loadops in the (non independent) schedule except for the ones that load the input buffers
|
||||
assert all(si.ast.op not in LoadOps or si.out in input_lb for si in schedule), "has loadops, can't compile to Thneed"
|
||||
return schedule, schedule_independent, inputs
|
||||
|
||||
def schedule_to_thneed(schedule, output_fn):
|
||||
from extra.thneed import Thneed
|
||||
|
||||
# transform to CL.CACHE
|
||||
used_ops = 0
|
||||
cl_cache = []
|
||||
for si in schedule:
|
||||
prg = Device["GPU"].method_cache[si.ast]
|
||||
args = (si.out,) + si.inputs
|
||||
|
||||
# pass these to thneed
|
||||
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
|
||||
setattr(prg.clprg, 'prg', prg.prg)
|
||||
|
||||
global_size = prg.global_size + [1]*(3-len(prg.global_size))
|
||||
local_size = prg.local_size + [1]*(3-len(prg.local_size))
|
||||
cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x.realized._buf for x in args]]))
|
||||
used_ops += prg.op_estimate
|
||||
|
||||
from extra.thneed import Thneed
|
||||
input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
|
||||
t = Thneed(cl_cache, {k:v._buf for k,v in input_rawbuffers.items()})
|
||||
|
||||
# save thneed (before run)
|
||||
t.save(output_fn)
|
||||
|
||||
print(f"buffers to save: {len(t.buffers_to_save)}, inputs: {list(t.inputs.keys())}, outputs: {t.outputs}")
|
||||
runtime = t.run()
|
||||
print(f"network using {used_ops/1e9:.2f} GOPS with runtime {runtime*1e3:.2f} ms that's {used_ops/runtime*1e-9:.2f} GFLOPS")
|
||||
|
||||
def thneed_test_onnx(onnx_data, output_fn):
|
||||
import onnx
|
||||
import pyopencl as cl
|
||||
from tinygrad.runtime.ops_gpu import CL
|
||||
import numpy as np
|
||||
from extra.thneed import Thneed
|
||||
onnx_model = onnx.load(io.BytesIO(onnx_data))
|
||||
|
||||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
||||
inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()}
|
||||
new_np_inputs = {k:v.realize().numpy() for k,v in inputs.items()}
|
||||
|
||||
if getenv("ORT"):
|
||||
# test with onnxruntime
|
||||
import onnxruntime as ort
|
||||
onnx_session = ort.InferenceSession(onnx_data)
|
||||
onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_np_inputs.items()})
|
||||
new_torch_out = onnx_output[0]
|
||||
else:
|
||||
# test with torch
|
||||
from test.models.test_onnx import run_onnx_torch
|
||||
new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
|
||||
|
||||
if output_fn is None:
|
||||
# non thneed
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
new_tinygrad_out = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).numpy()
|
||||
np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
|
||||
print("classic self-test passed!")
|
||||
else:
|
||||
# load thneed and try that
|
||||
nt = Thneed()
|
||||
nt.load(output_fn)
|
||||
|
||||
# inputs
|
||||
for k,v in nt.inputs.items():
|
||||
cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True)
|
||||
|
||||
nt.run()
|
||||
new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(new_torch_out.shape)
|
||||
cl.enqueue_copy(CL.cl_queue[0], new_thneed_out, nt.outputs[0], is_blocking=True)
|
||||
|
||||
# compare torch to thneed
|
||||
np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
|
||||
print("thneed self-test passed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL)
|
||||
|
||||
# quick test for ONNX issues
|
||||
#thneed_test_onnx(onnx_data, None)
|
||||
#exit(0)
|
||||
|
||||
schedule, schedule_independent, inputs = get_schedule(onnx_data)
|
||||
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps)
|
||||
print(f"{len(schedule_input)} inputs")
|
||||
|
||||
run_schedule(schedule_independent, disable_logging=True)
|
||||
run_schedule(schedule_input)
|
||||
with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
|
||||
schedule = fix_schedule_for_images(schedule)
|
||||
image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
|
||||
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
|
||||
|
||||
if GRAPH:
|
||||
for si in schedule_input: log_schedule_item(si)
|
||||
for si in schedule: log_schedule_item(si)
|
||||
|
||||
GlobalCounters.reset()
|
||||
run_schedule(schedule[:])
|
||||
|
||||
output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
|
||||
schedule_to_thneed(schedule, output_fn)
|
||||
|
||||
FLOAT16 = getenv("FLOAT16", 0)
|
||||
if FLOAT16 == 0:
|
||||
try:
|
||||
thneed_test_onnx(onnx_data, output_fn)
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"TEST NOT HAPPENING {e}")
|
||||
|
||||
|
||||
585
tinygrad_repo/tinygrad/codegen/kernel.py
Normal file
@@ -0,0 +1,585 @@
|
||||
from __future__ import annotations
|
||||
import os, math, itertools
|
||||
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
|
||||
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled
|
||||
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen, getenv, prod, DEBUG
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
|
||||
class OptOps(Enum):
|
||||
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class Opt:
|
||||
op: OptOps
|
||||
axis: Optional[int] = None
|
||||
amt: Optional[int] = None
|
||||
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TensorCore:
|
||||
device: str
|
||||
dims: List[int]
|
||||
dtype_in: DType
|
||||
dtype_out: DType
|
||||
threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
|
||||
upcast_dim: int # which TC dim to upcast
|
||||
thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim
|
||||
thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
|
||||
arch: Optional[str] = None
|
||||
def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
|
||||
|
||||
tensor_cores: Dict[str, List[TensorCore]] = {
|
||||
"METAL": [
|
||||
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"),
|
||||
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"),
|
||||
],
|
||||
"HIP": [
|
||||
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]),
|
||||
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]),
|
||||
]
|
||||
}
|
||||
|
||||
class LocalBuffer(NamedTuple):
|
||||
name: str
|
||||
size: int
|
||||
dtype: DType = dtypes.float32
|
||||
realized: None = None
|
||||
def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
|
||||
|
||||
class LinearizerOptions(NamedTuple):
|
||||
device: str = ""
|
||||
# TODO: make this generic with a list of supported types
|
||||
supports_float4: bool = True
|
||||
supports_float4_alu: bool = True
|
||||
has_local: bool = True
|
||||
has_shared: bool = True
|
||||
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
|
||||
global_max: Optional[List[int]] = None
|
||||
local_max: Optional[List[int]] = None
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
|
||||
self.opts = opts if opts else (cast(Compiled, Device[Device.DEFAULT]).linearizer_opts if isinstance(Device[Device.DEFAULT], Compiled) else LinearizerOptions())
|
||||
self.ast = ast
|
||||
|
||||
# fetch lazyop info
|
||||
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
|
||||
|
||||
# there's only allowed to be one reduceop
|
||||
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
|
||||
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
|
||||
self.reduceop = reduceops[0] if reduceops else None
|
||||
|
||||
# create new shapetrackers inside this kernel, we will permute them
|
||||
self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps])
|
||||
|
||||
# get earlybufs, before the one reduce op
|
||||
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] if self.reduceop else []
|
||||
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
|
||||
|
||||
# create the (permuted) shapetrackers
|
||||
self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]
|
||||
|
||||
# move all reduce axes to the end
|
||||
reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
|
||||
permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
|
||||
self.reshape_and_permute(None, permute)
|
||||
|
||||
# parameters for optimization
|
||||
self.applied_opts: List[Opt] = []
|
||||
self.group_for_reduce: List[int] = []
|
||||
self.upcasted: int = 0
|
||||
self.local_dims: int = 0
|
||||
self.local_alias: Dict[int, LocalBuffer] = {}
|
||||
self.tensor_core: Optional[TensorCore] = None
|
||||
self.dont_use_locals: bool = False
|
||||
|
||||
# group simplifies
|
||||
self.simplify_ones()
|
||||
self.simplify_merge_adjacent()
|
||||
|
||||
# cache
|
||||
self.applied_opts_cache: Optional[List[Opt]] = None
|
||||
|
||||
def copy(self):
|
||||
ret = type(self).__new__(type(self))
|
||||
|
||||
# base linearizer params
|
||||
ret.opts, ret.ast = self.opts, self.ast
|
||||
|
||||
# things downstream of the AST
|
||||
# NOTE: we copy bufs for local buffers and sts for optimizations
|
||||
ret.info, ret.reduceop, ret.bufs, ret.earlybufs, ret.full_buf_index, ret.sts = \
|
||||
self.info, self.reduceop, self.bufs[:], self.earlybufs, self.full_buf_index, self.sts[:]
|
||||
|
||||
# parameters for optimizations
|
||||
ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
|
||||
self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals
|
||||
|
||||
# uncached since linearize didn't run
|
||||
ret.applied_opts_cache = None
|
||||
|
||||
return ret
|
||||
|
||||
@property
|
||||
def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
|
||||
|
||||
def has_variable_shape(self) -> bool:
|
||||
for b in self.bufs:
|
||||
if not isinstance(b, LocalBuffer) and not all_int(b.st.views[-1].shape): return True
|
||||
return False
|
||||
|
||||
def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
|
||||
def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
|
||||
|
||||
def upcasted_axis(self, i):
|
||||
return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
|
||||
self.sts[i].real_strides()[self.shape_len-self.upcasted:],
|
||||
[x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
|
||||
|
||||
# TODO: is there a better way to write this?
|
||||
def acc_offsets(self, i):
|
||||
if self.upcasted == 0: return [0]
|
||||
upcasted_i = self.upcasted_axis(i)
|
||||
acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
|
||||
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
|
||||
|
||||
def get_upcast_dim(self, i) -> List[int]:
|
||||
should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
|
||||
return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]
|
||||
|
||||
@property
|
||||
def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
|
||||
|
||||
@property
|
||||
def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
|
||||
|
||||
@property
|
||||
def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape
|
||||
|
||||
@property
|
||||
def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted]
|
||||
|
||||
@property
|
||||
def shape_len(self) -> int: return len(self.sts[0].shape)
|
||||
|
||||
@property
|
||||
def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
|
||||
|
||||
@property
|
||||
def global_dims(self) -> int: return self.first_reduce-self.local_dims
|
||||
|
||||
# there are eight chunks of the shape
|
||||
# blue -- global dims
|
||||
# cyan -- local dims (warp ones first)
|
||||
# *** self.first_reduce
|
||||
# green -- reduce-local dims
|
||||
# white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
|
||||
# red -- reduce loops
|
||||
# *** self.upcasted
|
||||
# purple -- reduce upcasted
|
||||
# yellow -- normal upcasted dimensions
|
||||
def colors(self) -> List[str]:
|
||||
# first non local non reduce dims are global (blue)
|
||||
colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
|
||||
# after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
|
||||
colors += ["cyan"] * self.local_dims
|
||||
# between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green)
|
||||
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
|
||||
# between first_reduce + group_for_reduce and upcasted, they are reduce (red)
|
||||
colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
|
||||
# upcasted dimensions are reduce (magenta) or normal (yellow)
|
||||
colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
|
||||
assert len(colors) == self.shape_len, "colors size mismatch"
|
||||
return colors
|
||||
|
||||
def colored_shape(self, pad=None, dense=False) -> str:
|
||||
ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors()))
|
||||
if pad: ret += ' '*(pad-ansilen(ret))
|
||||
return ret
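# Illustrative only, not part of the original file: for a hypothetical GEMM-like kernel with
# full_shape (256, 256, 64) and one reduce axis, colors() starts as ["blue", "blue", "red"];
# applying Opt(OptOps.UPCAST, 0, 4) appends a yellow (normal upcast) dim, and Opt(OptOps.UNROLL, 0, 4)
# then adds a magenta (reduce-upcasted) dim at the end.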
|
||||
|
||||
# ******************** base simplifiers ********************
|
||||
|
||||
# apply reshape and permute to all shapetrackers
|
||||
def reshape_and_permute(self, new_shape_fxn, axis):
|
||||
new_sts = []
|
||||
for st in self.sts:
|
||||
if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
|
||||
if axis is not None: st = st.permute(tuple(axis))
|
||||
new_sts.append(st)
|
||||
self.sts = new_sts
|
||||
|
||||
# drops the final dimension
|
||||
def upcast(self):
|
||||
assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
|
||||
self.upcasted += 1
|
||||
|
||||
# axis : the axis to pull from
|
||||
# amount : the amount to take
|
||||
# top : if you want to pull that amount from the top
|
||||
# insert_before : place to insert the new stuff
|
||||
def shift_to(self, axis, amount, top=False, insert_before=None):
|
||||
if insert_before is None: insert_before = self.shape_len
|
||||
move_axis = axis if top else axis+1
|
||||
if move_axis < insert_before: insert_before += 1
|
||||
self.reshape_and_permute(
|
||||
lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
|
||||
[i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
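# Illustrative only, not part of the original file: a worked example of shift_to. With shape (8, 6),
# shift_to(axis=0, amount=4) reshapes axis 0 into (2, 4) and permutes the new size-4 axis to
# insert_before (the end by default), giving (2, 6, 4); top=True splits the axis as (4, 2) instead,
# pulling the factor of 4 from the top of the axis.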
|
||||
|
||||
# ******************** complex simplifiers ********************
|
||||
|
||||
def simplify_ones(self) -> bool:
|
||||
# remove places where the shape is all ones
|
||||
# TODO: this should be factored in to multi shape stride
|
||||
if self.shape_len == 0: return False
|
||||
all_ones = [s==1 for s in self.full_shape]
|
||||
self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
|
||||
self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
|
||||
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
|
||||
return any(all_ones)
|
||||
|
||||
def simplify_merge_adjacent(self):
|
||||
if self.shape_len == 0: return
|
||||
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
|
||||
|
||||
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
|
||||
if isinstance(self.bufs[0].dtype, ImageDType):
|
||||
base_shape = self.bufs[0].dtype.shape
|
||||
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
|
||||
special_strides: Tuple[int, ...] = tuple()
|
||||
for i,g in enumerate(shape_idx_groups):
|
||||
shape_piece = tuple(self.output_shape[x] for x in g)
|
||||
assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
|
||||
special_strides += strides_for_shape(shape_piece)
|
||||
# adding the fake image shape
|
||||
shapes.append(self.output_shape)
|
||||
strides.append(special_strides)
|
||||
|
||||
# merge dimensions if we can, multi get_shape_strides
|
||||
# TODO: does this always preserve the reduce dimension, NO
|
||||
# TODO: move this into shapetracker, with tests!
|
||||
rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
|
||||
for i in range(1, len(shapes[0])):
|
||||
can_merge = []
|
||||
for j in range(len(shapes)):
|
||||
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
|
||||
can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
|
||||
# more can merge than this
|
||||
mergeable = all(can_merge) and i != self.first_reduce
|
||||
for j in range(len(shapes)):
|
||||
if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
|
||||
else: rets[j].append((shapes[j][i], strides[j][i]))
|
||||
|
||||
# do the reshapes
|
||||
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
|
||||
|
||||
# ******************** GPU simplifiers ********************
|
||||
def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
|
||||
new_shape,dims = list(x), len(x)
|
||||
for i in range(dims):
|
||||
next_idx = (i + 1) % dims
|
||||
while new_shape[i] > max_size[i]:
|
||||
new_shape[i] = new_shape[i] // 2
|
||||
if (new_shape[next_idx] <= max_size[next_idx]):
|
||||
new_shape[next_idx] = new_shape[next_idx] * 2
|
||||
else:
|
||||
next_idx = (next_idx + 1) % dims
|
||||
new_shape[next_idx] = new_shape[next_idx] * 2
|
||||
return tuple(new_shape)
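# Illustrative only, not part of the original file: _limit_size halves an oversized dim and doubles a
# neighbor that still has room, e.g. _limit_size((4096, 16, 3), [1024, 1024, 1024]) -> (1024, 64, 3).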
|
||||
|
||||
def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
|
||||
# Check the global allocation limit; currently the global_size will be flipped during codegen
|
||||
# and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
|
||||
global_dims = self.first_reduce-self.local_dims
|
||||
if global_dims > 0:
|
||||
if global_max:
|
||||
tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
|
||||
if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
|
||||
assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
|
||||
for i in range(global_dims-1):
|
||||
if i < len(global_max) and self.full_shape[i] > global_max[i]:
|
||||
order = list(range(len(self.full_shape)))
|
||||
order[i], order[global_dims-1] = order[global_dims-1], order[i]
|
||||
self.reshape_and_permute(None, order)
|
||||
if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
|
||||
|
||||
def alias_buffer(self, i, pattern):
|
||||
assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
|
||||
|
||||
bst = 1
|
||||
real_strides = self.sts[i].real_strides()
|
||||
shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
|
||||
for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
|
||||
for j,p in enumerate(pattern):
|
||||
if priority == p and real_strides[j] != 0:
|
||||
stride[j] = bst
|
||||
bst *= shp[j]
|
||||
|
||||
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
|
||||
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
|
||||
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
|
||||
self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
|
||||
|
||||
# ******************** high level optimizers ********************
|
||||
|
||||
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None):
|
||||
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
|
||||
for tc in tensor_cores[self.opts.device]:
|
||||
if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue
|
||||
has_cast = tc.dtype_in != tc.dtype_out
|
||||
|
||||
if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
|
||||
mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
|
||||
|
||||
if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue
|
||||
if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.MEM and mul_op.src[0].arg.dtype == tc.dtype_in): continue
|
||||
if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.MEM and mul_op.src[1].arg.dtype == tc.dtype_in): continue
|
||||
buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg))
|
||||
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0]
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0]
|
||||
|
||||
if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue
|
||||
|
||||
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
|
||||
|
||||
s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way
|
||||
s0_exists, s1_exists = True, True
|
||||
assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0
|
||||
def fix(needed, ax):
|
||||
nonlocal s0, s1, s0_exists, s1_exists
|
||||
if not needed: return
|
||||
if s0_exists and ax == s0:
|
||||
if s1_exists and s0 < s1: s1 -= 1
|
||||
s0_exists = False
|
||||
elif s1_exists and ax == s1:
|
||||
if s0_exists and s1 < s0: s0 -= 1
|
||||
s1_exists = False
|
||||
|
||||
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
|
||||
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
|
||||
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
|
||||
for (tc_dim, tc_amt) in tc.threads:
|
||||
fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
|
||||
|
||||
# assert tensor core and prevent extra_opts from altering the key shape structure
|
||||
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
|
||||
|
||||
if extra_opts is not None:
|
||||
for opt in extra_opts:
|
||||
self.apply_opt(opt)
|
||||
else:
|
||||
# hand-coded TC opts
|
||||
if s1_exists:
|
||||
s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0]
|
||||
if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
|
||||
if s0_exists:
|
||||
s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0]
|
||||
if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
|
||||
if self.tensor_core and s0_exists:
|
||||
for upc in [4,2]:
|
||||
if self.full_shape[s0] % upc == 0:
|
||||
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
|
||||
break
|
||||
|
||||
# alias buffer
|
||||
alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
|
||||
self.alias_buffer(buf0, alias_pattern)
|
||||
self.alias_buffer(buf1, alias_pattern)
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_opt(self, opt:Opt):
|
||||
assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals"
|
||||
self.applied_opts.append(opt)
|
||||
if opt.axis is not None:
|
||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
|
||||
else:
|
||||
axis = -1
|
||||
if opt.amt is not None:
|
||||
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
|
||||
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
||||
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
|
||||
else:
|
||||
amt = -1
|
||||
if opt.op == OptOps.LOCAL: # cyan
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
assert not(self.tensor_core), "can't local with tensor cores"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.LASTLOCAL: # cyan
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.GROUP: # green
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
assert not(self.tensor_core), "can't group with tensor cores"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.GROUPTOP: # green
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
assert not(self.tensor_core), "can't group with tensor cores"
|
||||
self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.UNROLL: # purple
|
||||
assert axis < self.shape_len-self.upcasted, "can't unroll an already upcasted axis"
|
||||
assert amt <= 32, "don't unroll more than 32"
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCAST: # yellow
|
||||
assert axis < self.first_reduce, "upcast is for non-reduce"
|
||||
assert amt <= 8, "don't upcast more than 8"
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCASTMID: # white
|
||||
assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce"
|
||||
axes = self.sts[0].unit_stride_axes()
|
||||
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
|
||||
assert axes[0] == axis, "wrong axis"
|
||||
assert amt == 4, "don't upcast mid anything but 4"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.NOLOCALS:
|
||||
assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
|
||||
assert not self.dont_use_locals, "already not using locals"
|
||||
self.dont_use_locals = True
|
||||
return self.simplify_ones()
|
||||
|
||||
def required_optimizations(self, early_only=False):
|
||||
for buf_index,buf in enumerate(self.bufs):
|
||||
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
|
||||
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
|
||||
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
|
||||
if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
|
||||
if unit_stride_axes_mul_4[0] < self.first_reduce:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
|
||||
else:
|
||||
self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
|
||||
|
||||
def hand_coded_optimizations(self):
|
||||
# if there's images in the earlybufs, we have to make an axis the 4 loading one
|
||||
self.required_optimizations(early_only=True)
|
||||
|
||||
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
|
||||
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
|
||||
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
|
||||
self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
|
||||
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
|
||||
self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM:
|
||||
buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg)
|
||||
buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg)
|
||||
buf0_strides = self.sts[buf0].real_strides()
|
||||
buf1_strides = self.sts[buf1].real_strides()
|
||||
def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st))
|
||||
if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)):
|
||||
for global_idx in range(self.global_dims):
|
||||
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
|
||||
if MV_THREADS_PER_ROW > 1:
|
||||
self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
||||
if MV_BLOCKSIZE > 1:
|
||||
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||
if MV_ROWS_PER_THREAD > 1:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
|
||||
return
|
||||
|
||||
if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
|
||||
# are we grouping? (requires local shape support)
|
||||
if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
|
||||
# TODO: use 1024 if it's allowed in a smarter way
|
||||
for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
|
||||
if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
|
||||
self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
|
||||
break
|
||||
|
||||
# are we upcasting in mid reduce? (only for images)
|
||||
if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
|
||||
axes = self.sts[0].unit_stride_axes()
|
||||
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
|
||||
if self.sts[0].shape[axes[0]]%4 == 0:
|
||||
self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
|
||||
|
||||
# now do everything required
|
||||
self.required_optimizations()
|
||||
|
||||
# no more opt if we are grouping
|
||||
if self.group_for_reduce: return
|
||||
|
||||
# **** below this line need to be optional and benchmarked ****
|
||||
|
||||
# TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
|
||||
# to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
|
||||
# expression and run test/test_ops.py with IMAGE=2
|
||||
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
|
||||
# this can be made much smarter
|
||||
to_upcast: List[int] = []
|
||||
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
|
||||
for axis in range(self.first_reduce):
|
||||
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
|
||||
# for now skip upcasting here if there is a symbolic axis
|
||||
if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
|
||||
prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
|
||||
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
||||
to_upcast.append(axis)
|
||||
for axis in to_upcast[::-1]:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
||||
|
||||
# potentially do more upcasts of non reduce axes based on a heuristic
|
||||
upcasted_axis = set()
|
||||
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
|
||||
xb_choices = []
|
||||
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
|
||||
# if we haven't upcasted it, it's not symbolic, the upcast amount divides it evenly, and some buffer has stride 0 on this axis while having no stride-0 axis among its already upcasted axes
|
||||
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
|
||||
if xb_choices:
|
||||
xb_choices = sorted(xb_choices)
|
||||
if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
|
||||
self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
|
||||
upcasted_axis.add(xb_choices[0][2])
|
||||
else:
|
||||
break
|
||||
|
||||
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
|
||||
if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64):
|
||||
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
|
||||
else:
|
||||
for splits in [4]:
|
||||
if self.full_unupcasted_shape[-1]%splits == 0:
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
|
||||
break
|
||||
|
||||
# if nothing at all is upcasted and it's easy to, do an upcast
|
||||
# TODO: this is breaking the tests
|
||||
for splits in [4]:
|
||||
if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))
|
||||
|
||||
# **** local groups ****
|
||||
|
||||
if self.opts.has_local:
|
||||
if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
|
||||
self.apply_opt(Opt(OptOps.NOLOCALS))
|
||||
else:
|
||||
# prioritize making expand axes local
|
||||
local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
|
||||
to_local: List[Tuple[int, int]] = []
|
||||
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
|
||||
local_size = prod(sz for _, sz in to_local)
|
||||
local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
|
||||
if local_sz is not None: to_local.append((axis, local_sz))
|
||||
deleted_shape = 0
|
||||
for axis, local_sz in sorted(to_local[:3]):
|
||||
axis = axis - deleted_shape
|
||||
will_delete_shape = local_sz == self.full_shape[axis]
|
||||
self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
|
||||
if will_delete_shape: deleted_shape += 1
|
||||
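# --- illustrative sketch (not part of the original file) ---
# The local-dims loop above ranks global axes so that axes some buffer broadcasts over
# (stride 0) come first, then greedily picks the largest divisor that keeps the total
# local size under a budget. A standalone version with plain ints, using the same
# `prod` helper as the code above; `pick_local_sizes` and `budget` are made-up names.
def pick_local_sizes(full_shape, stride_zero, budget=128):
  # broadcast axes first, later axes before earlier ones (mirrors the sort key above)
  ranking = sorted(range(len(full_shape)), key=lambda ax: (-int(stride_zero[ax]), -ax))
  to_local = []
  for ax in ranking:
    local_size = prod(sz for _, sz in to_local)
    candidates = ([32] if ax == 0 else []) + [16, 8, 4, 3, 2]
    sz = next((x for x in candidates if full_shape[ax] % x == 0 and local_size * x <= budget), None)
    if sz is not None: to_local.append((ax, sz))
  return to_local
# e.g. pick_local_sizes([64, 28, 3], [True, False, False]) -> [(0, 32), (2, 3)]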
441
tinygrad_repo/tinygrad/codegen/linearizer.py
Normal file
@@ -0,0 +1,441 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Union, Sequence, Final, Set
|
||||
import itertools, math, functools
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
|
||||
from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
|
||||
from tinygrad.codegen.kernel import LocalBuffer, Kernel
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.features.image import to_image_idx
|
||||
|
||||
# bottom ones are asm only
|
||||
class UOps(Enum):
|
||||
LOOP = auto(); IF = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
|
||||
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
|
||||
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
|
||||
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
|
||||
|
||||
class UOp(NamedTuple):
|
||||
uop: UOps
|
||||
dtype: Optional[DType]
|
||||
vin: Tuple[UOp, ...]
|
||||
arg: Any
|
||||
def __repr__(self): return f"{self.num:4d} {str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.num for x in self.vin]):32s} {self.arg}"
|
||||
#def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str(self.vin):32s} {self.arg}"
|
||||
|
||||
# UOps are unique
|
||||
num: int
|
||||
def __hash__(self): return self.num
|
||||
def __eq__(self, x): return self.num == x.num
|
||||
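# --- illustrative sketch (not part of the original file) ---
# What a tiny linearized program looks like as raw UOps. Linearizer.uop() below
# normally assigns `num`, dedups via saved_exprs and applies the folding rules; here
# three nodes are built by hand (positional args follow the field order declared
# above) just to show the vin/num SSA-style linkage. `_uop_example` is a made-up name.
def _uop_example() -> List[UOp]:
  c0 = UOp(UOps.CONST, dtypes.float32, (), 2.0, 0)                  # %0 = const 2.0
  c1 = UOp(UOps.CONST, dtypes.float32, (), 3.0, 1)                  # %1 = const 3.0
  add = UOp(UOps.ALU, dtypes.float32, (c0, c1), BinaryOps.ADD, 2)   # %2 = add %0 %1
  return [c0, c1, add]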
|
||||
def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
  local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)]
  if maxdim != 0 and len(local_dims) > maxdim:
    dd = local_idxs[maxdim-1]
    nli = []
    for s in local_dims[maxdim-1:][::-1]:
      nli.append(dd % s)
      dd //= s
    local_idxs = local_idxs[0:maxdim-1] + nli[::-1]
  return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
|
||||
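# --- illustrative sketch (not part of the original file) ---
# get_grouped_dims collapses trailing dims beyond what the device exposes into one
# launch dim and recovers them inside the kernel with % and //. The same idea with
# plain ints (no Variables); `_collapse_dims` / `_recover_trailing` are made-up names.
def _collapse_dims(dims, maxdim=3):
  if len(dims) <= maxdim: return list(dims)
  return list(dims[:maxdim-1]) + [prod(dims[maxdim-1:])]
def _recover_trailing(flat_idx, trailing_dims):
  out = []                       # inverse of the merge, matching the % / //= loop above
  for s in trailing_dims[::-1]:
    out.append(flat_idx % s)
    flat_idx //= s
  return out[::-1]
# _collapse_dims([2, 4, 8, 3]) -> [2, 4, 24]; for i in range(24),
# _recover_trailing(i, [8, 3]) == [i // 3, i % 3]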
|
||||
class Linearizer(Kernel):
|
||||
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
|
||||
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
|
||||
return self.uop(UOps.ALU, dtype, (a, render_b), op)
|
||||
|
||||
# NOTE: the consts have to be cached for deduping of downstream uops to work
|
||||
def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b)
|
||||
|
||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
|
||||
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
|
||||
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
|
||||
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
|
||||
LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT, dtype=dtypes.bool),
|
||||
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
|
||||
def global_load(self, i:int, idxs:Sequence[Node], acc=None) -> List[UOp]:
|
||||
buf = self.bufs[i]
|
||||
const = buf.val if isinstance(buf, ConstBuffer) else acc
|
||||
|
||||
def rename_var(v: VariableOrNum, expr: str): return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max)
|
||||
|
||||
amt, dim = 1, None
|
||||
upcast_dim = self.get_upcast_dim(i)
|
||||
if len(upcast_dim) == 1 and len(float4_expand := idxs[upcast_dim[0]].expand()) in [4,2]:
|
||||
dim, amt = upcast_dim[0], len(float4_expand)
|
||||
|
||||
expand_vars = tuple([rename_var(idx.expand_idx(), f"_uidx{j}") for j, idx in enumerate(idxs)])
|
||||
fake_idxs = [idx.substitute({idx.expand_idx(): ev}) for idx, ev in zip(idxs, expand_vars)]
|
||||
if dim is not None:
|
||||
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs[:dim] + [float4_expand[0]] + fake_idxs[dim+1:])
|
||||
if (g_idx // amt * amt).render() != g_idx.render():
|
||||
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
|
||||
else:
|
||||
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
|
||||
localtype = dtypes.float32 if amt == 1 else dtypes._float4 if amt == 4 else dtypes._float2
|
||||
|
||||
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
|
||||
|
||||
ret = []
|
||||
invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0
|
||||
for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)):
|
||||
this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
|
||||
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"
|
||||
if key not in self.load_cache:
|
||||
if acc is not None:
|
||||
assert valid.min == 1
|
||||
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
|
||||
elif this_const is not None:
|
||||
self.load_cache[key] = self.const(this_const, localtype)
|
||||
if valid.min == 0 and valid.max == 1:
|
||||
valid_rendered = valid.render(self.render_ops, self)
|
||||
self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE)
|
||||
else:
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
|
||||
if valid.min == 0:
|
||||
valid_rendered = valid.render(self.render_ops, self)
|
||||
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)))
|
||||
else:
|
||||
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx))
|
||||
ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
|
||||
return ret
|
||||
|
||||
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None:
|
||||
buf = self.bufs[i]
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
|
||||
expanded_nodes = [idx.expand() for idx in idxs]
|
||||
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
|
||||
store_offset = dict(zip(_idxs, store))
|
||||
|
||||
# float4 grouping
|
||||
upcast_dim = self.get_upcast_dim(i)
|
||||
if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [2,4]:
|
||||
grouped_store_offset = defaultdict(list)
|
||||
for k in store_offset:
|
||||
_idx = k[:upcast_dim[0]] + (expanded_nodes[upcast_dim[0]][0],) + k[upcast_dim[0]+1:]
|
||||
grouped_store_offset[_idx].append(store_offset[k])
|
||||
store_offset_new = {}
|
||||
for k,out_tokens in grouped_store_offset.items():
|
||||
amt = len(out_tokens)
|
||||
idx, valid = self.sts[i].expr_idxs(k)
|
||||
assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
|
||||
assert valid.min == 1, "stores are always valid"
|
||||
store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
|
||||
store_offset = store_offset_new
|
||||
|
||||
for idx, var in store_offset.items():
|
||||
idx, valid = self.sts[i].expr_idxs(idx)
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))
|
||||
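# --- illustrative sketch (not part of the original file), not a Linearizer method ---
# The float4 grouping in global_store above collapses per-element stores whose
# expanded indices differ only in the upcast dim into one vector store. Plain tuples
# of ints stand in for the expanded symbolic index nodes; all names here are made up.
def _group_float4(store_offset, upcast_dim, expanded_len=4):
  grouped = defaultdict(list)
  for idx, val in store_offset.items():
    key = idx[:upcast_dim] + (0,) + idx[upcast_dim+1:]   # zero out the upcast coordinate
    grouped[key].append(val)
  # every group must have exactly `expanded_len` members for a legal vector store
  assert all(len(v) == expanded_len for v in grouped.values())
  return dict(grouped)
# e.g. 8 scalar stores over a (2, 4) index space with upcast_dim=1 become 2 float4
# stores: _group_float4({(i, j): f"v{i}{j}" for i in range(2) for j in range(4)}, 1)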
|
||||
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
|
||||
def linearize(self):
|
||||
# no new opts and we already ran? skip relinearizing
|
||||
if self.applied_opts == self.applied_opts_cache: return self
|
||||
|
||||
# save backups
|
||||
sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduce[:], self.upcasted
|
||||
|
||||
# global uop cache
|
||||
self.saved_exprs: Dict[Tuple, UOp] = dict()
|
||||
|
||||
# limit dims if we need to
|
||||
if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
|
||||
|
||||
# uops
|
||||
self.uops: List[UOp] = []
|
||||
self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
|
||||
self.loop_uops: Dict[str, UOp] = {}
|
||||
|
||||
# add global buffers
|
||||
for i,buf in enumerate(self.bufs):
|
||||
if isinstance(buf, MemBuffer):
|
||||
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
|
||||
# add var vals
|
||||
for var in sorted(vars_from_ast(self.ast), key=lambda k: k.key):
|
||||
assert var.expr is not None
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
|
||||
# define local buffers
|
||||
for lb in self.local_alias.values():
|
||||
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
|
||||
# add a local buffer for multistage reduce. # TODO: use local alias
|
||||
if self.group_for_reduce:
|
||||
# TODO: the strides of this can be controlled
|
||||
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
|
||||
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
|
||||
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
|
||||
|
||||
# kernel name (before late upcast)
|
||||
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
|
||||
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
|
||||
# name the function something unique
|
||||
Linearizer.kernel_cnt[self.function_name] += 1
|
||||
suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
|
||||
self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
|
||||
|
||||
# define indexes
|
||||
global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
|
||||
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0)
|
||||
full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
|
||||
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
|
||||
|
||||
# global and local loops
|
||||
def render_loop(xx:List[Variable]):
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
|
||||
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
|
||||
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
|
||||
def end_loop(xx:List[Variable]):
|
||||
for x in xx[::-1]:
|
||||
if not isinstance(x, NumNode) and x.expr is not None:
|
||||
loop_uop = self.loop_uops[x.expr]
|
||||
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
|
||||
|
||||
# set global/local size
|
||||
self.global_size: Optional[List[int]] = None
|
||||
self.local_size: Optional[List[int]] = None
|
||||
if self.dont_use_locals:
|
||||
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
||||
elif self.opts.has_local:
|
||||
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
|
||||
self.global_size += [1]*(3-len(self.global_size))
|
||||
self.local_size += [1]*(3-len(self.local_size))
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
|
||||
else:
|
||||
render_loop(loop_global_idxs+loop_local_idxs)
|
||||
|
||||
# parse AST
|
||||
loaded_buffers = {}
|
||||
acc = []
|
||||
self.load_cache: Dict[str, UOp] = {}
|
||||
if_gate: Optional[UOp] = None
|
||||
|
||||
# reduce op
|
||||
fake_reduce_idxs: List[Variable] = []
|
||||
if self.reduceop is not None:
|
||||
# define indexes
|
||||
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
|
||||
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
||||
|
||||
# define accumulator
|
||||
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
||||
|
||||
if self.tensor_core:
|
||||
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
|
||||
replace_idxs = []
|
||||
for alias in aliases:
|
||||
full_var, full_var_sz = Variable.num(0), 1
|
||||
if alias[0] != 0:
|
||||
for i in alias:
|
||||
next_var = local_idxs[-i] if i > 0 else Variable(None, 0, local_size-1)
|
||||
full_var += next_var * full_var_sz
|
||||
full_var_sz *= next_var.max+1
|
||||
replace_idxs.append(full_var)
|
||||
return replace_idxs
|
||||
replace_acc_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[2], self.tensor_core.thread_local_aliases[2])
|
||||
for n in range(len(self.tensor_core.threads)):
|
||||
local_idxs[self.local_dims-len(self.tensor_core.threads)+n] = replace_acc_idxs[n] # replace locals
|
||||
for n in range(len(replace_acc_idxs)-len(self.tensor_core.threads)):
|
||||
upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts
|
||||
|
||||
# reduce loop
|
||||
render_loop(reduce_idxs)
|
||||
|
||||
# barrier for fast GEMM
|
||||
if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||
|
||||
# compute local aliases
|
||||
locals_to_store = []
|
||||
for i in self.local_alias:
|
||||
localbuf_idx = self.bufs.index(self.local_alias[i])
|
||||
buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
|
||||
if self.tensor_core:
|
||||
min_alias_idx = min(self.local_alias.keys())
|
||||
replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx])
|
||||
for n in range(len(self.tensor_core.threads)):
|
||||
buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals
|
||||
for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)):
|
||||
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(self.tensor_core.threads)+n] # replace upcasts
|
||||
if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
|
||||
ll = self.global_load(i, buf_idxs)
|
||||
locals_to_store.append((localbuf_idx, buf_idxs, ll))
|
||||
|
||||
# copy in any global buffers
|
||||
if self.tensor_core:
|
||||
wmma_sz = self.tensor_core.thread_local_sizes
|
||||
# calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else
|
||||
nx, ny, nacc = (len(locals_to_store[0][2])//wmma_sz[0]), (len(locals_to_store[1][2])//wmma_sz[1]), (len(acc)//wmma_sz[2])
|
||||
acc_reds = math.isqrt((nx*ny)//nacc)
|
||||
i, bx, by = 0, nx//acc_reds, ny//acc_reds
|
||||
for y in range(by):
|
||||
for x in range(bx):
|
||||
for j in range(acc_reds):
|
||||
self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
|
||||
i += wmma_sz[2]
|
||||
else:
|
||||
if locals_to_store:
|
||||
self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||
for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll)
|
||||
self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||
|
||||
# load earlybufs
|
||||
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
|
||||
|
||||
# run early AST (with reduce)
|
||||
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True)
|
||||
|
||||
# end the reduce loop
|
||||
end_loop(reduce_idxs)
|
||||
self.load_cache.clear()
|
||||
|
||||
# end the local loop, do the local reduce
|
||||
if self.group_for_reduce:
|
||||
fake_global_idxs = [x*0 for x in global_idxs]
|
||||
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
|
||||
self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||
end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
|
||||
if self.opts.has_local:
|
||||
fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
|
||||
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
|
||||
if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
|
||||
if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False)
|
||||
|
||||
# create new late reduce local loops and replace local_idxs that have been used
|
||||
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
|
||||
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
|
||||
|
||||
# if any group_for_reduce items aren't reduces, upcast them here
|
||||
for j in self.upcast_in_mid_reduce_axes:
|
||||
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
|
||||
self.upcast()
|
||||
self.group_for_reduce.pop()
|
||||
local_idxs = local_idxs[:-1]
|
||||
end_local_idxs = end_local_idxs[:-1]
|
||||
# regenerate upcast_idxs
|
||||
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
|
||||
|
||||
# NOTE: this structure is the same as the reduce op above
|
||||
|
||||
# define late accumulator
|
||||
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
||||
|
||||
# late reduce loop
|
||||
render_loop(end_local_idxs)
|
||||
|
||||
# load localbufs
|
||||
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
|
||||
|
||||
# there's no AST here (and there's no shape for the reduce LazyOp)
|
||||
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True) # type: ignore
|
||||
|
||||
# end the late reduce loop
|
||||
end_loop(end_local_idxs)
|
||||
self.load_cache.clear()
|
||||
|
||||
# load latebufs
|
||||
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
|
||||
|
||||
# run late AST
|
||||
val = self.ast_parse(self.ast, acc, None, loaded_buffers)
|
||||
|
||||
# store
|
||||
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
|
||||
|
||||
# end the global (and maybe local) loop
|
||||
if if_gate: self.uop(UOps.END, None, (if_gate,))
|
||||
end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
|
||||
|
||||
# (recursively) remove childless uops
|
||||
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
|
||||
while 1:
|
||||
has_child: Set[UOp] = set()
|
||||
for ru in self.uops:
|
||||
for vu in ru.vin:
|
||||
has_child.add(vu)
|
||||
nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
|
||||
if len(nu) == len(self.uops): break
|
||||
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
|
||||
self.uops = nu
|
||||
|
||||
# restore backups
|
||||
self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup
|
||||
|
||||
# set cache and return
|
||||
self.applied_opts_cache = self.applied_opts[:]
|
||||
return self
|
||||
|
||||
def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp:
|
||||
key = (uop, dtype, vin, arg)
|
||||
if uop == UOps.PHI and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self phi is noop
|
||||
if uop == UOps.CAST and all(x.uop == UOps.GEP for x in vin) and all_same([x.vin[0] for x in vin]) and all(x.arg == i for i,x in enumerate(vin)): return vin[0].vin[0]
|
||||
if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype)
|
||||
if uop == UOps.ALU:
|
||||
# rewrites. NOTE: the rewritten NEG op is still around...
|
||||
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable)
|
||||
# constant folding
|
||||
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype)
|
||||
# zero folding
|
||||
for x in [0,1]:
|
||||
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
|
||||
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
|
||||
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
|
||||
if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
|
||||
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
|
||||
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
|
||||
self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
|
||||
if DEBUG >= 5: print(self.uops[-1])
|
||||
if cachable: self.saved_exprs[key] = self.uops[-1]
|
||||
return self.uops[-1]
|
||||
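# --- illustrative sketch (not part of the original file), not a Linearizer method ---
# The ALU folds in uop() above are purely local rewrites on (op, const) patterns, so
# plain tuples are enough to illustrate them: ('const', v) stands in for a CONST uop,
# anything else is an opaque value. `_fold_alu` and the string op names are made up.
def _fold_alu(op, a, b):
  is_const = lambda x: isinstance(x, tuple) and x[0] == 'const'
  for x, y in ((a, b), (b, a)):
    if op == 'add' and is_const(x) and x[1] == 0.0: return y   # x + 0 -> x
    if op == 'mul' and is_const(x) and x[1] == 1.0: return y   # x * 1 -> x
    if op == 'mul' and is_const(x) and x[1] == 0.0: return x   # x * 0 -> 0
  if op == 'sub' and is_const(b) and b[1] == 0.0: return a     # x - 0 -> x
  if op == 'div' and is_const(b) and b[1] == 1.0: return a     # x / 1 -> x
  return (op, a, b)                                            # nothing folded
# _fold_alu('mul', 'x', ('const', 1.0)) == 'x'; _fold_alu('add', 'x', 'y') == ('add', 'x', 'y')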
|
||||
def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False) -> List[UOp]:
|
||||
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
|
||||
if x.op in BufferOps: return loaded_buffers[x.arg]
|
||||
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, offs, loaded_buffers) # cast isn't an ALU op
|
||||
if x.op in ReduceOps and not do_reduce:
|
||||
assert offs is None, "not available if we aren't doing reduce"
|
||||
return acc
|
||||
# MULACC fusion. TODO: this is copied from Interpreted
|
||||
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
|
||||
x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
|
||||
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
|
||||
x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
|
||||
values = [self.ast_parse(v, acc, offs, loaded_buffers) for v in x.src]
|
||||
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
|
||||
if x.op in ops:
|
||||
ret = []
|
||||
for idx, val, off in zip([[i] for i in range(len(values[0]))], zip(*values), offs):
|
||||
new_val = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
|
||||
# NOTE: we could apply the phi node to only the last change, but this breaks CLANG with nested max(x,y)
|
||||
acc[off] = self.uop(UOps.PHI, dtypes.float32, (acc[off], new_val))
|
||||
ret.append((idx, acc[off]))
|
||||
else:
|
||||
ret = [(idx, self.uop(UOps.ALU, dtypes.float32, val, x.op)) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values))]
|
||||
ordered_ret: List[Optional[UOp]] = [None]*len(values[0])
|
||||
# scatter
|
||||
for i,j in ret:
|
||||
for k in i:
|
||||
ordered_ret[k] = j
|
||||
assert all(isinstance(x, UOp) for x in ordered_ret), "some tokens didn't get scattered?"
|
||||
return cast(List[UOp], ordered_ret)
|
||||
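# --- illustrative sketch (not part of the original file) ---
# The MULACC fusion in ast_parse above rewrites a SUM whose source is a MUL
# (possibly through a CAST) into one fused multiply-accumulate. The same pattern
# match on toy namedtuples; ToyOp and the string op names are made up.
from collections import namedtuple
ToyOp = namedtuple("ToyOp", ["op", "src", "arg"])
def _fuse_mulacc(x):
  if x.op == "SUM" and x.src[0].op == "MUL":
    return ToyOp("MULACC", x.src[0].src, x.arg)
  if x.op == "SUM" and x.src[0].op == "CAST" and x.src[0].src[0].op == "MUL":
    return ToyOp("MULACC", x.src[0].src[0].src, x.arg)
  return x
# _fuse_mulacc(ToyOp("SUM", (ToyOp("MUL", ("a", "b"), None),), None)).op == "MULACC"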
204
tinygrad_repo/tinygrad/features/image.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from typing import List, Tuple, Dict, Any
|
||||
from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG, flatten
|
||||
|
||||
# *** image Tensor function replacements ***
|
||||
|
||||
from tinygrad.lazy import get_single_root
|
||||
|
||||
def image_dot(self, w):
|
||||
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
||||
n1, n2 = len(self.shape), len(w.shape)
|
||||
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
||||
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
|
||||
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
|
||||
cin, cout = w.shape[-2], w.shape[-1]
|
||||
out_shape_t = self.shape[0:-2] + (cout,-1)
|
||||
if len(self.shape) > 1:
|
||||
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
|
||||
else:
|
||||
order, out_shape_t = (0,), (cout, )
|
||||
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
|
||||
|
||||
# NOTE: with NHWC we can remove the transposes
|
||||
# bs x groups*cin x H x W
|
||||
cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
|
||||
# groups*cout x cin x H, W
|
||||
cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
|
||||
return image_conv2d(cx, cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
|
||||
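# --- illustrative sketch (not part of the original file) ---
# A quick numpy sanity check of the reshapes in image_dot above: an (m,k) @ (k,n)
# matmul is the same computation as a 1x1 convolution of a (1, k, m, 1) "image"
# with (n, k, 1, 1) weights. einsum stands in for conv2d; numpy and the function
# name are only used for this check, they are not part of image_dot.
import numpy as np
def _matmul_as_1x1_conv(x, w):
  (m, k), (k2, n) = x.shape, w.shape
  assert k == k2
  cx = x.T.reshape(1, k, m, 1)                            # bs=1, cin=k, H=m, W=1
  cw = w.T.reshape(n, k, 1, 1)                            # cout=n, cin=k, 1x1 kernel
  out = np.einsum('bchw,oc->bohw', cx, cw.reshape(n, k))  # 1x1 conv = per-pixel dot over cin
  return out.reshape(n, m).T                              # undo the transposes, back to (m, n)
# np.allclose(_matmul_as_1x1_conv(a, b), a @ b) holds for any float a:(m,k), b:(k,n)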
|
||||
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
|
||||
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
|
||||
|
||||
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
|
||||
rcout = cout//groups
|
||||
x, w = self, weight.reshape(groups, rcout, cin, H, W)
|
||||
|
||||
# hack for non multiples of 4 on cin
|
||||
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
|
||||
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
|
||||
added_input_channels = 4 - (cin % 4)
|
||||
w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
|
||||
x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
|
||||
cin = cin + added_input_channels
|
||||
x = x.reshape(bs, groups*cin, iy, ix)
|
||||
|
||||
# hack for non multiples of 4 on rcout
|
||||
added_output_channels = 0
|
||||
if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
|
||||
added_output_channels = 4 - (rcout % 4)
|
||||
rcout += added_output_channels
|
||||
cout = groups * rcout
|
||||
w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape)))
|
||||
|
||||
# packed (note: flipping bs and iy would make the auto-padding work)
|
||||
x = x.permute(0,2,3,1)
|
||||
cin_last = iy == 1 and ix == 1
|
||||
if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
|
||||
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
|
||||
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
|
||||
|
||||
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
|
||||
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
|
||||
x, w = x.contiguous(), w.contiguous()
|
||||
if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize()
|
||||
|
||||
# expand out
|
||||
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
|
||||
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
|
||||
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
|
||||
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
|
||||
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
|
||||
|
||||
# padding
|
||||
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
|
||||
x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
|
||||
|
||||
# prepare input
|
||||
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
|
||||
oy, ox = x.shape[4:6]
|
||||
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
|
||||
x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
|
||||
|
||||
# prepare weights
|
||||
w = w.permute(0,4,2,5,1,3)
|
||||
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)
|
||||
|
||||
# the conv! (+ the bias)
|
||||
ret = x*w
|
||||
if IMAGE >= 2: ret = ret.cast(base_image_type((bs*oy, ox*cout//4, 4)))
|
||||
ret = ret.sum((-4, -3, -2, -1))
|
||||
|
||||
# undo hack for non multiples of 4 on C.rcout
|
||||
if added_output_channels != 0:
|
||||
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
|
||||
rcout -= added_output_channels
|
||||
cout = groups * rcout
|
||||
|
||||
# NCHW output
|
||||
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
|
||||
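# --- illustrative sketch (not part of the original file) ---
# The "non multiples of 4 on cin" hack above zero-pads the input channels of both the
# activation and the weight; the result is unchanged because the padded channels only
# contribute 0 * 0 terms to every output. A numpy sketch for the groups=1 case;
# numpy and `_pad_cin_to_4` are only used here, they are not part of image_conv2d.
import numpy as np
def _pad_cin_to_4(x, w):
  # x: (bs, cin, iy, ix), w: (cout, cin, H, W)
  added = (-x.shape[1]) % 4                       # 0 if cin is already a multiple of 4
  if added:
    x = np.pad(x, ((0, 0), (0, added), (0, 0), (0, 0)))
    w = np.pad(w, ((0, 0), (0, added), (0, 0), (0, 0)))
  return x, w
# every added term in the channel sum is 0 * 0, so any convolution of the padded pair
# produces exactly the same output as the original pair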
|
||||
# *** schedules with images need to be fixed to be valid ***
|
||||
|
||||
import dataclasses
|
||||
from tinygrad.ops import ScheduleItem, BufferOps, LazyOp, UnaryOps, LoadOps, MemBuffer, get_lazyop_info
|
||||
|
||||
def fix_schedule_for_images(schedule:List[ScheduleItem]):
|
||||
# this is the fundamental fix, find unwritable or unreadable images and convert them to normal float32 (TODO: should it be float16?)
|
||||
replace_inputs = {}
|
||||
for i, si in enumerate(schedule):
|
||||
if isinstance(si.out.dtype, ImageDType) and (prod(si.out.shape) != prod(si.out.dtype.shape) or not any(si.out.shape[x]%4 == 0 for x in si.out.st.unit_stride_axes())):
|
||||
if DEBUG >= 1: print(f"{i:3d}: rewrite output, output shape {prod(si.out.shape)}, image dtype {si.out.dtype} prod {prod(si.out.dtype.shape)}")
|
||||
si.out.dtype = dtypes.float32
|
||||
for b in si.ast.get_lazyops():
|
||||
if b.op != BufferOps.MEM: continue
|
||||
# TODO: unit_stride axes will fail if there's a mask, even if the mask is divisible by four. this is too aggressive
|
||||
if isinstance(si.inputs[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
|
||||
if DEBUG >= 1: print(f"{i:3d}: rewrite input, image dtype {si.inputs[b.arg.idx-1].dtype}, {b.arg.st.views}")
|
||||
if si.inputs[b.arg.idx-1].realized:
|
||||
# have to copy it
|
||||
replace_inputs[si.inputs[b.arg.idx-1]] = si.inputs[b.arg.idx-1].cast(dtypes.float32)
|
||||
else:
|
||||
# change it before it's created
|
||||
si.inputs[b.arg.idx-1].dtype = dtypes.float32
|
||||
|
||||
# now fix up the schedule to reflect the new dtypes
|
||||
fixed_schedule:List[ScheduleItem] = []
|
||||
for i,si in enumerate(schedule):
|
||||
ast = si.ast
|
||||
inputs = si.inputs
|
||||
|
||||
# replace inputs with casted versions
|
||||
if any(x in replace_inputs for x in inputs):
|
||||
fixed_schedule += flatten([replace_inputs[x].schedule() for x in inputs if x in replace_inputs])
|
||||
inputs = tuple(replace_inputs.get(x, x) for x in inputs)
|
||||
|
||||
# fix input dtypes to match what they actually are
|
||||
replacements = {}
|
||||
for b in si.ast.get_lazyops():
|
||||
if b.op != BufferOps.MEM: continue
|
||||
if b.arg.dtype != inputs[b.arg.idx-1].dtype:
|
||||
replacements[b] = LazyOp(BufferOps.MEM, (), MemBuffer(b.arg.idx, inputs[b.arg.idx-1].dtype, b.arg.st))
|
||||
if replacements: ast = ast.map_buffers(replacements)
|
||||
|
||||
# fix the ops to create the output dtype
|
||||
if ast.op not in LoadOps:
|
||||
info = get_lazyop_info(ast)
|
||||
if info.dtype != si.out.dtype:
|
||||
if DEBUG >= 3: print(f"{i:3d}: info.dtype {info.dtype} != {si.out.dtype} -> {si.out.dtype}")
|
||||
ast = LazyOp(UnaryOps.CAST, (ast,), (si.out.dtype, False))
|
||||
|
||||
# put this in the fixed schedule
|
||||
fixed_schedule.append(dataclasses.replace(si, ast=ast, inputs=inputs))
|
||||
return fixed_schedule
|
||||
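# --- illustrative sketch (not part of the original file) ---
# The predicate above that forces an image output back to float32, pulled out with
# plain arguments: the buffer must hold exactly as many elements as the image dtype
# describes, and some unit-stride axis must be divisible by 4 so a whole texel
# (4 floats) stays contiguous. `_image_output_ok` is a made-up name; `prod` is the
# same helper imported at the top of this file.
def _image_output_ok(out_shape, image_shape, unit_stride_axes):
  if prod(out_shape) != prod(image_shape): return False           # element counts must match
  return any(out_shape[ax] % 4 == 0 for ax in unit_stride_axes)   # need a texel-aligned axis
# _image_output_ok((8, 12), (8, 3, 4), [1]) -> True  (96 == 96 and 12 % 4 == 0)
# _image_output_ok((8, 10), (8, 3, 4), [1]) -> False (80 != 96)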
|
||||
# *** images have weird indexing requirements ***
|
||||
|
||||
from tinygrad.shape.symbolic import Node, AndNode, Variable, NumNode, SumNode, LtNode
|
||||
|
||||
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
  idx = (idxy // 4) % base_shape[1]
  idy = (idxy // (4 * base_shape[1]))

  if valid.min == 0 and isinstance(idxy, SumNode):
    nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
    val_dict: Dict[Node, Any] = {}
    idxy_flat_var = [(i, i.vars()[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]

    for node in nodes:
      assert isinstance(node, LtNode)
      node_flat, node_vars = node.a.flat_components if isinstance(node.a, SumNode) else [node.a], node.vars()
      same_sym = [i for (i, var) in idxy_flat_var if var in node_vars]
      if len(same_sym) == 0: continue
      first, second = sorted(same_sym)[0], sorted(node_flat)[0]
      f_b = 1 if isinstance(first, Variable) else first.b
      s_b = 1 if isinstance(second, Variable) else second.b
      sig = -1 if s_b < 0 else 1
      key_node = sig*node.a
      if key_node not in val_dict: val_dict[key_node] = [key_node.min, key_node.max, abs(f_b//s_b)]
      val_dict[key_node][(sig + 1)//2] = sig*(node.b - 1)

    fakes = {}
    for cnt, (key_node, (mnn, mxn, multip)) in enumerate(val_dict.items()):
      fake_var = Variable("fake_" + str(cnt), mnn, mxn)
      fakes[fake_var] = key_node
      idxy += multip*(fake_var - key_node)

    idx = (idxy // 4) % base_shape[1]
    idy = (idxy // (4 * base_shape[1]))

    fake_rep = {fake: node for fake, node in fakes.items()}

    idx = idx.substitute(fake_rep)
    idy = idy.substitute(fake_rep)

    idy_vars, idx_vars, ones = set(idy.vars()), set(idx.vars()), []
    for node in nodes:
      node_vars = set(node.vars())
      if not node_vars & (idx_vars | idy_vars): continue  # there is a simplified NumNode here, which can not go outside the bounds
      # NOTE: why is only idy problematic, and not idx?
      if idy_vars == node_vars or idy_vars & node_vars == set(): ones.append(node)
    valid = Variable.ands([i for i in nodes if i not in ones])

  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
  return (idx, idy), valid
|
||||
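# --- illustrative sketch (not part of the original file) ---
# The basic split at the top of to_image_idx: an image of logical width base_shape[1]
# stores 4 floats per texel, so a flat element index maps to texel coordinates by
# stripping the float-within-texel part and then splitting column/row. Plain ints
# instead of symbolic nodes; `_flat_to_image_xy` is a made-up name.
def _flat_to_image_xy(idxy, width):
  idx = (idxy // 4) % width        # texel column
  idy = idxy // (4 * width)        # texel row
  return idx, idy
# with width=10: elements 0..3 -> (0, 0), element 43 -> (0, 1), element 47 -> (1, 1)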
151
tinygrad_repo/tinygrad/features/search.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
||||
import itertools, random
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.ops import Device, Compiled, MemBuffer
|
||||
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
from collections import defaultdict
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)])
|
||||
actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
|
||||
actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
|
||||
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
|
||||
actions += [
|
||||
Opt(op=OptOps.LOCAL, axis=0, amt=32),
|
||||
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
|
||||
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
|
||||
Opt(op=OptOps.NOLOCALS),
|
||||
]
|
||||
|
||||
# returns time in seconds
|
||||
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
|
||||
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size}
|
||||
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
|
||||
var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
|
||||
try:
|
||||
lin.linearize()
|
||||
prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin)
|
||||
real_global_size = prg.global_size
|
||||
if allow_test_size and prg.global_size:
|
||||
test_global_size = prg.global_size[:]
|
||||
while prod(test_global_size) > max_global_size:
|
||||
for j in range(2,-1,-1):
|
||||
if test_global_size[j] > 16:
|
||||
test_global_size[j] //= 2
|
||||
break
|
||||
factor = prod(prg.global_size) / prod(test_global_size)
|
||||
prg.global_size = test_global_size
|
||||
#print(real_global_size, test_global_size, factor)
|
||||
else:
|
||||
factor = 1
|
||||
# TODO: this is super broken for var_vals
|
||||
# TODO: this is copied from prg.__call__
|
||||
global_size, local_size = prg.launch_dims(var_vals)
|
||||
if global_size is not None and local_size is None:
|
||||
local_size = prg.optimize_local_size(global_size, rawbufs)
|
||||
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
|
||||
tms = []
|
||||
for _ in range(cnt):
|
||||
if clear_l2:
|
||||
# TODO: this is too small for many L2 caches
|
||||
with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
|
||||
lra = prg.runtime_args.copy()
|
||||
if global_size: lra['global_size'] = global_size
|
||||
if local_size: lra['local_size'] = local_size
|
||||
tms.append(prg.clprg(*rawbufs, *var_vals.values(), **lra, wait=True)*factor)
|
||||
prg.global_size = real_global_size
|
||||
except Exception:
|
||||
if DEBUG >= 4:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print("FAILED")
|
||||
print(lin.ast)
|
||||
print(lin.applied_opts)
|
||||
tms = [float('inf')]
|
||||
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
|
||||
return min(tms)
|
||||
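# --- illustrative sketch (not part of the original file) ---
# The test-size reduction inside time_linearizer above, as a standalone helper:
# while the launch grid is larger than max_global_size, halve the highest-indexed
# dim still above 16 and remember the shrink factor so the measured time can be
# scaled back up. The early `break` when nothing is left to halve is an addition
# of this sketch; `_shrink_global_size` is a made-up name.
def _shrink_global_size(global_size, max_global_size=65536):
  test = list(global_size)
  while prod(test) > max_global_size:
    for j in range(len(test)-1, -1, -1):
      if test[j] > 16:
        test[j] //= 2
        break
    else:
      break                                    # nothing left to halve, give up
  return test, prod(global_size) / prod(test)  # (shrunk size, scale factor for the timing)
# _shrink_global_size([4096, 64, 1]) -> ([4096, 16, 1], 4.0)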
|
||||
# get (scrap) buffers for timing the linearizer
|
||||
def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
|
||||
bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
|
||||
for x in lin.membufs: bufsts[x.idx].append(x)
|
||||
rawbufs:List[Optional[RawBuffer]] = [None]*len(bufsts)
|
||||
for k,lx in bufsts.items():
|
||||
rawbufs[k] = cast(Compiled, Device[Device.DEFAULT]).buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
|
||||
assert all(r is not None for r in rawbufs)
|
||||
return cast(List[RawBuffer], rawbufs)
|
||||
|
||||
# get dictionary of all possible actions
|
||||
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
|
||||
acted_lins = {0:lin} if include_0 else {}
|
||||
for i,a in enumerate(actions):
|
||||
if a.axis is not None and a.axis >= lin.shape_len: continue
|
||||
if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
|
||||
lin2 = lin.copy()
|
||||
try:
|
||||
lin2.apply_opt(a)
|
||||
up, lcl = 1, 1
|
||||
for s,c in zip(lin2.full_shape, lin2.colors()):
|
||||
if c in {"magenta", "yellow"}: up *= s
|
||||
if c in {"cyan", "green", "white"}: lcl *= s
|
||||
if up > 256 or lcl > 256: continue
|
||||
acted_lins[i+1] = lin2
|
||||
except Exception:
|
||||
pass
|
||||
return acted_lins
|
||||
|
||||
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
|
||||
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size}
|
||||
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
|
||||
ret = lin.copy()
|
||||
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
|
||||
return ret
|
||||
|
||||
# init the BEAM with the base linearizer
|
||||
beam: List[Tuple[Linearizer, float]] = [(lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size))]
|
||||
|
||||
# NOTE: real uops use a weird compare method that's only valid inside a linearizer
|
||||
def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
|
||||
seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}
|
||||
|
||||
while 1:
|
||||
acted_lins = lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
|
||||
|
||||
# dedup with uops (TODO: double linearize not needed)
|
||||
acted_lins_dedup = []
|
||||
for lin in acted_lins:
|
||||
tuops = tuplize_uops(lin.linearize().uops)
|
||||
if tuops in seen_uops:
|
||||
#print(seen_uops[tuops], lin.applied_opts)
|
||||
continue
|
||||
seen_uops[tuops] = tuple(lin.applied_opts)
|
||||
acted_lins_dedup.append(lin)
|
||||
acted_lins = acted_lins_dedup
|
||||
|
||||
# time linearizers
|
||||
timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins]
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
if len(opts) == 0 or beam[0][1] <= opts[0][1]: break # we didn't get faster
|
||||
|
||||
# keep the BEAM best
|
||||
beam = opts[:amt]
|
||||
if DEBUG >= 2: print(f"{opts[0][1]*1e6:12.2f} us from {len(lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
|
||||
|
||||
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
|
||||
if DEBUG >= 3: print(beam[0][0].applied_opts)
|
||||
return beam[0][0]
|
||||
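# --- illustrative sketch (not part of the original file) ---
# The search loop of beam_search above with the kernel-specific parts (actions, uop
# dedup, disk cache) replaced by two illustrative callables: neighbors(state) lists
# candidate mutations and cost(state) times one of them. All names here are made up.
def _beam_search(start, neighbors, cost, amt):
  beam = [(start, cost(start))]
  while True:
    candidates = sorted(((n, cost(n)) for s, _ in beam for n in neighbors(s)), key=lambda x: x[1])
    if not candidates or beam[0][1] <= candidates[0][1]: break   # no improvement -> stop
    beam = candidates[:amt]                                      # keep the `amt` best
  return beam[0][0]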
|
||||
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
  test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
  MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024
  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
  def try_exec(local_size):
    try:
      return clprg(*test_rawbuffers, global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)
    except Exception:
      return float('inf')
  return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
|
||||
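# --- illustrative sketch (not part of the original file) ---
# The candidate enumeration inside optimize_local_size above: per global dim keep the
# power-of-two sizes (plus the dim itself) that fit, then take every combination whose
# product respects the workgroup limit. `_local_size_candidates` is a made-up name.
def _local_size_candidates(global_size, max_workgroup=1024):
  per_dim = [[x for x in {sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, max_workgroup} if x <= sz] for sz in global_size]
  return [list(c) for c in itertools.product(*per_dim) if prod(c) <= max_workgroup]
# _local_size_candidates([64, 1, 1]) yields 7 candidates, [1, 1, 1] through [64, 1, 1]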
117
tinygrad_repo/tinygrad/graph.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import os, atexit, functools
|
||||
try:
|
||||
import networkx as nx # type: ignore
|
||||
except ImportError:
|
||||
nx = None # graph won't work
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
|
||||
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
|
||||
from tinygrad.codegen.linearizer import UOps
|
||||
|
||||
# **** debugging and graphing ****
|
||||
|
||||
G = nx.DiGraph() if nx is not None else None
|
||||
cnts: Dict[OpType, int] = defaultdict(int)
|
||||
if DEBUG >= 2:
|
||||
def print_globalcounters():
|
||||
if GlobalCounters.time_sum_s == 0: return
|
||||
print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",
|
||||
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
|
||||
atexit.register(print_globalcounters)
|
||||
if GRAPH:
|
||||
def save_graph_exit():
|
||||
for k,v in cnts.items(): print(k, v)
|
||||
print("saving", G)
|
||||
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
|
||||
# -Gnslimit=100 can make it finish, but you won't like the results
|
||||
os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
|
||||
atexit.register(save_graph_exit)
|
||||
|
||||
node_count = 0
|
||||
def nm(x):
|
||||
global node_count
|
||||
if not hasattr(x, 'node_id'):
|
||||
setattr(x, 'node_id', node_count)
|
||||
node_count += 1
|
||||
return x.node_id
|
||||
|
||||
def get_sop(op: List[Op]):
|
||||
op = [x for x in op if x not in BufferOps]
|
||||
if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
|
||||
if len(op) <= 6: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
|
||||
return str(len(op))
|
||||
|
||||
def str_dtype(dtyp):
|
||||
ret = str(dtyp)[7:]
|
||||
return "" if ret == 'float' else f"\n{ret}"
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def add_st_node(nmx, nmo, label, st):
|
||||
global node_count
|
||||
inter_node = node_count
|
||||
node_count += 1
|
||||
G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=f"{st.shape}\n{st.real_strides()}" + (f"\n{st.real_offset()}" if st.real_offset() != 0 else ""))
|
||||
G.add_edge(nmx, inter_node, color='#00000060')
|
||||
G.add_edge(inter_node, nmo, label=label, color='#00000060')
|
||||
|
||||
logops = open(getenv("LOGOPS", ""),"a") if getenv("LOGOPS", "") else None
|
||||
def log_schedule_item(si: ScheduleItem):
|
||||
if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
|
||||
show_graph = bool(GRAPH)
|
||||
if not DEBUG and not show_graph: return
|
||||
if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base))
|
||||
if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
|
||||
|
||||
op: List[Op] = [x.op for x in si.ast.get_lazyops()]
|
||||
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
|
||||
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
|
||||
cnts[optype] += 1
|
||||
if show_graph:
|
||||
assert si.out.base == si.out, "all outputs based"
|
||||
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
|
||||
|
||||
# get inputs for shapetrackers
|
||||
input_to_st = defaultdict(list)
|
||||
for lo in si.ast.get_lazyops():
|
||||
if lo.op != BufferOps.MEM: continue
|
||||
input_to_st[si.inputs[lo.arg.idx-1]].append(lo.arg.st)
|
||||
|
||||
# add them to the graph, potentially with a movement op separating them
|
||||
for x in input_to_st:
|
||||
for st in dedup(input_to_st[x]):
|
||||
if st.contiguous:
|
||||
G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060')
|
||||
else:
|
||||
add_st_node(nm(x), nm(si.out), get_sop(op), st)
|
||||
if 'label' not in G.nodes[nm(x)]:
|
||||
G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(si.out.dtype)
|
||||
|
||||
if nm(si.out) not in G.nodes: G.add_node(nm(si.out))
|
||||
|
||||
G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
|
||||
G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype]
|
||||
G.nodes[nm(si.out)]['color'] = 'black'
|
||||
G.nodes[nm(si.out)]['style'] = 'filled'
|
||||
|
||||
def _tree(lazydata, prefix=""):
|
||||
if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
|
||||
if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
|
||||
lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
|
||||
childs = [_tree(c) for c in lazydata.src[:]]
|
||||
for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
|
||||
return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
|
||||
|
||||
def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))]))
|
||||
|
||||
def graph_uops(uops):
|
||||
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
|
||||
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
|
||||
UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0"}
|
||||
G = nx.DiGraph()
|
||||
for u in uops:
|
||||
G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
|
||||
for v in u.vin: G.add_edge(v.num, u.num)
|
||||
GRAPHPATH = "/tmp/uops"
|
||||
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
|
||||
os.system(f'dot -Grankdir=LR -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
|
||||
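A hedged usage sketch for the graph/debug helpers above: running with GRAPH=1 writes {GRAPHPATH}.dot/.svg at exit (graphviz's dot must be installed), and print_tree pretty-prints the LazyOp AST of each scheduled kernel. The Tensor calls are assumptions about the surrounding API.

# GRAPH=1 python3 this_script.py   -> dumps /tmp/net.svg on exit
from tinygrad.tensor import Tensor
from tinygrad.graph import print_tree

out = (Tensor.ones(4, 4) @ Tensor.ones(4, 4)).relu()
for si in out.lazydata.schedule():  # the ScheduleItems that would be executed
  print_tree(si.ast)                # pretty-print each kernel's LazyOp tree
out.realize()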
208
tinygrad_repo/tinygrad/helpers.py
Normal file
@@ -0,0 +1,208 @@
|
||||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3
|
||||
import numpy as np
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
T = TypeVar("T")
|
||||
# NOTE: this returns the int 1 if x is empty, regardless of the element type of x
|
||||
def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.__mul__, x, 1)
|
||||
|
||||
# NOTE: helpers is not allowed to import from anything else in tinygrad
|
||||
OSX = platform.system() == "Darwin"
|
||||
CI = os.getenv("CI", "") != ""
|
||||
|
||||
def dedup(x): return list(dict.fromkeys(x)) # retains list order
|
||||
def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
|
||||
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
||||
def all_same(items): return all(x == items[0] for x in items)
|
||||
def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
|
||||
def colored(st, color, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
|
||||
def ansistrip(s): return re.sub('\x1b\\[(K|.*?m)', '', s)
|
||||
def ansilen(s): return len(ansistrip(s))
|
||||
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
|
||||
def flatten(l:Union[List, Iterator]): return [item for sublist in l for item in sublist]
|
||||
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
|
||||
def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
|
||||
def merge_dicts(ds:Iterable[Dict]) -> Dict:
|
||||
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
|
||||
return {k:v for d in ds for k,v in d.items()}
|
||||
def partition(lst, fxn):
|
||||
a: list[Any] = []
|
||||
b: list[Any] = []
|
||||
for s in lst: (a if fxn(s) else b).append(s)
|
||||
return a,b
|
||||
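A few hedged examples of the small helpers above; the expected values are what the definitions imply, not taken from the test suite.

from tinygrad.helpers import dedup, argsort, merge_dicts, partition

assert dedup([1, 2, 1, 3]) == [1, 2, 3]                       # keeps first occurrence, preserves order
assert argsort((3, 1, 2)) == (1, 2, 0)                        # indices that would sort the tuple
assert merge_dicts([{"a": 1}, {"b": 2}]) == {"a": 1, "b": 2}  # asserts on conflicting keys
assert partition([1, 2, 3, 4], lambda x: x % 2 == 0) == ([2, 4], [1, 3])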
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def getenv(key, default=0): return type(default)(os.getenv(key, default))
|
||||
|
||||
class Context(contextlib.ContextDecorator):
|
||||
stack: ClassVar[List[dict[str, int]]] = [{}]
|
||||
def __init__(self, **kwargs): self.kwargs = kwargs
|
||||
def __enter__(self):
|
||||
Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state.
|
||||
for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state.
|
||||
Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
|
||||
def __exit__(self, *args):
|
||||
for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
|
||||
|
||||
class ContextVar:
|
||||
_cache: ClassVar[Dict[str, ContextVar]] = {}
|
||||
value: int
|
||||
def __new__(cls, key, default_value):
|
||||
if key in ContextVar._cache: return ContextVar._cache[key]
|
||||
instance = ContextVar._cache[key] = super().__new__(cls)
|
||||
instance.value = getenv(key, default_value)
|
||||
return instance
|
||||
def __bool__(self): return bool(self.value)
|
||||
def __ge__(self, x): return self.value >= x
|
||||
def __gt__(self, x): return self.value > x
|
||||
def __lt__(self, x): return self.value < x
|
||||
|
||||
DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
|
||||
GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
|
||||
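A hedged sketch of how Context and ContextVar compose: DEBUG above is a ContextVar, so it can be overridden for a block (or, via ContextDecorator, a function) and is restored on exit.

from tinygrad.helpers import Context, DEBUG

prev = DEBUG.value
with Context(DEBUG=2):
  assert DEBUG >= 2          # temporarily raised inside the block
assert DEBUG.value == prev   # restored when the block exits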
|
||||
class Timing(contextlib.ContextDecorator):
|
||||
def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
|
||||
def __enter__(self): self.st = time.perf_counter_ns()
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.et = time.perf_counter_ns() - self.st
|
||||
if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
|
||||
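Timing is used the same way; a small hedged example:

import time
from tinygrad.helpers import Timing

with Timing("sleep took "):
  time.sleep(0.01)   # prints something like "sleep took 10.12 ms" on exit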
|
||||
# **** tinygrad now supports dtypes! *****
|
||||
|
||||
class DType(NamedTuple):
|
||||
priority: int # this determines when things get upcasted
|
||||
itemsize: int
|
||||
name: str
|
||||
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
|
||||
sz: int = 1
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}"
|
||||
|
||||
# dependent typing?
|
||||
class ImageDType(DType):
|
||||
def __new__(cls, priority, itemsize, name, np, shape):
|
||||
return super().__new__(cls, priority, itemsize, name, np)
|
||||
def __init__(self, priority, itemsize, name, np, shape):
|
||||
self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
|
||||
super().__init__()
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
# TODO: fix this to not need these
|
||||
def __hash__(self): return hash((super().__hash__(), self.shape))
|
||||
def __eq__(self, x): return super().__eq__(x) and self.shape == x.shape
|
||||
def __ne__(self, x): return super().__ne__(x) or self.shape != x.shape
|
||||
|
||||
class PtrDType(DType):
|
||||
def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
|
||||
def __repr__(self): return f"ptr.{super().__repr__()}"
|
||||
|
||||
class dtypes:
|
||||
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
|
||||
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
@staticmethod
|
||||
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
|
||||
@staticmethod
|
||||
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
@staticmethod
|
||||
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
|
||||
@staticmethod
|
||||
def fields() -> Dict[str, DType]: return DTYPES_DICT
|
||||
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
|
||||
float16: Final[DType] = DType(0, 2, "half", np.float16)
|
||||
half = float16
|
||||
float32: Final[DType] = DType(4, 4, "float", np.float32)
|
||||
float = float32
|
||||
float64: Final[DType] = DType(0, 8, "double", np.float64)
|
||||
double = float64
|
||||
int8: Final[DType] = DType(0, 1, "char", np.int8)
|
||||
int16: Final[DType] = DType(1, 2, "short", np.int16)
|
||||
int32: Final[DType] = DType(2, 4, "int", np.int32)
|
||||
int64: Final[DType] = DType(3, 8, "long", np.int64)
|
||||
uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
|
||||
uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
|
||||
uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
|
||||
uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
|
||||
|
||||
# NOTE: bfloat16 isn't supported in numpy
|
||||
bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
|
||||
|
||||
# NOTE: these are internal dtypes, should probably check for that
|
||||
_int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
|
||||
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
|
||||
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
|
||||
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
|
||||
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
|
||||
|
||||
# NOTE: these are image dtypes
|
||||
@staticmethod
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp)
|
||||
@staticmethod
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp)
|
||||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
|
||||
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
|
||||
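Hedged examples of the dtype helpers above; the expected results follow directly from the definitions.

import numpy as np
from tinygrad.helpers import dtypes

assert dtypes.from_np(np.float32) == dtypes.float32
assert dtypes.is_float(dtypes.half) and not dtypes.is_float(dtypes.int32)
assert dtypes.is_unsigned(dtypes.uint8)
assert repr(dtypes.int8) == "dtypes.int8"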
|
||||
class GlobalCounters:
|
||||
global_ops: ClassVar[int] = 0
|
||||
global_mem: ClassVar[int] = 0
|
||||
time_sum_s: ClassVar[float] = 0.0
|
||||
kernel_count: ClassVar[int] = 0
|
||||
mem_used: ClassVar[int] = 0 # NOTE: this is not reset
|
||||
mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
|
||||
@staticmethod
|
||||
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
|
||||
|
||||
# *** universal database cache ***
|
||||
|
||||
CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
|
||||
CACHELEVEL = getenv("CACHELEVEL", 2)
|
||||
|
||||
VERSION = 6
|
||||
_db_connection = None
|
||||
def db_connection():
|
||||
global _db_connection
|
||||
if _db_connection is None:
|
||||
_db_connection = sqlite3.connect(CACHEDB)
|
||||
if DEBUG >= 5: _db_connection.set_trace_callback(print)
|
||||
if diskcache_get("meta", "version") != VERSION:
|
||||
print("cache is out of date, clearing it")
|
||||
os.unlink(CACHEDB)
|
||||
_db_connection = sqlite3.connect(CACHEDB)
|
||||
if DEBUG >= 5: _db_connection.set_trace_callback(print)
|
||||
diskcache_put("meta", "version", VERSION)
|
||||
return _db_connection
|
||||
|
||||
def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
|
||||
if isinstance(key, (str,int)): key = {"key": key}
|
||||
try:
|
||||
res = db_connection().cursor().execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
|
||||
except sqlite3.OperationalError:
|
||||
return None # table doesn't exist
|
||||
if (val:=res.fetchone()) is not None:
|
||||
return pickle.loads(val[0])
|
||||
return None
|
||||
|
||||
_db_tables = set()
|
||||
def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
|
||||
if isinstance(key, (str,int)): key = {"key": key}
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
if table not in _db_tables:
|
||||
TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
|
||||
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
|
||||
cur.execute(f"CREATE TABLE IF NOT EXISTS {table} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
|
||||
_db_tables.add(table)
|
||||
cur.execute(f"REPLACE INTO {table} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return val
|
||||
|
||||
def diskcache(func):
|
||||
def wrapper(*args, **kwargs) -> bytes:
|
||||
table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
|
||||
if (ret:=diskcache_get(table, key)): return ret
|
||||
return diskcache_put(table, key, func(*args, **kwargs))
|
||||
setattr(wrapper, "__wrapped__", func)
|
||||
return wrapper
|
||||
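A hedged sketch of the @diskcache decorator above: results are pickled into the sqlite database at CACHEDB, keyed by a sha256 of the arguments, so repeated calls are served from disk. slow_square is an illustrative function, not part of tinygrad.

from tinygrad.helpers import diskcache

@diskcache
def slow_square(x):
  print("computing", x)   # only printed on a cache miss
  return x * x

slow_square(12)   # computes and stores 144 in table cache_slow_square
slow_square(12)   # served from the on-disk cache, nothing printed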
77
tinygrad_repo/tinygrad/jit.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
|
||||
from collections import defaultdict
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts
|
||||
from tinygrad.ops import RawBuffer, Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
|
||||
|
||||
class TinyJit:
|
||||
def __init__(self, fxn:Callable):
|
||||
self.fxn: Callable = fxn
|
||||
self.cnt: int = 0
|
||||
self.jit_cache: List[Tuple[Any, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
|
||||
self.ret: Any = None
|
||||
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
|
||||
self.updatable_entries: Dict[int, List[int]] = defaultdict(list) # (kernel_number) -> list(argument id). These are buffers from input + variables.
|
||||
|
||||
# add support for instance methods
|
||||
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
|
||||
# NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
|
||||
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
|
||||
assert len(input_rawbuffers) != 0, "no inputs to JIT"
|
||||
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
|
||||
if self.cnt >= 2:
|
||||
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
|
||||
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
|
||||
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
|
||||
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
|
||||
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
|
||||
# NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly
|
||||
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
|
||||
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
|
||||
for j in self.updatable_entries.keys():
|
||||
for k in self.jit_cache[j][2].keys():
|
||||
try: self.jit_cache[j][2][k] = var_vals[k]
|
||||
except KeyError: pass
|
||||
for prg, pargs, variables in self.jit_cache: prg(pargs, variables, jit=True)
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
|
||||
elif self.cnt == 1:
|
||||
CacheCollector.start()
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
self.jit_cache = CacheCollector.finish()
|
||||
assert len(self.jit_cache) != 0, "didn't JIT anything!"
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
|
||||
|
||||
# get the inputs for replacement
|
||||
for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
|
||||
for i,a in enumerate(cache[1]):
|
||||
if a in [v[0] for v in input_rawbuffers.values()]:
|
||||
self.input_replace[(j_,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
|
||||
self.updatable_entries[j_].append(i)
|
||||
for i in range(len(cache[2])): self.updatable_entries[j_].append(len(cache[1])+i)
|
||||
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
|
||||
elif self.cnt == 0:
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
self.cnt += 1
|
||||
return self.ret
|
||||
|
||||
class _CacheCollector:
|
||||
def __init__(self): self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None
|
||||
def start(self): self.cache = []
|
||||
def add(self, prg, rawbufs, var_vals):
|
||||
if self.cache is None: return
|
||||
self.cache.append((prg, rawbufs, var_vals))
|
||||
def finish(self):
|
||||
if self.cache is None: return []
|
||||
ret = self.cache
|
||||
self.cache = None
|
||||
return ret
|
||||
CacheCollector = _CacheCollector()
|
||||
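A hedged usage sketch of TinyJit: the first call runs eagerly, the second call captures the kernels via CacheCollector, and from the third call on the captured kernels are replayed with the new input buffers (assuming the default device is in JIT_SUPPORTED_DEVICE and the input shapes stay fixed).

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x).relu().realize()   # outputs should be realized inside the jitted function

for _ in range(4):
  out = step(Tensor.randn(8, 8))    # calls 3 and 4 replay the captured kernels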
347
tinygrad_repo/tinygrad/lazy.py
Normal file
@@ -0,0 +1,347 @@
|
||||
from __future__ import annotations
|
||||
import sys, operator, math, functools
|
||||
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping
|
||||
from weakref import ref, WeakSet, WeakValueDictionary
|
||||
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int
|
||||
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
|
||||
|
||||
# lazy can recurse a lot
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
OPT = getenv("OPT", 2)
|
||||
LAZYCACHE = getenv("LAZYCACHE", 1)
|
||||
|
||||
# TODO: movement ops that only change shape are really nops. treat them as such
|
||||
REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
|
||||
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2
|
||||
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
|
||||
PUSH_RESHAPES = OPT>=4
|
||||
|
||||
# **** ast fixing functions ****
|
||||
|
||||
def _ast_reduceops(op:LazyOp) -> LazyOp:
|
||||
# TODO: this can also corealize a binary op after the reduce, not just before
|
||||
src = op.src[0]
|
||||
if not src.realized:
|
||||
assert isinstance(src.op, LazyOp), "if not src.realized, then src.op must be a LazyOp"
|
||||
if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1: src = src.op
|
||||
return LazyOp(op.op, (src,), op.arg)
|
||||
|
||||
# this supports late merging an upstream Reduce op and even an Elementwise op above that
|
||||
def _ast_binaryops(op:LazyOp, shape: Tuple[sint, ...]) -> LazyOp:
|
||||
real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {x:None for x in op.buffers}
|
||||
# NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
|
||||
# TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
|
||||
psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
|
||||
intermediate_shape: Tuple[sint, ...] = shape
|
||||
if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs:
|
||||
psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop
|
||||
if psrc[1].optype == ReduceOps:
|
||||
top = _ast_reduceops(psrc[1].op)
|
||||
real_srcs[psrc[0]] = top
|
||||
real_srcs.update({x:x for x in top.buffers}) # the reduce op buffers are not modified
|
||||
|
||||
# if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
|
||||
if psrc[0].shape != psrc[1].shape:
|
||||
intermediate_shape = psrc[1].shape
|
||||
assert psrc[0].shape == shape, f"shape mismatch {psrc[0].shape} != {shape}"
|
||||
|
||||
# reshape all the late ops into the output shape
|
||||
# NOTE: these RESHAPEs will return self if they don't change the shape
|
||||
for x in real_srcs.keys():
|
||||
if real_srcs[x] is None: real_srcs[x] = x.reshape(intermediate_shape)
|
||||
# NOTE: cast the type to remove the Optional
|
||||
ast = op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs))
|
||||
return LazyOp(MovementOps.RESHAPE, (ast, ), shape) if intermediate_shape != shape else ast
|
||||
|
||||
def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
|
||||
replacements:Dict[LazyBuffer, LazyOp] = {}
|
||||
base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()])
|
||||
for x in op.buffers:
|
||||
st = x.st.simplify().unbind()
|
||||
if x.base in base_bufs:
|
||||
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
|
||||
elif not x.realized and x.base.op.op == LoadOps.CONST:
|
||||
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st))
|
||||
else:
|
||||
raise NotImplementedError(f"not handled {x}")
|
||||
return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), base_bufs
|
||||
|
||||
# **** lazy operations ****
|
||||
|
||||
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 and isinstance(root.op.src[0], LazyBuffer) else root
|
||||
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
|
||||
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
|
||||
|
||||
def vars_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], []))
|
||||
|
||||
lazycache: WeakValueDictionary = WeakValueDictionary()
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
|
||||
# fromcpu-style loads (EMPTY/RAND/CONST) aren't cached
|
||||
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)
|
||||
|
||||
# wop is the deduping key. i feel this used to compare more deeply
|
||||
wop = (device, dtype, optype, ref(op), ref(base) if base else None)
|
||||
if wop in lazycache:
|
||||
for x in op.buffers: x.children.add(lazycache[wop])
|
||||
return lazycache[wop]
|
||||
|
||||
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base)
|
||||
return ret
|
||||
|
||||
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
|
||||
|
||||
class LazyBuffer:
|
||||
__deletable__ = ('op',)
|
||||
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
|
||||
self.st: ShapeTracker = st
|
||||
self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
|
||||
self._realized: Optional[RawBuffer] = src
|
||||
self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
|
||||
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
|
||||
self.children: WeakSet = WeakSet()
|
||||
self.views: WeakSet = WeakSet()
|
||||
# NOTE: op should be read only after construction of LazyBuffer. it is now with schedule
|
||||
if op is not None:
|
||||
self.op: LazyOp = op
|
||||
for x in op.buffers: x.children.add(self)
|
||||
assert optype != MovementOps or (base is not None and base.optype != MovementOps), "MovementOps must be based"
|
||||
self._base = base
|
||||
if base: base.views.add(self)
|
||||
else: assert st.contiguous, "unbased LazyBuffers must be contiguous"
|
||||
|
||||
@property
|
||||
def base(self): return self._base if self._base is not None else self
|
||||
|
||||
def is_unrealized_const(self): return not self.realized and self.base.op.op == LoadOps.CONST
|
||||
|
||||
@property
|
||||
def realized(self): return self.base._realized
|
||||
@realized.setter
|
||||
def realized(self, val):
|
||||
assert self._base is None, "no setting realized of based LazyBuffers"
|
||||
self._realized = val
|
||||
@property
|
||||
def dtype(self): return self.base._dtype
|
||||
@dtype.setter
|
||||
def dtype(self, val):
|
||||
assert self._base is None, "no setting dtype of based LazyBuffers"
|
||||
self._dtype = val
|
||||
|
||||
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>"
|
||||
@property
|
||||
def key(self):
|
||||
if self.realized: return (self.dtype, self.realized.key, self.st)
|
||||
return (self.dtype, self.op.op, self.st)
|
||||
|
||||
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
|
||||
|
||||
@property
|
||||
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
|
||||
def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]): return real_srcs.get(self, self)
|
||||
def get_lazyops(self) -> List[LazyOp]: return []
|
||||
|
||||
# *** scheduling ***
|
||||
|
||||
def schedule(self, seen=None) -> List[ScheduleItem]:
|
||||
if seen is None: seen = set()
|
||||
if self in seen or self.realized or self.is_unrealized_const(): return []
|
||||
seen.add(self)
|
||||
if self.base != self: return self.base.schedule(seen)
|
||||
|
||||
# rewrite unbased CONTIGUOUS into UnaryOps.NOOP
|
||||
op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
|
||||
|
||||
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
|
||||
elif self.optype is ReduceOps: op = _ast_reduceops(op)
|
||||
|
||||
# schedule the past
|
||||
ret = []
|
||||
for x in op.buffers: ret += x.schedule(seen)
|
||||
|
||||
var_vals = dict(sorted(merge_dicts([self.st.var_vals] + [buf.st.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
|
||||
# run the ast and log the op
|
||||
op, base_bufs = _replace_bufferops(op)
|
||||
return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})]
|
||||
|
||||
# *** creation/special ops ***
|
||||
|
||||
@staticmethod
|
||||
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
|
||||
|
||||
# create a constant with the shape and dtype of self
|
||||
def const(self, val:Union[float, int]) -> LazyBuffer:
|
||||
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
|
||||
return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
|
||||
|
||||
def copy_to_device(self, device:str) -> LazyBuffer:
|
||||
# back off a FROM if it's a double FROM
|
||||
if not self.realized and self.op.op == LoadOps.FROM and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0])
|
||||
return LazyBuffer.loadop(LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous())
|
||||
|
||||
def contiguous(self:LazyBuffer) -> LazyBuffer:
|
||||
if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: return self # all LoadOps are already contiguous (except CONST)
|
||||
if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const():
|
||||
# this will turn into nothing, it's based and a copy
|
||||
# TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, base=self.base)
|
||||
# real contiguous, this will turn into a UnaryOps.NOOP
|
||||
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self)
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x: np.ndarray) -> LazyBuffer:
|
||||
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
|
||||
|
||||
def cast(self, dtype:DType, bitcast:bool=False):
|
||||
return self.e(UnaryOps.CAST, arg=(dtype, bitcast))
|
||||
|
||||
# *** elementwise ops ***
|
||||
|
||||
def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
|
||||
# srcs includes self
|
||||
srcs = (self,)+srcs
|
||||
|
||||
# if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
|
||||
if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
|
||||
|
||||
# get outputs now
|
||||
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0]
|
||||
|
||||
# push all contiguous to the end of BinaryOps. kernels 198 -> 196
|
||||
if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
|
||||
new_srcs: List[LazyBuffer] = []
|
||||
for x in srcs:
|
||||
if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
|
||||
x.op.src[0].children.discard(x)
|
||||
new_srcs.append(cast(LazyBuffer, x.op.src[0]))
|
||||
else:
|
||||
new_srcs.append(x)
|
||||
return new_srcs[0].e(op, *new_srcs[1:], arg=arg).contiguous()
|
||||
|
||||
if MERGE_ELEMENTWISE_OPS:
|
||||
# remove the buffers from any (childless) BinaryOps that feed into this
|
||||
_srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
|
||||
# TODO: needs general merge limiting
|
||||
if out_device != "WEBGPU" or len(dedup([x.base for _src in _srcs for x in _src.buffers if not x.is_unrealized_const()])) < 7: srcs = _srcs # type: ignore
|
||||
|
||||
return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)
|
||||
|
||||
# *** reduce ops ***
|
||||
|
||||
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
|
||||
if self.shape == tuple(new_shape): return self
|
||||
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
|
||||
unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)
|
||||
|
||||
def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
|
||||
if not all_int(self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
|
||||
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
|
||||
if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
|
||||
def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
|
||||
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
|
||||
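A hedged worked example of the split heuristic above, for a contiguous (2048, 2048) buffer (strides (2048, 1)) being summed down to (1, 1):

import math

shape, new_shape, strides = (2048, 2048), (1, 1), (2048, 1)
scores = [((d := math.gcd(256, old)) / (st or math.inf), d, i)
          for i, (old, new, st) in enumerate(zip(shape, new_shape, strides)) if old != new]
print(max(scores))  # (256.0, 256, 1): split dim 1 by 256
# so the reduce runs as two kernels:
# (2048, 2048) -reshape-> (2048, 8, 256) -reduce-> (2048, 8, 1) -reshape-> (2048, 8) -reduce-> (1, 1)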
|
||||
# *** movement ops ***
|
||||
|
||||
def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
|
||||
if SHUFFLE_MOVEMENT_OPS and not self.realized and self.optype == BinaryOps and not self.children:
|
||||
if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)):
|
||||
return self.op.replace_with_movement_ops([(op, arg)])
|
||||
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
|
||||
# MovementOps aren't stacked any more, they each have one parent, find the root
|
||||
root = get_movementroot(self)
|
||||
if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
|
||||
return root.reshape(st.shape)
|
||||
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)
|
||||
|
||||
def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
|
||||
if self.shape == arg: return self
|
||||
if not self.realized and self.op.op == MovementOps.RESHAPE:
|
||||
assert isinstance(self.op.src[0], LazyBuffer)
|
||||
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
|
||||
return self.op.src[0].reshape(arg)
|
||||
return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg)
|
||||
|
||||
def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
|
||||
if all(b == 0 and e == 0 for b,e in arg): return self
|
||||
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
|
||||
return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg)
|
||||
|
||||
def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
|
||||
if self.shape == arg: return self
|
||||
if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg)
|
||||
return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg)
|
||||
|
||||
def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
|
||||
if arg == tuple(range(len(self.shape))): return self
|
||||
if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg]))
|
||||
if SHUFFLE_MOVEMENT_OPS and not self.realized:
|
||||
if PUSH_PERMUTES and self.optype == ReduceOps:
|
||||
# reduceops have one buffer input, permute it
|
||||
narg = tuple([self.op.arg[a] for a in arg])
|
||||
src, rop = self.op.src[0], self.op.op
|
||||
src.children.discard(self)
|
||||
del self # TODO: why doesn't this delete remove it from the children
|
||||
return src.permute(arg).r(cast(ReduceOps, rop), narg)
|
||||
|
||||
# move permutes before expands (always, this is safe)
|
||||
if self.op.op == MovementOps.EXPAND:
|
||||
return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg]))
|
||||
|
||||
# move permutes before reshapes if we can
|
||||
if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
|
||||
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
|
||||
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
|
||||
return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape)
|
||||
return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg)
|
||||
|
||||
def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
|
||||
if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
|
||||
if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
|
||||
return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg)
|
||||
|
||||
def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
|
||||
if all(a == 1 for a in arg): return self
|
||||
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
|
||||
return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)
|
||||
|
||||
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
|
||||
y = self
|
||||
for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
|
||||
return y
|
||||
|
||||
def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
|
||||
new_srcs = []
|
||||
for x in srcs:
|
||||
mops: List[Tuple[MovementOps, Any]] = []
|
||||
bx = x
|
||||
# backwalk all the movement ops. don't push PAD or EXPAND
|
||||
while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1:
|
||||
assert isinstance(bx.op.op, MovementOps)
|
||||
mops.append((bx.op.op, bx.op.arg))
|
||||
assert isinstance(bx.op.src[0], LazyBuffer)
|
||||
bx = bx.op.src[0]
|
||||
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
|
||||
if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(y[0] is not MovementOps.PAD for y in mops) or all(y.op not in UNSAFE_PAD_OPS for y in bx.op.get_lazyops())):
|
||||
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
|
||||
else:
|
||||
new_srcs.append(x)
|
||||
return tuple(new_srcs)
|
||||
|
||||
MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
|
||||
MovementOps.RESHAPE: LazyBuffer.reshape,
|
||||
MovementOps.EXPAND: LazyBuffer.expand,
|
||||
MovementOps.SHRINK: LazyBuffer.shrink,
|
||||
MovementOps.PERMUTE: LazyBuffer.permute,
|
||||
MovementOps.PAD: LazyBuffer.pad,
|
||||
MovementOps.STRIDE: LazyBuffer.stride,
|
||||
}
|
||||
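A hedged sketch of the laziness this file implements: no kernel runs until realize, consecutive movement ops fold away (the reshapes below cancel back to their source), and schedule() shows exactly what would execute.

from tinygrad.tensor import Tensor

a = Tensor.ones(4, 4)
b = a.reshape(16).reshape(4, 4) + 1   # the two reshapes cancel back to the base buffer
for si in b.lazydata.schedule():      # the ScheduleItems that realize() would run
  print(si.ast.op, si.out.shape)
b.realize()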
211
tinygrad_repo/tinygrad/mlops.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import math
|
||||
from typing import Tuple, Optional, cast
|
||||
from tinygrad.helpers import argsort, DType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.tensor import Function
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.shape.symbolic import sint
|
||||
|
||||
class Contiguous(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output
|
||||
|
||||
class ContiguousBackward(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer: return x
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
|
||||
|
||||
class Cast(Function):
|
||||
def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
|
||||
self.input_dtype, self.bitcast = x.dtype, bitcast
|
||||
return x.cast(dtype, bitcast)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.cast(self.input_dtype, self.bitcast)
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
class Zero(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.const(0)
|
||||
def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.const(0)
|
||||
|
||||
class Neg(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
|
||||
def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.e(UnaryOps.NEG)
|
||||
|
||||
class Sin(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.x = x
|
||||
return x.e(UnaryOps.SIN)
|
||||
|
||||
def backward(self, grad:LazyBuffer) -> LazyBuffer:
|
||||
return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)
|
||||
|
||||
# NOTE: maximum(x, 0) behaves differently where x=0
|
||||
class Relu(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.e(BinaryOps.MAX, x.const(0))
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)
|
||||
|
||||
class Log(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.x = x
|
||||
return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.e(BinaryOps.DIV, self.x)
|
||||
|
||||
class Exp(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return self.ret.e(BinaryOps.MUL, grad_output)
|
||||
|
||||
class Sqrt(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.e(UnaryOps.SQRT)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.e(BinaryOps.DIV, self.ret.e(BinaryOps.MUL, self.ret.const(2)))
|
||||
|
||||
# NOTE: the implicit derivative of sigmoid is not stable
|
||||
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
|
||||
# TODO: have the backend automatically find this
|
||||
class Sigmoid(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)
|
||||
|
||||
# ************* binary ops *************
|
||||
|
||||
class Less(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
return x.e(BinaryOps.CMPLT, y)
|
||||
|
||||
class Add(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
return x.e(BinaryOps.ADD, y)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return grad_output if self.needs_input_grad[0] else None, \
|
||||
grad_output if self.needs_input_grad[1] else None
|
||||
|
||||
class Sub(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
return x.e(BinaryOps.SUB, y)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return grad_output if self.needs_input_grad[0] else None, \
|
||||
grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None
|
||||
|
||||
class Mul(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
self.x, self.y = x, y
|
||||
return x.e(BinaryOps.MUL, y)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
|
||||
self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
|
||||
|
||||
class Div(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
self.x, self.y = x, y
|
||||
return x.e(BinaryOps.DIV, y)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
|
||||
grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
|
||||
|
||||
# ************* ternary ops *************
|
||||
|
||||
class Where(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
|
||||
self.x = x
|
||||
return x.e(TernaryOps.WHERE, y, z)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return None, \
|
||||
self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
|
||||
self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
|
||||
|
||||
# ************* reduce ops *************
|
||||
|
||||
class Sum(Function):
|
||||
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.input_shape = x.shape
|
||||
return x.r(ReduceOps.SUM, new_shape)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.expand(self.input_shape)
|
||||
|
||||
class Max(Function):
|
||||
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
# 1s in locations where the max was chosen (can be two locations)
|
||||
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
|
||||
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
|
||||
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
|
||||
|
||||
# ************* movement ops *************
|
||||
|
||||
# NOTE: this is sum in reverse
|
||||
class Expand(Function):
|
||||
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.input_shape = x.shape
|
||||
return x.expand(shape)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.r(ReduceOps.SUM, self.input_shape)
|
||||
|
||||
class Reshape(Function):
|
||||
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.input_shape = x.shape
|
||||
return x.reshape(shape)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.reshape(self.input_shape)
|
||||
|
||||
class Permute(Function):
|
||||
def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.input_order = order
|
||||
return x.permute(order)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.permute(argsort(self.input_order))
|
||||
|
||||
class Pad(Function):
|
||||
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
|
||||
self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
|
||||
return x.pad(arg)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.shrink(self.narg)
|
||||
|
||||
class Shrink(Function):
|
||||
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
|
||||
self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
|
||||
return x.shrink(arg)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
|
||||
# need this cast because mypy cannot narrow the type even with assert
|
||||
return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))
|
||||
|
||||
class Flip(Function):
|
||||
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
|
||||
return x.stride(self.arg)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.stride(self.arg)
|
||||
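A hedged sketch of extending mlops in the same style (not part of tinygrad): a Square Function whose backward is grad * 2x, built only from the primitives this file already uses. It would be invoked through Function.apply like the classes above.

from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import BinaryOps

class Square(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.e(BinaryOps.MUL, x)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # d/dx x^2 = 2x
    return grad_output.e(BinaryOps.MUL, self.x.e(BinaryOps.MUL, self.x.const(2)))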
128
tinygrad_repo/tinygrad/nn/__init__.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import math
|
||||
from typing import Optional, Union, Tuple
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import prod, all_int
|
||||
|
||||
class BatchNorm2d:
|
||||
def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
|
||||
self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
|
||||
|
||||
if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
|
||||
else: self.weight, self.bias = None, None
|
||||
|
||||
self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
|
||||
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
if Tensor.training:
|
||||
# This requires two full memory accesses to x
|
||||
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
||||
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
|
||||
batch_mean = x.mean(axis=(0,2,3))
|
||||
y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
batch_var = (y*y).mean(axis=(0,2,3))
|
||||
batch_invstd = batch_var.add(self.eps).pow(-0.5)
|
||||
|
||||
# NOTE: wow, this is done all throughout training in most PyTorch models
|
||||
if self.track_running_stats:
|
||||
self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
|
||||
self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach() )
|
||||
self.num_batches_tracked += 1
|
||||
else:
|
||||
batch_mean = self.running_mean
|
||||
# NOTE: this can be precomputed for static inference. we expand it here so it fuses
|
||||
batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
|
||||
|
||||
return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
|
||||
|
||||
# TODO: these Conv lines are terrible
|
||||
def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)
|
||||
|
||||
class Conv2d:
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
|
||||
self.weight = self.initialize_weight(out_channels, in_channels, groups)
|
||||
assert all_int(self.weight.shape), "does not support symbolic shape"
|
||||
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
|
||||
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
|
||||
|
||||
def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
|
||||
return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
|
||||
|
||||
class ConvTranspose2d(Conv2d):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
|
||||
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
|
||||
self.output_padding = output_padding
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
|
||||
|
||||
class Linear:
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
|
||||
# TODO: remove this once we can represent Tensor with int shape in typing
|
||||
assert isinstance(self.weight.shape[1], int), "does not support symbolic shape"
|
||||
bound = 1 / math.sqrt(self.weight.shape[1])
|
||||
self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
return x.linear(self.weight.transpose(), self.bias)
|
||||
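A hedged usage sketch composing the layers above into a tiny MLP; the Tensor methods used here are assumptions about the rest of the API.

from tinygrad.tensor import Tensor
from tinygrad.nn import Linear

class TinyMLP:
  def __init__(self):
    self.l1, self.l2 = Linear(784, 128), Linear(128, 10)
  def __call__(self, x:Tensor) -> Tensor:
    return self.l2(self.l1(x).relu()).log_softmax()

logits = TinyMLP()(Tensor.randn(32, 784))   # shape (32, 10)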
|
||||
class GroupNorm:
|
||||
def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
|
||||
self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
|
||||
self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
|
||||
self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
# reshape for layernorm to work as group norm
|
||||
# subtract mean and divide stddev
|
||||
x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
|
||||
|
||||
if self.weight is None or self.bias is None: return x
|
||||
# elementwise_affine on channels
|
||||
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
|
||||
|
||||
class InstanceNorm:
|
||||
def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
|
||||
self.num_features, self.eps = num_features, eps
|
||||
self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
|
||||
self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
|
||||
if self.weight is None or self.bias is None: return x
|
||||
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
|
||||
|
||||
class LayerNorm:
|
||||
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
|
||||
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
|
||||
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
|
||||
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
|
||||
x = x.layernorm(eps=self.eps, axis=self.axis)
|
||||
if not self.elementwise_affine: return x
|
||||
return x * self.weight + self.bias
|
||||
|
||||
class LayerNorm2d(LayerNorm):
|
||||
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
class Embedding:
|
||||
def __init__(self, vocab_size:int, embed_size:int):
|
||||
self.vocab_size = vocab_size
|
||||
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
|
||||
|
||||
def __call__(self, idx:Tensor) -> Tensor:
|
||||
if not hasattr(self, 'vocab_counter'): self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size)
|
||||
return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight
|
||||
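
A minimal sketch of how the layers above compose, assuming the Tensor API vendored in this release; the TinyConvNet class, its layer sizes, and the random input are illustrative placeholders, not part of the commit:

from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d, Linear

class TinyConvNet:
  def __init__(self):
    self.c1 = Conv2d(3, 8, kernel_size=3, padding=1)   # arbitrary example sizes
    self.bn1 = BatchNorm2d(8)
    self.l1 = Linear(8*32*32, 10)
  def __call__(self, x:Tensor) -> Tensor:
    x = self.bn1(self.c1(x)).relu()
    return self.l1(x.reshape(x.shape[0], -1))

out = TinyConvNet()(Tensor.rand(4, 3, 32, 32))  # -> Tensor of shape (4, 10)
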
68
tinygrad_repo/tinygrad/nn/optim.py
Normal file
@@ -0,0 +1,68 @@
# sorted in order of increasing complexity
from typing import List
from tinygrad.helpers import dedup
from tinygrad.tensor import Tensor

class Optimizer:
  def __init__(self, params: List[Tensor], lr: float):
    # if it's None, but being put into an optimizer, set it to True
    for x in params:
      if x.requires_grad is None: x.requires_grad = True

    self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
    self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
    self.lr = Tensor([lr], requires_grad=False).contiguous()

  def zero_grad(self):
    for param in self.params: param.grad = None

  def realize(self, extra=None):
    # NOTE: in extra is too late for most of the params due to issues with assign
    Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)

class SGD(Optimizer):
  def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
    super().__init__(params, lr)
    self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
    self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []

  # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
  def step(self) -> None:
    for i, t in enumerate(self.params):
      assert t.grad is not None
      g = t.grad.realize() + self.wd * t.detach()
      if self.momentum:
        self.b[i].assign(self.momentum * self.b[i] + g).realize() # NOTE: self.b[i] is zero on the first run, no if required
        g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
      t.assign(t.detach() - g * self.lr)
    self.realize(self.b)

# LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): return LAMB(params, lr, b1, b2, eps, wd, adam=True)
def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)

class LAMB(Optimizer):
  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
    super().__init__(params, lr)
    self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize()
    self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
    self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]

  def step(self) -> None:
    self.t.assign(self.t + 1).realize()
    for i, t in enumerate(self.params):
      assert t.grad is not None
      g = t.grad.realize()
      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
      m_hat = self.m[i] / (1.0 - self.b1**self.t)
      v_hat = self.v[i] / (1.0 - self.b2**self.t)
      up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
      if not self.adam:
        r1 = t.detach().square().sum().sqrt()
        r2 = up.square().sum().sqrt()
        r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
      else:
        r = 1.0
      t.assign(t.detach() - self.lr * r * up)
    self.realize([self.t] + self.m + self.v)
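
A rough training-loop sketch for these optimizers, assuming Tensor autograd and get_parameters from tinygrad/nn/state.py below; the model, data, and hyperparameters are placeholder examples:

from tinygrad.tensor import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters
from tinygrad.nn.optim import SGD

model = Linear(16, 4)                                    # toy model
opt = SGD(get_parameters(model), lr=0.01, momentum=0.9)
x, y = Tensor.rand(8, 16), Tensor.rand(8, 4)             # placeholder batch
for _ in range(3):
  opt.zero_grad()
  loss = (model(x) - y).square().mean()
  loss.backward()
  opt.step()                                             # assigns the updated params, then realizes them
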
124
tinygrad_repo/tinygrad/nn/state.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import os, json, pathlib, zipfile, pickle
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Union, List, Optional, Any, Tuple
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.ops import Device
|
||||
|
||||
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
|
||||
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
||||
|
||||
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
|
||||
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
json_len = t[0:1].cast(dtypes.int64).numpy()[0]
|
||||
return (t, json_len, json.loads(t[8:8+json_len].numpy().tobytes()))
|
||||
|
||||
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
t, json_len, metadata = safe_load_metadata(fn)
|
||||
return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
|
||||
|
||||
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
|
||||
headers, offset = {}, 0
|
||||
if metadata: headers['__metadata__'] = metadata
|
||||
for k,v in tensors.items():
|
||||
headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
|
||||
offset += v.nbytes()
|
||||
j = json.dumps(headers, separators=(',', ':'))
|
||||
j += "\x20"*((8-len(j)%8)%8)
|
||||
pathlib.Path(fn).unlink(missing_ok=True)
|
||||
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
t[0:1].cast(dtypes.int64).assign([len(j)])
|
||||
t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
|
||||
for k,v in safe_load(t).items(): v.assign(tensors[k])
|
||||
|
||||
# state dict
|
||||
|
||||
from collections import OrderedDict
|
||||
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
|
||||
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
|
||||
if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
|
||||
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
|
||||
if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
|
||||
state_dict = {}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
|
||||
elif isinstance(obj, dict):
|
||||
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
|
||||
return state_dict
|
||||
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
|
||||
|
||||
def load_state_dict(model, state_dict, strict=True, verbose=True):
|
||||
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
|
||||
model_state_dict = get_state_dict(model)
|
||||
if DEBUG >= 1 and len(state_dict) > len(model_state_dict): print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
|
||||
for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
|
||||
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
|
||||
if k not in state_dict and not strict:
|
||||
if DEBUG >= 1: print(f"WARNING: not loading {k}")
|
||||
continue
|
||||
v.assign(state_dict[k].to(v.device)).realize()
|
||||
|
||||
# torch support!
|
||||
|
||||
def torch_load(fn:str):
|
||||
t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
|
||||
offsets: Dict[str, int] = {}
|
||||
lens: Dict[str, int] = {}
|
||||
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
|
||||
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
|
||||
lens[storage[2]] = storage[4] * storage[1].itemsize
|
||||
if storage[2] not in offsets: return None
|
||||
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
|
||||
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
|
||||
# convert bfloat16 -> float16 using LLVM for Llama 2
|
||||
# upstream LLaMA also does this conversion:
|
||||
# https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
|
||||
# TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
|
||||
if storage[1] == dtypes.bfloat16:
|
||||
ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
|
||||
#ret = ret.to("LLVM").half().to(Device.DEFAULT)
|
||||
|
||||
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
|
||||
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
|
||||
permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
|
||||
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
|
||||
intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
|
||||
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
|
||||
if DEBUG >= 2: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
|
||||
# TODO: find a nice way to support all shapetracker on disktensors
|
||||
ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
|
||||
|
||||
return ret.reshape(size)
|
||||
|
||||
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
|
||||
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
|
||||
class Dummy: pass
|
||||
class TorchPickle(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
module_root = module.split(".")[0]
|
||||
if module_root not in whitelist:
|
||||
if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
|
||||
return Dummy
|
||||
return intercept[name] if module_root == "torch" else super().find_class(module, name)
|
||||
def persistent_load(self, pid): return pid
|
||||
|
||||
if tuple(t[0:2].numpy()) == (0x50, 0x4b):
|
||||
myzip = zipfile.ZipFile(fn, 'r')
|
||||
base_name = myzip.namelist()[0].split('/', 1)[0]
|
||||
for n in myzip.namelist():
|
||||
if n.startswith(f'{base_name}/data/'):
|
||||
with myzip.open(n) as myfile:
|
||||
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
|
||||
with myzip.open(f'{base_name}/data.pkl') as myfile:
|
||||
return TorchPickle(myfile).load()
|
||||
else:
|
||||
with open(fn, "rb") as f:
|
||||
pkl = TorchPickle(f)
|
||||
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
|
||||
for i in ids:
|
||||
offsets[i] = base_offset + 8
|
||||
base_offset += 8 + lens[i]
|
||||
f.seek(rwd)
|
||||
return TorchPickle(f).load()
|
||||
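
A short sketch of the intended safetensors round trip through this module; the model and the /tmp path are arbitrary examples:

from tinygrad.nn import Linear
from tinygrad.nn.state import get_state_dict, safe_save, safe_load, load_state_dict

model = Linear(16, 4)
safe_save(get_state_dict(model), "/tmp/model.safetensors")   # writes the json header plus raw bytes via a disk Tensor
load_state_dict(model, safe_load("/tmp/model.safetensors"))  # assigns each loaded tensor back onto the model
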
300
tinygrad_repo/tinygrad/ops.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
import importlib, inspect, functools, pathlib
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
|
||||
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
from dataclasses import dataclass
|
||||
|
||||
# these are the llops your accelerator must implement, along with toCpu
|
||||
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
|
||||
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
|
||||
# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
|
||||
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
|
||||
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
|
||||
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
|
||||
class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
|
||||
# Ops below this line are not allowed in ASTs
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
|
||||
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
|
||||
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MemBuffer:
|
||||
idx: int
|
||||
dtype: DType
|
||||
st: ShapeTracker
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstBuffer:
|
||||
val: Any
|
||||
dtype: DType
|
||||
st: ShapeTracker
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItem:
|
||||
ast: LazyOp
|
||||
out: LazyBuffer
|
||||
inputs: Tuple[LazyBuffer, ...]
|
||||
var_vals: Dict[Variable, int]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LazyOp:
|
||||
op: Op
|
||||
src: Tuple[Union[LazyOp, LazyBuffer], ...]
|
||||
arg: Any = None
|
||||
def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
|
||||
@property
|
||||
def buffers(self):
|
||||
buffers: Tuple[Union[LazyOp, LazyBuffer], ...] = ()
|
||||
try: # NOTE: the linearizer's key function maps the buffers to ints, and LOCAL_BUFFER is used. we don't care about buffers in these cases
|
||||
for x in self.src: buffers += x.buffers
|
||||
except AttributeError: buffers = ()
|
||||
return buffers
|
||||
|
||||
@property
|
||||
def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
|
||||
|
||||
def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] for y in self.src]), self.arg)
|
||||
def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
|
||||
|
||||
def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
|
||||
assert self.op in BinaryOps or self.op in UnaryOps or self.op in TernaryOps
|
||||
srcs = [z.replace_with_movement_ops(ops) for z in self.src]
|
||||
return srcs[0].e(self.op, *srcs[1:], arg=self.arg) # type: ignore
|
||||
|
||||
@property
|
||||
def st(self): raise NotImplementedError
|
||||
@property
|
||||
def realized(self): raise NotImplementedError
|
||||
@property
|
||||
def children(self): raise NotImplementedError
|
||||
|
||||
# movement ops
|
||||
def reshape(self, _): raise NotImplementedError
|
||||
def pad(self, _): raise NotImplementedError
|
||||
def expand(self, _): raise NotImplementedError
|
||||
def permute(self, _): raise NotImplementedError
|
||||
def shrink(self, _): raise NotImplementedError
|
||||
def stride(self, _): raise NotImplementedError
|
||||
|
||||
# **************** Device ****************
|
||||
|
||||
class _Device:
|
||||
def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
|
||||
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
|
||||
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
|
||||
x = x.split(":")[0].upper()
|
||||
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
|
||||
@functools.cached_property
|
||||
def DEFAULT(self) -> str:
|
||||
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None)
|
||||
if device_from_env: return device_from_env
|
||||
for device in ["METAL", "CUDA", "GPU"]:
|
||||
try:
|
||||
if self[device]: return device
|
||||
except Exception: pass
|
||||
return "CPU"
|
||||
Device = _Device()
|
||||
|
||||
# **************** for Interpreted Buffers ****************
|
||||
|
||||
class Interpreted:
|
||||
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
|
||||
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
|
||||
self.synchronize = lambda: None
|
||||
self.codegen = None
|
||||
self.method_cache: Dict[LazyOp, Callable] = {}
|
||||
|
||||
def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
|
||||
tglob: Dict[str, Any] = {}
|
||||
lines: List[str] = []
|
||||
f = self.fxn_for_op
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def gstr(x:Any, nm=None) -> str:
|
||||
ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
|
||||
tglob[ret] = x
|
||||
return ret
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _interpret_ast(ast:LazyOp) -> str:
|
||||
if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
|
||||
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
|
||||
|
||||
if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
|
||||
tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
|
||||
for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
|
||||
else:
|
||||
inp = [_interpret_ast(src) for src in ast.src]
|
||||
tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
|
||||
|
||||
ret = f"a{len(lines)}"
|
||||
lines.append(f" {ret} = {tmp}")
|
||||
return ret
|
||||
|
||||
ret = _interpret_ast(ast)
|
||||
src = '\n'.join(['def run(inputs):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f" return {ret}"])
|
||||
if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
|
||||
exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
|
||||
return tglob['run']
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, **kwargs):
|
||||
if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
|
||||
ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None)
|
||||
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
|
||||
ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.fxn_for_op[BufferOps.MEM](ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
|
||||
# TODO: is this used?
|
||||
if output is not None and output.output_buffer is not None:
|
||||
assert output.output_buffer.dtype == ret.dtype
|
||||
output.output_buffer._buf = ret._buf
|
||||
return output.output_buffer
|
||||
return ret
|
||||
|
||||
@dataclass
|
||||
class FlopCounter:
|
||||
shape: Tuple[int, ...]
|
||||
dtype: DType
|
||||
flops: int
|
||||
mem: Dict[int, int]
|
||||
@property
|
||||
def mem_estimate(self): return sum(self.mem.values()) + self.dtype.itemsize*prod(self.shape)
|
||||
def consume_flops(self):
|
||||
self.flops, ret = 0, self.flops
|
||||
return ret
|
||||
InterpretedFlopCounter = Interpreted(FlopCounter, {
|
||||
BufferOps.MEM: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
|
||||
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
|
||||
**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST},
|
||||
**{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps},
|
||||
**{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
|
||||
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})})
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
|
||||
|
||||
# **************** for Compiled Buffers ****************
|
||||
|
||||
class ASTRunner:
|
||||
def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
|
||||
if DEBUG >= 4: print(prg)
|
||||
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
|
||||
|
||||
def build(self, compiler, runtime):
|
||||
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
|
||||
self.clprg = runtime(self.name, self.lib)
|
||||
return self
|
||||
|
||||
def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
|
||||
from tinygrad.jit import CacheCollector
|
||||
CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
|
||||
return self(rawbufs, var_vals, force_wait=force_wait)
|
||||
|
||||
def launch_dims(self, var_vals):
|
||||
global_size = ([sym_infer(sz, var_vals) for sz in self.global_size] + [1]*(3-len(self.global_size))) if self.global_size is not None else self.global_size
|
||||
local_size = ([sym_infer(sz, var_vals) for sz in self.local_size] + [1]*(3-len(self.local_size))) if self.local_size is not None else self.local_size
|
||||
return global_size, local_size
|
||||
|
||||
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
|
||||
if var_vals is None: var_vals = {}
|
||||
global_size, local_size = self.launch_dims(var_vals)
|
||||
if global_size is not None and local_size is None:
|
||||
# TODO: this is copied from get_program
|
||||
from tinygrad.features.search import optimize_local_size
|
||||
local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
|
||||
global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
|
||||
lra = self.runtime_args.copy()
|
||||
if global_size: lra['global_size'] = global_size
|
||||
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
|
||||
if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
|
||||
op_estimate = sym_infer(self.op_estimate, var_vals)
|
||||
mem_estimate = sym_infer(self.mem_estimate, var_vals)
|
||||
if DEBUG >= 2:
|
||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(37-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
|
||||
GlobalCounters.kernel_count += 1
|
||||
GlobalCounters.global_ops += op_estimate
|
||||
GlobalCounters.global_mem += mem_estimate
|
||||
return et
|
||||
|
||||
class Compiled:
|
||||
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None):
|
||||
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
|
||||
self.method_cache: Dict[LazyOp, ASTRunner] = {}
|
||||
|
||||
def to_program(self, k):
|
||||
k.linearize()
|
||||
src, runtime_args = self.renderer(k.function_name, k.uops)
|
||||
return ASTRunner(k.function_name, src, k.global_size, k.local_size,
|
||||
op_estimate=k.info.flops, mem_estimate=k.info.mem_estimate,
|
||||
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
|
||||
# check if we can reuse the output buffer
|
||||
# if it's aliased, don't use it
|
||||
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
|
||||
output.realized = output.output_buffer
|
||||
if output.realized:
|
||||
for i,a in enumerate(inputs):
|
||||
# TODO: if this is contiguous it's fine
|
||||
if a.realized == output.realized:
|
||||
if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == BufferOps.MEM and x.arg.idx == i+1):
|
||||
output.realized = None
|
||||
break
|
||||
|
||||
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
||||
if not output.realized:
|
||||
output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
|
||||
|
||||
# all the rawbuffers
|
||||
rawbuffers = [output.realized] + [x.realized for x in inputs]
|
||||
|
||||
# extract real vars used in ast
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
ast_vars = vars_from_ast(ast)
|
||||
assert all(v.val is None for v in ast_vars), f"ast contains bound Variable {ast_vars}"
|
||||
|
||||
# compilation time
|
||||
def get_program():
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
k = Linearizer(ast, self.linearizer_opts)
|
||||
assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
|
||||
if not NOOPT:
|
||||
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
|
||||
if BEAM >= 1 and not vars_from_ast(ast):
|
||||
lins = [(("tc" if used_tensor_cores else "hc"), k)]
|
||||
# allocate a scratch buffer if output buffer is also input
|
||||
test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
|
||||
kb = Linearizer(ast, self.linearizer_opts)
|
||||
kb.required_optimizations()
|
||||
from tinygrad.features.search import beam_search, time_linearizer
|
||||
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
|
||||
if used_tensor_cores:
|
||||
lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
|
||||
lins[-1][1].hand_coded_optimizations()
|
||||
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
|
||||
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
|
||||
k = timed[0][1]
|
||||
else:
|
||||
k.required_optimizations()
|
||||
return self.to_program(k)
|
||||
|
||||
if getenv("ENABLE_METHOD_CACHE", 1):
|
||||
if ast not in self.method_cache: self.method_cache[ast] = get_program()
|
||||
prg = self.method_cache[ast]
|
||||
else:
|
||||
prg = get_program()
|
||||
|
||||
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
|
||||
|
||||
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in ast_vars})
|
||||
return output.realized
|
||||
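
For reference, a small sketch of how the Device singleton above resolves names; the printed default depends on which runtimes import successfully on the machine:

from tinygrad.ops import Device

print(Device.DEFAULT)                 # e.g. "METAL", "CUDA", "GPU" or "CPU"
print(Device.canonicalize("gpu:0"))   # "GPU"    (upper-cased, ":0" stripped)
print(Device.canonicalize("cuda:1"))  # "CUDA:1"
print(Device.canonicalize(None))      # falls back to Device.DEFAULT
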
74
tinygrad_repo/tinygrad/realize.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import List, cast, Dict, Callable
|
||||
import numpy as np
|
||||
from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, BufferOps
|
||||
from tinygrad.graph import log_schedule_item, print_tree
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE
|
||||
|
||||
from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
from tinygrad.features.image import fix_schedule_for_images
|
||||
|
||||
def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
|
||||
# HACK: images can be not usable due to shape
|
||||
if IMAGE >= 2: schedule = fix_schedule_for_images(schedule)
|
||||
|
||||
# NOTE: if you for loop the schedule it's slow because nothing frees
|
||||
while len(schedule):
|
||||
si = schedule.pop(0)
|
||||
if not disable_logging: log_schedule_item(si)
|
||||
assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
|
||||
if DEBUG >= 3: print_tree(si.ast)
|
||||
if si.ast.op in LoadOps:
|
||||
# confirm the LoadOps are contiguous and in order
|
||||
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
|
||||
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
|
||||
else:
|
||||
si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
|
||||
del si.out.op
|
||||
for v in si.out.views: del v.op
|
||||
assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
|
||||
assert si.out.realized.dtype == si.out.dtype, "realized dtype is incorrect"
|
||||
|
||||
# *** zero op LoadOps ***
|
||||
|
||||
def _realize_empty(buffer: LazyBuffer) -> None:
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
if DEBUG >= 2: print(f"*** empty {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
|
||||
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
|
||||
|
||||
def _realize_rand(buffer: LazyBuffer) -> None:
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
if DEBUG >= 2: print(f"*** rand {buffer.device} seed {buffer.op.arg:<10d} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
|
||||
rng = np.random.default_rng(buffer.op.arg)
|
||||
buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())
|
||||
|
||||
# *** one op LoadOps ***
|
||||
|
||||
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
|
||||
assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}"
|
||||
assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
|
||||
if DEBUG >= 2: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size:<16d} shape {str(buffer.shape):23s} dtype {src.realized.dtype}")
|
||||
# TODO: make this generic
|
||||
if isinstance(src.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
|
||||
src.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer())
|
||||
elif isinstance(src.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and getenv("P2P", 0) >= 1:
|
||||
buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(src.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args())
|
||||
else:
|
||||
# TODO: schedule this as FROM to go to CPU, and a FROM to go to device
|
||||
buffer.realized = Device[buffer.device].buffer.fromCPU(src.realized.toCPU(), **buffer._device_extra_args())
|
||||
|
||||
# *** n op LoadOps ***
|
||||
|
||||
def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
|
||||
if DEBUG >= 2: print(f"*** custom {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
|
||||
buffer.realized = buffer.op.arg(buffer, *inputs)
|
||||
|
||||
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
|
||||
LoadOps.EMPTY: _realize_empty,
|
||||
LoadOps.RAND: _realize_rand,
|
||||
LoadOps.FROM: _realize_from,
|
||||
LoadOps.CUSTOM: _realize_custom,
|
||||
}
|
||||
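
Roughly how this module gets exercised from user code (a sketch; the call chain from Tensor.realize() into run_schedule lives in tensor.py and lazy.py, not here):

from tinygrad.tensor import Tensor

# building ops is lazy; realizing turns the graph into ScheduleItems which
# run_schedule() above executes: LoadOps through LOAD_OPS_DISPATCHER,
# everything else through the device's exec_ast
out = (Tensor.rand(16) + 1).sum()
out.realize()
print(out.numpy())
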
212
tinygrad_repo/tinygrad/renderer/cstyle.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens
|
||||
|
||||
class CStyleLanguage(NamedTuple):
|
||||
size_prefix: str = "int"
|
||||
generic_var_prefix: str = ""
|
||||
kernel_prefix: str = ""
|
||||
buffer_prefix: str = ""
|
||||
buffer_suffix: str = ""
|
||||
smem_align: str = ""
|
||||
smem_prefix: str = ""
|
||||
smem_prefix_for_cast: bool = True
|
||||
arg_int_prefix: str = ""
|
||||
barrier: str = ""
|
||||
xid: List[str] = []
|
||||
gid: List[str] = []
|
||||
lid: List[str] = []
|
||||
global_max: List[int] = []
|
||||
local_max: List[int] = []
|
||||
extra_args: List[str] = []
|
||||
float4: Optional[str] = None
|
||||
half_prekernel: Optional[str] = None
|
||||
uses_vload: bool = False
|
||||
external_local_bufs: bool = False
|
||||
uses_ptr_arithmetic: bool = False
|
||||
launch_bounds: bool = False
|
||||
code_for_op: Dict = {
|
||||
UnaryOps.NEG: lambda x: f"(-{x})",
|
||||
UnaryOps.EXP2: lambda x: f"exp2({x})",
|
||||
UnaryOps.LOG2: lambda x: f"log2({x})",
|
||||
UnaryOps.SIN: lambda x: f"sin({x})",
|
||||
UnaryOps.SQRT: lambda x: f"sqrt({x})",
|
||||
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
|
||||
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
|
||||
BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
|
||||
TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
|
||||
}
|
||||
|
||||
# returns a str expression of the casted xs with the given type
|
||||
def render_cast(self, x:List[str], var_dtype:DType) -> str:
|
||||
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
|
||||
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
|
||||
assert self.float4 is not None, "cast is not supported on this platform"
|
||||
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
|
||||
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
|
||||
if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})"
|
||||
raise NotImplementedError(f"no cast for {var_dtype}")
|
||||
|
||||
# returns a str expression of the const with the given type
|
||||
def render_const(self, x:Union[float,int], var_dtype) -> str:
|
||||
if math.isnan(x): val = "NAN"
|
||||
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
|
||||
else: val = f"{x}f" if dtypes.is_float(var_dtype) and isinstance(x, float) else f"{int(x)}"
|
||||
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
|
||||
|
||||
# returns a str expression of the loaded value with the output type
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert output_dtype == dtypes._float4, f"images must be float4, getting {output_dtype}"
|
||||
return f"read_imagef({buf_name}, smp, {idx})"
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
|
||||
if output_dtype.sz > 1:
|
||||
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
|
||||
else:
|
||||
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
|
||||
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
|
||||
|
||||
def render_local(self, name:str, size:int):
|
||||
return self.smem_align + self.smem_prefix + f"float {name}[{size}];"
|
||||
|
||||
def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
|
||||
return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{"
|
||||
|
||||
def render_if(self, cond: str):
|
||||
return f"if ({cond}) {{"
|
||||
|
||||
def render_conditional(self, cond: str, x:str, y:str) -> str:
|
||||
return f"({cond})?({x}):{y}"
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
|
||||
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
self.arg_int_prefix if dtype == dtypes._arg_int32 else
|
||||
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
|
||||
prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
|
||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||
if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
|
||||
return prg
|
||||
|
||||
# returns a str statement that does the store
|
||||
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert var_dtype == dtypes._float4, "images must be float4"
|
||||
return f"write_imagef({buf_name}, {idx}, {var_name});"
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
|
||||
if var_dtype.sz > 1:
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
|
||||
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
|
||||
local_size: List[int] = []
|
||||
kernel,prekernel,bufs = [],[],[]
|
||||
#pend_close = None
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, str] = {}
|
||||
def ssa(u, prefix="t"):
|
||||
nonlocal c, r
|
||||
c[prefix] += 1
|
||||
r[u]=f"{prefix}{c[prefix]-1}"
|
||||
return r[u]
|
||||
|
||||
child_count: DefaultDict[UOp, int] = defaultdict(int)
|
||||
for ru in uops:
|
||||
for v in ru.vin:
|
||||
child_count[v] += 1
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args,_ = u
|
||||
if uop == UOps.LOOP:
|
||||
kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
|
||||
depth += 1
|
||||
elif uop == UOps.IF:
|
||||
kk(lang.render_if(r[vin[0]]))
|
||||
depth += 1
|
||||
elif uop == UOps.BARRIER:
|
||||
kk(lang.barrier)
|
||||
elif uop == UOps.END:
|
||||
depth -= 1
|
||||
kk("}")
|
||||
elif uop == UOps.WMMA:
|
||||
if args[0] == "METAL":
|
||||
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
|
||||
kk("{ simdgroup_float8x8 a,b,c;")
|
||||
kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
|
||||
kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
|
||||
kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
|
||||
kk("simdgroup_multiply_accumulate(c, a, b, c);")
|
||||
kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}")
|
||||
elif args[0] == "HIP":
|
||||
kk("{")
|
||||
kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};")
|
||||
kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};")
|
||||
kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};")
|
||||
kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
|
||||
for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];")
|
||||
kk("}")
|
||||
else:
|
||||
raise NotImplementedError(f"WMMA not implemented for {args}")
|
||||
elif uop == UOps.ALU:
|
||||
assert dtype is not None
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
|
||||
val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]])
|
||||
else:
|
||||
val = lang.code_for_op[args](*[r[x] for x in vin])
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
if child_count[u] <= 1 or dtypes.is_int(dtype): # fix index rendering issue
|
||||
r[u] = val
|
||||
else:
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
assert dtype is not None
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
|
||||
elif uop == UOps.SPECIAL:
|
||||
xid = lang.gid if args[1].startswith("g") else (lang.xid if args[1].startswith("i") else lang.lid)
|
||||
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
|
||||
if args[1].startswith("l"): local_size.append(args[2])
|
||||
r[u] = args[1]
|
||||
elif uop == UOps.CONST:
|
||||
r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
|
||||
elif uop == UOps.LOAD:
|
||||
assert dtype is not None
|
||||
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
|
||||
if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};")
|
||||
elif uop == UOps.PHI:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop == UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
|
||||
elif uop == UOps.CAST and dtype is not None and dtype.sz > 1:
|
||||
val = lang.render_cast([r[x] for x in vin], dtype)
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
if lang.external_local_bufs:
|
||||
prekernel.append(lang.render_local(args[0], args[1]))
|
||||
else:
|
||||
kk(lang.render_local(args[0], args[1]))
|
||||
r[u] = args[0]
|
||||
elif uop == UOps.DEFINE_GLOBAL:
|
||||
bufs.append(args)
|
||||
r[u] = args[0]
|
||||
elif uop == UOps.GEP:
|
||||
r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
|
||||
else:
|
||||
raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {}
|
||||
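
A small sketch of what the renderer building blocks above emit when a CStyleLanguage with default fields is called directly:

from tinygrad.helpers import dtypes
from tinygrad.renderer.cstyle import CStyleLanguage

lang = CStyleLanguage()
print(lang.render_for("ridx0", 0, 4))
# for (int ridx0 = 0; ridx0 < 4; ++ridx0) {
print(lang.render_store("data0", dtypes.float32, "val0", dtypes.float32, "gidx0"))
# data0[gidx0] = val0;
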
23
tinygrad_repo/tinygrad/renderer/opencl.py
Normal file
@@ -0,0 +1,23 @@
import functools
from tinygrad.helpers import dtypes
from tinygrad.ops import TernaryOps
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint64: "ulong" }
class OpenCLLanguage(CStyleLanguage):
  kernel_prefix = "__kernel "
  buffer_prefix = "__global "
  smem_align = "__attribute__ ((aligned (16))) "
  smem_prefix = "__local "
  arg_int_prefix = "const int"
  half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable"
  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
  float4 = "(float4)"
  gid = [f'get_group_id({i})' for i in range(3)]
  lid = [f'get_local_id({i})' for i in range(3)]
  xid = [f'get_global_id({i})' for i in range(3)]
  uses_vload = True
  # NOTE: mad is used so the loads aren't reordered into the math on 845
  code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c: f"mad({a},{b},{c})"}

OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
111
tinygrad_repo/tinygrad/runtime/lib.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from collections import defaultdict, deque
|
||||
from typing import TypeVar, Type, Any, Dict, Deque, Tuple
|
||||
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType
|
||||
|
||||
_T = TypeVar("_T")
|
||||
class RawBuffer: # pylint: disable=abstract-method
|
||||
def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
|
||||
self.size: int = size
|
||||
self.dtype: DType = dtype
|
||||
self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
|
||||
self._memsz: int = size*dtype.itemsize
|
||||
self._allocator = allocator
|
||||
self._device = kwargs.get('device', None)
|
||||
GlobalCounters.mem_used += self._memsz
|
||||
def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
|
||||
if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
|
||||
if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
|
||||
def __repr__(self): return f"buffer<{self.size}, {self.dtype}, {id(self)}>"
|
||||
@property
|
||||
def key(self): return (self.size, self.dtype)
|
||||
|
||||
# NOTE: this interface allows for 0 copy
|
||||
@classmethod
|
||||
def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
|
||||
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
|
||||
|
||||
class RawBufferCopyIn(RawBuffer):
|
||||
def _copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
|
||||
|
||||
@classmethod
|
||||
def fromCPU(cls, x:np.ndarray, **kwargs):
|
||||
ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
|
||||
if x.size > 0: ret._copyin(x)
|
||||
return ret
|
||||
|
||||
class RawBufferMapped(RawBufferCopyIn):
|
||||
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
|
||||
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
|
||||
def buffer_view(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size) # type: ignore
|
||||
def toCPU(self) -> np.ndarray: return self.buffer_view().copy() # Need a copy, since jit will write to the same buffer.
|
||||
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.buffer_view(), x.reshape(-1))
|
||||
|
||||
# this one is simple enough that i moved it out of the runtimes
|
||||
class RawMallocBuffer(RawBufferMapped):
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64, dtypes.int16: ctypes.c_int16, dtypes.uint16: ctypes.c_uint16}[dtype] * size)())
|
||||
def _buffer(self): return memoryview(self._buf)
|
||||
|
||||
class RawBufferCopyInOut(RawBufferCopyIn):
|
||||
def _copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
|
||||
|
||||
def toCPU(self) -> np.ndarray:
|
||||
x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
|
||||
if x.size > 0: self._copyout(x)
|
||||
return x
|
||||
|
||||
class RawBufferTransfer(RawBuffer):
|
||||
def _transfer(self, x) -> None: raise NotImplementedError("must be implemented")
|
||||
|
||||
@classmethod
|
||||
def transfer(cls, x, shape, dtype, **kwargs):
|
||||
ret = cls(prod(shape), dtype, **kwargs)
|
||||
ret._transfer(x)
|
||||
return ret
|
||||
|
||||
class LRUAllocator:
|
||||
def __init__(self, dev_memsz=(4<<30)):
|
||||
self.epoch = 0
|
||||
self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
|
||||
self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
|
||||
self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, splitted by type and size, newest first.
|
||||
self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
|
||||
|
||||
def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
|
||||
GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
|
||||
return rawbufs.popleft()[0]
|
||||
|
||||
def ensure_has_free_space(self, size, dtype, device):
|
||||
while len(self.aging_order[device]) and (self.free_space[device]-size*dtype.itemsize) < 0: # When OOM removing lru buffers.
|
||||
bucket, epoch = self.aging_order[device].popleft()
|
||||
if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
|
||||
|
||||
def _alloc_buffer(self, size, dtype, device, **kwargs):
|
||||
self.ensure_has_free_space(size, dtype, device)
|
||||
self.free_space[device] -= size*dtype.itemsize
|
||||
newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
|
||||
self.buffer_info[newbuf] = (size, dtype, device)
|
||||
return newbuf
|
||||
|
||||
def _free_buffer(self, buf_to_free):
|
||||
self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
|
||||
GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
|
||||
self.buffer_info.pop(buf_to_free)
|
||||
self._do_free(buf_to_free)
|
||||
|
||||
def alloc(self, size, dtype, device='0', **kwargs):
|
||||
rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
|
||||
return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
|
||||
|
||||
def free(self, buf): # free() just caches buffer. It might be freed later when OOM during allocation.
|
||||
self.epoch += 1
|
||||
size, dtype, device = self.buffer_info[buf]
|
||||
self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
|
||||
self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
|
||||
GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
|
||||
|
||||
def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
|
||||
def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys.
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
|
||||
def _do_free(self, buf): pass
|
||||
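
A quick sketch of the RawBuffer contract using the malloc-backed implementation above; the numpy array is an arbitrary example:

import numpy as np
from tinygrad.runtime.lib import RawMallocBuffer

x = np.arange(6, dtype=np.float32)
buf = RawMallocBuffer.fromCPU(x)               # classmethod inherited from RawBufferCopyIn
print(buf.size, buf.dtype)                     # 6, dtypes.float32
np.testing.assert_array_equal(buf.toCPU(), x)  # toCPU() copies out of the mapped buffer
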
52
tinygrad_repo/tinygrad/runtime/ops_cpu.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import numpy as np
import operator
from typing import Callable, Dict, Tuple, Optional
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer

def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
  assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
  return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)

base_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.MEM: lambda x: x._buf, UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
  ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
  ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
  MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
}

def match_types(x, y):
  up = x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
  return x.astype(up, copy=False), y.astype(up, copy=False)

def einsum_mulacc(einsum, get_strides, expand):
  def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
  def axes_slice(strides): return [i for i,s in enumerate(strides) if s != 0], tuple([slice(None) if s != 0 else 0 for i,s in enumerate(strides)])
  def mulacc(a, b, new_shape):
    (a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
    out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
    ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
    return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
  return mulacc

numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
  BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
  BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
  MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
  MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
  MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
  TernaryOps.WHERE: np.where,
}}

class RawNumpyBuffer(RawBuffer):
  def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf if buf is not None else np.empty([size], dtype.np))
  @classmethod
  def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
  def toCPU(self): return self._buf
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)
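# Usage sketch (illustrative): the CPU backend is just a table mapping ops to numpy calls, which
# can be driven directly; in normal use Interpreted/CPUBuffer dispatches into it.
import numpy as np
from tinygrad.ops import ReduceOps, BinaryOps
from tinygrad.runtime.ops_cpu import shape_to_axis, numpy_fxn_for_op

x = np.arange(12, dtype=np.float32).reshape(3, 4)
print(shape_to_axis((3, 4), (1, 4)))                # (0,): the dimension being reduced
print(numpy_fxn_for_op[ReduceOps.SUM](x, (1, 4)))   # column sums with keepdims, shape (1, 4)
y = np.ones((3, 4), dtype=np.int32)
print(numpy_fxn_for_op[BinaryOps.ADD](x, y).dtype)  # match_types upcasts both sides to the higher-priority dtype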
41
tinygrad_repo/tinygrad/runtime/ops_disk.py
Normal file
@@ -0,0 +1,41 @@
import os, mmap
from typing import Optional
from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps

class RawDiskBuffer(RawBufferMapped):
  def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0): # pylint: disable=super-init-not-called
    self.shape = (size, ) if shape is None else shape
    self.offset = offset # this is an offset in bytes
    assert device is not None or buf is not None, "disk tensor needs a path or a buf"
    if device is not None:
      f = open(device, "a+b")
      if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
      buf = [f, mmap.mmap(f.fileno(), size * dtype.itemsize), 1]
    else:
      buf[2] += 1
    # NOTE: we don't call super since disk tensors don't use RAM
    self.size, self.dtype, self._buf = size, dtype, buf
  def __del__(self):
    self._buf[2] -= 1
    if self._buf[2] == 0: self._buf[0].close()
  def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
  def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
  def shrink(self, arg):
    assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
    offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
    size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
    return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])

  def as_strided(self, arg):
    return RawDiskBuffer(prod(arg[0]), self.dtype, buf=self._buf, offset=self.offset+arg[2]*self.dtype.itemsize, shape=arg[0])

  def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
  def readinto(self, buf):
    self._buf[0].seek(self.offset)
    self._buf[0].readinto(buf)

disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, from_underlying=lambda x:x)
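# Usage sketch (illustrative): DISK buffers are an mmap'd file plus a byte offset; the path below
# is only an example. shrink() slices the first dimension by moving the offset, so no data is copied.
import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_disk import RawDiskBuffer

buf = RawDiskBuffer(16, dtypes.float32, device="/tmp/weights.bin")   # creates/extends the file to 64 bytes
view = np.frombuffer(buf._buffer(), dtype=np.float32)                # zero-copy view of the mmap
view[:] = np.arange(16, dtype=np.float32)

tail = buf.shrink(((8, 16),))                                        # same mmap, byte offset 8*4, shape (8,)
print(np.frombuffer(tail._buffer(), dtype=np.float32))               # the last 8 values: 8.0 ... 15.0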
110
tinygrad_repo/tinygrad/runtime/ops_gpu.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
os.environ['PYOPENCL_NO_CACHE'] = '1'
|
||||
import pathlib
|
||||
import numpy as np
|
||||
import pyopencl as cl # type: ignore
|
||||
from typing import Optional, List, Tuple
|
||||
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.renderer.opencl import OpenCLRenderer
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
|
||||
OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
|
||||
|
||||
# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
|
||||
ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin")
|
||||
#ROCM_LLVM_PATH = pathlib.Path(__file__).parents[3] / "extra/rocm/build/llvm-project/bin"
|
||||
if DEBUG >= 5:
|
||||
early_exec = fromimport("extra.helpers", "enable_early_exec")()
|
||||
|
||||
class CLAllocator(LRUAllocator):
|
||||
def _do_alloc(self, size, dtype, device, **kwargs):
|
||||
if isinstance(dtype, ImageDType):
|
||||
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
|
||||
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
|
||||
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
|
||||
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
|
||||
else:
|
||||
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
|
||||
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
|
||||
return buf
|
||||
|
||||
class _CL:
|
||||
def __init__(self):
|
||||
cl_platforms = cl.get_platforms()
|
||||
platform_devices: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if y]
|
||||
self.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")]
|
||||
self.cl_platform = self.devices[0].platform
|
||||
def post_init(self, device=None):
|
||||
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in self.devices] if device is None else [cl.Context(devices=[self.devices[device]])]
|
||||
if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
|
||||
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
|
||||
self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
|
||||
def synchronize(self):
|
||||
for q in self.cl_queue: q.finish()
|
||||
CL = _CL()
|
||||
if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init()
|
||||
|
||||
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
|
||||
def _copyin(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
|
||||
self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements=['C', 'A']), is_blocking=False)
|
||||
def _copyout(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
|
||||
CL.cl_allocator.ensure_has_free_space(self.size, self.dtype, self._device)
|
||||
buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
|
||||
mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
|
||||
with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else []))
|
||||
def _transfer(self, x):
|
||||
if "gfx" in CL.cl_ctxs[x._buf.device].devices[0].name:
|
||||
cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait()
|
||||
else: raise NotImplementedError("p2p transfer between devices not implemented on non-amd")
|
||||
|
||||
@diskcache
|
||||
def compile_gpu(prg:str) -> bytes:
|
||||
clprg = cl.Program(CL.cl_ctxs[0], prg)
|
||||
clprg.build()
|
||||
return clprg.get_info(cl.program_info.BINARIES)[0]
|
||||
|
||||
class CLProgram:
|
||||
def __init__(self, name:str, prg:bytes, argdtypes=None, options=None):
|
||||
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) for ctx in CL.cl_ctxs] # type: ignore
|
||||
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
|
||||
self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs]
|
||||
if DEBUG >= 5 and not OSX:
|
||||
if 'Adreno' in CL.cl_ctxs[0].devices[0].name:
|
||||
fromimport('disassemblers.adreno', 'disasm')(prg)
|
||||
elif CL.cl_ctxs[0].devices[0].name.startswith('gfx'):
|
||||
asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], prg))
|
||||
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
|
||||
else:
|
||||
# print the PTX for NVIDIA. TODO: probably broken for everything else
|
||||
print(prg.decode('utf-8'))
|
||||
if argdtypes is not None: self.set_argdtypes(argdtypes)
|
||||
|
||||
def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_scalar_arg_dtypes(argdtypes) for clprg in self.clprgs]
|
||||
|
||||
@staticmethod
|
||||
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
|
||||
|
||||
def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Optional[Tuple[int,int,int]]=None, wait=False) -> Optional[float]:
|
||||
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if x.__class__ is CLBuffer else np.int32 for x in bufs))
|
||||
cl_bufs, wait_for = [], []
|
||||
for x in bufs:
|
||||
if x.__class__ is CLBuffer:
|
||||
cl_bufs.append(x._buf)
|
||||
if hasattr(x, "event"): wait_for.append(x.event)
|
||||
else: cl_bufs.append(x)
|
||||
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [int(g*l) for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
|
||||
if wait:
|
||||
e.wait()
|
||||
try:
|
||||
return ((e.profile.end - e.profile.start) * OSX_TIMING_RATIO) * 1e-9
|
||||
except cl.RuntimeError: # no profiling info available
|
||||
return None
|
||||
return None
|
||||
|
||||
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), OpenCLRenderer, compile_gpu, CLProgram, CL.synchronize)
|
||||
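# Usage sketch (illustrative): compiling and launching a raw OpenCL kernel through this runtime.
# A minimal sketch only: it assumes pyopencl plus at least one OpenCL device is available, and it
# drives the low-level CLBuffer/CLProgram path directly rather than going through Tensor.
import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_gpu import CL, CLBuffer, CLProgram, compile_gpu

src = "__kernel void add1(__global float *out, __global const float *a) { int i = get_global_id(0); out[i] = a[i] + 1.0f; }"
prg = CLProgram("add1", compile_gpu(src))          # compile_gpu is @diskcache'd, so rebuilds hit the cache

a = CLBuffer(4, dtypes.float32)
a._copyin(np.arange(4, dtype=np.float32))
out = CLBuffer(4, dtypes.float32)
prg(out, a, global_size=(4, 1, 1))
res = np.empty(4, dtype=np.float32)
out._copyout(res)                                  # blocking map + copy back to the host array
print(res)                                         # [1. 2. 3. 4.]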
221
tinygrad_repo/tinygrad/shape/shapetracker.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
import functools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, cast
|
||||
from tinygrad.ops import MovementOps
|
||||
from tinygrad.helpers import prod, DEBUG, dedup
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
|
||||
assert len(shape) == len(strides)
|
||||
ret = [(shape[0], strides[0])] if shape else []
|
||||
for i in range(1, len(shape)):
|
||||
if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
|
||||
ret[-1] = (ret[-1][0] * shape[i], strides[i])
|
||||
elif shape[i] == 1:
|
||||
continue
|
||||
else:
|
||||
ret.append((shape[i], strides[i]))
|
||||
return tuple(ret)
|
||||
|
||||
def expr_node_mask(view:View, idx, valid=None) -> Node:
|
||||
expr = [valid] if valid is not None else []
|
||||
if view.mask is not None:
|
||||
acc = 1
|
||||
for ns,(x,y) in reversed(list(zip(view.shape, view.mask))):
|
||||
if x != 0 or y != ns:
|
||||
base = ((idx//acc) % ns)
|
||||
expr += [base >= x, base < y]
|
||||
acc *= ns
|
||||
return Variable.ands(expr)
|
||||
|
||||
# generate an expression if you have a single idx variable
|
||||
def expr_node(view:View, idx=None) -> Node:
|
||||
if idx is None: idx = Variable('idx', 0, prod(view.shape)-1)
|
||||
ret: List[Node] = [Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] if view.offset else []
|
||||
acc = 1
|
||||
for d,s in reversed(to_shape_strides(view.shape, view.strides)):
|
||||
ret.append(((idx//acc)%d)*s)
|
||||
acc *= d
|
||||
return Variable.sum(ret)
|
||||
|
||||
# generate an expression if you have a variable or expression for each index
|
||||
def expr_idxs(view:View, idxs) -> Node:
|
||||
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
|
||||
return Variable.sum([Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0])
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
if vm2.mask: return None # this isn't supported yet
|
||||
mst = ShapeTracker((vm2, vm1))
|
||||
strides = mst.real_strides()
|
||||
if None in strides: return None
|
||||
return View.create(vm1.shape, cast(Tuple[sint, ...], strides), mst.real_offset(), vm1.mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
|
||||
assert len(idxs) == len(shape), "need an idx for all dimensions"
|
||||
acc = 1
|
||||
ret = []
|
||||
for tidx,d in reversed(list(zip(idxs, shape))):
|
||||
ret.append(tidx * acc)
|
||||
acc *= d
|
||||
return Variable.sum(ret)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
views: Tuple[View, ...]
|
||||
def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
|
||||
|
||||
@staticmethod
|
||||
def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
|
||||
|
||||
@property
|
||||
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
|
||||
|
||||
# this is the real size (ish)
|
||||
def size(self): return self.views[-1].size()
|
||||
|
||||
def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], []))
|
||||
|
||||
@property
|
||||
def var_vals(self) -> Dict[Variable, int]:
|
||||
ret:Dict[Variable, int] = {}
|
||||
for v in self.vars():
|
||||
var, val = v.unbind()
|
||||
assert var not in ret or ret[var] == val, f"{var} has conflicted values {val} and {ret[var]}"
|
||||
ret[var] = val
|
||||
return ret
|
||||
|
||||
def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views))
|
||||
|
||||
def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
|
||||
to_apply:List[Tuple[MovementOps, Tuple]] = []
|
||||
for v in self.views:
|
||||
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
|
||||
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
|
||||
# first, we apply the offset
|
||||
# then, we make it the correct shape
|
||||
# then, we apply permutations
|
||||
# TODO: don't use as_strided
|
||||
to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
|
||||
# then, we apply pre expand pads
|
||||
if v.mask is not None:
|
||||
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
|
||||
post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
|
||||
if any(x != (0,0) for x in pre_expand_pads):
|
||||
to_apply.append((MovementOps.PAD, pre_expand_pads))
|
||||
real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
|
||||
# then, we do any expands
|
||||
if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
|
||||
# lastly, we apply post expand pads
|
||||
if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
|
||||
return to_apply
|
||||
|
||||
# these are multiview strides, value is None if it's not a simple strided dimension
|
||||
# TODO: this can be shared code between simplify and merge_views
|
||||
def real_offset(self) -> sint:
|
||||
real_offset, _ = self.expr_node(Variable('zero', 0, 0))
|
||||
return real_offset.b if isinstance(real_offset, NumNode) else real_offset
|
||||
|
||||
# NOTE: if a stride is not always valid, it will be None
|
||||
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
|
||||
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
|
||||
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
|
||||
idx, valid = self.expr_idxs(idxs)
|
||||
ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
|
||||
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
|
||||
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
|
||||
ret[idxs.index(this_dim.a)] = this_dim.b
|
||||
elif isinstance(this_dim, Variable) and this_dim in idxs:
|
||||
ret[idxs.index(this_dim)] = 1
|
||||
idx_vars, valid_vars = idx.vars(), valid.vars()
|
||||
for i,tidx in enumerate(idxs):
|
||||
if tidx in valid_vars and not ignore_valid: ret[i] = None
|
||||
elif tidx not in idx_vars: ret[i] = 0
|
||||
return tuple(ret)
|
||||
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
||||
|
||||
def _expr_idx(self, idx, valid) -> Tuple[Node, Node]:
|
||||
for v in reversed(self.views[0:-1]):
|
||||
if valid.max == 0: return Variable.num(-1), valid
|
||||
valid = expr_node_mask(v, idx, valid)
|
||||
idx = expr_node(v, idx)
|
||||
return idx, valid
|
||||
|
||||
def simplify(self) -> ShapeTracker:
|
||||
if len(self.views) >= 2:
|
||||
new_view = merge_views(self.views[-2], self.views[-1])
|
||||
if new_view:
|
||||
if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
|
||||
return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
|
||||
return self
|
||||
|
||||
def expr_idxs(self, idxs=None):
|
||||
if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
|
||||
idx = expr_idxs(self.views[-1], tuple(idxs))
|
||||
valid = expr_node_mask(self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs)))
|
||||
return self._expr_idx(idx, valid)
|
||||
|
||||
def expr_node(self, idx='idx'):
|
||||
if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1)
|
||||
return self._expr_idx(expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx))
|
||||
|
||||
def axis_is_masked(self, axis) -> bool:
|
||||
_, valid = self.expr_idxs()
|
||||
return f'idx{axis}' in [v.expr for v in valid.vars()]
|
||||
|
||||
# *** under this line are the movement ops ***
|
||||
|
||||
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
|
||||
|
||||
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
|
||||
|
||||
def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
|
||||
|
||||
def permute(self, axis: Tuple[int, ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
|
||||
|
||||
def stride(self, mul: Tuple[int, ...]) -> ShapeTracker:
|
||||
return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
|
||||
|
||||
def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
|
||||
new_view = self.views[-1].reshape(new_shape)
|
||||
if new_view is None:
|
||||
extra_view = View.create(new_shape)
|
||||
# last chance to merge. TODO: move into View
|
||||
if (merged_view := merge_views(self.views[-1], extra_view)) is not None:
|
||||
return ShapeTracker(self.views[0:-1] + (merged_view,))
|
||||
return ShapeTracker(self.views + (extra_view, ))
|
||||
return ShapeTracker(self.views[0:-1] + (new_view,))
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
# TODO: if we remove movementops from lazy.py we can delete this
|
||||
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
|
||||
# Pre-allocate all groups.
|
||||
axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
|
||||
# Index for new_shape and axis_groups.
|
||||
i: int = 0
|
||||
old_shape_i: int = 0
|
||||
while old_shape_i < len(old_shape):
|
||||
# 1s that exist only in new_shape will lead to empty axis group creations.
|
||||
if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
|
||||
if i < len(new_shape) - 1: i += 1
|
||||
else:
|
||||
axis_groups[i].append(old_shape_i)
|
||||
axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
|
||||
# Move to next axes group if total size of all dimensions match.
|
||||
if axis_group_size == new_shape[i]:
|
||||
if i < len(new_shape) - 1: i += 1
|
||||
elif axis_group_size > new_shape[i]: return None
|
||||
old_shape_i += 1
|
||||
return axis_groups
|
||||
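# Usage sketch (illustrative): a ShapeTracker records movement ops as stacked Views instead of
# copying data. A permute is just a stride change; a reshape that cannot be expressed on the
# permuted strides appends a second View.
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((3, 4))
st = st.permute((1, 0))
print(st.shape, st.real_strides())     # (4, 3) (1, 4): same buffer, reinterpreted
st = st.reshape((12,))
print(len(st.views), st.contiguous)    # 2 False: the reshape could not be merged into one View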
352
tinygrad_repo/tinygrad/shape/symbolic.py
Normal file
@@ -0,0 +1,352 @@
|
||||
from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
import functools
|
||||
from math import gcd
|
||||
from itertools import product
|
||||
from tinygrad.helpers import partition
|
||||
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator
|
||||
|
||||
# NOTE: Python has different behavior for negative mod and floor div than c
|
||||
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
||||
|
||||
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
|
||||
|
||||
class Node:
|
||||
b: Union[Node, int]
|
||||
min: int
|
||||
max: int
|
||||
def render(self, ops=None, ctx=None) -> Any:
|
||||
if ops is None: ops = render_python
|
||||
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
def vars(self): return []
|
||||
|
||||
def expand_idx(self) -> VariableOrNum: return next((v for v in self.vars() if v.expr is None), NumNode(0))
|
||||
# expand a Node into List[Node] that enumerates the underlying Variables from min to max
|
||||
# expand increments earlier variables faster than later variables (as specified in the argument)
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def expand(self, idxs:Optional[Tuple[VariableOrNum, ...]]=None) -> List[Node]:
|
||||
if idxs is None: idxs = (self.expand_idx(),)
|
||||
return [self.substitute(dict(zip(idxs, (NumNode(x) for x in rep)))) for rep in Node.iter_idxs(idxs)]
|
||||
@staticmethod
|
||||
def iter_idxs(idxs:Tuple[VariableOrNum, ...]) -> Iterator[Tuple[int,...]]:
|
||||
yield from (x[::-1] for x in product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
|
||||
# substitute Variables with the values in var_vals
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
|
||||
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
|
||||
|
||||
@functools.cached_property
|
||||
def key(self) -> str: return self.render(ctx="DEBUG")
|
||||
@functools.cached_property
|
||||
def hash(self) -> int: return hash(self.key)
|
||||
def __repr__(self): return self.render(ctx="REPR")
|
||||
def __str__(self): return "<"+self.key+">"
|
||||
def __hash__(self): return self.hash
|
||||
def __bool__(self): return not (self.max == self.min == 0)
|
||||
def __eq__(self, other:object) -> bool:
|
||||
if not isinstance(other, Node): return NotImplemented
|
||||
return self.key == other.key
|
||||
def __neg__(self): return self*-1
|
||||
def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
|
||||
def __radd__(self, b:int): return self+b
|
||||
def __sub__(self, b:Union[Node,int]): return self+-b
|
||||
def __rsub__(self, b:int): return -self+b
|
||||
def __le__(self, b:Union[Node,int]): return self < (b+1)
|
||||
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
|
||||
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
|
||||
def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
|
||||
def __mul__(self, b:Union[Node, int]):
|
||||
if b == 0: return NumNode(0)
|
||||
if b == 1: return self
|
||||
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
|
||||
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
|
||||
def __rmul__(self, b:int): return self*b
|
||||
|
||||
# *** complex ops ***
|
||||
|
||||
def __rfloordiv__(self, b:int):
|
||||
if self.min > b >= 0: return NumNode(0)
|
||||
if isinstance(self, NumNode): return NumNode(b // self.b)
|
||||
raise RuntimeError(f"not supported: {b} // {self}")
|
||||
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
|
||||
if isinstance(b, Node):
|
||||
if b.__class__ is NumNode: return self // b.b
|
||||
if self == b: return NumNode(1)
|
||||
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
|
||||
raise RuntimeError(f"not supported: {self} // {b}")
|
||||
assert b != 0
|
||||
if b < 0: return (self//-b)*-1
|
||||
if b == 1: return self
|
||||
|
||||
# the numerator of div is not allowed to be negative
|
||||
if self.min < 0:
|
||||
offset = self.min//b
|
||||
# factor out an "offset" to make the numerator positive. don't allow factoring again
|
||||
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
|
||||
return create_node(DivNode(self, b))
|
||||
|
||||
def __rmod__(self, b:int):
|
||||
if self.min > b >= 0: return NumNode(b)
|
||||
if isinstance(self, NumNode): return NumNode(b % self.b)
|
||||
raise RuntimeError(f"not supported: {b} % {self}")
|
||||
def __mod__(self, b:Union[Node,int]):
|
||||
if isinstance(b, Node):
|
||||
if b.__class__ is NumNode: return self % b.b
|
||||
if self == b: return NumNode(0)
|
||||
if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
|
||||
raise RuntimeError(f"not supported: {self} % {b}")
|
||||
assert b > 0
|
||||
if b == 1: return NumNode(0)
|
||||
if self.min >= 0 and self.max < b: return self
|
||||
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
|
||||
if self.min < 0: return (self - ((self.min//b)*b)) % b
|
||||
return create_node(ModNode(self, b))
|
||||
|
||||
@staticmethod
|
||||
def num(num:int) -> NumNode: return NumNode(num)
|
||||
|
||||
@staticmethod
|
||||
def factorize(nodes:List[Node]) -> List[Node]:
|
||||
mul_groups: Dict[Node, int] = {}
|
||||
for x in nodes:
|
||||
a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
|
||||
mul_groups[a] = mul_groups.get(a, 0) + b
|
||||
return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
|
||||
|
||||
@staticmethod
|
||||
def sum(nodes:List[Node]) -> Node:
|
||||
nodes = [x for x in nodes if x.max or x.min]
|
||||
if not nodes: return NumNode(0)
|
||||
if len(nodes) == 1: return nodes[0]
|
||||
|
||||
new_nodes: List[Node] = []
|
||||
num_node_sum = 0
|
||||
for node in SumNode(nodes).flat_components:
|
||||
if node.__class__ is NumNode: num_node_sum += node.b
|
||||
else: new_nodes.append(node)
|
||||
|
||||
if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
|
||||
new_nodes = Node.factorize(new_nodes)
|
||||
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
|
||||
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
|
||||
|
||||
@staticmethod
|
||||
def ands(nodes:List[Node]) -> Node:
|
||||
if not nodes: return NumNode(1)
|
||||
if len(nodes) == 1: return nodes[0]
|
||||
if any(not x for x in nodes): return NumNode(0)
|
||||
|
||||
# filter 1s
|
||||
nodes = [x for x in nodes if x.min != x.max]
|
||||
return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
|
||||
|
||||
# 4 basic node types
|
||||
|
||||
class Variable(Node):
|
||||
def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
|
||||
assert nmin >= 0 and nmin <= nmax
|
||||
if nmin == nmax: return NumNode(nmin)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
|
||||
self.expr, self.min, self.max = expr, nmin, nmax
|
||||
self.val:Optional[int] = None
|
||||
def bind(self, val):
|
||||
assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
|
||||
self.val = val
|
||||
return self
|
||||
def unbind(self) -> Tuple[Variable, int]:
|
||||
assert self.val is not None, f"cannot unbind {self}"
|
||||
return Variable(self.expr, self.min, self.max), self.val
|
||||
def vars(self): return [self]
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self
|
||||
|
||||
class NumNode(Node):
|
||||
def __init__(self, num:int):
|
||||
assert isinstance(num, int), f"{num} is not an int"
|
||||
self.b:int = num
|
||||
self.min, self.max = num, num
|
||||
def bind(self, val):
|
||||
assert self.b == val, f"cannot bind {val} to {self}"
|
||||
return self
|
||||
def __eq__(self, other): return self.b == other
|
||||
def __hash__(self): return self.hash # needed with __eq__ override
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
|
||||
|
||||
def create_node(ret:Node):
|
||||
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
|
||||
if ret.min == ret.max: return NumNode(ret.min)
|
||||
return ret
|
||||
|
||||
class OpNode(Node):
|
||||
def __init__(self, a:Node, b:Union[Node, int]):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = self.get_bounds()
|
||||
def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
|
||||
@abstractmethod
|
||||
def get_bounds(self) -> Tuple[int, int]: pass
|
||||
|
||||
class LtNode(OpNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
if isinstance(self.b, int):
|
||||
return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
|
||||
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
||||
|
||||
class MulNode(OpNode):
|
||||
def __lt__(self, b: Union[Node, int]):
|
||||
if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b)
|
||||
sgn = 1 if self.b > 0 else -1
|
||||
return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b))
|
||||
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
|
||||
if self.b % b == 0: return self.a*(self.b//b)
|
||||
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
a = (self.a * (self.b%b))
|
||||
return Node.__mod__(a, b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
||||
|
||||
class DivNode(OpNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0 and isinstance(self.b, int)
|
||||
return self.a.min//self.b, self.a.max//self.b
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) // self.b
|
||||
|
||||
class ModNode(OpNode):
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b)
|
||||
return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b)
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
||||
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0 and isinstance(self.b, int)
|
||||
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) % self.b
|
||||
|
||||
class RedNode(Node):
|
||||
def __init__(self, nodes:List[Node]): self.nodes = nodes
|
||||
def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
|
||||
|
||||
class SumNode(RedNode):
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
||||
fully_divided: List[Node] = []
|
||||
rest: List[Node] = []
|
||||
if isinstance(b, SumNode):
|
||||
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
||||
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
||||
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
|
||||
if isinstance(b, Node):
|
||||
for x in self.flat_components:
|
||||
if x % b == 0: fully_divided.append(x // b)
|
||||
else: rest.append(x)
|
||||
if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b
|
||||
return Node.__floordiv__(self, b, False)
|
||||
if b == 1: return self
|
||||
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
||||
fully_divided, rest = [], []
|
||||
_gcd = b
|
||||
divisor = 1
|
||||
for x in self.flat_components:
|
||||
if x.__class__ in (NumNode, MulNode):
|
||||
if x.b%b == 0: fully_divided.append(x//b)
|
||||
else:
|
||||
rest.append(x)
|
||||
_gcd = gcd(_gcd, x.b)
|
||||
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
|
||||
else:
|
||||
rest.append(x)
|
||||
_gcd = 1
|
||||
if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
|
||||
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
|
||||
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
if isinstance(b, SumNode):
|
||||
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
||||
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
||||
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
|
||||
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
|
||||
new_nodes: List[Node] = []
|
||||
for x in self.nodes:
|
||||
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
|
||||
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
|
||||
else: new_nodes.append(x)
|
||||
return Node.__mod__(Node.sum(new_nodes), b)
|
||||
|
||||
def __lt__(self, b:Union[Node,int]):
|
||||
lhs: Node = self
|
||||
if isinstance(b, int):
|
||||
new_sum = []
|
||||
for x in self.nodes:
|
||||
# TODO: should we just force the last one to always be the number
|
||||
if isinstance(x, NumNode): b -= x.b
|
||||
else: new_sum.append(x)
|
||||
lhs = Node.sum(new_sum)
|
||||
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
|
||||
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
||||
if muls:
|
||||
# NOTE: gcd in python 3.8 takes exactly 2 args
|
||||
mul_gcd = b
|
||||
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here
|
||||
all_others = Variable.sum(others)
|
||||
if all_others.min >= 0 and all_others.max < mul_gcd:
|
||||
lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
||||
return Node.__lt__(lhs, b)
|
||||
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes])
|
||||
|
||||
@property
|
||||
def flat_components(self): # recursively expand sumnode components
|
||||
new_nodes = []
|
||||
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
|
||||
return new_nodes
|
||||
|
||||
class AndNode(RedNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes])
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
|
||||
subed = []
|
||||
for node in self.nodes:
|
||||
if not (sub:=node.substitute(var_vals)): return NumNode(0)
|
||||
subed.append(sub)
|
||||
return Variable.ands(subed)
|
||||
|
||||
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
|
||||
ret = typ(nodes)
|
||||
if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
|
||||
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
|
||||
return create_node(ret)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
|
||||
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
|
||||
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
|
||||
if isinstance(a, (int, float)): return a
|
||||
ret = a.substitute({k:Variable.num(v) for k, v in var_vals.items()})
|
||||
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
|
||||
return ret.b
|
||||
|
||||
# symbolic int
|
||||
sint = Union[Node, int]
|
||||
VariableOrNum = Union[Variable, NumNode]
|
||||
|
||||
render_python: Dict[Type, Callable] = {
|
||||
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
|
||||
NumNode: lambda self,ops,ctx: f"{self.b}",
|
||||
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
|
||||
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
|
||||
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
|
||||
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
|
||||
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
||||
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
|
||||
}
|
||||
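# Usage sketch (illustrative): symbolic index arithmetic simplifies as it is built, using the
# bounds carried by each Variable; a small sketch of the behavior implemented above.
from tinygrad.shape.symbolic import Variable

i = Variable("i", 0, 9)
print((i*4 + 8) // 4)   # the div folds through the sum and mul: equivalent to i+2
print((i*4 + 2) % 4)    # each term is reduced mod 4, leaving the constant 2
print(i < 20)           # provable from the bounds [0, 9], so it collapses to the constant 1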
132
tinygrad_repo/tinygrad/shape/view.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
import functools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, cast
|
||||
from tinygrad.helpers import prod, all_int, dedup
|
||||
from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, is_sym_int, sint
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
strides = [1] if shape else []
|
||||
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
|
||||
return filter_strides(shape, tuple(strides))
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class View:
|
||||
shape:Tuple[sint, ...]
|
||||
strides:Tuple[sint, ...]
|
||||
offset:sint
|
||||
mask:Optional[Tuple[Tuple[sint, sint], ...]]
|
||||
contiguous:bool
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
|
||||
strides = filter_strides(shape, strides) if strides else strides_for_shape(shape)
|
||||
contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape)))
|
||||
return View(shape, strides, offset, mask, contiguous)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def size(self): return prod([s.max if isinstance(s, Node) else s for s,st in zip(self.shape, self.strides) if st != 0])
|
||||
|
||||
def vars(self) -> List[Variable]:
|
||||
flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
|
||||
return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], []))
|
||||
|
||||
def unbind(self) -> View:
|
||||
unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
|
||||
new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
|
||||
new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
|
||||
new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
|
||||
new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
|
||||
return View.create(new_shape, new_strides, new_offset, new_mask)
|
||||
|
||||
# MovementOps live here now
|
||||
|
||||
def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
|
||||
offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
|
||||
if self.mask:
|
||||
# move the old mask
|
||||
nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.mask, arg)])
|
||||
# merge the masks if we have two
|
||||
mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
|
||||
shape = [y-x for x,y in arg]
|
||||
return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
|
||||
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
|
||||
if any(b or e for b, e in arg):
|
||||
zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
|
||||
mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
|
||||
return self.__unsafe_resize(zvarg, mask=mask)
|
||||
return self
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
|
||||
assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
|
||||
return self.__unsafe_resize(arg)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def expand(self, new_shape: Tuple[sint, ...]) -> View:
|
||||
assert len(new_shape) == len(self.shape)
|
||||
assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
|
||||
# NOTE: can the mask ever be (0,0)?
|
||||
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
|
||||
return View.create(new_shape, self.strides, self.offset, mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def permute(self, axis: Tuple[int, ...]) -> View:
|
||||
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
|
||||
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
|
||||
return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def stride(self, mul: Tuple[int, ...]) -> View:
|
||||
# except for the negative case, you can build this from the others. invertible in the negative case
|
||||
assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
|
||||
strides = tuple([z*m for z,m in zip(self.strides, mul)])
|
||||
new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
|
||||
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
|
||||
mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
|
||||
return View.create(new_shape, strides, self.offset + offset, mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
|
||||
if self.shape == new_shape: return self
|
||||
|
||||
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
|
||||
# check for the same size
|
||||
if all_int(self.shape):
|
||||
if all_int(new_shape):
|
||||
assert prod(self.shape) == prod(new_shape), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
|
||||
else:
|
||||
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
|
||||
assert prod(self.shape) == prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
|
||||
|
||||
# after the asserts, it's okay to check contiguous
|
||||
if self.contiguous: return View.create(new_shape)
|
||||
|
||||
# check if this is adding or removing 1s (only)
|
||||
# NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
|
||||
if [x for x in self.shape if x != 1] == [x for x in new_shape if x != 1]:
|
||||
new_strides: List[sint] = [y for x,y in zip(self.shape, self.strides) if x != 1]
|
||||
new_strides_tuple: Tuple[sint, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
|
||||
new_mask_tuple: Optional[Tuple[Tuple[sint, sint], ...]] = None
|
||||
if self.mask:
|
||||
for x,y in zip(self.shape, self.mask):
|
||||
if x == 1 and y != (0, 1):
|
||||
new_mask_tuple = ((0,0),) * len(new_shape)
|
||||
break
|
||||
else:
|
||||
new_mask: List[Tuple[sint, sint]] = [y for x,y in zip(self.shape, self.mask) if x != 1]
|
||||
new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
|
||||
return View.create(new_shape, new_strides_tuple, self.offset, new_mask_tuple)
|
||||
|
||||
# TODO: bring the merge_views logic here for more caching
|
||||
|
||||
return None
|
||||
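# Usage sketch (illustrative): a View is shape/strides/offset/mask over a flat buffer. pad() only
# adds a mask and shifts the offset, and a contiguous view reshapes in place without a new View.
from tinygrad.shape.view import View, strides_for_shape

print(strides_for_shape((2, 3, 4)))    # (12, 4, 1): row-major strides
v = View.create((2, 3, 4))
print(v.contiguous)                    # True
p = v.pad(((0, 0), (1, 1), (0, 0)))
print(p.shape, p.mask)                 # (2, 5, 4) ((0, 2), (1, 4), (0, 4))
print(v.reshape((6, 4)).strides)       # (4, 1)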
790
tinygrad_repo/tinygrad/tensor.py
Normal file
@@ -0,0 +1,790 @@
|
||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import time, math
|
||||
from collections import defaultdict
|
||||
from functools import partialmethod, reduce
|
||||
from itertools import accumulate
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set
|
||||
|
||||
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import Device, LoadOps
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.realize import run_schedule
|
||||
|
||||
# An instantiation of the Function is the Context
|
||||
class Function:
|
||||
def __init__(self, device:str, *tensors:Tensor):
|
||||
self.device = device
|
||||
self.needs_input_grad = [t.requires_grad for t in tensors]
|
||||
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
|
||||
if self.requires_grad: self.parents = tensors
|
||||
|
||||
def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
|
||||
def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
|
||||
|
||||
@classmethod
|
||||
def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
|
||||
ctx = fxn(x[0].device, *x)
|
||||
ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
|
||||
if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine
|
||||
return ret
|
||||
|
||||
import tinygrad.mlops as mlops
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
class Tensor:
|
||||
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
|
||||
__deletable__ = ('_ctx',)
|
||||
training: ClassVar[bool] = False
|
||||
class train:
|
||||
def __init__(self, val=True): self.val = val
|
||||
def __enter__(self):
|
||||
self.prev = Tensor.training
|
||||
Tensor.training = self.val
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
|
||||
|
||||
no_grad: ClassVar[bool] = False
|
||||
default_type: ClassVar[DType] = dtypes.float32
|
||||
def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
|
||||
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
||||
device = Device.canonicalize(device)
|
||||
# tensors have gradients, buffers do not
|
||||
self.grad: Optional[Tensor] = None
|
||||
|
||||
# NOTE: this can be in three states. False and None: no gradient, True: gradient
|
||||
# None (the default) will be updated to True if it's put in an optimizer
|
||||
self.requires_grad: Optional[bool] = requires_grad
|
||||
|
||||
# internal variables used for autograd graph construction
|
||||
self._ctx: Optional[Function] = None
|
||||
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
|
||||
elif isinstance(data, (int, float)):
|
||||
data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
|
||||
elif data.__class__ is list:
|
||||
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
|
||||
data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
|
||||
elif isinstance(data, np.ndarray):
|
||||
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
|
||||
if data.shape == ():
|
||||
data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
|
||||
else:
|
||||
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
|
||||
else: raise RuntimeError(f"can't create Tensor from {data}")
|
||||
|
||||
# data is a LazyBuffer, but it might be on the wrong device
|
||||
self.lazydata = data if data.device == device else data.copy_to_device(device)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
|
||||
|
||||
# Python has a non-moving GC, so this should be okay
|
||||
def __hash__(self): return id(self)
|
||||
|
||||
@property
|
||||
def device(self) -> str: return self.lazydata.device
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType: return self.lazydata.dtype
|
||||
|
||||
# ***** data handlers ****
|
||||
|
||||
@staticmethod
|
||||
def corealize(lst:Iterable[Tensor]):
|
||||
seen:Set[LazyBuffer] = set()
|
||||
sched = []
|
||||
for t in lst: sched += t.lazydata.schedule(seen)
|
||||
run_schedule(sched)
|
||||
|
||||
def realize(self) -> Tensor:
|
||||
run_schedule(self.lazydata.schedule())
|
||||
return self
|
||||
|
||||
def assign(self, x) -> Tensor:
|
||||
# TODO: this is a hack for writing to DISK
|
||||
if self.device.startswith("DISK"):
|
||||
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
|
||||
self.contiguous().realize().lazydata.realized._copyin(x.numpy()) # type: ignore
|
||||
return self
|
||||
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
|
||||
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
|
||||
assert not x.requires_grad # self requires_grad is okay?
|
||||
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
|
||||
if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
|
||||
self.lazydata = x.lazydata
|
||||
return self
|
||||
|
||||
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
|
||||
def numpy(self) -> np.ndarray:
|
||||
assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
|
||||
assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
|
||||
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
|
||||
|
||||
# TODO: if things are realized this won't work
|
||||
def to_(self, device:str):
|
||||
assert self.lazydata.realized is None
|
||||
self.lazydata.device = device
|
||||
if self.grad: self.grad.to_(device)
|
||||
|
||||
def to(self, device:str) -> Tensor:
|
||||
ret = Tensor(self.lazydata, device)
|
||||
if self.grad: ret.grad = self.grad.to(device)
|
||||
return ret
|
||||
|
||||
# ***** creation llop entrypoint *****
|
||||
|
||||
@staticmethod
|
||||
def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
|
||||
return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def empty(*shape, **kwargs):
|
||||
assert all_int(shape), f"cannot create with symbolic shape {shape}"
|
||||
return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)
|
||||
|
||||
_seed: int = int(time.time())
|
||||
@staticmethod
|
||||
def manual_seed(seed=0): Tensor._seed = seed
|
||||
|
||||
@staticmethod
|
||||
def rand(*shape, **kwargs):
|
||||
assert all_int(shape), f"cannot create with symbolic shape {shape}"
|
||||
Tensor._seed += 1
|
||||
return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
|
||||
|
||||
# ***** creation helper functions *****
|
||||
|
||||
@staticmethod
|
||||
def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
|
||||
|
||||
@staticmethod
|
||||
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def arange(start, stop=None, step=1, **kwargs):
|
||||
if stop is None: stop, start = start, 0
|
||||
return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
|
||||
|
||||
@staticmethod
|
||||
def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
|
||||
|
||||
def full_like(self, fill_value, **kwargs):
|
||||
return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
|
||||
def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
|
||||
def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
|
||||
|
||||
# ***** rng hlops *****
|
||||
|
||||
@staticmethod
|
||||
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
|
||||
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
||||
src = Tensor.rand(2, *shape, **kwargs)
|
||||
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
|
||||
|
||||
@staticmethod
|
||||
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
|
||||
|
||||
@staticmethod
|
||||
def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
|
||||
dtype = kwargs.pop("dtype", Tensor.default_type)
|
||||
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
|
||||
|
||||
@staticmethod
|
||||
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(shape)**-0.5)
|
||||
|
||||
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
|
||||
@staticmethod
|
||||
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
|
||||
|
||||
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
|
||||
@staticmethod
|
||||
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
|
||||
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||
|
||||
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
|
||||
@staticmethod
|
||||
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
|
||||
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
def deepwalk(self):
|
||||
def _deepwalk(node, visited, nodes):
|
||||
visited.add(node)
|
||||
if getattr(node, "_ctx", None):
|
||||
for i in node._ctx.parents:
|
||||
if i not in visited: _deepwalk(i, visited, nodes)
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
return _deepwalk(self, set(), [])
|
||||
|
||||
def backward(self):
|
||||
assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"
|
||||
|
||||
# fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
|
||||
# this is "implicit gradient creation"
|
||||
self.grad = Tensor(1, device=self.device, requires_grad=False)
|
||||
|
||||
for t0 in reversed(self.deepwalk()):
|
||||
assert (t0.grad is not None)
|
||||
grads = t0._ctx.backward(t0.grad.lazydata)
|
||||
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
|
||||
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
|
||||
for t, g in zip(t0._ctx.parents, grads):
|
||||
if g is not None and t.requires_grad:
|
||||
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
|
||||
t.grad = g if t.grad is None else (t.grad + g)
|
||||
del t0._ctx
|
||||
|
||||
# ***** movement mlops *****
|
||||
def reshape(self, shape, *args) -> Tensor:
|
||||
new_shape = argfix(shape, *args)
|
||||
assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
|
||||
return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
|
||||
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
|
||||
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
|
||||
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
|
||||
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
|
||||
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
|
||||
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
|
||||
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
|
||||
|
||||
# ***** movement hlops *****
|
||||
|
||||
# - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
|
||||
# - A slice i:j returns the elements with indices in [i, j)
|
||||
# - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
|
||||
# - Negative values for i and j are taken relative to the end of the sequence
|
||||
# - Both i and j will be clamped to the range (-N, N], where N in the length of the sequence
|
||||
# - Indexing with None on a given axis will add a new dimension of size one before that axis
|
||||
# - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
|
||||
# - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
|
||||
# - Strides > 1 and < 0 are now allowed!:
|
||||
# - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
|
||||
# - Idea of stride < 0 support:
|
||||
# - Do the slice first, flip the axes were slice.step is negative, do slice.step -> -slice.step. Go to steps below.
|
||||
# - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
|
||||
# - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
|
||||
# - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
|
||||
# is possible.
|
||||
# - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
|
||||
# - Fancy indexing and combined indexing is supported
|
||||
# - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t to Tensors passed in -> fancy indexing
|
||||
# - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively
|
||||
# - The first iteration will expand the dim of self while consecutive iterations will reduce the dim
|
||||
# - There's a special case where a permute is needed at the end:
|
||||
# - if first Tensor passed in (expand dims) is not at dim 0
|
||||
# - and following Tensors does not follow consecutively to the end of fancy indexing's dims
|
||||
def __getitem__(self, val): # val: Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
|
||||
def normalize_int(e, i, dim_sz):
|
||||
if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
|
||||
raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
|
||||
|
||||
orig_slices = list(val) if isinstance(val, tuple) else [val]
|
||||
count = defaultdict(list)
|
||||
for i,v in enumerate(orig_slices): count[type(v)].append(i)
|
||||
|
||||
if (num_slices := len(count[int]) + len(count[slice]) + len(count[Tensor])) > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
|
||||
if len(ellipsis_found := count[type(Ellipsis)]) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
|
||||
|
||||
ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices)
|
||||
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
|
||||
|
||||
valid_slices = [v for v in orig_slices if v is not None]
|
||||
valid_slices = [v if isinstance(v, slice) else slice(y_ := normalize_int(v, i, dim_sz), y_+1) if isinstance(v, int) else slice(None) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
|
||||
|
||||
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
|
||||
new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
|
||||
sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
|
||||
new_shape = sliced_tensor.shape
|
||||
if any(abs(s) != 1 for s in strides):
|
||||
strides = tuple(abs(s) for s in strides)
|
||||
# Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
|
||||
padded_tensor = sliced_tensor.pad(tuple((0, s-(dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape)))
|
||||
# Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
|
||||
reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)))
|
||||
new_shape = reshaped_tensor.shape[::2]
|
||||
# Shrink: do [:, 0]
|
||||
sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)))
|
||||
|
||||
final_shape, it_shape, dim, tensors, dim_collapsed = [], iter(new_shape), [], [], 0
|
||||
for i,s in enumerate(orig_slices):
|
||||
if s is None: final_shape.append(1)
|
||||
else: # s is int or slice or Tensor
|
||||
dim_shape = next(it_shape)
|
||||
if isinstance(s, int):
|
||||
dim_collapsed += 1
|
||||
else:
|
||||
assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}"
|
||||
final_shape.append(dim_shape)
|
||||
if isinstance(s, Tensor):
|
||||
tensors.append(s)
|
||||
dim.append(i-dim_collapsed)
|
||||
ret = sliced_tensor.reshape(tuple(final_shape))
|
||||
|
||||
if tensors: # Fancy/tensor indexing
|
||||
# normalize idx
|
||||
# TODO: first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
|
||||
idx = [t.sign().contiguous().__neg__().contiguous().relu() * ret.shape[d] + t for d,t in zip(dim, tensors)]
|
||||
max_dim = max(i.ndim for i in idx)
|
||||
# compute sum_dim, arange, and idx
|
||||
sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(dim)]
|
||||
arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, dim))]
|
||||
first_idx = [idx[0].reshape(*[1]*dim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - dim[0] - 1))]
|
||||
rest_idx = [i.reshape(*[1]*dim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - dim[0] - n)) for n,i in enumerate(idx[1:], 1)]
|
||||
idx = first_idx + rest_idx
|
||||
ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
|
||||
# iteratively fancy index
|
||||
for a,i,sd in zip(arange, idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
|
||||
# special permute case
|
||||
if dim[0] != 0 and len(dim) != 1 and dim != list(range(dim[0], dim[-1]+1)):
|
||||
ret_dims = list(range(ret.ndim))
|
||||
ret = ret.permute(ret_dims[dim[0]:dim[0]+max_dim] + ret_dims[:dim[0]] + ret_dims[dim[0]+max_dim:])
|
||||
return ret
|
||||
|
||||
def __setitem__(self,s,v): return self.__getitem__(s).assign(v)
|
||||
|
||||
# NOTE: using slice is discouraged and things should migrate to pad and shrink
|
||||
def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
|
||||
arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
|
||||
padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
|
||||
return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
|
||||
|
||||
def gather(self: Tensor, idx: Tensor, dim: int):
|
||||
assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
|
||||
assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
|
||||
if dim < 0: dim += self.ndim
|
||||
idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
|
||||
permarg = list(range(self.ndim))
|
||||
permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
|
||||
return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
|
||||
|
||||
def cat(self, *args, dim=0):
|
||||
dim = (dim + len(self.shape)) if dim < 0 else dim
|
||||
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
|
||||
catargs = [self, *args]
|
||||
assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
|
||||
shapes = [s.shape[dim] for s in catargs]
|
||||
shape_cumsum = [0, *accumulate(shapes)]
|
||||
slc = [[(0, 0) for _ in self.shape] for _ in catargs]
|
||||
for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
|
||||
s[dim] = (k, shape_cumsum[-1] - k - shp)
|
||||
return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
|
||||
|
||||
@staticmethod
|
||||
def stack(tensors, dim=0):
|
||||
first = tensors[0].unsqueeze(dim)
|
||||
unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]]
|
||||
# checks for shapes and number of dimensions delegated to cat
|
||||
return first.cat(*unsqueezed_tensors, dim=dim)
|
||||
|
||||
def repeat(self, repeats):
|
||||
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
|
||||
new_shape = [x for b in base_shape for x in [1, b]]
|
||||
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
|
||||
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
||||
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
|
||||
|
||||
def chunk(self, num:int, dim:int) -> List[Tensor]:
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
|
||||
slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
|
||||
return [self[tuple(sl)] for sl in slice_params]
|
||||
|
||||
def squeeze(self, dim=None):
|
||||
if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
|
||||
if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
|
||||
if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
|
||||
if dim < 0: dim += self.ndim
|
||||
return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])
|
||||
|
||||
def unsqueeze(self, dim):
|
||||
if dim < 0: dim = len(self.shape) + dim + 1
|
||||
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
||||
|
||||
# (padding_left, padding_right, padding_top, padding_bottom)
|
||||
def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0):
|
||||
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
|
||||
return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
|
||||
|
||||
@property
|
||||
def T(self) -> Tensor: return self.transpose()
|
||||
def transpose(self, ax1=1, ax2=0) -> Tensor:
|
||||
order = list(range(len(self.shape)))
|
||||
order[ax1], order[ax2] = order[ax2], order[ax1]
|
||||
return self.permute(order)
|
||||
def flatten(self, start_dim=0): return self.reshape(shape=self.shape[:start_dim] + (-1,))
|
||||
|
||||
# ***** reduce ops *****
|
||||
|
||||
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
|
||||
axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis)) # type: ignore
|
||||
axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
|
||||
shape = [s for i,s in enumerate(self.shape) if i not in axis_]
|
||||
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
|
||||
return ret if keepdim else ret.reshape(shape=shape)
|
||||
|
||||
def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
|
||||
def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
|
||||
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
|
||||
|
||||
def mean(self, axis=None, keepdim=False):
|
||||
assert all_int(self.shape), "does not support symbolic shape"
|
||||
out = self.sum(axis=axis, keepdim=keepdim)
|
||||
return out.mul(prod(out.shape)/prod(self.shape))
|
||||
def std(self, axis=None, keepdim=False, correction=1):
|
||||
assert all_int(self.shape), "does not support symbolic shape"
|
||||
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
|
||||
return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction).sqrt()
|
||||
def _softmax(self, axis):
|
||||
m = self - self.max(axis=axis, keepdim=True)
|
||||
e = m.exp()
|
||||
return m, e, e.sum(axis=axis, keepdim=True)
|
||||
|
||||
def softmax(self, axis=-1):
|
||||
_, e, ss = self._softmax(axis)
|
||||
return e.div(ss)
|
||||
|
||||
def log_softmax(self, axis=-1):
|
||||
m, _, ss = self._softmax(axis)
|
||||
return m - ss.log()
|
||||
|
||||
def argmax(self, axis=None, keepdim=False):
|
||||
if axis is None:
|
||||
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
|
||||
return prod(self.shape) - idx.max() - 1
|
||||
axis = axis + len(self.shape) if axis < 0 else axis
|
||||
m = self == self.max(axis=axis, keepdim=True)
|
||||
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
|
||||
return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
|
||||
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
|
||||
|
||||
# ***** processing ops *****
|
||||
|
||||
def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
|
||||
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
|
||||
assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
||||
slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
|
||||
if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
|
||||
o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
|
||||
e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
|
||||
xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
|
||||
# slide by dilation
|
||||
xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
|
||||
xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
|
||||
xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))
|
||||
# handle stride, and permute to move reduce to the end
|
||||
xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
|
||||
xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
|
||||
xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
|
||||
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
|
||||
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
|
||||
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
|
||||
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
|
||||
xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
|
||||
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
|
||||
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
|
||||
|
||||
# NOTE: these work for more than 2D
|
||||
def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
|
||||
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
|
||||
|
||||
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
|
||||
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
|
||||
x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
|
||||
stride = make_pair(stride, len(HW))
|
||||
if any(s>1 for s in stride):
|
||||
x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
|
||||
x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
|
||||
x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
|
||||
x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
|
||||
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
|
||||
return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
wino = int(getenv("WINO", "0"))
|
||||
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
|
||||
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
||||
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
|
||||
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}"
|
||||
padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1])
|
||||
|
||||
# conv2d is a pooling op (with padding)
|
||||
x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
|
||||
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
|
||||
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino:
|
||||
# normal conv
|
||||
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
|
||||
|
||||
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
||||
|
||||
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
|
||||
def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat])
|
||||
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
||||
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
|
||||
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
|
||||
winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order almost doubles compilation time
|
||||
|
||||
# todo: stride == dilation
|
||||
# use padding to round up to 4x4 output tiles
|
||||
d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # (bs, cin_, tyx, HWI)
|
||||
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx)
|
||||
tyx = d.shape[-len(HWI):] # dim of tiling
|
||||
|
||||
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
|
||||
|
||||
# compute 6x6 winograd tiles: GgGt, BtdB
|
||||
gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
|
||||
dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
|
||||
|
||||
ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW))) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
||||
|
||||
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
|
||||
ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
|
||||
|
||||
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
|
||||
|
||||
def dot(self, w:Tensor) -> Tensor:
|
||||
n1, n2 = len(self.shape), len(w.shape)
|
||||
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
||||
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
|
||||
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
|
||||
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
|
||||
return (x*w).sum(-1)
|
||||
|
||||
def cumsum(self, axis:int=0) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-1,0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
|
||||
|
||||
# ***** mlops (unary) *****
|
||||
|
||||
def __neg__(self): return mlops.Neg.apply(self)
|
||||
def contiguous(self): return mlops.Contiguous.apply(self)
|
||||
def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
|
||||
def log(self): return mlops.Log.apply(self)
|
||||
def log2(self): return mlops.Log.apply(self)/math.log(2)
|
||||
def exp(self): return mlops.Exp.apply(self)
|
||||
def exp2(self): return mlops.Exp.apply(self*math.log(2))
|
||||
def relu(self): return mlops.Relu.apply(self)
|
||||
def sigmoid(self): return mlops.Sigmoid.apply(self)
|
||||
def sin(self): return mlops.Sin.apply(self)
|
||||
def sqrt(self): return mlops.Sqrt.apply(self)
|
||||
def rsqrt(self): return (1/self).sqrt()
|
||||
def cos(self): return ((math.pi/2)-self).sin()
|
||||
def tan(self): return self.sin() / self.cos()
|
||||
|
||||
@staticmethod
|
||||
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
|
||||
def triu(self, k:int=0) -> Tensor:
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
|
||||
def tril(self, k:int=0) -> Tensor:
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
|
||||
|
||||
# ***** math functions (unary) *****
|
||||
def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
|
||||
def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
|
||||
def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
|
||||
|
||||
def square(self): return self*self
|
||||
def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
|
||||
def abs(self): return self.relu() + (-self).relu()
|
||||
def sign(self): return self / (self.abs() + 1e-10)
|
||||
def reciprocal(self): return 1.0/self
|
||||
|
||||
# ***** activation functions (unary) *****
|
||||
def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
|
||||
def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
|
||||
def swish(self): return self * self.sigmoid()
|
||||
def silu(self): return self.swish() # The SiLU function is also known as the swish function.
|
||||
def relu6(self): return self.relu() - (self-6).relu()
|
||||
def hardswish(self): return self * (self+3).relu6() * (1/6)
|
||||
def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
|
||||
def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
|
||||
def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
|
||||
def quick_gelu(self): return self * (self * 1.702).sigmoid()
|
||||
def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
|
||||
def mish(self): return self * self.softplus().tanh()
|
||||
def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
|
||||
def softsign(self): return self / (1 + self.abs())
|
||||
|
||||
# ***** broadcasted binary mlops *****
|
||||
|
||||
def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]:
|
||||
x: Tensor = self
|
||||
if not isinstance(y, Tensor):
|
||||
y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32)
|
||||
if reverse: x, y = y, x
|
||||
if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
|
||||
|
||||
shape_delta = len(xshape) - len(yshape)
|
||||
if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape)
|
||||
elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape)
|
||||
if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
|
||||
|
||||
shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)])
|
||||
if xshape != shape_ret: x = x.expand(shape_ret)
|
||||
if yshape != shape_ret: y = y.expand(shape_ret)
|
||||
return (x, y)
|
||||
|
||||
def _to_float(self, x:Union[Tensor, float]):
|
||||
return x.lazydata.op.arg if isinstance(x, Tensor) and not x.lazydata.realized and x.lazydata.op.op == LoadOps.CONST and not x.requires_grad \
|
||||
and x.lazydata.st.contiguous and self._broadcasted(x)[0].shape == self.shape else x
|
||||
|
||||
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:
|
||||
x = self._to_float(x)
|
||||
return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
|
||||
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor:
|
||||
x = self._to_float(x)
|
||||
return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self)
|
||||
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor:
|
||||
x = self._to_float(x)
|
||||
if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
|
||||
if x.__class__ is not Tensor and x == -1.0: return -self
|
||||
return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
|
||||
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor:
|
||||
x = self._to_float(x)
|
||||
return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)
|
||||
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
|
||||
x = self._to_float(x)
|
||||
if x.__class__ is not Tensor and not reverse:
|
||||
# simple pow identities
|
||||
if x < 0: return self.reciprocal().pow(-x)
|
||||
if x == 3.0: return self*self*self
|
||||
if x == 2.0: return self*self
|
||||
if x == 1.0: return self
|
||||
if x == 0.5: return self.sqrt()
|
||||
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
|
||||
ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
|
||||
# correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
|
||||
sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
|
||||
# we only need to correct the sign if the base is negative
|
||||
base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
|
||||
# we need 0 to be positive so we need to correct base_sign when the base is 0
|
||||
base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
|
||||
# inject nan if the base is negative and the power is not an integer
|
||||
to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
|
||||
inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
|
||||
return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
|
||||
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
|
||||
|
||||
def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
|
||||
def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
|
||||
|
||||
def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
|
||||
x_,y = self._broadcasted(input_)
|
||||
x,z = x_._broadcasted(other)
|
||||
return mlops.Where.apply(x, *y._broadcasted(z))
|
||||
|
||||
# ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
|
||||
|
||||
# NOTE: __pow__ and friends are broken in mypyc with the ** operator
|
||||
def __add__(self, x) -> Tensor: return self.add(x)
|
||||
def __sub__(self, x) -> Tensor: return self.sub(x)
|
||||
def __mul__(self, x) -> Tensor: return self.mul(x)
|
||||
def __pow__(self, x) -> Tensor: return self.pow(x)
|
||||
def __truediv__(self, x) -> Tensor: return self.div(x)
|
||||
def __matmul__(self, x) -> Tensor: return self.matmul(x)
|
||||
|
||||
def __radd__(self, x) -> Tensor: return self.add(x, True)
|
||||
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
|
||||
def __rmul__(self, x) -> Tensor: return self.mul(x, True)
|
||||
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
|
||||
def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
|
||||
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
|
||||
|
||||
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
|
||||
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
|
||||
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
|
||||
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
|
||||
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
|
||||
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
|
||||
|
||||
def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
|
||||
def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
|
||||
def __ge__(self, x) -> Tensor: return 1.0-(self<x)
|
||||
def __le__(self, x) -> Tensor: return 1.0-(self>x)
|
||||
def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore
|
||||
def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
|
||||
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
|
||||
return x.add(bias) if bias is not None else x
|
||||
|
||||
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)
|
||||
|
||||
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
|
||||
y = (self - self.mean(axis, keepdim=True))
|
||||
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
|
||||
|
||||
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
|
||||
x = (self - mean.reshape(shape=[1, -1, 1, 1]))
|
||||
if weight: x = x * weight.reshape(shape=[1, -1, 1, 1])
|
||||
ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd)
|
||||
return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret
|
||||
|
||||
def dropout(self, p=0.5) -> Tensor:
|
||||
if not Tensor.training or p == 0: return self
|
||||
mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool)
|
||||
return self * mask * (1/(1.0 - p))
|
||||
|
||||
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
|
||||
# NOTE: it works if key, value have symbolic shape
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
|
||||
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
|
||||
return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
|
||||
|
||||
def binary_crossentropy(self, y:Tensor) -> Tensor:
|
||||
return (-y*self.log() - (1-y)*(1-self).log()).mean()
|
||||
|
||||
def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
|
||||
return (self.maximum(0) - y * self + (1 + self.abs().__neg__().exp()).log()).mean()
|
||||
|
||||
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
|
||||
loss_mask = Y != ignore_index
|
||||
y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
|
||||
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
|
||||
return self.log_softmax().mul(y).sum() / loss_mask.sum()
|
||||
|
||||
# ***** cast ops *****
|
||||
|
||||
def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
|
||||
def bitcast(self, dtype:DType) -> Tensor:
|
||||
assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
|
||||
return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
|
||||
def float(self) -> Tensor: return self.cast(dtypes.float32)
|
||||
def half(self) -> Tensor: return self.cast(dtypes.float16)
|
||||
|
||||
# ***** convenience stuff *****
|
||||
|
||||
@property
|
||||
def ndim(self) -> int: return len(self.shape)
|
||||
def numel(self) -> sint: return prod(self.shape)
|
||||
def element_size(self) -> int: return self.dtype.itemsize
|
||||
def nbytes(self) -> int: return self.numel() * self.element_size()
|
||||
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
|
||||
|
||||
# register functions to move between devices
|
||||
for device in Device._buffers:
|
||||
setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
|
||||
setattr(Tensor, f"{device.lower()}_", partialmethod(Tensor.to_, device))
|
||||
|
||||
if IMAGE:
|
||||
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
|
||||
from tinygrad.features.image import image_conv2d, image_dot
|
||||
setattr(Tensor, "conv2d", image_conv2d)
|
||||
setattr(Tensor, "dot", image_dot)
|
||||
Reference in New Issue
Block a user