openpilot v0.9.6 release
date: 2024-01-12T10:13:37 master commit: ba792d576a49a0899b88a753fa1c52956bedf9e6
This commit is contained in:
204
tinygrad_repo/tinygrad/features/image.py
Normal file
204
tinygrad_repo/tinygrad/features/image.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from typing import List, Tuple, Dict, Any
|
||||
from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG, flatten
|
||||
|
||||
# *** image Tensor function replacements ***
|
||||
|
||||
from tinygrad.lazy import get_single_root
|
||||
|
||||
def image_dot(self, w):
|
||||
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
||||
n1, n2 = len(self.shape), len(w.shape)
|
||||
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
||||
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
|
||||
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
|
||||
cin, cout = w.shape[-2], w.shape[-1]
|
||||
out_shape_t = self.shape[0:-2] + (cout,-1)
|
||||
if len(self.shape) > 1:
|
||||
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
|
||||
else:
|
||||
order, out_shape_t = (0,), (cout, )
|
||||
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
|
||||
|
||||
# NOTE: with NHWC we can remove the transposes
|
||||
# bs x groups*cin x H x W
|
||||
cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
|
||||
# groups*cout x cin x H, W
|
||||
cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
|
||||
return image_conv2d(cx, cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
|
||||
|
||||
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
|
||||
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
|
||||
|
||||
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
|
||||
rcout = cout//groups
|
||||
x, w = self, weight.reshape(groups, rcout, cin, H, W)
|
||||
|
||||
# hack for non multiples of 4 on cin
|
||||
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
|
||||
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
|
||||
added_input_channels = 4 - (cin % 4)
|
||||
w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
|
||||
x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
|
||||
cin = cin + added_input_channels
|
||||
x = x.reshape(bs, groups*cin, iy, ix)
|
||||
|
||||
# hack for non multiples of 4 on rcout
|
||||
added_output_channels = 0
|
||||
if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
|
||||
added_output_channels = 4 - (rcout % 4)
|
||||
rcout += added_output_channels
|
||||
cout = groups * rcout
|
||||
w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape)))
|
||||
|
||||
# packed (note: flipping bs and iy would make the auto-padding work)
|
||||
x = x.permute(0,2,3,1)
|
||||
cin_last = iy == 1 and ix == 1
|
||||
if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
|
||||
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
|
||||
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
|
||||
|
||||
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
|
||||
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
|
||||
x, w = x.contiguous(), w.contiguous()
|
||||
if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize()
|
||||
|
||||
# expand out
|
||||
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
|
||||
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
|
||||
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
|
||||
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
|
||||
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
|
||||
|
||||
# padding
|
||||
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
|
||||
x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
|
||||
|
||||
# prepare input
|
||||
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
|
||||
oy, ox = x.shape[4:6]
|
||||
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
|
||||
x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
|
||||
|
||||
# prepare weights
|
||||
w = w.permute(0,4,2,5,1,3)
|
||||
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)
|
||||
|
||||
# the conv! (+ the bias)
|
||||
ret = x*w
|
||||
if IMAGE >= 2: ret = ret.cast(base_image_type((bs*oy, ox*cout//4, 4)))
|
||||
ret = ret.sum((-4, -3, -2, -1))
|
||||
|
||||
# undo hack for non multiples of 4 on C.rcout
|
||||
if added_output_channels != 0:
|
||||
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
|
||||
rcout -= added_output_channels
|
||||
cout = groups * rcout
|
||||
|
||||
# NCHW output
|
||||
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
|
||||
|
||||
# *** schedules with images need to be fixed to be valid ***
|
||||
|
||||
import dataclasses
|
||||
from tinygrad.ops import ScheduleItem, BufferOps, LazyOp, UnaryOps, LoadOps, MemBuffer, get_lazyop_info
|
||||
|
||||
def fix_schedule_for_images(schedule:List[ScheduleItem]):
|
||||
# this is the fundamental fix, find unwritable or unreadable images and convert them to normal float32 (TODO: should it be float16?)
|
||||
replace_inputs = {}
|
||||
for i, si in enumerate(schedule):
|
||||
if isinstance(si.out.dtype, ImageDType) and (prod(si.out.shape) != prod(si.out.dtype.shape) or not any(si.out.shape[x]%4 == 0 for x in si.out.st.unit_stride_axes())):
|
||||
if DEBUG >= 1: print(f"{i:3d}: rewrite output, output shape {prod(si.out.shape)}, image dtype {si.out.dtype} prod {prod(si.out.dtype.shape)}")
|
||||
si.out.dtype = dtypes.float32
|
||||
for b in si.ast.get_lazyops():
|
||||
if b.op != BufferOps.MEM: continue
|
||||
# TODO: unit_stride axes will fail if there's a mask, even if the mask is divisble by four. this is too aggressive
|
||||
if isinstance(si.inputs[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
|
||||
if DEBUG >= 1: print(f"{i:3d}: rewrite input, image dtype {si.inputs[b.arg.idx-1].dtype}, {b.arg.st.views}")
|
||||
if si.inputs[b.arg.idx-1].realized:
|
||||
# have to copy it
|
||||
replace_inputs[si.inputs[b.arg.idx-1]] = si.inputs[b.arg.idx-1].cast(dtypes.float32)
|
||||
else:
|
||||
# change it before it's created
|
||||
si.inputs[b.arg.idx-1].dtype = dtypes.float32
|
||||
|
||||
# now fix up the schedule to reflect the new dtypes
|
||||
fixed_schedule:List[ScheduleItem] = []
|
||||
for i,si in enumerate(schedule):
|
||||
ast = si.ast
|
||||
inputs = si.inputs
|
||||
|
||||
# replace inputs with casted versions
|
||||
if any(x in replace_inputs for x in inputs):
|
||||
fixed_schedule += flatten([replace_inputs[x].schedule() for x in inputs if x in replace_inputs])
|
||||
inputs = tuple(replace_inputs.get(x, x) for x in inputs)
|
||||
|
||||
# fix input dtypes to match what they actually are
|
||||
replacements = {}
|
||||
for b in si.ast.get_lazyops():
|
||||
if b.op != BufferOps.MEM: continue
|
||||
if b.arg.dtype != inputs[b.arg.idx-1].dtype:
|
||||
replacements[b] = LazyOp(BufferOps.MEM, (), MemBuffer(b.arg.idx, inputs[b.arg.idx-1].dtype, b.arg.st))
|
||||
if replacements: ast = ast.map_buffers(replacements)
|
||||
|
||||
# fix the ops to create the output dtype
|
||||
if ast.op not in LoadOps:
|
||||
info = get_lazyop_info(ast)
|
||||
if info.dtype != si.out.dtype:
|
||||
if DEBUG >= 3: print(f"{i:3d}: info.dtype {info.dtype} != {si.out.dtype} -> {si.out.dtype}")
|
||||
ast = LazyOp(UnaryOps.CAST, (ast,), (si.out.dtype, False))
|
||||
|
||||
# put this in the fixed schedule
|
||||
fixed_schedule.append(dataclasses.replace(si, ast=ast, inputs=inputs))
|
||||
return fixed_schedule
|
||||
|
||||
# *** images have weird indexing requirements ***
|
||||
|
||||
from tinygrad.shape.symbolic import Node, AndNode, Variable, NumNode, SumNode, LtNode
|
||||
|
||||
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
|
||||
idx = (idxy // 4) % base_shape[1]
|
||||
idy = (idxy // (4 * base_shape[1]))
|
||||
|
||||
if valid.min == 0 and isinstance(idxy, SumNode):
|
||||
nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
|
||||
val_dict: Dict[Node, Any] = {}
|
||||
idxy_flat_var = [(i, i.vars()[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]
|
||||
|
||||
for node in nodes:
|
||||
assert isinstance(node, LtNode)
|
||||
node_flat, node_vars = node.a.flat_components if isinstance(node.a, SumNode) else [node.a], node.vars()
|
||||
same_sym = [i for (i, var) in idxy_flat_var if var in node_vars]
|
||||
if len(same_sym) == 0: continue
|
||||
first, second = sorted(same_sym)[0], sorted(node_flat)[0]
|
||||
f_b = 1 if isinstance(first, Variable) else first.b
|
||||
s_b = 1 if isinstance(second, Variable) else second.b
|
||||
sig = -1 if s_b < 0 else 1
|
||||
key_node = sig*node.a
|
||||
if key_node not in val_dict: val_dict[key_node] = [key_node.min, key_node.max, abs(f_b//s_b)]
|
||||
val_dict[key_node][(sig + 1)//2] = sig*(node.b - 1)
|
||||
|
||||
fakes = {}
|
||||
for cnt, (key_node, (mnn, mxn, multip)) in enumerate(val_dict.items()):
|
||||
fake_var = Variable("fake_" + str(cnt), mnn, mxn)
|
||||
fakes[fake_var] = key_node
|
||||
idxy += multip*(fake_var - key_node)
|
||||
|
||||
idx = (idxy // 4) % base_shape[1]
|
||||
idy = (idxy // (4 * base_shape[1]))
|
||||
|
||||
fake_rep = {fake: node for fake, node in fakes.items()}
|
||||
|
||||
idx = idx.substitute(fake_rep)
|
||||
idy = idy.substitute(fake_rep)
|
||||
|
||||
idy_vars, idx_vars, ones = set(idy.vars()), set(idx.vars()), []
|
||||
for node in nodes:
|
||||
node_vars = set(node.vars())
|
||||
if not node_vars & (idx_vars | idy_vars): continue #There is simplified NumNode which can not go outside the bounds
|
||||
# NOTE: Why does only idy is problematic? and not the idx
|
||||
if idy_vars == node_vars or idy_vars & node_vars == set(): ones.append(node)
|
||||
valid = Variable.ands([i for i in nodes if i not in ones])
|
||||
|
||||
if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
|
||||
return (idx, idy), valid
|
||||
151
tinygrad_repo/tinygrad/features/search.py
Normal file
151
tinygrad_repo/tinygrad/features/search.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
||||
import itertools, random
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.ops import Device, Compiled, MemBuffer
|
||||
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
from collections import defaultdict
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)])
|
||||
actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
|
||||
actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
|
||||
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
|
||||
actions += [
|
||||
Opt(op=OptOps.LOCAL, axis=0, amt=32),
|
||||
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
|
||||
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
|
||||
Opt(op=OptOps.NOLOCALS),
|
||||
]
|
||||
|
||||
# returns time in seconds
|
||||
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
|
||||
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size}
|
||||
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
|
||||
var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
|
||||
try:
|
||||
lin.linearize()
|
||||
prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin)
|
||||
real_global_size = prg.global_size
|
||||
if allow_test_size and prg.global_size:
|
||||
test_global_size = prg.global_size[:]
|
||||
while prod(test_global_size) > max_global_size:
|
||||
for j in range(2,-1,-1):
|
||||
if test_global_size[j] > 16:
|
||||
test_global_size[j] //= 2
|
||||
break
|
||||
factor = prod(prg.global_size) / prod(test_global_size)
|
||||
prg.global_size = test_global_size
|
||||
#print(real_global_size, test_global_size, factor)
|
||||
else:
|
||||
factor = 1
|
||||
# TODO: this is super broken for var_vals
|
||||
# TODO: this is copied from prg.__call__
|
||||
global_size, local_size = prg.launch_dims(var_vals)
|
||||
if global_size is not None and local_size is None:
|
||||
local_size = prg.optimize_local_size(global_size, rawbufs)
|
||||
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
|
||||
tms = []
|
||||
for _ in range(cnt):
|
||||
if clear_l2:
|
||||
# TODO: this is too small for many L2 caches
|
||||
with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
|
||||
lra = prg.runtime_args.copy()
|
||||
if global_size: lra['global_size'] = global_size
|
||||
if local_size: lra['local_size'] = local_size
|
||||
tms.append(prg.clprg(*rawbufs, *var_vals.values(), **lra, wait=True)*factor)
|
||||
prg.global_size = real_global_size
|
||||
except Exception:
|
||||
if DEBUG >= 4:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print("FAILED")
|
||||
print(lin.ast)
|
||||
print(lin.applied_opts)
|
||||
tms = [float('inf')]
|
||||
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
|
||||
return min(tms)
|
||||
|
||||
# get (scrap) buffers for timing the linearizer
|
||||
def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
|
||||
bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
|
||||
for x in lin.membufs: bufsts[x.idx].append(x)
|
||||
rawbufs:List[Optional[RawBuffer]] = [None]*len(bufsts)
|
||||
for k,lx in bufsts.items():
|
||||
rawbufs[k] = cast(Compiled, Device[Device.DEFAULT]).buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
|
||||
assert all(r is not None for r in rawbufs)
|
||||
return cast(List[RawBuffer], rawbufs)
|
||||
|
||||
# get dictionary of all possible actions
|
||||
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
|
||||
acted_lins = {0:lin} if include_0 else {}
|
||||
for i,a in enumerate(actions):
|
||||
if a.axis is not None and a.axis >= lin.shape_len: continue
|
||||
if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
|
||||
lin2 = lin.copy()
|
||||
try:
|
||||
lin2.apply_opt(a)
|
||||
up, lcl = 1, 1
|
||||
for s,c in zip(lin2.full_shape, lin2.colors()):
|
||||
if c in {"magenta", "yellow"}: up *= s
|
||||
if c in {"cyan", "green", "white"}: lcl *= s
|
||||
if up > 256 or lcl > 256: continue
|
||||
acted_lins[i+1] = lin2
|
||||
except Exception:
|
||||
pass
|
||||
return acted_lins
|
||||
|
||||
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
|
||||
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size}
|
||||
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
|
||||
ret = lin.copy()
|
||||
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
|
||||
return ret
|
||||
|
||||
# init the BEAM with the base linearizer
|
||||
beam: List[Tuple[Linearizer, float]] = [(lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size))]
|
||||
|
||||
# NOTE: real uops use a weird compare method that's only valid inside a linearizer
|
||||
def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
|
||||
seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}
|
||||
|
||||
while 1:
|
||||
acted_lins = lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
|
||||
|
||||
# dedup with uops (TODO: double linearize not needed)
|
||||
acted_lins_dedup = []
|
||||
for lin in acted_lins:
|
||||
tuops = tuplize_uops(lin.linearize().uops)
|
||||
if tuops in seen_uops:
|
||||
#print(seen_uops[tuops], lin.applied_opts)
|
||||
continue
|
||||
seen_uops[tuops] = tuple(lin.applied_opts)
|
||||
acted_lins_dedup.append(lin)
|
||||
acted_lins = acted_lins_dedup
|
||||
|
||||
# time linearizers
|
||||
timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins]
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
if len(opts) == 0 or beam[0][1] <= opts[0][1]: break # we didn't get faster
|
||||
|
||||
# keep the BEAM best
|
||||
beam = opts[:amt]
|
||||
if DEBUG >= 2: print(f"{opts[0][1]*1e6:12.2f} us from {len(lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
|
||||
|
||||
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
|
||||
if DEBUG >= 3: print(beam[0][0].applied_opts)
|
||||
return beam[0][0]
|
||||
|
||||
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
|
||||
test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
|
||||
MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024
|
||||
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
|
||||
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
|
||||
def try_exec(local_size):
|
||||
try:
|
||||
return clprg(*test_rawbuffers, global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)
|
||||
except Exception:
|
||||
return float('inf')
|
||||
return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
|
||||
Reference in New Issue
Block a user