openpilot v0.9.6 release
date: 2024-01-12T10:13:37 master commit: ba792d576a49a0899b88a753fa1c52956bedf9e6
This commit is contained in:
585
tinygrad_repo/tinygrad/codegen/kernel.py
Normal file
585
tinygrad_repo/tinygrad/codegen/kernel.py
Normal file
@@ -0,0 +1,585 @@
|
||||
from __future__ import annotations
|
||||
import os, math, itertools
|
||||
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
|
||||
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled
|
||||
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen, getenv, prod, DEBUG
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
|
||||
class OptOps(Enum):
|
||||
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class Opt:
|
||||
op: OptOps
|
||||
axis: Optional[int] = None
|
||||
amt: Optional[int] = None
|
||||
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TensorCore:
|
||||
device: str
|
||||
dims: List[int]
|
||||
dtype_in: DType
|
||||
dtype_out: DType
|
||||
threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
|
||||
upcast_dim: int # which TC dim to upcast
|
||||
thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim
|
||||
thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
|
||||
arch: Optional[str] = None
|
||||
def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
|
||||
|
||||
tensor_cores: Dict[str, List[TensorCore]] = {
|
||||
"METAL": [
|
||||
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"),
|
||||
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"),
|
||||
],
|
||||
"HIP": [
|
||||
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]),
|
||||
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]),
|
||||
]
|
||||
}
|
||||
|
||||
class LocalBuffer(NamedTuple):
|
||||
name: str
|
||||
size: int
|
||||
dtype: DType = dtypes.float32
|
||||
realized: None = None
|
||||
def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
|
||||
|
||||
class LinearizerOptions(NamedTuple):
|
||||
device: str = ""
|
||||
# TODO: make this generic with a list of supported types
|
||||
supports_float4: bool = True
|
||||
supports_float4_alu: bool = True
|
||||
has_local: bool = True
|
||||
has_shared: bool = True
|
||||
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
|
||||
global_max: Optional[List[int]] = None
|
||||
local_max: Optional[List[int]] = None
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
|
||||
self.opts = opts if opts else (cast(Compiled, Device[Device.DEFAULT]).linearizer_opts if isinstance(Device[Device.DEFAULT], Compiled) else LinearizerOptions())
|
||||
self.ast = ast
|
||||
|
||||
# fetch lazyop info
|
||||
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
|
||||
|
||||
# there's only allowed to be one reduceop
|
||||
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
|
||||
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
|
||||
self.reduceop = reduceops[0] if reduceops else None
|
||||
|
||||
# create new shapetrackers inside this kernel, we will permute them
|
||||
self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps])
|
||||
|
||||
# get earlybufs, before the one reduce op
|
||||
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] if self.reduceop else []
|
||||
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
|
||||
|
||||
# create the (permuted) shapetrackers
|
||||
self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]
|
||||
|
||||
# move all reduce axes to the end
|
||||
reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
|
||||
permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
|
||||
self.reshape_and_permute(None, permute)
|
||||
|
||||
# parameters for optimization
|
||||
self.applied_opts: List[Opt] = []
|
||||
self.group_for_reduce: List[int] = []
|
||||
self.upcasted: int = 0
|
||||
self.local_dims: int = 0
|
||||
self.local_alias: Dict[int, LocalBuffer] = {}
|
||||
self.tensor_core: Optional[TensorCore] = None
|
||||
self.dont_use_locals: bool = False
|
||||
|
||||
# group simplifies
|
||||
self.simplify_ones()
|
||||
self.simplify_merge_adjacent()
|
||||
|
||||
# cache
|
||||
self.applied_opts_cache: Optional[List[Opt]] = None
|
||||
|
||||
def copy(self):
|
||||
ret = type(self).__new__(type(self))
|
||||
|
||||
# base linearizer params
|
||||
ret.opts, ret.ast = self.opts, self.ast
|
||||
|
||||
# things downstream of the AST
|
||||
# NOTE: we copy bufs for local buffers and sts for optimizations
|
||||
ret.info, ret.reduceop, ret.bufs, ret.earlybufs, ret.full_buf_index, ret.sts = \
|
||||
self.info, self.reduceop, self.bufs[:], self.earlybufs, self.full_buf_index, self.sts[:]
|
||||
|
||||
# parameters for optimizations
|
||||
ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
|
||||
self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals
|
||||
|
||||
# uncached since linearize didn't run
|
||||
ret.applied_opts_cache = None
|
||||
|
||||
return ret
|
||||
|
||||
@property
|
||||
def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
|
||||
|
||||
def has_variable_shape(self) -> bool:
|
||||
for b in self.bufs:
|
||||
if not isinstance(b, LocalBuffer) and not all_int(b.st.views[-1].shape): return True
|
||||
return False
|
||||
|
||||
def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
|
||||
def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
|
||||
|
||||
def upcasted_axis(self, i):
|
||||
return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
|
||||
self.sts[i].real_strides()[self.shape_len-self.upcasted:],
|
||||
[x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
|
||||
|
||||
# TODO: is there a better way to write this?
|
||||
def acc_offsets(self, i):
|
||||
if self.upcasted == 0: return [0]
|
||||
upcasted_i = self.upcasted_axis(i)
|
||||
acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
|
||||
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
|
||||
|
||||
def get_upcast_dim(self, i) -> List[int]:
|
||||
should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
|
||||
return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]
|
||||
|
||||
@property
|
||||
def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
|
||||
|
||||
@property
|
||||
def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
|
||||
|
||||
@property
|
||||
def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape
|
||||
|
||||
@property
|
||||
def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted]
|
||||
|
||||
@property
|
||||
def shape_len(self) -> int: return len(self.sts[0].shape)
|
||||
|
||||
@property
|
||||
def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
|
||||
|
||||
@property
|
||||
def global_dims(self) -> int: return self.first_reduce-self.local_dims
|
||||
|
||||
# there's eight chunks of the shape
|
||||
# blue -- global dims
|
||||
# cyan -- local dims (warp ones first)
|
||||
# *** self.first_reduce
|
||||
# green -- reduce-local dims
|
||||
# white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
|
||||
# red -- reduce loops
|
||||
# *** self.upcasted
|
||||
# purple -- reduce upcasted
|
||||
# yellow -- normal upcasted dimensions
|
||||
def colors(self) -> List[str]:
|
||||
# first non local non reduce dims are global (blue)
|
||||
colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
|
||||
# after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
|
||||
colors += ["cyan"] * self.local_dims
|
||||
# between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green)
|
||||
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
|
||||
# between first_reduce + group_for_reduce and upcasted, they are reduce (red)
|
||||
colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
|
||||
# upcasted dimensions are reduce (magenta) or normal (yellow)
|
||||
colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
|
||||
assert len(colors) == self.shape_len, "colors size mismatch"
|
||||
return colors
|
||||
|
||||
def colored_shape(self, pad=None, dense=False) -> str:
|
||||
ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors()))
|
||||
if pad: ret += ' '*(pad-ansilen(ret))
|
||||
return ret
|
||||
|
||||
# ******************** base simplifiers ********************
|
||||
|
||||
# apply reshape and permute to all shapetrackers
|
||||
def reshape_and_permute(self, new_shape_fxn, axis):
|
||||
new_sts = []
|
||||
for st in self.sts:
|
||||
if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
|
||||
if axis is not None: st = st.permute(tuple(axis))
|
||||
new_sts.append(st)
|
||||
self.sts = new_sts
|
||||
|
||||
# drops the final dimension
|
||||
def upcast(self):
|
||||
assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
|
||||
self.upcasted += 1
|
||||
|
||||
# axis : the axis to pull from
|
||||
# amount : the amount to take
|
||||
# top : if you want to pull that amount from the top
|
||||
# insert_before : place to insert the new stuff
|
||||
def shift_to(self, axis, amount, top=False, insert_before=None):
|
||||
if insert_before is None: insert_before = self.shape_len
|
||||
move_axis = axis if top else axis+1
|
||||
if move_axis < insert_before: insert_before += 1
|
||||
self.reshape_and_permute(
|
||||
lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
|
||||
[i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
|
||||
|
||||
# ******************** complex simplifiers ********************
|
||||
|
||||
def simplify_ones(self) -> bool:
|
||||
# remove places where the shape is all ones
|
||||
# TODO: this should be factored in to multi shape stride
|
||||
if self.shape_len == 0: return False
|
||||
all_ones = [s==1 for s in self.full_shape]
|
||||
self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
|
||||
self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
|
||||
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
|
||||
return any(all_ones)
|
||||
|
||||
def simplify_merge_adjacent(self):
|
||||
if self.shape_len == 0: return
|
||||
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
|
||||
|
||||
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
|
||||
if isinstance(self.bufs[0].dtype, ImageDType):
|
||||
base_shape = self.bufs[0].dtype.shape
|
||||
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
|
||||
special_strides: Tuple[int, ...] = tuple()
|
||||
for i,g in enumerate(shape_idx_groups):
|
||||
shape_piece = tuple(self.output_shape[x] for x in g)
|
||||
assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
|
||||
special_strides += strides_for_shape(shape_piece)
|
||||
# adding the fake image shape
|
||||
shapes.append(self.output_shape)
|
||||
strides.append(special_strides)
|
||||
|
||||
# merge dimensions if we can, multi get_shape_strides
|
||||
# TODO: does this always preserve the reduce dimension, NO
|
||||
# TODO: move this into shapetracker, with tests!
|
||||
rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
|
||||
for i in range(1, len(shapes[0])):
|
||||
can_merge = []
|
||||
for j in range(len(shapes)):
|
||||
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
|
||||
can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
|
||||
# more can merge than this
|
||||
mergeable = all(can_merge) and i != self.first_reduce
|
||||
for j in range(len(shapes)):
|
||||
if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
|
||||
else: rets[j].append((shapes[j][i], strides[j][i]))
|
||||
|
||||
# do the reshapes
|
||||
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
|
||||
|
||||
# ******************** GPU simplifiers ********************
|
||||
def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
|
||||
new_shape,dims = list(x), len(x)
|
||||
for i in range(dims):
|
||||
next_idx = (i + 1) % dims
|
||||
while new_shape[i] > max_size[i]:
|
||||
new_shape[i] = new_shape[i] // 2
|
||||
if (new_shape[next_idx] <= max_size[next_idx]):
|
||||
new_shape[next_idx] = new_shape[next_idx] * 2
|
||||
else:
|
||||
next_idx = (next_idx + 1) % dims
|
||||
new_shape[next_idx] = new_shape[next_idx] * 2
|
||||
return tuple(new_shape)
|
||||
|
||||
def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
|
||||
# Check the global allocation limit, current the global_size will be flipped during codegen
|
||||
# and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
|
||||
global_dims = self.first_reduce-self.local_dims
|
||||
if global_dims > 0:
|
||||
if global_max:
|
||||
tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
|
||||
if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
|
||||
assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
|
||||
for i in range(global_dims-1):
|
||||
if i < len(global_max) and self.full_shape[i] > global_max[i]:
|
||||
order = list(range(len(self.full_shape)))
|
||||
order[i], order[global_dims-1] = order[global_dims-1], order[i]
|
||||
self.reshape_and_permute(None, order)
|
||||
if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
|
||||
|
||||
def alias_buffer(self, i, pattern):
|
||||
assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
|
||||
|
||||
bst = 1
|
||||
real_strides = self.sts[i].real_strides()
|
||||
shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
|
||||
for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
|
||||
for j,p in enumerate(pattern):
|
||||
if priority == p and real_strides[j] != 0:
|
||||
stride[j] = bst
|
||||
bst *= shp[j]
|
||||
|
||||
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
|
||||
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
|
||||
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
|
||||
self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
|
||||
|
||||
# ******************** high level optimizers ********************
|
||||
|
||||
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None):
|
||||
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
|
||||
for tc in tensor_cores[self.opts.device]:
|
||||
if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue
|
||||
has_cast = tc.dtype_in != tc.dtype_out
|
||||
|
||||
if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
|
||||
mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
|
||||
|
||||
if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue
|
||||
if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.MEM and mul_op.src[0].arg.dtype == tc.dtype_in): continue
|
||||
if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.MEM and mul_op.src[1].arg.dtype == tc.dtype_in): continue
|
||||
buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg))
|
||||
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0]
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0]
|
||||
|
||||
if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue
|
||||
|
||||
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
|
||||
|
||||
s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way
|
||||
s0_exists, s1_exists = True, True
|
||||
assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0
|
||||
def fix(needed, ax):
|
||||
nonlocal s0, s1, s0_exists, s1_exists
|
||||
if not needed: return
|
||||
if s0_exists and ax == s0:
|
||||
if s1_exists and s0 < s1: s1 -= 1
|
||||
s0_exists = False
|
||||
elif s1_exists and ax == s1:
|
||||
if s0_exists and s1 < s0: s0 -= 1
|
||||
s1_exists = False
|
||||
|
||||
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
|
||||
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
|
||||
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
|
||||
for (tc_dim, tc_amt) in tc.threads:
|
||||
fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
|
||||
|
||||
# assert tensor core and prevent extra_opts from altering the key shape structure
|
||||
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
|
||||
|
||||
if extra_opts is not None:
|
||||
for opt in extra_opts:
|
||||
self.apply_opt(opt)
|
||||
else:
|
||||
# hand-coded TC opts
|
||||
if s1_exists:
|
||||
s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0]
|
||||
if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
|
||||
if s0_exists:
|
||||
s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0]
|
||||
if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
|
||||
if self.tensor_core and s0_exists:
|
||||
for upc in [4,2]:
|
||||
if self.full_shape[s0] % upc == 0:
|
||||
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
|
||||
break
|
||||
|
||||
# alias buffer
|
||||
alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
|
||||
self.alias_buffer(buf0, alias_pattern)
|
||||
self.alias_buffer(buf1, alias_pattern)
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_opt(self, opt:Opt):
|
||||
assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals"
|
||||
self.applied_opts.append(opt)
|
||||
if opt.axis is not None:
|
||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
|
||||
else:
|
||||
axis = -1
|
||||
if opt.amt is not None:
|
||||
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
|
||||
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
||||
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
|
||||
else:
|
||||
amt = -1
|
||||
if opt.op == OptOps.LOCAL: # cyan
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
assert not(self.tensor_core), "can't local with tensor cores"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.LASTLOCAL: # cyan
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.GROUP: # green
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
assert not(self.tensor_core), "can't group with tensor cores"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.GROUPTOP: # green
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
assert not(self.tensor_core), "can't group with tensor cores"
|
||||
self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.UNROLL: # purple
|
||||
assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
|
||||
assert amt <= 32, "don't unroll more than 32"
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCAST: # yellow
|
||||
assert axis < self.first_reduce, "upcast is for non-reduce"
|
||||
assert amt <= 8, "don't upcast more than 8"
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCASTMID: # white
|
||||
assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce"
|
||||
axes = self.sts[0].unit_stride_axes()
|
||||
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
|
||||
assert axes[0] == axis, "wrong axis"
|
||||
assert amt == 4, "don't upcast mid anything but 4"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.NOLOCALS:
|
||||
assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
|
||||
assert not self.dont_use_locals, "already not using locals"
|
||||
self.dont_use_locals = True
|
||||
return self.simplify_ones()
|
||||
|
||||
def required_optimizations(self, early_only=False):
|
||||
for buf_index,buf in enumerate(self.bufs):
|
||||
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
|
||||
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
|
||||
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
|
||||
if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
|
||||
if unit_stride_axes_mul_4[0] < self.first_reduce:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
|
||||
else:
|
||||
self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
|
||||
|
||||
def hand_coded_optimizations(self):
|
||||
# if there's images in the earlybufs, we have to make an axis the 4 loading one
|
||||
self.required_optimizations(early_only=True)
|
||||
|
||||
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
|
||||
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
|
||||
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
|
||||
self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
|
||||
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
|
||||
self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM:
|
||||
buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg)
|
||||
buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg)
|
||||
buf0_strides = self.sts[buf0].real_strides()
|
||||
buf1_strides = self.sts[buf1].real_strides()
|
||||
def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st))
|
||||
if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)):
|
||||
for global_idx in range(self.global_dims):
|
||||
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
|
||||
if MV_THREADS_PER_ROW > 1:
|
||||
self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
||||
if MV_BLOCKSIZE > 1:
|
||||
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||
if MV_ROWS_PER_THREAD > 1:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
|
||||
return
|
||||
|
||||
if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
|
||||
# are we grouping? (requires local shape support)
|
||||
if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
|
||||
# TODO: use 1024 if it's allowed in a smarter way
|
||||
for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
|
||||
if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
|
||||
self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
|
||||
break
|
||||
|
||||
# are we upcasting in mid reduce? (only for images)
|
||||
if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
|
||||
axes = self.sts[0].unit_stride_axes()
|
||||
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
|
||||
if self.sts[0].shape[axes[0]]%4 == 0:
|
||||
self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
|
||||
|
||||
# now do everything required
|
||||
self.required_optimizations()
|
||||
|
||||
# no more opt if we are grouping
|
||||
if self.group_for_reduce: return
|
||||
|
||||
# **** below this line need to be optional and benchmarked ****
|
||||
|
||||
# TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
|
||||
# to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
|
||||
# expression and run test/test_ops.py with IMAGE=2
|
||||
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
|
||||
# this can be made much smarter
|
||||
to_upcast: List[int] = []
|
||||
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
|
||||
for axis in range(self.first_reduce):
|
||||
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
|
||||
# for now skip upcasting here if there is a symbolic axis
|
||||
if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
|
||||
prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
|
||||
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
||||
to_upcast.append(axis)
|
||||
for axis in to_upcast[::-1]:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
||||
|
||||
# potentially do more upcasts of non reduce axes based on a heuristic
|
||||
upcasted_axis = set()
|
||||
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
|
||||
xb_choices = []
|
||||
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
|
||||
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
|
||||
if xb_choices:
|
||||
xb_choices = sorted(xb_choices)
|
||||
if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
|
||||
self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
|
||||
upcasted_axis.add(xb_choices[0][2])
|
||||
else:
|
||||
break
|
||||
|
||||
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
|
||||
if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64):
|
||||
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
|
||||
else:
|
||||
for splits in [4]:
|
||||
if self.full_unupcasted_shape[-1]%splits == 0:
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
|
||||
break
|
||||
|
||||
# if nothing at all is upcasted and it's easy to, do an upcast
|
||||
# TODO: this is breaking the tests
|
||||
for splits in [4]:
|
||||
if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))
|
||||
|
||||
# **** local groups ****
|
||||
|
||||
if self.opts.has_local:
|
||||
if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
|
||||
self.apply_opt(Opt(OptOps.NOLOCALS))
|
||||
else:
|
||||
# prioritize making expand axes local
|
||||
local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
|
||||
to_local: List[Tuple[int, int]] = []
|
||||
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
|
||||
local_size = prod(sz for _, sz in to_local)
|
||||
local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
|
||||
if local_sz is not None: to_local.append((axis, local_sz))
|
||||
deleted_shape = 0
|
||||
for axis, local_sz in sorted(to_local[:3]):
|
||||
axis = axis - deleted_shape
|
||||
will_delete_shape = local_sz == self.full_shape[axis]
|
||||
self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
|
||||
if will_delete_shape: deleted_shape += 1
|
||||
441
tinygrad_repo/tinygrad/codegen/linearizer.py
Normal file
441
tinygrad_repo/tinygrad/codegen/linearizer.py
Normal file
@@ -0,0 +1,441 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Union, Sequence, Final, Set
|
||||
import itertools, math, functools
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
|
||||
from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
|
||||
from tinygrad.codegen.kernel import LocalBuffer, Kernel
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.features.image import to_image_idx
|
||||
|
||||
# bottom ones are asm only
|
||||
class UOps(Enum):
|
||||
LOOP = auto(); IF = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
|
||||
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
|
||||
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
|
||||
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
|
||||
|
||||
class UOp(NamedTuple):
|
||||
uop: UOps
|
||||
dtype: Optional[DType]
|
||||
vin: Tuple[UOp, ...]
|
||||
arg: Any
|
||||
def __repr__(self): return f"{self.num:4d} {str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.num for x in self.vin]):32s} {self.arg}"
|
||||
#def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str(self.vin):32s} {self.arg}"
|
||||
|
||||
# UOps are unique
|
||||
num: int
|
||||
def __hash__(self): return self.num
|
||||
def __eq__(self, x): return self.num == x.num
|
||||
|
||||
def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
|
||||
local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)]
|
||||
if maxdim != 0 and len(local_dims) > maxdim:
|
||||
dd = local_idxs[maxdim-1]
|
||||
nli = []
|
||||
for s in local_dims[maxdim-1:][::-1]:
|
||||
nli.append(dd % s)
|
||||
dd //= s
|
||||
local_idxs = local_idxs[0:maxdim-1] + nli[::-1]
|
||||
return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
|
||||
|
||||
class Linearizer(Kernel):
|
||||
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
|
||||
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
|
||||
return self.uop(UOps.ALU, dtype, (a, render_b), op)
|
||||
|
||||
# NOTE: the consts have to be be cached for deduping of downstream uops to work
|
||||
def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b)
|
||||
|
||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
|
||||
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
|
||||
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
|
||||
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
|
||||
LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT, dtype=dtypes.bool),
|
||||
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
|
||||
def global_load(self, i:int, idxs:Sequence[Node], acc=None) -> List[UOp]:
|
||||
buf = self.bufs[i]
|
||||
const = buf.val if isinstance(buf, ConstBuffer) else acc
|
||||
|
||||
def rename_var(v: VariableOrNum, expr: str): return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max)
|
||||
|
||||
amt, dim = 1, None
|
||||
upcast_dim = self.get_upcast_dim(i)
|
||||
if len(upcast_dim) == 1 and len(float4_expand := idxs[upcast_dim[0]].expand()) in [4,2]:
|
||||
dim, amt = upcast_dim[0], len(float4_expand)
|
||||
|
||||
expand_vars = tuple([rename_var(idx.expand_idx(), f"_uidx{j}") for j, idx in enumerate(idxs)])
|
||||
fake_idxs = [idx.substitute({idx.expand_idx(): ev}) for idx, ev in zip(idxs, expand_vars)]
|
||||
if dim is not None:
|
||||
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs[:dim] + [float4_expand[0]] + fake_idxs[dim+1:])
|
||||
if (g_idx // amt * amt).render() != g_idx.render():
|
||||
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
|
||||
else:
|
||||
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
|
||||
localtype = dtypes.float32 if amt == 1 else dtypes._float4 if amt == 4 else dtypes._float2
|
||||
|
||||
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
|
||||
|
||||
ret = []
|
||||
invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0
|
||||
for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)):
|
||||
this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
|
||||
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"
|
||||
if key not in self.load_cache:
|
||||
if acc is not None:
|
||||
assert valid.min == 1
|
||||
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
|
||||
elif this_const is not None:
|
||||
self.load_cache[key] = self.const(this_const, localtype)
|
||||
if valid.min == 0 and valid.max == 1:
|
||||
valid_rendered = valid.render(self.render_ops, self)
|
||||
self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE)
|
||||
else:
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
|
||||
if valid.min == 0:
|
||||
valid_rendered = valid.render(self.render_ops, self)
|
||||
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)))
|
||||
else:
|
||||
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx))
|
||||
ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
|
||||
return ret
|
||||
|
||||
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None:
|
||||
buf = self.bufs[i]
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
|
||||
expanded_nodes = [idx.expand() for idx in idxs]
|
||||
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
|
||||
store_offset = dict(zip(_idxs, store))
|
||||
|
||||
# float4 grouping
|
||||
upcast_dim = self.get_upcast_dim(i)
|
||||
if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [2,4]:
|
||||
grouped_store_offset = defaultdict(list)
|
||||
for k in store_offset:
|
||||
_idx = k[:upcast_dim[0]] + (expanded_nodes[upcast_dim[0]][0],) + k[upcast_dim[0]+1:]
|
||||
grouped_store_offset[_idx].append(store_offset[k])
|
||||
store_offset_new = {}
|
||||
for k,out_tokens in grouped_store_offset.items():
|
||||
amt = len(out_tokens)
|
||||
idx, valid = self.sts[i].expr_idxs(k)
|
||||
assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
|
||||
assert valid.min == 1, "stores are always valid"
|
||||
store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
|
||||
store_offset = store_offset_new
|
||||
|
||||
for idx, var in store_offset.items():
|
||||
idx, valid = self.sts[i].expr_idxs(idx)
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))
|
||||
|
||||
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
|
||||
def linearize(self):
|
||||
# no new opts and we already ran? skip relinearizing
|
||||
if self.applied_opts == self.applied_opts_cache: return self
|
||||
|
||||
# save backups
|
||||
sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduce[:], self.upcasted
|
||||
|
||||
# global uop cache
|
||||
self.saved_exprs: Dict[Tuple, UOp] = dict()
|
||||
|
||||
# limit dims if we need to
|
||||
if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
|
||||
|
||||
# uops
|
||||
self.uops: List[UOp] = []
|
||||
self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
|
||||
self.loop_uops: Dict[str, UOp] = {}
|
||||
|
||||
# add global buffers
|
||||
for i,buf in enumerate(self.bufs):
|
||||
if isinstance(buf, MemBuffer):
|
||||
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
|
||||
# add var vals
|
||||
for var in sorted(vars_from_ast(self.ast), key=lambda k: k.key):
|
||||
assert var.expr is not None
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
|
||||
# define local buffers
|
||||
for lb in self.local_alias.values():
|
||||
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
|
||||
# add a local buffer for multistage reduce. # TODO: use local alias
|
||||
if self.group_for_reduce:
|
||||
# TODO: the strides of this can be controlled
|
||||
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
|
||||
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
|
||||
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
|
||||
|
||||
# kernel name (before late upcast)
|
||||
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
|
||||
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
|
||||
# name the function something unique
|
||||
Linearizer.kernel_cnt[self.function_name] += 1
|
||||
suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
|
||||
self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
|
||||
|
||||
# define indexes
|
||||
global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
|
||||
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0)
|
||||
full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
|
||||
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
|
||||
|
||||
# global and local loops
|
||||
def render_loop(xx:List[Variable]):
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
|
||||
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
|
||||
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
|
||||
def end_loop(xx:List[Variable]):
|
||||
for x in xx[::-1]:
|
||||
if not isinstance(x, NumNode) and x.expr is not None:
|
||||
loop_uop = self.loop_uops[x.expr]
|
||||
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
|
||||
|
||||
# set global/local size
|
||||
self.global_size: Optional[List[int]] = None
|
||||
self.local_size: Optional[List[int]] = None
|
||||
if self.dont_use_locals:
|
||||
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
||||
elif self.opts.has_local:
|
||||
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
|
||||
self.global_size += [1]*(3-len(self.global_size))
|
||||
self.local_size += [1]*(3-len(self.local_size))
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
|
||||
else:
|
||||
render_loop(loop_global_idxs+loop_local_idxs)
|
||||
|
||||
# parse AST
|
||||
loaded_buffers = {}
|
||||
acc = []
|
||||
self.load_cache: Dict[str, UOp] = {}
|
||||
if_gate: Optional[UOp] = None
|
||||
|
||||
# reduce op
|
||||
fake_reduce_idxs: List[Variable] = []
|
||||
if self.reduceop is not None:
|
||||
# define indexes
|
||||
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
|
||||
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
||||
|
||||
# define accumulator
|
||||
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
||||
|
||||
if self.tensor_core:
|
||||
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
|
||||
replace_idxs = []
|
||||
for alias in aliases:
|
||||
full_var, full_var_sz = Variable.num(0), 1
|
||||
if alias[0] != 0:
|
||||
for i in alias:
|
||||
next_var = local_idxs[-i] if i > 0 else Variable(None, 0, local_size-1)
|
||||
full_var += next_var * full_var_sz
|
||||
full_var_sz *= next_var.max+1
|
||||
replace_idxs.append(full_var)
|
||||
return replace_idxs
|
||||
replace_acc_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[2], self.tensor_core.thread_local_aliases[2])
|
||||
for n in range(len(self.tensor_core.threads)):
|
||||
local_idxs[self.local_dims-len(self.tensor_core.threads)+n] = replace_acc_idxs[n] # replace locals
|
||||
for n in range(len(replace_acc_idxs)-len(self.tensor_core.threads)):
|
||||
upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts
|
||||
|
||||
# reduce loop
|
||||
render_loop(reduce_idxs)
|
||||
|
||||
# barrier for fast GEMM
|
||||
if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||
|
||||
# compute local aliases
|
||||
locals_to_store = []
|
||||
for i in self.local_alias:
|
||||
localbuf_idx = self.bufs.index(self.local_alias[i])
|
||||
buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
|
||||
if self.tensor_core:
|
||||
min_alias_idx = min(self.local_alias.keys())
|
||||
replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx])
|
||||
for n in range(len(self.tensor_core.threads)):
|
||||
buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals
|
||||
for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)):
|
||||
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(self.tensor_core.threads)+n] # replace upcasts
|
||||
if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
|
||||
ll = self.global_load(i, buf_idxs)
|
||||
locals_to_store.append((localbuf_idx, buf_idxs, ll))
|
||||
|
||||
# copy in any global buffers
|
||||
if self.tensor_core:
|
||||
wmma_sz = self.tensor_core.thread_local_sizes
|
||||
# calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else
|
||||
nx, ny, nacc = (len(locals_to_store[0][2])//wmma_sz[0]), (len(locals_to_store[1][2])//wmma_sz[1]), (len(acc)//wmma_sz[2])
|
||||
acc_reds = math.isqrt((nx*ny)//nacc)
|
||||
i, bx, by = 0, nx//acc_reds, ny//acc_reds
|
||||
for y in range(by):
|
||||
for x in range(bx):
|
||||
for j in range(acc_reds):
|
||||
self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
|
||||
i += wmma_sz[2]
|
||||
else:
|
||||
if locals_to_store:
|
||||
self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||
for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll)
|
||||
self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||
|
||||
# load earlybufs
|
||||
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
|
||||
|
||||
# run early AST (with reduce)
|
||||
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True)
|
||||
|
||||
# end the reduce loop
|
||||
end_loop(reduce_idxs)
|
||||
self.load_cache.clear()
|
||||
|
||||
# end the local loop, do the local reduce
|
||||
if self.group_for_reduce:
|
||||
fake_global_idxs = [x*0 for x in global_idxs]
|
||||
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
|
||||
self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||
end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
|
||||
if self.opts.has_local:
|
||||
fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
|
||||
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
|
||||
if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
|
||||
if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False)
|
||||
|
||||
# create new late reduce local loops and replace local_idxs that have been used
|
||||
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
|
||||
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
|
||||
|
||||
# if any group_for_reduce items aren't reduces, upcast them here
|
||||
for j in self.upcast_in_mid_reduce_axes:
|
||||
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
|
||||
self.upcast()
|
||||
self.group_for_reduce.pop()
|
||||
local_idxs = local_idxs[:-1]
|
||||
end_local_idxs = end_local_idxs[:-1]
|
||||
# regenerate upcast_idxs
|
||||
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
|
||||
|
||||
# NOTE: this structure is the same as the reduce op above
|
||||
|
||||
# define late accumulator
|
||||
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
||||
|
||||
# late reduce loop
|
||||
render_loop(end_local_idxs)
|
||||
|
||||
# load localbufs
|
||||
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
|
||||
|
||||
# there's no AST here (and there's no shape for the reduce LazyOp)
|
||||
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True) # type: ignore
|
||||
|
||||
# end the late reduce loop
|
||||
end_loop(end_local_idxs)
|
||||
self.load_cache.clear()
|
||||
|
||||
# load latebufs
|
||||
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
|
||||
|
||||
# run late AST
|
||||
val = self.ast_parse(self.ast, acc, None, loaded_buffers)
|
||||
|
||||
# store
|
||||
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
|
||||
|
||||
# end the global (and maybe local) loop
|
||||
if if_gate: self.uop(UOps.END, None, (if_gate,))
|
||||
end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
|
||||
|
||||
# (recursively) remove childless uops
|
||||
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
|
||||
while 1:
|
||||
has_child: Set[UOp] = set()
|
||||
for ru in self.uops:
|
||||
for vu in ru.vin:
|
||||
has_child.add(vu)
|
||||
nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
|
||||
if len(nu) == len(self.uops): break
|
||||
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
|
||||
self.uops = nu
|
||||
|
||||
# restore backups
|
||||
self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup
|
||||
|
||||
# set cache and return
|
||||
self.applied_opts_cache = self.applied_opts[:]
|
||||
return self
|
||||
|
||||
def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp:
|
||||
key = (uop, dtype, vin, arg)
|
||||
if uop == UOps.PHI and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self phi is noop
|
||||
if uop == UOps.CAST and all(x.uop == UOps.GEP for x in vin) and all_same([x.vin[0] for x in vin]) and all(x.arg == i for i,x in enumerate(vin)): return vin[0].vin[0]
|
||||
if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype)
|
||||
if uop == UOps.ALU:
|
||||
# rewrites. NOTE: the rewritten NEG op is still around...
|
||||
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable)
|
||||
# constant folding
|
||||
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype)
|
||||
# zero folding
|
||||
for x in [0,1]:
|
||||
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
|
||||
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
|
||||
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
|
||||
if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
|
||||
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
|
||||
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
|
||||
self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
|
||||
if DEBUG >= 5: print(self.uops[-1])
|
||||
if cachable: self.saved_exprs[key] = self.uops[-1]
|
||||
return self.uops[-1]
|
||||
|
||||
def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False) -> List[UOp]:
|
||||
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
|
||||
if x.op in BufferOps: return loaded_buffers[x.arg]
|
||||
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, offs, loaded_buffers) # cast isn't an ALU op
|
||||
if x.op in ReduceOps and not do_reduce:
|
||||
assert offs is None, "not available if we aren't doing reduce"
|
||||
return acc
|
||||
# MULACC fusion. TODO: this is copied from Interpreted
|
||||
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
|
||||
x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
|
||||
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
|
||||
x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
|
||||
values = [self.ast_parse(v, acc, offs, loaded_buffers) for v in x.src]
|
||||
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
|
||||
if x.op in ops:
|
||||
ret = []
|
||||
for idx, val, off in zip([[i] for i in range(len(values[0]))], zip(*values), offs):
|
||||
new_val = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
|
||||
# NOTE: we could apply the phi node to only the last change, but this breaks CLANG with nested max(x,y)
|
||||
acc[off] = self.uop(UOps.PHI, dtypes.float32, (acc[off], new_val))
|
||||
ret.append((idx, acc[off]))
|
||||
else:
|
||||
ret = [(idx, self.uop(UOps.ALU, dtypes.float32, val, x.op)) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values))]
|
||||
ordered_ret: List[Optional[UOp]] = [None]*len(values[0])
|
||||
# scatter
|
||||
for i,j in ret:
|
||||
for k in i:
|
||||
ordered_ret[k] = j
|
||||
assert all(isinstance(x, UOp) for x in ordered_ret), "some tokens didn't get scattered?"
|
||||
return cast(List[UOp], ordered_ret)
|
||||
Reference in New Issue
Block a user