# AUTOGENERATED FILE! PLEASE DON'T EDIT
from k1lib.callbacks import Callback
import k1lib, numpy as np; from torch import nn
# Column widths of the report table: computation, "% total", rate, "% selected".
_lcomp, _lp1, _lp2, _lp3 = 14, 8, 15, 14


def _spacing(s):
    """Return `s` formatted as text with a single trailing space.

    Inserted at the end of every table cell, if that cell exists.
    """
    return f"{s} "
class ComputationData:
    """Per-module operation counter used by ``ComputationProfiler``.

    Creating an instance registers a forward hook on ``mS.nnModule``. Every
    forward pass through that module adds an operation estimate to
    :attr:`flop`:

    - ``nn.Linear``: ``input.numel() * out_features``
    - ``nn.Conv2d``: ``out_channels * input.numel() * prod(kernel_size)``
    - ``nn.LeakyReLU`` / ``nn.ReLU`` / ``nn.Sigmoid``: ``input.numel()``

    Any other module type contributes nothing.
    """

    def __init__(self, cProfiler, mS: "k1lib.selector.ModuleSelector"):
        """
        :param cProfiler: owning profiler; its ``totalFlop``, ``tpAvailable``,
            ``selected`` and ``selectedTotalFlop`` attributes are read here
        :param mS: module selector node whose ``nnModule`` gets hooked
        """
        self.cProfiler = cProfiler
        self.mS = mS
        self.flop = 0  # operations counted so far by the forward hook
        self.handle = None
        self.hook()
        self.flops = 0  # flop / measured time, refreshed inside __str__
        self.tS = None  # corresponding time selector

    def hook(self):
        """Register the counting forward hook and keep its handle."""
        def hk(m, i, o):
            # presumably unwraps the 1-element tuple of hook inputs -- TODO confirm
            i = k1lib.squeeze(i)
            if isinstance(m, nn.Linear):
                self.flop += i.numel() * m.out_features
            elif isinstance(m, nn.Conv2d):
                self.flop += m.out_channels * i.shape.numel() * np.prod(m.kernel_size)
            elif isinstance(m, (nn.LeakyReLU, nn.ReLU, nn.Sigmoid)):
                self.flop += i.numel()
        self.handle = self.mS.nnModule.register_forward_hook(hk)

    def unhook(self):
        """Add this module's count to the profiler total and remove the hook."""
        self.cProfiler.totalFlop += self.flop
        self.handle.remove()

    def __getstate__(self):
        """Return the instance state without the ``mS`` and ``cProfiler`` links."""
        answer = dict(self.__dict__)
        # pop() instead of del: an instance restored through __setstate__ no
        # longer carries these keys, and must still be picklable/copyable.
        answer.pop("mS", None)
        answer.pop("cProfiler", None)
        return answer

    def __setstate__(self, state):
        """Restore the attributes saved by :meth:`__getstate__`."""
        self.__dict__.update(dict(state))

    def __str__(self):
        """Render this module's table row; empty string if nothing was counted."""
        if self.flop <= 0:
            return ""
        a = _spacing(f"{k1lib.format.computation(self.flop)}".ljust(_lcomp))
        b = _spacing(f"{round(100 * self.flop / self.cProfiler.totalFlop)}%".rjust(_lp1))
        c = ""
        if self.cProfiler.tpAvailable:
            # rate column: counted operations divided by the time recorded on tS
            self.flops = self.flop / self.tS.data.time
            c = _spacing(f"{k1lib.format.computationRate(self.flops)}".ljust(_lp2))
        d = ""
        if self.cProfiler.selected:
            if self.mS.selected("_compProf_"):
                d = f"{round(100 * self.flop / self.cProfiler.selectedTotalFlop)}%"
            # unselected rows still get a blank, padded cell
            d = _spacing(d.rjust(_lp3))
        return f"{a}{b}{c}{d}"
class ComputationProfiler(Callback):
    """Profiles computation.

    Only provides reports on well-known layers, and thus can't really be
    universal.
    """

    def __init__(self, profiler: "Profiler"):
        """
        :param profiler: parent profiler; used to check for, and read, the
            time profiler's results
        """
        super().__init__()
        self.profiler = profiler

    def startRun(self):
        """Attach a ``ComputationData`` counter to every module and reset totals."""
        if not hasattr(self, "selector"):  # if no selectors found
            self.selector = self.l.selector.copy().clearProps()
        # NOTE(review): counters are (re)attached on every run so that a
        # selector assigned from outside also gets its ``data`` -- confirm.
        for m in self.selector.modules():
            m.data = ComputationData(self, m)
        self.selector.displayF = lambda m: (k1lib.format.red if m.selected("_compProf_") else k1lib.format.identity)(m.data)
        self.totalFlop = 0
        self.selectedTotalFlop = None

    @property
    def selected(self):
        """Whether a :meth:`css` selection is currently being displayed."""
        return self.selectedTotalFlop is not None

    @property
    def tpAvailable(self) -> bool:
        """Whether TimeProfiler's results are available"""
        # best-effort probe: any failure simply means "not available"
        try:
            self.profiler._time()
            return True
        except Exception:
            return False

    def startStep(self):
        return True

    def run(self):
        """Runs everything"""
        with self.cbs.context(), self.cbs.suspendEval():
            self.cbs.withCpu()
            self.l.run(1, 1)
        # fold every module's count into totalFlop and remove the hooks
        for m in self.selector.modules():
            m.data.unhook()

    def detached(self):  # time profiler integration, so that flops can be displayed
        if self.tpAvailable:
            for cS, tS in zip(self.selector.modules(), self.profiler.time.selector.modules()):
                cS.data.tS = tS  # injecting dependency

    def css(self, css: str):
        """Selects a small part of the network to highlight

        Prints the report with the selection highlighted, then clears the
        selection again, even if building or printing the report fails.
        """
        try:
            self.selector.parse(k1lib.selector.filter(css, "_compProf_"))
            self.selectedTotalFlop = 0
            for m in self.selector.modules():
                if m.selected("_compProf_"):
                    self.selectedTotalFlop += m.data.flop
            print(self.__repr__())
        finally:
            # never leave the profiler stuck in "selected" mode
            self.selector.clearProps()
            self.selectedTotalFlop = None

    def __repr__(self):
        """Build the report; optional columns appear only when their data exists."""
        header = _spacing("computation".ljust(_lcomp))
        header += _spacing("% total".rjust(_lp1))
        header += _spacing("rate".ljust(_lp2)) if self.tpAvailable else ""
        header += _spacing("% selected".rjust(_lp3)) if self.selected else ""
        footer = _spacing(f"{k1lib.format.computation(self.totalFlop)}".ljust(_lcomp))
        footer += _spacing("".rjust(_lp1))
        footer += _spacing("".ljust(_lp2)) if self.tpAvailable else ""
        footer += _spacing(f"{k1lib.format.computation(self.selectedTotalFlop)}".rjust(_lp3)) if self.selected else ''
        footer = ("Total", footer)
        return f"""ComputationProfiler:
{k1lib.tab(self.selector.__repr__(intro=False, header=header, footer=footer))}
The "rate" column will appear if integration with Profiler.time is
possible, showing actual ops/s
Can...
- cp.css("..."): highlights a particular part of the network
- cp.selector: to get internal k1lib.ModuleSelector object"""