# AUTOGENERATED FILE! PLEASE DON'T EDIT
import k1lib, time, torch
from k1lib.callbacks import Callback
# Column widths used when rendering the profiler table
_ltime = 15  # "time" column (left-justified)
_lt1 = 8     # "% total" column (right-justified)
_lt2 = 18    # "% selected max" column (right-justified)
class TimeData:
    """Accumulates the forward-pass execution time of a single module.

    A forward pre-hook records a start timestamp, and a forward hook adds the
    elapsed time to :attr:`time`. When the owning profiler reports
    ``is_cuda``, ``torch.cuda.synchronize()`` is called before each clock read
    so that queued GPU work is not left out of the measurement."""
    def __init__(self, tProfiler, mS:"k1lib.selector.ModuleSelector"):
        # tProfiler: owning profiler; this class reads its is_cuda,
        #   totalTime and selectedMaxTime, and writes totalTime in unhook()
        # mS: selector node whose ``nnModule`` gets hooked
        self.tProfiler = tProfiler; self.mS = mS
        self.startTime = None; self.time = 0; self.hook()
    def hook(self):
        """Registers the pre/post forward hooks on ``mS.nnModule``."""
        def fpHk(m, i):
            if self.tProfiler.is_cuda: torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution, unlike time.time()
            self.startTime = time.perf_counter()
        def fHk(m, i, o):
            if self.tProfiler.is_cuda: torch.cuda.synchronize()
            self.time += time.perf_counter() - self.startTime
        self.fpH = self.mS.nnModule.register_forward_pre_hook(fpHk)
        self.fH = self.mS.nnModule.register_forward_hook(fHk)
    def unhook(self):
        """Removes both hooks and folds this module's time into the
        profiler's ``totalTime``, which is kept as the max over all modules."""
        self.tProfiler.totalTime = max(self.tProfiler.totalTime, self.time)
        self.fpH.remove(); self.fH.remove()
    def __getstate__(self):
        # drop back-references to the selector and profiler when pickling
        answer = dict(self.__dict__)
        del answer["mS"]; del answer["tProfiler"]; return answer
    def __setstate__(self, state): self.__dict__.update(dict(state))
    def __str__(self):
        """Formats one table row: time, % of total, and (when this module is
        highlighted) % of the selected max. Empty for negligible times."""
        if self.time <= 1e-20: return ""
        # totalTime already includes self.time once unhook() has run; taking the
        # max only matters if rendered earlier, and avoids dividing by zero
        total = max(self.tProfiler.totalTime, self.time)
        selMax = self.tProfiler.selectedMaxTime
        a = f"{k1lib.format.time(self.time)}".ljust(_ltime)
        b = f"{round(100 * self.time / total)}%".rjust(_lt1)
        c = f"{round(100 * self.time / selMax)}%".rjust(_lt2) if selMax is not None and self.mS.selected("_timeProf_") else ""
        return f"{a}{b}{c}"
class TimeProfiler(Callback):
    """Profiles execution time. Only measures forward times, as
backward times can't really be measured.

Call :meth:`run` to time the model, then inspect the result via ``repr()``
or highlight part of the network with :meth:`css`."""
    def startRun(self):
        """Attaches a fresh :class:`TimeData` timer to every module in the
        selector and resets the aggregate timings."""
        if not hasattr(self, "selector"): # no selector yet: copy the learner's, without its props
            self.selector = self.l.selector.copy().clearProps()
        for m in self.selector.modules(): m.data = TimeData(self, m)
        # modules tagged "_timeProf_" by css() are rendered in red
        self.selector.displayF = lambda m: (k1lib.format.red if m.selected("_timeProf_") else k1lib.format.identity)(m.data)
        self.totalTime = 0; self.selectedMaxTime = None
    def startStep(self):
        # NOTE(review): returning True presumably tells the callback system to
        # skip the rest of the step (backward/optimizer); confirm in k1lib
        return True
    def run(self):
        """Runs everything: calls ``self.l.run(1, 1)`` with timing hooks
        attached. Hooks are removed afterwards, even if the run raises."""
        try:
            with self.cbs.context(), self.cbs.suspendEval():
                # a model without parameters has no device to inspect; treat it as CPU
                firstParam = next(self.l.model.parameters(), None)
                self.is_cuda = firstParam is not None and firstParam.is_cuda
                if self.is_cuda: self.cbs.withCuda()
                else: self.cbs.withCpu()
                self.l.run(1, 1)
        finally: self._unhookAll()
    def _unhookAll(self):
        """Removes the timing hooks installed by :meth:`startRun`, if any."""
        selector = getattr(self, "selector", None)
        if selector is None: return # startRun never ran, so nothing was hooked
        for m in selector.modules():
            data = getattr(m, "data", None)
            if isinstance(data, TimeData): data.unhook()
    def css(self, css:str):
        """Selects a small part of the network to highlight, prints the
        report, then clears the selection again (even if printing fails)."""
        try:
            self.selector.parse(k1lib.selector.filter(css, "_timeProf_"))
            selectedTimes = [m.data.time for m in self.selector.modules() if m.selected("_timeProf_")]
            self.selectedMaxTime = max([0, *selectedTimes])
            print(self.__repr__())
        finally:
            self.selector.clearProps(); self.selectedMaxTime = None
    def __repr__(self):
        header = "time".ljust(_ltime) + "% total".rjust(_lt1) + ("% selected max" if self.selectedMaxTime is not None else "").rjust(_lt2)
        footer = ""
        if self.selectedMaxTime is not None:
            # totalTime stays 0 when no hooked module ran; avoid dividing by zero
            pct = round(100 * self.selectedMaxTime / self.totalTime) if self.totalTime > 0 else 0
            b = f"{pct}%".rjust(_lt1, " ")
            st = f"{k1lib.format.time(self.selectedMaxTime)}".rjust(_lt2)
            footer = ("Selected max", " " * _ltime + b + st)
        return f"""TimeProfiler ({"GPU" if self.is_cuda else "CPU"}):
{k1lib.tab(self.selector.__repr__(intro=False, header=header, footer=footer))}
Caveats: This one's a bit stranger than memory and computation profilers
1. There is no "total" time (adding all times in all modules). There
    is total network execution time tho, which is just the time taken
    for the top level module to execute
2. "% selected max" column displays percentage of selected max, not
    percentage of total selected time, which may matter in your analysis
Can...
- tp.css("..."): highlights a particular part of the network
- tp.selector: to get internal k1lib.ModuleSelector object"""