openpilot v0.9.6 release
date: 2024-02-21T23:02:42 master commit: 0b4d08fab8e35a264bc7383e878538f8083c33e5
tinygrad_repo/extra/onnx.py (new file, 209 lines)
@@ -0,0 +1,209 @@
from __future__ import annotations
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
import importlib
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from typing import List,Dict
from onnx.onnx_pb import AttributeProto, ModelProto, TensorProto, TypeProto
try:
  from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
  # for onnx < 1.13
  from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
  tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x]

# global numpy cache for parameters
numpy_cache = {}
def safe_numpy(t) -> np.ndarray:
  if not isinstance(t, Tensor): return t
  global numpy_cache
  if t not in numpy_cache:
    if DEBUG >= 3: print("numpy cache miss", t)
    tmp = t.numpy()
    numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1)
  assert len(numpy_cache[t].shape) > 0
  return numpy_cache[t]

onnx_ops = importlib.import_module('extra.onnx_ops')

ONNXLIMIT = getenv("ONNXLIMIT", -1)

def get_run_onnx(onnx_model: ModelProto):
  def type_parse(type_proto: TypeProto):
    ret = []
    while True:
      attr = type_proto.WhichOneof('value')
      if attr == 'tensor_type':
        if "dim_value" not in getattr(type_proto, attr).shape.dim.__dir__(): return () # variable type, unable to determine shape
        elif not ret:
          return tuple([x.dim_value for x in getattr(type_proto, attr).shape.dim])
        else:
          ret.extend([(x.dim_value,) for x in getattr(type_proto, attr).shape.dim])
          return tuple(ret)
      elif attr == 'sequence_type':
        type_proto = getattr(type_proto, attr).elem_type
        ret.append(1)
      elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
      elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
      elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
      elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
      else: raise Exception(f"unknown attr: {attr}, {type_proto}")

  def buffer_parse(inp: TensorProto) -> Tensor:
    if inp.data_type in (1,10,6,7):
      # TODO: this is shared with below
      if len(inp.float_data) > 0:
        ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
      elif len(inp.int64_data) > 0:
        ret = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
      elif len(inp.int32_data) > 0:
        ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
      else:
        ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
    else:
      raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}")
    return ret

  def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
    # TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
    if a.type == AttributeProto.FLOAT: return float(a.f)
    elif a.type == AttributeProto.INT: return int(a.i)
    elif a.type == AttributeProto.STRING: return a.s.decode("utf-8")
    elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t) # TENSOR
    elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
    elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
    elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
    elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}")
    else: raise Exception(f"can't parse {a.type} {a}")
  def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a}

  tensors: Dict[str, Tensor] = {}

  # get weights and biases
  for inp in onnx_model.graph.initializer:
    if len(inp.raw_data) > 0:
      tensors[inp.name] = buffer_parse(inp)
    elif len(inp.float_data) > 0:
      tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
    elif len(inp.int64_data) > 0:
      tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
    elif len(inp.raw_data) == 0:
      tensors[inp.name] = Tensor(np.array([], dtype=np.float32), requires_grad=False)
    else:
      print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
      print(inp)
      raise Exception("no data")

  # preparse the attributes
  attribute_dict = {}
  domain = ""
  for num,n in enumerate(onnx_model.graph.node):
    attribute_dict[num] = attribute_to_dict(n.attribute)
    if n.domain: domain = n.domain

  onnx_model_version = onnx_model.opset_import[0].version

  def run_onnx(inputs={}, debug=0):
    debug = getenv("DEBUGONNX") or debug
    input_tensors: Dict[str,Tensor] = {}
    intermediate_tensors: Dict[str,Tensor] = {}
    output_tensor_names = [x.name for x in onnx_model.graph.output]

    # get inputs
    for inp in onnx_model.graph.input:
      if inp.name in tensors: continue
      shape = type_parse(inp.type)
      if inp.name in inputs:
        if isinstance(inputs[inp.name], Tensor):
          input_tensors[inp.name] = inputs[inp.name]
        elif isinstance(inputs[inp.name], list):
          input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]]
        elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training"
          input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
        else:
          input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
        if shape: # if only input_tensor is not variable type
          input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
          assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
      else:
        raise Exception(f"no data for {inp.name} with shape {shape}")

    def fetch_tensor(x: str):
      if x in tensors: return tensors[x]
      if x in intermediate_tensors: return intermediate_tensors[x]
      if x != str(): return input_tensors[x]
      return None

    for num,n in enumerate(onnx_model.graph.node):
      inp: List[Tensor] = []
      if debug >= 3: print("inputs:")
      for x in n.input:
        t = fetch_tensor(x)
        if debug >= 3: print(f"\t{x} - {t}")
        inp.append(t)
      opt: Dict = attribute_dict[num]
      if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
      # some ops live here because they require some local variables
      if n.op_type == "Split": # have to use n.output for cases when num_outputs is absent
        axis = opt.get("axis", 0)
        split = None if len(inp) == 1 else [int(x) for x in safe_numpy(inp[1])]
        if split is None:
          split = [inp[0].shape[axis] // len(n.output)] * len(n.output)
          for i in range(inp[0].shape[axis] % len(n.output)):
            split[i] += 1
        i, ret = 0, []
        arg = [(0,x) for x in inp[0].shape]
        for s in split:
          arg[axis] = (i,i+s)
          ret.append(inp[0].shrink(arg=tuple(arg)))
          i = i+s
        ret = tuple(ret)
      elif n.op_type == "Slice": # need to check onnx_model_version
        if onnx_model_version < 10:
          axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
        else:
          starts, ends = inp[1:3]
          axes = safe_numpy(Tensor.arange(inp[0].ndim, dtype=dtypes.int32) if len(inp) <= 3 else inp[3]).tolist()
          steps = safe_numpy(inp[4]) if len(inp) > 4 else [1]*inp[0].ndim
          starts, ends = safe_numpy(starts.ceil().cast(dtypes.int32)).tolist(), safe_numpy(ends.ceil().cast(dtypes.int32)).tolist()
        arg = [(0,x,1) for x in inp[0].shape]
        for i, axis in enumerate(axes):
          axis = int(axis) + inp[0].ndim if axis < 0 else int(axis)
          starts[i], ends[i] = starts[i] + inp[0].shape[axis] if starts[i] < 0 else starts[i], ends[i] + inp[0].shape[axis] if ends[i] < 0 else ends[i]
          starts[i], ends[i] = max(0, min(starts[i], inp[0].shape[axis])), max(0, min(ends[i], inp[0].shape[axis]))
          if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i]
          arg[axis] = (starts[i], ends[i], steps[i])
        new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
        if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
        else: ret = inp[0].__getitem__(tuple([slice(s,e,st) for s,e,st in arg]))
      elif n.op_type == "Gradient": # need to call backward on intermediate_tensors
        assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match"
        y = opt["y"]
        intermediate_tensors[y].backward()
        ret = tuple([t.grad for t in inp])
      elif hasattr(onnx_ops, n.op_type):
        fxn = getattr(onnx_ops, n.op_type)
        if isinstance(fxn, dict):
          for k in sorted(fxn.keys()):
            if k <= onnx_model_version:
              real_fxn = fxn[k]
        else:
          real_fxn = fxn
        ret = real_fxn(*inp, **opt)
      else:
        print("UNSUPPORTED", n.op_type, n.input, n.output)
        raise Exception(f"op_type {n.op_type} not supported")
      if not isinstance(ret, tuple): ret = (ret, )
      assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
      if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
      if debug >= 2: print("outputs:")
      for i in range(len(n.output)):
        if debug >= 2: print(f"\t{n.output[i]} - {ret[i]}")
        intermediate_tensors[n.output[i]] = ret[i]
      if num == ONNXLIMIT:
        output_tensor_names = n.output
        break

    return {outp:intermediate_tensors[outp] for outp in output_tensor_names}
  return run_onnx
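A minimal usage sketch for the runner above (the model path and the "input" name here are hypothetical, for illustration only):

  import onnx
  from extra.onnx import get_run_onnx
  from tinygrad.tensor import Tensor

  model = onnx.load("model.onnx")                          # any ONNX ModelProto
  run = get_run_onnx(model)                                # parses initializers and attributes once
  outputs = run({"input": Tensor.randn(1, 3, 224, 224)})   # inputs: graph input name -> Tensor/array
  print({name: t.shape for name, t in outputs.items()})    # one entry per graph output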
tinygrad_repo/extra/onnx_ops.py (new file, 720 lines)
@@ -0,0 +1,720 @@
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx.onnx_pb import TensorProto
import os
import numpy as np
import functools
from typing import Union, Tuple, Optional, List, Any
import math

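# NOTE: every function below is dispatched by ONNX op name from extra/onnx.py via getattr on this
# module; opset-versioned ops are exposed as dicts keyed by opset (e.g. Softmax = {1: ..., 13: ...}).
# A rough sketch of that caller (not code from this file):
#   fxn = getattr(onnx_ops, n.op_type)
#   if isinstance(fxn, dict): fxn = fxn[max(k for k in fxn if k <= onnx_model_version)]
#   ret = fxn(*inp, **opt)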
# **************** Free Ops ****************

def Identity(input: Tensor): return input
def Neg(input: Tensor): return -input
def Add(input: Tensor, other: Tensor, broadcast=None): return input + other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input + other).cast(input.dtype)
def Sub(input: Union[Tensor, Any], other: Tensor): return input - other # some test has input as int
def Mul(input: Tensor, other: Tensor): return (input * other) if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input * other).cast(input.dtype)
# in openpilot, due to SHUFFLE_PAD_OPS issues, we are spending an extra kernel
def Div(input: Tensor, other: Tensor): return input / other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else input.div(other).floor()
def Pow(input: Tensor, other: Tensor): return (input.float() ** other.float()).cast(input.dtype)
def Reciprocal(input: Tensor): return input.reciprocal()
def Sqrt(input: Tensor): return input.sqrt()
def Sign(input: Tensor): return input.sign()
def Abs(input: Tensor): return input.abs()
def Exp(input: Tensor): return input.exp()
def Log(input: Tensor): return input.log()
def Mish(input: Tensor): return input.mish()
def Sin(x: Tensor): return x.sin()
def Cos(x: Tensor): return x.cos()
def Tan(x: Tensor): return x.tan()
def Relu(input: Tensor): return input.relu()
def Sigmoid(input: Tensor): return input.sigmoid()
def Tanh(input: Tensor): return input.tanh()
def MatMul(input: Tensor, other: Tensor): return input.matmul(other)
def Floor(x:Tensor): return x.floor()
def Ceil(x:Tensor): return x.ceil()
def Less(x:Tensor,y:Tensor): return (x<y).cast(dtypes.bool)
def LessOrEqual(x:Tensor,y:Tensor): return (x<=y).cast(dtypes.bool)
def Greater(x:Tensor,y:Tensor): return (x>y).cast(dtypes.bool)
def GreaterOrEqual(x:Tensor,y:Tensor): return (x>=y).cast(dtypes.bool)
def Equal(x:Tensor,y:Tensor): return (x==y).cast(dtypes.bool)
def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0): return functools.reduce(Tensor.__add__, data_0)
def Mean(*data_0): return functools.reduce(Tensor.__add__, data_0) / len(data_0)
def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y).cast(X.dtype)
def Cast(input: Tensor, to): return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))

# **************** Simple Ops ****************

def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
  if value: return value
  elif value_float: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
  elif value_floats: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
  elif value_int: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
  elif value_ints: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
  elif value_string or value_strings: raise NotImplementedError(f'value_string or value_strings not implemented for Constant op')

def Softsign(input: Tensor): return input / (1+input.abs())
def Cosh(x): return (math.e ** x + math.e ** -x) / 2
def Sinh(x): return (math.e ** x - math.e ** -x) / 2
def Tanh(x): return x.tanh()

def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1)
def HardSwish(input: Tensor): return input * HardSigmoid(input, 1/6, 0.5)
def Celu(X: Tensor, alpha=1.0): return X.relu() - (-alpha*(X/alpha).exp()+1).relu()
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
def Softplus(X: Tensor): return X.softplus()
def PRelu(X:Tensor, slope:Tensor):
  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
  return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope
def LeakyRelu(X: Tensor, alpha=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X: Tensor, alpha=1.0): return (X-alpha).relu() + (X-alpha).relu().sign() * alpha
def Softmax_1(input: Tensor, axis=1): return input.softmax(axis)
def Softmax_13(input: Tensor, axis=-1): return input.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
def LogSoftmax(input: Tensor, axis=-1): return input.log_softmax(axis)
def Clip(input: Tensor, min=None, max=None): return input.clip(float('-inf') if min is None else min, float('inf') if max is None else max)

# NOTE ReduceProd would require a new llop
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)) else ([] if noop_with_empty_axes else None)
def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()
def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()

def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
def OptionalHasElement(x: Tensor=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([], dtype=dtypes.float32)

def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)])
def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start)), stop=int(safe_numpy(limit)), step=int(safe_numpy(delta))).cast(dtype=start.dtype)
def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64) # TODO: really?
def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))])
def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd)*(input+bias) + (input > lambd)*(input-bias)
def And(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
def Or(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.ones(*x.shape)).cast(dtypes.bool)
def Xor(x:Tensor, y:Tensor): return Where((x==y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def Not(x:Tensor): return Where((x==1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)

def Asin(x): return Atan(x / Tensor.sqrt(1 - x * x))
def Asinh(x): return Tensor.log(x + Tensor.sqrt(x * x + 1))
def Acosh(x): return Tensor.log(x + Tensor.sqrt(x * x - 1))
def Atanh(x): return 0.5 * Tensor.log((1 + x)/(1 - x))
def Acos(x: Tensor):
  negate = (x < 0)
  x = x.abs()
  ret = ((((-0.0187293 * x) + 0.0742610)*x - 0.2121144) * x + 1.5707288) * Tensor.sqrt(1.0 - x)
  ret = ret - 2 * negate * ret
  return negate * 3.14159265358979 + ret
def Atan(y: Tensor):
  x = Tensor.ones(y.shape)
  t3 = x
  t1 = y.abs()
  t0 = (t3 > t1).where(t3, t1)
  t1 = (t3 < t1).where(t3, t1)
  t3 = t1 / t0
  t4 = t3 * t3
  t0 = ((((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 - 0.332994597) * t4 + 0.999995630
  t3 = t0 * t3
  t3 = (y.abs() > x.abs()).where(1.570796327 - t3, t3)
  return (y < 0).where(-t3, t3)
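# NOTE: Acos and Atan above appear to be fixed low-degree polynomial approximations (the Acos
# coefficients match the Abramowitz & Stegun 4.4.45 fit), trading a few digits of accuracy for
# staying entirely on Tensor ops; Asin is then derived from Atan.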

def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
  k = int(k.numpy().item()) if k != 0 else 0 # onnx passes k as a tensor int64 with one element, default is 0
  return x.triu(k) if upper else x.tril(k)

def Squeeze(input: Tensor, axes):
  if isinstance(axes, Tensor): axes = safe_numpy(axes)
  axes = [int(x) if x >= 0 else int(x+input.ndim) for x in axes]
  return input.reshape([s for i,s in enumerate(input.shape) if i not in axes])
def Unsqueeze(data: Tensor, axes):
  axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
  new_shape = [1] * (len(data.shape) + len(axes))
  ptr = iter(data.shape)
  for i in range(len(new_shape)):
    if i not in axes:
      new_shape[i] = next(ptr)
  return data.reshape(new_shape)

def Binarizer(input, threshold=0.0): return input > threshold

def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
  axis = axis + x.ndim if axis < 0 else axis
  m = x == (x.max(axis=axis, keepdim=keepdims) if keepdims else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis))
  c = Tensor.arange(x.shape[axis]).reshape(*[1]*(axis), x.shape[axis], *[1]*(x.ndim - axis-1)) * m
  return c.max(axis=axis,keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)

def Elu(input: Tensor, alpha=1.0): return input.elu(alpha=alpha)
def Concat(*inputs: List[Tensor], axis): return inputs[0].cat(*inputs[1:], dim=axis)
def Transpose(input: Tensor, perm=None): return input.permute(order=list(range(len(input.shape))[::-1]) if perm is None else perm)

# NOTE: since we only have one type, this is valid!
def CastLike(input, target_type):
  assert isinstance(target_type, Tensor), "can only CastLike Tensor"
  return input

def ConstantOfShape(input, value:Tensor=None):
  if value is None: value=Tensor([0.0])
  shape = [int(x) for x in safe_numpy(input)]
  return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)

# TODO: abstract out the broadcast logic in tensor
def Expand(input: Tensor, shape):
  x_shape, y_shape = input.shape, [int(x) for x in safe_numpy(shape)]
  # copied from _broadcasted
  x_shape, y_shape = [([1]*(max(len(x_shape), len(y_shape))-len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]]
  shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
  return input.reshape(x_shape).expand(shape_ret)

# **************** Complex Ops ****************

def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
  if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]))
  return ret

# works with Tensors.ndim != 4
def _batchnorm(self:Tensor, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor):
  shape = [1, -1] + [1] * (self.ndim-2)
  x = (self - mean.reshape(shape=shape))
  if weight: x = x * weight.reshape(shape=shape)
  ret = x.mul(invstd.reshape(shape=shape) if len(invstd.shape) == 1 else invstd)
  return (ret + bias.reshape(shape=shape)) if bias else ret

# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1, is_test=0):
  if training_mode:
    x_detached = X.detach()
    current_mean = x_detached.mean(axis=(0,2,3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y*y).mean(axis=(0,2,3))
    current_invstd = current_var.add(epsilon).pow(-0.5)

    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)

    return _batchnorm(X, scale, B, current_mean, current_invstd), running_mean, running_var
  else:
    invstd = (input_var + epsilon)**-0.5
    return _batchnorm(X, scale, B, input_mean, invstd)

def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
  axis = tuple(range(2, len(x.shape)))
  mean = x.mean(axis=axis, keepdim=True)
  invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))

def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
  assert stash_type == 1, "only float32 is supported"
  axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)))
  mean = x.mean(axis=axis, keepdim=True)
  return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal()

def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)

# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
def _format_padding(onnx_pads, ndims=None, axes=None):
  if ndims and len(onnx_pads)//2 != ndims: onnx_pads = onnx_pads * ndims # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
  if ndims is None: ndims = len(onnx_pads) // 2
  if axes is None: axes = list(range(ndims))
  num_axes = len(axes)
  np_pads = [(0,0)] * ndims
  for i in range(num_axes):
    np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
  return np_pads
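# e.g. for a 2-D tensor, ONNX pads [1, 2, 3, 4] (x1_begin, x2_begin, x1_end, x2_end) become
# numpy-style pads [(1, 3), (2, 4)] here (illustrative values only).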

def _padding(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None):
  if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  if pads is None: return X
  pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
  return X.pad(tuple(pads), value=constant_value)

def _auto_pad(X, auto_pad, strides, kernel_shape, dilations):
  strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
  dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
  pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(strides):], strides, kernel_shape, dilations)]
  if auto_pad == "SAME_UPPER": return [pad_shape[0]//2, pad_shape[1]//2, pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2]
  elif auto_pad == "SAME_LOWER": return [pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2, pad_shape[0]//2, pad_shape[1]//2]
  else: raise NotImplementedError(f"auto_pad={auto_pad} not implemented, yet")

def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
  constant_value = value if constant_value is None else float(safe_numpy(constant_value)[0])
  seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads)
  seq_pads = [math.ceil(i) for i in seq_pads]
  seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None
  base_shape = x.shape
  pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes)
  if mode == "wrap":
    repeat_args = [math.ceil(dim[0]/sh) + math.ceil(dim[1]/sh) + 1 for dim, sh in zip(pads, base_shape)]
    new_shape = [s*r for s,r in zip(base_shape, repeat_args)]
    shrink_args = [(sh-dim[0]%sh if dim[0]%sh != 0 else 0, nsh-(sh-dim[1]%sh) if dim[1]%sh != 0 else nsh) for dim, sh, nsh in zip(pads, base_shape, new_shape)]
    return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args))
  elif mode == "reflect":
    for i,s in enumerate(x.shape):
      if pads[i] == (0,0): continue
      elif pads[i][0] and not pads[i][1]:
        x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (s-pads[i][0]-1, s_-1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (0,s) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
      elif not pads[i][0] and pads[i][1]:
        x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (1, pads[i][1]+1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (s,0) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
      else:
        x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (s-pads[i][0]-1, s_-1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (0,s+pads[i][1]) for i_ in range(x.ndim)])) + \
            x.flip(i).shrink(tuple([(0,s_) if i_ != i else (1, pads[i][1]+1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
    return x
  elif mode == "edge":
    for i,s in enumerate(x.shape):
      if pads[i] == (0,0): continue
      elif pads[i][0] and not pads[i][1]:
        x = x.shrink(tuple([(0,s_) if i_ != i else (0,1) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (0,s) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
      elif not pads[i][0] and pads[i][1]:
        x = x.shrink(tuple([(0,s_) if i_ != i else (s_-1, s_) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
      else:
        x = x.shrink(tuple([(0,s_) if i_ != i else (0,1) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (0,s+pads[i][1]) for i_ in range(x.ndim)])) + \
            x.shrink(tuple([(0,s_) if i_ != i else (s_-1, s_) for i_,s_ in enumerate(x.shape)])).expand([pads[i][1] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
    return x
  elif mode == "constant":
    return _padding(x, seq_pads, axes=seq_axes, constant_value=constant_value)

def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
  if dilations != 1: raise NotImplementedError(f"dilations != 1 not supported, dilations:{dilations}")
  pixel_axes = tuple(range(len(X.shape)))[-2:]
  if ceil_mode: auto_pad = "SAME_UPPER"
  padding_included = _padding(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations).avg_pool2d(kernel_shape, stride=strides)
  if count_include_pad:
    return padding_included
  else:
    div = _padding(Tensor.ones(*X.shape), pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations).avg_pool2d(kernel_shape, stride=strides)
    return padding_included / div

def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
  if ceil_mode: auto_pad = "SAME_UPPER"
  ret = _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-len(kernel_shape):], strides=strides, kernel_shape=kernel_shape, dilations=dilations)
  ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations)
  ret_len, X_len = ret.numel(), X.numel()
  indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().reshape(1, X_len).expand(ret_len, X_len)) * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)).sum(1).reshape(ret.shape).cast(dtypes.int64)
  if storage_order: indices = indices.transpose(indices.ndim-2, indices.ndim-1)
  return ret, indices

def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Tensor=None, kernel_shape=None, pads=None, strides=None):
  out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
  outlength = prod(out_sh)
  xI = xI.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
  arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
  xT = xT.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
  ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
  if outshape is not None:
    outshape = safe_numpy(outshape).tolist()
    if outshape != ret.shape:
      diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
      pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2]
      ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
  return ret

def Conv(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
  if auto_pad != "NOTSET": padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  else: padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0 # reorder padding
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)

def ConvTranspose(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
  if not kernel_shape: kernel_shape = W.shape
  if pads is None and auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  elif pads is None and auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2)
  strides_ = [1]*(W.ndim-1) + [strides] if isinstance(strides, int) else [1]*(W.ndim-len(strides)) + list(strides)
  dilations_ = [1]*(W.ndim-1) + [dilations] if isinstance(dilations, int) else [1]*(W.ndim-len(dilations)) + list(dilations)
  if output_shape and not output_padding:
    out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides_, X.shape, kernel_shape, dilations_))]
    output_padding = [os - rs for os, rs in zip(output_shape, out_sh[-len(output_shape):])]
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)

# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
  if isinstance(ratio, Tensor) and not ratio.shape: ratio = safe_numpy(ratio) # ratio and tensor is passed in as Tensor with shape: ()
  if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = safe_numpy(training_mode)
  if not training_mode: return data, Tensor.ones(*data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
  rng = np.random.RandomState(seed)
  ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio
  mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device)
  return data * mask * (1/(1.0 - ratio)), mask

def LRN(input: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
  bs, c, iy, ix = input.shape
  return input / input.mul(input).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta)

def MeanVarianceNormalization(input: Tensor, axis=(0, 2, 3)):
  data_mean = input.mean(axis=axis, keepdim=True)
  std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
  return (input - data_mean) / (std + 1e-9)

def NegativeLogLikelihoodLoss(input: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"):
  target = target.cast(dtypes.float32)
  N, C, i_shape = input.shape[0], input.shape[1], input.shape
  t_shape = target.shape
  if len(input.shape) != 3:
    input = input.reshape((N, C, -1))
    target = target.reshape((N, -1))
  if weight is not None:
    mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1))
    weight = (mask * weight).sum(axis=-1)
  if ignore_index is not None:
    cond = target == ignore_index
    weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
  mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(input.shape) -2))
  loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
  if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
  elif reduction == "sum": return loss.sum()
  return loss.reshape(t_shape) if len(i_shape) != 3 else loss

def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
  N, C, *s_dimensions = scores.shape
  if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
  mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
  y = scores.log_softmax(axis=1)
  if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)]))
  loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights
  if reduction == "mean": loss = loss.sum() / (loss == 0).where(0, 1).sum() if weights is None else loss.sum() / weights.sum()
  elif reduction == "sum": loss = loss.sum()
  return loss, y

def ArrayFeatureExtractor(input: Tensor, indices: Tensor): return input.__getitem__(tuple([slice(None) if i != (input.ndim-1) else indices for i in range(input.ndim)]))
def Gather(input: Tensor, indices: Tensor, axis=0):
  if indices.numel() < 9: # NOTE: fewer kernels for small indices, but the kernel count grows with the size of indices
    input_sh = list(input.shape)
    ret_shape = input_sh[:axis] + list(indices.shape) + input_sh[axis+1:]
    if indices.ndim > 1: indices = indices.flatten()
    indices = [int(safe_numpy(indices))] if indices.shape == () else [input_sh[axis]+int(x) if x<0 else int(x) for x in safe_numpy(indices)]
    args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(input_sh)] for i in indices]
    return input.shrink(arg=tuple(args[0])).cat(*[input.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
  else: # NOTE: faster gather with a fixed number of kernels, but it exceeds the kernel limit for openpilot
    return input.__getitem__(tuple([slice(None) if i != axis else indices for i in range(input.ndim)]))

def GatherElements(input: Tensor, indices: Tensor, axis):
  indices = indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis] + indices
  return input.gather(indices, axis)

def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
  def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
  assert n <= 1, f"n:{n} shouldn't be larger than 1"
  b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
  b = (b >= 0).where(b+n, b-n)
  if equidistant_case == "round_down":
    return (x > b).where(b+1-n, b-n)
  elif equidistant_case == "round_up":
    return (x >= b).where(b+1-n, b-n)
  elif equidistant_case == "round_to_even":
    x_ceil_fraction = x.ceil()/2
    cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
    x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
    x = (x > b).where(b+1-n, b-n)
    return x

def Round(X:Tensor): return _round(X, 0.5, "round_to_even")
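# Round is round-half-to-even (banker's rounding), matching np.round:
# e.g. Round(Tensor([2.5, 3.5])) should give [2., 4.] (illustrative values only).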

def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel', cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch', mode='nearest', nearest_mode='round_prefer_floor'):
  def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out]
  def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
    if nearest_mode == "round_prefer_floor": ret = _round(x_resized, 0.5, "round_down")
    elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, 0.5, "round_up")
    elif nearest_mode == "floor": ret = x_resized.floor()
    elif nearest_mode == "ceil": ret = x_resized.ceil()
    return ret.clip(0, x_len-1)
  def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None):
    if coordinate_transformation_mode == "half_pixel":
      x_out = (x_out + 0.5)/Tensor(scales_lol[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuracy.
      y_out = (y_out + 0.5)/Tensor(scales_lol[-2]) - 0.5
    elif coordinate_transformation_mode == "align_corners":
      x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
      y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
    elif coordinate_transformation_mode == "asymmetric":
      x_out = x_out/scales_lol[-1]
      y_out = y_out/scales_lol[-2]
    elif coordinate_transformation_mode == "half_pixel_symmetric":
      x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_lol[-1] - 0.5
      y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_lol[-2] - 0.5
    elif coordinate_transformation_mode == "pytorch_half_pixel":
      x_out = (x_out + 0.5)/scales_lol[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
      y_out = (y_out + 0.5)/scales_lol[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
    elif coordinate_transformation_mode == "tf_crop_and_resize":
      x_out = roi[-1][0] * (X.shape[-1] - 1) + x_out * ((roi[-1][1] - roi[-1][0]) * (X.shape[-1] - 1) / (output_shape[-1] - 1)) if output_shape[-1] > 1 else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
      y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
    return x_out.clip(0, X.shape[-1]-1), y_out.clip(0, X.shape[-2]-1)
  if roi is not None:
    roi = safe_numpy(roi)
    roi = [(st,ed) for st, ed in zip(roi[:len(roi)//2], roi[len(roi)//2:])]
    roi_ = [(1,1)] * 4
    if axes is not None:
      for a,r in zip(axes, roi):
        roi_[a] = r
      roi = roi_
  if scales is not None:
    scales = safe_numpy(scales).tolist()
    if axes is not None:
      scales_ = [1]*X.ndim
      for a,s in zip(axes, scales):
        scales_[a] = s
      scales = scales_
  elif sizes is not None:
    sizes = [int(i) for i in safe_numpy(sizes)]
    scales = []
    if axes is not None:
      sizes_ = [1]*X.ndim
      for a,s in zip(axes, sizes):
        sizes_[a] = s
        scales.append(s/X.shape[a])
      sizes = sizes_
    else: scales = [si/xs for xs, si in zip(X.shape, sizes)]
    if keep_aspect_ratio_policy == "not_larger":
      scale = min(scales)
      sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
      sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
    elif keep_aspect_ratio_policy == "not_smaller":
      scale = max(scales)
      sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
      sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
  output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
  output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
  scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)]
  x_out = Tensor.arange(output_shape[-1])
  y_out = Tensor.arange(output_shape[-2])
  if mode == "nearest":
    x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi)
    x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
    y_out = _nearest_mode(y_out, nearest_mode, X.shape[-1])
    return _nearest_gather(X, x_out, y_out)
  elif mode == "linear":
    x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape_, scales, roi)
    ret = []
    for y in safe_numpy(y_out):
      for x in safe_numpy(x_out):
        x_floor, y_floor = int(x), int(y)
        y_shrink = (0, X.shape[2]) if X.shape[2] == 1 else (y_floor, y_floor+2) if y != y_floor else (y_floor, y_floor+1)
        x_shrink = (x_floor, x_floor+2) if x != x_floor else (x_floor, x_floor+1)
        shrink_args = ((0, X.shape[0]), (0, X.shape[1]), y_shrink, x_shrink)
        corners = safe_numpy(X.shrink(shrink_args))
        x1, x2, y1, y2 = x_floor, x_floor+1, y_floor, y_floor+1
        if x == x_floor and y == y_floor: # TODO https://en.wikipedia.org/wiki/Bilinear_interpolation#Weighted_mean maybe do weighted mean?
          ret.append(corners[0,0,0,0])
        elif x == x_floor:
          ret.append((corners[0,0,0,0] * (y2 - y) + corners[0,0,1,0] * (y - y1)) / (y2 - y1))
        elif y == y_floor:
          ret.append((corners[0,0,0,0] * (x2 - x) + corners[0,0,0,1] * (x - x1)) / (x2 - x1))
        else:
          ret.append((corners[0,0,0,0] * (x2 - x) * (y2 - y) + corners[0,0,0,1] * (x - x1) * (y2 - y) + corners[0,0,1,0] * (x2 - x) * (y - y1) + corners[0,0,1,1] * (x - x1) * (y - y1)) / ((x2 - x1) * (y2 - y1)))
    return Tensor(ret).reshape(output_shape)
  elif mode == "cubic":
    raise Exception("cubic interpolation is not implemented")

def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
  if not axes: axes = list(range(input.ndim))
  shrink_arg = [(0,i) for i in input.shape]
  pad_arg = [(0,0) for _ in range(input.ndim)]
  shape = safe_numpy(shape).tolist()
  for s, x in zip(shape, axes):
    if s < input.shape[x]: shrink_arg[x] = (input.shape[x]//2 - s//2, input.shape[x]//2 + s//2) if s%2 == 0 else (input.shape[x]//2 - s//2 - 1, input.shape[x]//2 + s//2)
    elif s > input.shape[x]: pad_arg[x] = ((s - input.shape[x])//2, (s - input.shape[x])//2) if (s - input.shape[x])% 2 == 0 else ((s - input.shape[x])//2, (s - input.shape[x])//2 + 1)
  return input.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))

def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
  depth = int(safe_numpy(depth).item())
  indices, rank = (indices < 0).where(indices+depth, indices), len(indices.shape)
  if axis < 0: axis += rank + 1
  ls, rs = indices.shape[0:axis], indices.shape[axis: rank]
  cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
  return cond.where(values[1], values[0]).cast(values.dtype)

def Erf(x: Tensor):
  sign = x.sign()
  x = x.abs()
  t = 1.0 / (1.0 + 0.3275911 * x)
  term1 = 0.254829592 * t
  term2 = -0.284496736 * t ** 2
  term3 = 1.421413741 * t ** 3
  term4 = -1.453152027 * t ** 4
  term5 = 1.061405429 * t ** 5
  y = (term1 + term2 + term3 + term4 + term5)
  return sign * (1.0 - y * Tensor.exp(-x * x))
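# NOTE: the Erf polynomial above matches the Abramowitz & Stegun 7.1.26 approximation
# (max absolute error around 1.5e-7).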

def Compress(inp: Tensor, condition: Tensor, axis=None):
  if axis is None:
    inp = inp.flatten()
    axis = 0

  axis = axis + inp.ndim if axis < 0 else axis

  con_np = safe_numpy(condition)
  con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor
  return inp.__getitem__(tuple([slice(None) if i != axis else con for i in range(inp.ndim)]))

type_map = {TensorProto.DOUBLE: dtypes.double, TensorProto.FLOAT: dtypes.float32}
def EyeLike(x: Tensor, dtype=None, k=0):
  if dtype is None: dtype = x.dtype
  else: dtype = type_map[dtype]
  shape = x.shape
  dim = min(x.shape)
  if shape[0] == shape[1]: return Tensor.eye(dim=dim, dtype=dtype)
  else:
    diff = (shape[0]-dim, shape[1]-dim)
    padarg = tuple([(d, d) if d == 0 else (k, d-k) for d in diff])
    return Tensor.eye(dim=dim, dtype=dtype).pad(padarg)

def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)

# Needs work
def IsInf(x,detect_negative=1,detect_positive=1):
  ret = (x == float("inf"))*detect_positive + (x == float("-inf"))*detect_negative + Tensor.zeros(*x.shape)
  return ret.cast(dtypes.bool)

# Needs work
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point=0, axis=1):
  axis = axis + x.ndim if axis < 0 else axis
  x_sc = x_scale.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
  x_zer = x_zero_point.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point
  return (x - x_zer) * x_sc

# Needs work
def IsNaN(x):
  return (x < float("-inf")).cast(dtypes.bool)

# **************** com.microsoft Ops ****************

def SkipLayerNormalization(input:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):
  if epsilon is None: epsilon=1e-12
  x = input + skip + bias
  return x.layernorm(eps=epsilon) * gamma + beta, None, None, x

def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
  x = x + bias
  return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())

def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
  assert (segment_ids is None) is (segment_embedding is None)
  assert (mask is None) is (mask_index_type is None)
  assert mask is None, "functionality not supported yet" # TODO
  input_shape = input_ids.shape
  bsz, seq_length = input_shape[0], input_shape[1]
  compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
  vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)

  def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
    vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
    return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight

  # bert embedding layer
  if epsilon is None: epsilon = 1e-12
  if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
  seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None

  embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
  return out, None, embedding_sum

def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
  assert num_heads is not None # required
  assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
  assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, "functionality not supported yet" # TODO strange params
  hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)

  if unidirectional: # gpt-style
    assert hidden_size == v_hidden_size
    xqkv = input.linear(weights, bias)
    xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
  else: # bert-style
    wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
    bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None
    xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
  xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]

  if past is not None:
    xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
    present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))

  def attn(query, key, value, attn_mask):
    query_length, key_length = query.shape[-2], key.shape[-2]
    cdim = max(query_length, key_length) + 1
    attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
    # This is where Tensor.scaled_dot_product_attention differs:
    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool)
    return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value

  bsz, _, seq_len, _ = xq.shape
  out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
  return out, present

# **************** ai.onnx.preview.training Ops ****************

# TODO not entirely sure these optimizers are correct
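# The flattened *inputs below arrive as the ONNX training spec lays them out, i.e. all parameters,
# then all gradients, then all state tensors (X1..Xn, G1..Gn, H1..Hn, ...); inputs[i::groups]
# regathers them into per-parameter tuples, and the results are re-flattened to the same layout.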
def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
|
||||
groups = len(inputs) // 3
|
||||
grouped_inputs = [inputs[i::groups] for i in range(groups)]
|
||||
T, R = safe_numpy(T)[0], safe_numpy(R)[0]
|
||||
r = R / (1 + T * decay_factor)
|
||||
ret = []
|
||||
for input in grouped_inputs:
|
||||
X, G, H = input
|
||||
X.grad = norm_coefficient * X + G
|
||||
X.grad.requires_grad, H.requires_grad = False, False # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py
|
||||
H.assign(H.detach() + X.grad * X.grad).realize()
|
||||
H_adaptive = H.sqrt() + epsilon
|
||||
X.assign(X.detach() - r * X.grad / H_adaptive)
|
||||
ret.extend([X, H])
|
||||
ret = ret[::2] + ret[1::2]
|
||||
return tuple(ret)
|
||||
|
||||
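# reference update rule (ONNX ai.onnx.preview.training Momentum), which the code below mirrors:
#   g = norm_coefficient*X + G; beta_adjusted = beta if T > 0 else 1
#   V_new = alpha*V + beta_adjusted*g
#   standard: X_new = X - R*V_new    nesterov: X_new = X - R*(g + alpha*V_new)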
def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
|
||||
groups = len(inputs) // 3
|
||||
grouped_inputs = [inputs[i::groups] for i in range(groups)]
|
||||
T, R = safe_numpy(T)[0], safe_numpy(R)[0]
|
||||
beta_adjusted = beta if T > 0 else 1
|
||||
ret = []
|
||||
for input in grouped_inputs:
|
||||
X, G, V = input
|
||||
X.grad = (norm_coefficient * X + G).realize()
|
||||
X.grad.requires_grad, V.requires_grad = False, False
|
||||
V.assign(alpha * V + beta_adjusted * X.grad).realize()
|
||||
if mode == "standard": X.assign(X.detach() - R * V).realize()
|
||||
elif mode == "nesterov": X.assign(X.detach() - R * (X.grad + alpha + V)).realize()
|
||||
ret.extend([X, V])
|
||||
ret = ret[::2] + ret[1::2]
|
||||
return tuple(ret)
|
||||
|
||||
# copied from tinygrad/nn/optim.py: LAMB with some edits
|
||||
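# reference update rule (ONNX ai.onnx.preview.training Adam) as implemented below:
#   g = norm_coefficient*X + G; V_new = alpha*V + (1-alpha)*g; H_new = beta*H + (1-beta)*g*g
#   bias correction (when T > 0): V_hat = V_new/(1-alpha**T), H_hat = H_new/(1-beta**T)
#   X_new = X - R * V_hat / (sqrt(H_hat) + epsilon), then scaled by (1 - norm_coefficient_post)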
def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0, norm_coefficient_post=0.0):
|
||||
groups = len(inputs) // 4
|
||||
grouped_inputs = [inputs[i::groups] for i in range(groups)]
|
||||
T, R = safe_numpy(T)[0], safe_numpy(R)[0]
|
||||
ret = []
|
||||
for input in grouped_inputs:
|
||||
X, G, V, H = input
|
||||
X.grad = (norm_coefficient * X + G).realize()
|
||||
V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False
|
||||
V.assign(alpha * V + (1.0 - alpha) * X.grad).realize()
|
||||
H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize()
|
||||
up = (V / (1.0 - alpha**T)) / ((H / (1.0 - beta**T)).sqrt() + epsilon) if T > 0 else V / (H.sqrt() + epsilon)
|
||||
X.assign(X.detach() - R * up).realize()
|
||||
X = (1 - norm_coefficient_post) * X
|
||||
ret.extend([X, V, H])
|
||||
ret = ret[::3] + ret[1::3] + ret[2::3]
|
||||
return tuple(ret)
|
||||
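# Rough, untested sketch (not part of the original file): drive the gpt-style path of Attention above
# with toy shapes. The packed [Wq|Wk|Wv] weight layout and the all-zero additive mask are assumptions
# that simply mirror the slicing and masking logic inside the function.
if __name__ == "__main__":
  import numpy as np
  bsz, seq, hidden, heads = 1, 4, 8, 2
  x = Tensor(np.random.randn(bsz, seq, hidden).astype(np.float32))
  w = Tensor(np.random.randn(hidden, 3*hidden).astype(np.float32))  # packed [Wq|Wk|Wv]
  b = Tensor(np.zeros(3*hidden, dtype=np.float32))
  out, present = Attention(x, w, b, mask_index=Tensor.zeros(1), num_heads=heads,
                           qkv_hidden_sizes=[hidden, hidden, hidden], unidirectional=1)
  print(out.shape, present.shape)  # expected: (1, 4, 8) and (2, 1, 2, 4, 4)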
285
tinygrad_repo/extra/thneed.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# this can be constructed from a cl_cache or loaded from a thneed file
|
||||
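# On-disk layout, as implied by load()/save() below: a 4-byte JSON header length (struct "I"),
# the JSON header itself (latin_1), then the concatenated raw weight buffers, then the concatenated
# compiled program binaries; sizes come from the 'size'/'length' fields recorded in the header.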
import time
|
||||
import struct
|
||||
import json
|
||||
import traceback
|
||||
import numpy as np
|
||||
from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
|
||||
from tinygrad.helpers import DEBUG, getenv
|
||||
from collections import defaultdict
|
||||
import pyopencl as cl
|
||||
from tinygrad.runtime.ops_gpu import CL, OSX_TIMING_RATIO
|
||||
|
||||
DEBUGCL = getenv("DEBUGCL", 0)
|
||||
FLOAT16 = getenv("FLOAT16", 0)
|
||||
|
||||
class Thneed:
|
||||
def __init__(self, cl_cache=[], inputs={}):
|
||||
self.cl_cache, self.inputs = cl_cache[:], inputs
|
||||
self.gobj = 0
|
||||
|
||||
# build graph
|
||||
# NOTE: if CLCACHE=1, this is wrong!
|
||||
nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []})
|
||||
for _, args in self.cl_cache:
|
||||
# output is always the first parameter
|
||||
for a in args[3:]:
|
||||
nodes[a]['out_edges'].append(args[2])
|
||||
nodes[args[2]]['in_edges'].append(a)
|
||||
|
||||
# get buffers to save
|
||||
self.buffers_to_save = set()
|
||||
self.outputs = []
|
||||
for n in nodes.keys():
|
||||
if len(nodes[n]['in_edges']) == 0:
|
||||
self.buffers_to_save.add(n)
|
||||
if len(nodes[n]['out_edges']) == 0:
|
||||
self.outputs.append(n)
|
||||
|
||||
fake_inputs = []
|
||||
for k,n in self.inputs.items():
|
||||
if n in self.buffers_to_save:
|
||||
self.buffers_to_save.remove(n)
|
||||
else:
|
||||
print(f"WARNING: {k} was not a used input, removing it")
|
||||
fake_inputs.append(k)
|
||||
for k in fake_inputs:
|
||||
del self.inputs[k]
|
||||
|
||||
def load(self, input_fn):
|
||||
float32 = not FLOAT16
|
||||
|
||||
mf = cl.mem_flags
|
||||
image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT)
|
||||
image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT)
|
||||
|
||||
with open(input_fn, "rb") as f:
|
||||
json_len = struct.unpack("I", f.read(4))[0]
|
||||
jdat = json.loads(f.read(json_len).decode('latin_1'))
|
||||
weights = f.read()
|
||||
|
||||
# load in the buffers
|
||||
bufs = {'\x00\x00\x00\x00\x00\x00\x00\x00': None}
|
||||
bufs_loaded = {}
|
||||
ptr = 0
|
||||
for o in jdat['objects']:
|
||||
#print(o)
|
||||
if o['needs_load']:
|
||||
nptr = ptr + o['size']
|
||||
o['data'] = weights[ptr:nptr]
|
||||
ptr = nptr
|
||||
|
||||
if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t":
|
||||
tfmt = image_fmt_32 if 'float32' in o and o['float32'] else image_fmt
|
||||
if o['arg_type'] == "image2d_t":
|
||||
if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
|
||||
# hack: use an image1d since we can back that with a buffer
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
|
||||
else:
|
||||
# buffer isn't supported in image2d, copy buffer into image
|
||||
if 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
|
||||
arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
|
||||
cl.enqueue_copy(CL.cl_queue[0], arr, bufs[o['buffer_id']])
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
|
||||
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
|
||||
elif o['needs_load']:
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
|
||||
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
|
||||
else:
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
|
||||
if o['arg_type'] == "image1d_t":
|
||||
assert not o['needs_load']
|
||||
assert not bufs_loaded[o['buffer_id']]
|
||||
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
|
||||
else:
|
||||
if 'data' in o:
|
||||
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
|
||||
else:
|
||||
# zero out buffers
|
||||
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
|
||||
|
||||
bufs[o['id']] = buf
|
||||
bufs_loaded[o['id']] = 'data' in o
|
||||
# if it's loaded, it's saved
|
||||
if 'data' in o:
|
||||
self.buffers_to_save.add(buf)
|
||||
|
||||
# load binaries
|
||||
prgs = {}
|
||||
for o in jdat['binaries']:
|
||||
nptr = ptr + o['length']
|
||||
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr])
|
||||
ptr = nptr
|
||||
|
||||
# populate the cl_cache
|
||||
for i,k in enumerate(jdat['kernels']):
|
||||
kernel = prgs[k['name']]
|
||||
aaa = []
|
||||
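# kernel args are serialized as latin_1 strings: an empty string means local memory of args_size bytes,
# 2 or 4 bytes are a packed uint16/uint32 scalar, and 8 bytes are the global_id key of a buffer/image in bufs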
for j,(a,sz) in enumerate(zip(k['args'], k['args_size'])):
|
||||
if len(a) == 0:
|
||||
aa = cl.LocalMemory(sz)
|
||||
elif len(a) == 4:
|
||||
a = a.encode('latin_1')
|
||||
aa = np.uint32(struct.unpack("I", a)[0])
|
||||
elif len(a) == 2:
|
||||
a = a.encode('latin_1')
|
||||
aa = np.uint16(struct.unpack("H", a)[0])
|
||||
elif len(a) == 8:
|
||||
#print(i,j,struct.unpack("Q", a.encode('latin_1'))[0])
|
||||
aa = bufs[a]
|
||||
aaa.append(aa)
|
||||
self.cl_cache.append((kernel, [k['global_work_size'], k['local_work_size'], *aaa]))
|
||||
|
||||
if DEBUG >= 1: print(f"thneed: total bufs loaded: {len(bufs.keys())}")
|
||||
|
||||
# load inputs
|
||||
for k in jdat['inputs']:
|
||||
self.inputs[k['name']] = bufs[k['buffer_id']]
|
||||
|
||||
# load outputs
|
||||
for k in jdat['outputs']:
|
||||
self.outputs.append(bufs[k['buffer_id']])
|
||||
|
||||
|
||||
def save(self, output_fn):
|
||||
# this is the struct that will be saved
|
||||
jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}
|
||||
|
||||
# build the pieces of this struct
|
||||
weights = []
|
||||
binaries = []
|
||||
saved_objs = set()
|
||||
saved_binaries = set()
|
||||
for prg, args in self.cl_cache:
|
||||
# get binaries for saving
|
||||
if prg.name not in saved_binaries:
|
||||
binary = prg.clprograms[0].get_info(cl.program_info.BINARIES)
|
||||
assert len(binary) == 1
|
||||
jdat['binaries'].append({"name":prg.name, "length":len(binary[0])})
|
||||
binaries.append(binary[0])
|
||||
saved_binaries.add(prg.name)
|
||||
|
||||
# get the args from the kernel, some need the data saved
|
||||
targs, args_size = [], []
|
||||
argdtypes = prg.argdtypes if prg.argdtypes is not None else [None]*(len(args)-2)
|
||||
for a,d in zip(args[2:], argdtypes):
|
||||
if d == np.int16:
|
||||
targs.append(struct.pack("H", a).decode("latin_1"))
|
||||
args_size.append(2)
|
||||
elif d == np.int32:
|
||||
targs.append(struct.pack("I", a).decode("latin_1"))
|
||||
args_size.append(4)
|
||||
elif isinstance(a, cl.LocalMemory):
|
||||
targs.append("")
|
||||
args_size.append(a.size)
|
||||
elif d is None:
|
||||
if getattr(a, "global_id", None) is None:
|
||||
setattr(a, "global_id", self.gobj)
|
||||
self.gobj += 1
|
||||
ptr = struct.pack("Q", a.global_id).decode("latin_1")
|
||||
if ptr not in saved_objs:
|
||||
if isinstance(a, cl.Buffer):
|
||||
needs_load = a in self.buffers_to_save
|
||||
jdat['objects'].append({
|
||||
"id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size,
|
||||
})
|
||||
if needs_load:
|
||||
data = np.empty(a.size//4, dtype=np.float32)
|
||||
cl.enqueue_copy(CL.cl_queue[0], data, a, is_blocking=True)
|
||||
weights.append(data.tobytes())
|
||||
elif isinstance(a, cl.Image):
|
||||
assert a.format == cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT), "wrong type"
|
||||
needs_load = a in self.buffers_to_save
|
||||
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
|
||||
size = row_pitch * a.shape[1]
|
||||
# this is *2 if float16 and *4 if float32
|
||||
buf = cl.Buffer(CL.cl_ctxs[0], cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
|
||||
|
||||
# zero out the buffer
|
||||
cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True)
|
||||
|
||||
CLProgram("from_image_strided", compile_gpu("""
|
||||
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
|
||||
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
int2 l;
|
||||
l.y = get_global_id(1);
|
||||
l.x = get_global_id(0);
|
||||
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
|
||||
}
|
||||
"""), argdtypes=(None, None, np.int32))(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape)
|
||||
|
||||
# multiple of 32 isn't enough
|
||||
jdat['objects'].append({
|
||||
"id": ptr, "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
|
||||
"width": a.shape[0], "height": a.shape[1], "row_pitch": row_pitch, "float32": not FLOAT16,
|
||||
})
|
||||
|
||||
if needs_load:
|
||||
data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
|
||||
cl.enqueue_copy(CL.cl_queue[0], data, buf, is_blocking=True)
|
||||
if FLOAT16: data = data.astype(np.float16)
|
||||
weights.append(data.tobytes())
|
||||
else:
|
||||
raise Exception("unknown object", a)
|
||||
#print(jdat['objects'][-1])
|
||||
saved_objs.add(ptr)
|
||||
targs.append(ptr)
|
||||
args_size.append(8)
|
||||
else:
|
||||
raise Exception("idk this type")
|
||||
|
||||
# save the kernel itself
|
||||
jdat['kernels'].append({
|
||||
"name": prg.name,
|
||||
"work_dim": len(args[0]),
|
||||
"global_work_size": args[0],
|
||||
# TODO: C++ thneed requires a local_work_size, so we fill it with ones
|
||||
"local_work_size": [1 for _ in args[0]] if args[1] is None else args[1],
|
||||
"num_args": len(args)-2,
|
||||
"args": targs,
|
||||
"args_size": args_size
|
||||
})
|
||||
|
||||
jdat['outputs'] = [{
|
||||
"buffer_id": struct.pack("Q", x.global_id).decode("latin_1"),
|
||||
"size": x.size,
|
||||
} for x in self.outputs]
|
||||
|
||||
jdat['inputs'] = [{
|
||||
"buffer_id": struct.pack("Q", v.global_id).decode("latin_1"),
|
||||
"size": v.size,
|
||||
"name": k
|
||||
} for k,v in self.inputs.items()][::-1]
|
||||
|
||||
print(f"saving thneed to {output_fn}")
|
||||
with open(output_fn, "wb") as f:
|
||||
j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
|
||||
f.write(struct.pack("I", len(j)))
|
||||
f.write(j)
|
||||
f.write(b''.join(weights))
|
||||
f.write(b''.join(binaries))
|
||||
|
||||
def run(self):
|
||||
events = []
|
||||
st = time.monotonic()
|
||||
for prg, args in self.cl_cache:
|
||||
events.append(prg.clprgs[0](CL.cl_queue[0], *args))
|
||||
mt = time.monotonic()
|
||||
CL.synchronize()
|
||||
et = time.monotonic() - st
|
||||
print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")
|
||||
|
||||
if DEBUGCL >= 2:
|
||||
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
|
||||
print(f"{i:3d} {prg.name:25s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end]))
|
||||
if DEBUGCL >= 1:
|
||||
total_runtime = 0
|
||||
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
|
||||
runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
|
||||
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
|
||||
if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3):
|
||||
print(prg.prg)
|
||||
total_runtime += runtime
|
||||
print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")
|
||||
return total_runtime/1e9
|
||||
return et
|
||||
205
tinygrad_repo/extra/utils.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# type: ignore
|
||||
import pickle, hashlib, zipfile, io, requests, struct, tempfile, platform, concurrent.futures
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
from typing import Union
|
||||
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import Device
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
OSX = platform.system() == "Darwin"
|
||||
WINDOWS = platform.system() == "Windows"
|
||||
|
||||
def temp(x:str) -> str: return (Path(tempfile.gettempdir()) / x).as_posix()
|
||||
|
||||
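# usage sketch (hypothetical URL): raw = fetch("https://example.com/model.onnx") downloads once and
# then reuses the copy cached under the temp dir, keyed by md5(url), unless NOCACHE is set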
def fetch(url):
|
||||
if url.startswith("/") or url.startswith("."):
|
||||
with open(url, "rb") as f:
|
||||
return f.read()
|
||||
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
|
||||
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
|
||||
with open(fp, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def fetch_as_file(url):
|
||||
if url.startswith("/") or url.startswith("."):
return url  # local path: it's already a file on disk, return the path (fetch() above returns bytes instead)
|
||||
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
|
||||
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
|
||||
return fp
|
||||
|
||||
def download_file(url, fp, skip_if_exists=True):
|
||||
if skip_if_exists and Path(fp).is_file() and Path(fp).stat().st_size > 0:
|
||||
return
|
||||
r = requests.get(url, stream=True)
|
||||
assert r.status_code == 200
|
||||
progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url)
|
||||
(path := Path(fp).parent).mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
|
||||
for chunk in r.iter_content(chunk_size=16384):
|
||||
progress_bar.update(f.write(chunk))
|
||||
f.close()
|
||||
Path(f.name).rename(fp)
|
||||
|
||||
def my_unpickle(fb0):
|
||||
key_prelookup = defaultdict(list)
|
||||
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
|
||||
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
|
||||
ident, storage_type, obj_key, location, obj_size = storage[0:5]
|
||||
assert ident == 'storage'
|
||||
assert prod(size) <= (obj_size - storage_offset)
|
||||
|
||||
if storage_type not in [np.float16, np.float32]:
|
||||
if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}")
|
||||
ret = None
|
||||
else:
|
||||
ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type))
|
||||
key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset))
|
||||
return ret
|
||||
|
||||
def _rebuild_parameter(*args):
|
||||
#print(args)
|
||||
pass
|
||||
|
||||
class Dummy: pass
|
||||
|
||||
class MyPickle(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
#print(module, name)
|
||||
if name == 'FloatStorage': return np.float32
|
||||
if name == 'LongStorage': return np.int64
|
||||
if name == 'IntStorage': return np.int32
|
||||
if name == 'HalfStorage': return np.float16
|
||||
if module == "torch._utils":
|
||||
if name == "_rebuild_tensor_v2": return _rebuild_tensor_v2
|
||||
if name == "_rebuild_parameter": return _rebuild_parameter
|
||||
else:
|
||||
if module.startswith('pytorch_lightning'): return Dummy
|
||||
try:
|
||||
return super().find_class(module, name)
|
||||
except Exception:
|
||||
return Dummy
|
||||
|
||||
def persistent_load(self, pid):
|
||||
return pid
|
||||
|
||||
return MyPickle(fb0).load(), key_prelookup
|
||||
|
||||
def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset, mmap_allowed=False):
|
||||
bytes_size = np.dtype(dtype).itemsize
|
||||
if t is None:
|
||||
myfile.seek(prod(shape) * bytes_size, 1)
|
||||
return
|
||||
|
||||
bytes_offset = 0
|
||||
if storage_offset is not None:
|
||||
bytes_offset = storage_offset * bytes_size
|
||||
myfile.seek(bytes_offset)
|
||||
|
||||
assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}"
|
||||
assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size
|
||||
if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)):
|
||||
# slow path
|
||||
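# a non-contiguous view spans sum((shape[i]-1)*strides[i])*itemsize + itemsize bytes in the file;
# read exactly that much and rebuild the strided view with as_strided below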
buffer_size = sum(strides[i]*t.dtype.itemsize * (shape[i] - 1) for i in range(len(shape)))
|
||||
buffer_size += t.dtype.itemsize
|
||||
np_array = np.frombuffer(myfile.read(buffer_size), t.dtype.np)
|
||||
|
||||
np_array = np.lib.stride_tricks.as_strided(
|
||||
np_array, shape=shape, strides=[i*t.dtype.itemsize for i in strides])
|
||||
|
||||
lna = t.lazydata.op.arg
|
||||
lna.fxn = lambda _: np_array
|
||||
t.realize()
|
||||
return
|
||||
|
||||
# ["METAL", "CLANG", "LLVM"] support readinto for more speed
|
||||
# ["GPU", "CUDA"] use _mmap since they have to copy in to the GPU anyway
|
||||
# this needs real APIs
|
||||
if t.device in ["METAL", "CLANG", "LLVM"]:
|
||||
del t.lazydata.op
|
||||
t.lazydata.realized = Device[t.lazydata.device].buffer(prod(t.shape), dtype=t.dtype)
|
||||
myfile.readinto(t.lazydata.realized._buffer())
|
||||
else:
|
||||
def _mmap(lna):
|
||||
assert myfile._compress_type == 0, "compressed data can't be mmaped"
|
||||
return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r', offset=myfile._orig_compress_start + bytes_offset, shape=lna.shape)
|
||||
def _read(lna):
|
||||
ret = np.empty(lna.shape, dtype=lna.dtype)
|
||||
myfile.readinto(ret.data)
|
||||
return ret
|
||||
if mmap_allowed and not OSX and t.device in ["GPU", "CUDA"]: t.lazydata.op.arg.fxn = _mmap
|
||||
else: t.lazydata.op.arg.fxn = _read
|
||||
t.realize()
|
||||
|
||||
def fake_torch_load_zipped(fb0, load_weights=True, multithreaded=True):
|
||||
if Device.DEFAULT in ["TORCH", "GPU", "CUDA"]: multithreaded = False # multithreaded doesn't work with CUDA or TORCH. for GPU it's a wash with _mmap
|
||||
with zipfile.ZipFile(fb0, 'r') as myzip:
|
||||
base_name = myzip.namelist()[0].split('/', 1)[0]
|
||||
with myzip.open(f'{base_name}/data.pkl') as myfile:
|
||||
ret = my_unpickle(myfile)
|
||||
if load_weights:
|
||||
def load_weight(k, vv):
|
||||
with myzip.open(f'{base_name}/data/{k}') as myfile:
|
||||
for v in vv:
|
||||
load_single_weight(v[2], myfile, v[3], v[4], v[0], v[5], mmap_allowed=True)
|
||||
if multithreaded:
|
||||
# 2 seems fastest
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
futures = {executor.submit(load_weight, k, v):k for k,v in ret[1].items()}
|
||||
for future in (t:=tqdm(concurrent.futures.as_completed(futures), total=len(futures))):
|
||||
if future.exception() is not None: raise future.exception()
|
||||
k = futures[future]
|
||||
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
|
||||
else:
|
||||
for k,v in (t := tqdm(ret[1].items())):
|
||||
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
|
||||
load_weight(k,v)
|
||||
return ret[0]
|
||||
|
||||
def fake_torch_load(b0):
|
||||
|
||||
# convert it to a file
|
||||
fb0 = io.BytesIO(b0)
|
||||
|
||||
if b0[0:2] == b"\x50\x4b":
|
||||
return fake_torch_load_zipped(fb0)
|
||||
|
||||
# skip three junk pickles
|
||||
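# legacy (non-zip) torch format: the three "junk" pickles are the magic number, protocol version and
# system info; then comes the main data pickle, a pickle listing the storage keys in order, and for
# each storage an 8-byte element count followed by its raw bytes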
pickle.load(fb0)
|
||||
pickle.load(fb0)
|
||||
pickle.load(fb0)
|
||||
|
||||
ret, key_prelookup = my_unpickle(fb0)
|
||||
|
||||
# create key_lookup
|
||||
key_lookup = pickle.load(fb0)
|
||||
key_real = [None] * len(key_lookup)
|
||||
for k,v in key_prelookup.items():
|
||||
assert len(v) == 1
|
||||
key_real[key_lookup.index(k)] = v[0]
|
||||
|
||||
# read in the actual data
|
||||
for storage_type, obj_size, tensor, np_shape, np_strides, storage_offset in key_real:
|
||||
ll = struct.unpack("Q", fb0.read(8))[0]
|
||||
assert ll == obj_size, f"size mismatch {ll} != {obj_size}"
|
||||
assert storage_offset == 0, "not implemented"
|
||||
load_single_weight(tensor, fb0, np_shape, np_strides, storage_type, None)
|
||||
|
||||
return ret
|
||||
|
||||
def get_child(parent, key):
|
||||
obj = parent
|
||||
for k in key.split('.'):
|
||||
if k.isnumeric():
|
||||
obj = obj[int(k)]
|
||||
elif isinstance(obj, dict):
|
||||
obj = obj[k]
|
||||
else:
|
||||
obj = getattr(obj, k)
|
||||
return obj
|
||||
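# Rough usage sketch (not part of the original file; file name and key are hypothetical):
#   state_dict = fake_torch_load_zipped(open("model.pt", "rb"))
#   weight = get_child(state_dict, "model.layers.0.weight")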