# AUTOGENERATED FILE! PLEASE DON'T EDIT
from k1lib.callbacks import Callback
import k1lib
# column width (in characters) used to left-justify each shape column, both in
# IOData.__str__ and in the "input shape"/"output shape" header of IOProfiler.__repr__
_li = 30
class IOData:
    """Records the input and output shapes of one module during a forward pass.

    Construction registers a forward hook on ``mS.nnModule``. Every time that
    module runs, the hook stores the input's shape description in ``iS`` and the
    output's in ``oS``. Call :meth:`unhook` to remove the hook again."""
    def __init__(self, ioProfiler, mS:k1lib.selector.ModuleSelector):
        self.ioProfiler = ioProfiler; self.mS = mS
        self.iS = None; self.oS = None # stay None until the hooked module runs
        self.handle = None; self.hook()
    @staticmethod
    def _shape(x):
        """Describes ``x`` as a (possibly nested) list of shapes.

        ``x`` is first unwrapped with ``k1lib.squeeze(x, True)``. Anything with
        a ``.shape`` becomes ``list(x.shape)``. A list/tuple (e.g. a module that
        returns several tensors) becomes a list of its elements' descriptions.
        Anything else becomes None."""
        x = k1lib.squeeze(x, True)
        if hasattr(x, "shape"): return list(x.shape)
        if isinstance(x, (list, tuple)): return [IOData._shape(e) for e in x]
        return None
    def hook(self):
        """Registers the forward hook that fills in ``iS`` and ``oS``."""
        def hk(m, i, o):
            # an exception raised inside a forward hook aborts the model's forward
            # pass, so inputs/outputs without a ``.shape`` are described, not assumed
            self.iS = IOData._shape(i)
            self.oS = IOData._shape(o)
        self.handle = self.mS.nnModule.register_forward_hook(hk)
    def unhook(self):
        """Removes the forward hook, if one is registered. Safe to call repeatedly."""
        if self.handle is not None:
            self.handle.remove(); self.handle = None
    def __getstate__(self):
        """Pickle state, minus the back-references ``mS`` and ``ioProfiler``."""
        answer = dict(self.__dict__)
        del answer["mS"]; del answer["ioProfiler"]; return answer
    def __setstate__(self, state): self.__dict__.update(dict(state))
    def __str__(self):
        """``iS`` and ``oS`` side by side, each left-justified to ``_li`` characters."""
        a = f"{self.iS}".ljust(_li); b = f"{self.oS}".ljust(_li)
        return f"{a}{b}"
class IOProfiler(Callback):
    """Gets input and output shapes of each layer"""
    def _unhookAll(self):
        """Removes the IOData forward hook from every selected module, if present."""
        if not hasattr(self, "selector"): return # startRun never got far enough
        for m in self.selector.modules():
            data = getattr(m, "data", None)
            if isinstance(data, IOData): data.unhook()
    def startRun(self):
        """Attaches a fresh IOData hook to every module in the selector.

        On first use, copies the learner's selector with its props cleared.
        Hooks left over from a previous run are removed first, so repeated runs
        don't stack hooks on the same module."""
        if not hasattr(self, "selector"): # if no selectors found
            self.selector = self.l.selector.copy().clearProps()
        self._unhookAll()
        for m in self.selector.modules(): m.data = IOData(self, m)
        self.selector.displayF = lambda m: (k1lib.format.red if m.selected("_ioProf_") else k1lib.format.identity)(m.data)
    def startStep(self):
        # always returns True; presumably this tells the Learner to skip the
        # step's default handling while profiling - TODO confirm against Learner
        return True
    def run(self):
        """Runs everything"""
        try:
            with self.cbs.suspendEval(): self.l.run(1, 1)
        finally:
            # remove the hooks even if the run raised, so they don't stay on the model
            self._unhookAll()
    def css(self, css:str):
        """Selects a small part of the network to highlight"""
        try:
            self.selector.parse(k1lib.selector.filter(css, "_ioProf_"))
            print(self.__repr__())
        finally: self.selector.clearProps() # never leave the highlight selection behind
    def __repr__(self):
        header = "input shape".ljust(_li) + "output shape".ljust(_li)
        return f"""IOProfiler:
{k1lib.tab(self.selector.__repr__(intro=False, header=header))}
Can...
- iop.css("..."): highlights a particular part of the network
- iop.selector: to get internal k1lib.ModuleSelector object"""