openpilot v0.9.6 release
date: 2024-02-21T23:02:42 master commit: 0b4d08fab8e35a264bc7383e878538f8083c33e5
This commit is contained in:
117
tinygrad_repo/tinygrad/graph.py
Normal file
117
tinygrad_repo/tinygrad/graph.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import os, atexit, functools
|
||||
try:
|
||||
import networkx as nx # type: ignore
|
||||
except ImportError:
|
||||
nx = None # graph won't work
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
|
||||
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
|
||||
from tinygrad.codegen.linearizer import UOps
|
||||
|
||||
# **** debugging and graphing ****
|
||||
|
||||
G = nx.DiGraph() if nx is not None else None
|
||||
cnts: Dict[OpType, int] = defaultdict(int)
|
||||
if DEBUG >= 2:
|
||||
def print_globalcounters():
|
||||
if GlobalCounters.time_sum_s == 0: return
|
||||
print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",
|
||||
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
|
||||
atexit.register(print_globalcounters)
|
||||
if GRAPH:
|
||||
def save_graph_exit():
|
||||
for k,v in cnts.items(): print(k, v)
|
||||
print("saving", G)
|
||||
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
|
||||
# -Gnslimit=100 can make it finish, but you won't like results
|
||||
os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
|
||||
atexit.register(save_graph_exit)
|
||||
|
||||
node_count = 0
|
||||
def nm(x):
|
||||
global node_count
|
||||
if not hasattr(x, 'node_id'):
|
||||
setattr(x, 'node_id', node_count)
|
||||
node_count += 1
|
||||
return x.node_id
|
||||
|
||||
def get_sop(op: List[Op]):
|
||||
op = [x for x in op if x not in BufferOps]
|
||||
if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
|
||||
if len(op) <= 6: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
|
||||
return str(len(op))
|
||||
|
||||
def str_dtype(dtyp):
|
||||
ret = str(dtyp)[7:]
|
||||
return "" if ret == 'float' else f"\n{ret}"
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def add_st_node(nmx, nmo, label, st):
|
||||
global node_count
|
||||
inter_node = node_count
|
||||
node_count += 1
|
||||
G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=f"{st.shape}\n{st.real_strides()}" + (f"\n{st.real_offset()}" if st.real_offset() != 0 else ""))
|
||||
G.add_edge(nmx, inter_node, color='#00000060')
|
||||
G.add_edge(inter_node, nmo, label=label, color='#00000060')
|
||||
|
||||
logops = open(getenv("LOGOPS", ""),"a") if getenv("LOGOPS", "") else None
|
||||
def log_schedule_item(si: ScheduleItem):
|
||||
if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
|
||||
show_graph = bool(GRAPH)
|
||||
if not DEBUG and not show_graph: return
|
||||
if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base))
|
||||
if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
|
||||
|
||||
op: List[Op] = [x.op for x in si.ast.get_lazyops()]
|
||||
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
|
||||
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
|
||||
cnts[optype] += 1
|
||||
if show_graph:
|
||||
assert si.out.base == si.out, "all outputs based"
|
||||
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
|
||||
|
||||
# get inputs for shapetrackers
|
||||
input_to_st = defaultdict(list)
|
||||
for lo in si.ast.get_lazyops():
|
||||
if lo.op != BufferOps.MEM: continue
|
||||
input_to_st[si.inputs[lo.arg.idx-1]].append(lo.arg.st)
|
||||
|
||||
# add them to the graph, potentially with a movement op seperating them
|
||||
for x in input_to_st:
|
||||
for st in dedup(input_to_st[x]):
|
||||
if st.contiguous:
|
||||
G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060')
|
||||
else:
|
||||
add_st_node(nm(x), nm(si.out), get_sop(op), st)
|
||||
if 'label' not in G.nodes[nm(x)]:
|
||||
G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(si.out.dtype)
|
||||
|
||||
if nm(si.out) not in G.nodes: G.add_node(nm(si.out))
|
||||
|
||||
G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
|
||||
G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype]
|
||||
G.nodes[nm(si.out)]['color'] = 'black'
|
||||
G.nodes[nm(si.out)]['style'] = 'filled'
|
||||
|
||||
def _tree(lazydata, prefix=""):
|
||||
if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
|
||||
if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
|
||||
lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
|
||||
childs = [_tree(c) for c in lazydata.src[:]]
|
||||
for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
|
||||
return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
|
||||
|
||||
def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))]))
|
||||
|
||||
def graph_uops(uops):
|
||||
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
|
||||
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
|
||||
UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0"}
|
||||
G = nx.DiGraph()
|
||||
for u in uops:
|
||||
G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
|
||||
for v in u.vin: G.add_edge(v.num, u.num)
|
||||
GRAPHPATH = "/tmp/uops"
|
||||
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
|
||||
os.system(f'dot -Grankdir=LR -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
|
||||
Reference in New Issue
Block a user