Source code for k1lib.callbacks.hookModule

# AUTOGENERATED FILE! PLEASE DON'T EDIT
from .callbacks import Callback, Callbacks, Cbs
import k1lib; from k1lib import squeeze
import torch; import torch.nn as nn
from functools import partial
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict, Iterator, Union, Any, Callable
__all__ = ["HookModule"]
class Handles:
    """Holds the forward and backward hook handles registered on one module.

    Both attributes are ``None`` while nothing is registered. Anything assigned
    to them is expected to expose a ``.remove()`` method."""
    def __init__(self):
        self.forward = None; self.backward = None
    def remove(self):
        """Removes whichever handles are currently set and clears them.

        Each handle is released on its own, so a half-registered state (only
        one of the two set) gets cleaned up instead of raising."""
        if self.forward is not None:
            self.forward.remove(); self.forward = None
        if self.backward is not None:
            self.backward.remove(); self.backward = None
    @property
    def active(self):
        """Whether hooks are currently registered.

        :return: True if both handles are set, False if neither is
        :raises Exception: if exactly one of the two handles is set"""
        hasForward = self.forward is not None
        hasBackward = self.backward is not None
        if hasForward != hasBackward: raise Exception("Supposed to be unreachable")
        return hasForward
class Data(k1lib.Object):
    """Storage object handed to hook functions for one pass direction.

    Auto-declare is configured with a factory that returns a new empty list."""
    def __init__(self):
        super().__init__()
        self.withAutoDeclare(lambda: [])
class ModuleData:
    """Everything recorded for a single module: one :class:`Data` store for the
    forward pass and one for the backward pass."""
    def __init__(self):
        self.forward = Data()
        self.backward = Data()
    def _plot(self, axes, field:str, rangeSlice:slice):
        """Draws the recorded ``field`` of both passes onto a pair of axes.

        :param axes: indexable pair; item 0 gets the forward curve, item 1 the backward one
        :param field: name of the recorded field, looked up in both stores
        :param rangeSlice: slice handed to ``k1lib.Range.proportionalSlice``; its
            step (1 if unset) subsamples what gets drawn"""
        fwd = self.forward[field]
        stride = rangeSlice.step or 1
        bwd = self.backward[field]
        # nothing is drawn unless both directions have recorded something
        if not (len(fwd) and len(bwd)): return
        ranges = k1lib.Range.proportionalSlice(len(fwd), len(bwd), rangeSlice)
        for i, (rng, values) in enumerate(zip(ranges, (fwd, bwd))):
            axes[i].plot(rng.range_[::stride], values[rng.slice_][::stride], alpha=0.5)
    def __repr__(self):
        return ("Module's saved data. can...\n"
                "- d.forward: to get data stored during forward pass\n"
                "- d.backward: to get data stored during backward pass")
# Type of a hook function: called as (data, module, inp, out) and returns nothing
_Fn = Callable[[Data, nn.Module, Tuple[torch.Tensor], Tuple[torch.Tensor]], None]
class Function:
    """Pairs a hook callable with a display name."""
    def __init__(self, f:_Fn, name=None):
        """
        :param f: the callable to wrap
        :param name: display name; a placeholder is used when it is falsy"""
        self.f = f
        self.name = name if name else "f(<no name>)"
    def __call__(self, *args, **kwargs):
        """Invokes the wrapped callable; its return value is discarded."""
        self.f(*args, **kwargs)
def hook(fns:List[Function], *args):
    """Calls every function in ``fns``, in order, with ``*args``. Returns nothing."""
    for fn in fns:
        fn(*args)
class Module:
    """Wraps one ``nn.Module`` together with its hook handles and recorded data."""
    def __init__(self, module:nn.Module):
        """
        :param module: the network module to wrap; its class name becomes ``self.name``"""
        self.name = module.__class__.__name__
        self.nnModule = module
        self.data = ModuleData()
        self.handles = Handles()
    def registerHooks(self, forwardFns:List[Function], backwardFns:List[Function]):
        """Registers a forward and a full backward hook on the wrapped module.

        Each hook runs the given function list, with this module's forward or
        backward data store passed as the first argument.

        :return: self, so calls can be chained"""
        target, handles = self.nnModule, self.handles
        onForward = partial(hook, forwardFns, self.data.forward)
        handles.forward = target.register_forward_hook(onForward)
        onBackward = partial(hook, backwardFns, self.data.backward)
        handles.backward = target.register_full_backward_hook(onBackward)
        return self
    def unregisterHooks(self):
        """Removes the hooks held in ``self.handles``."""
        self.handles.remove()
    def __repr__(self):
        lines = (f"Module `{self.name}`. Use...",
                 "- m.data: to get data stored",
                 "- m.nnModule: to get actual nn.Module object",
                 '- m.plot("means", "stds"): to plot simple statistics')
        return "\n".join(lines)
@k1lib.patch(Cbs)
class HookModule(Callback):
    """Hooks into selected modules in the network, and execute functions like
    .mean(), .std(). This is fairly complicated, and I highly recommend
    displaying this callback in a cell for more info"""
    def __init__(self, persistent:bool=False):
        """
        :param persistent: whether to save results across runs. If false, then
            can execute `.reset()` to reset everything"""
        super(HookModule, self).__init__()
        self.modules:List[Module] = []
        self.forwardFns:List[Function] = []
        self.backwardFns:List[Function] = []
        self.cleanFns = []; self.persistent = persistent
        # defined up front so restore() is safe to call before any suspend()
        self.actuallyRestore = False
    def reset(self):
        """Intended to be called by end user only, to reset everything if
        choose to persist results across runs."""
        self._end(); self._start()
    def startRun(self):
        """Starts hooking, unless results persist and modules already exist."""
        if (not self.persistent) or (len(self.modules) == 0):
            self._start()
    def _registerHooks(self):
        """Registers the current forward/backward function lists on every module."""
        for module in self.modules:
            module.registerHooks(self.forwardFns, self.backwardFns)
    def _unregisterHooks(self):
        """Removes the hooks of every module."""
        for module in self.modules:
            module.unregisterHooks()
    def endRun(self):
        """Tears hooking down, unless results are meant to persist."""
        if not self.persistent:
            self._end()
    def suspend(self):
        """Unregisters the hooks, remembering whether restore() should re-register them."""
        self.actuallyRestore = len(self) == 0 or self[0].handles.active
        if self.actuallyRestore:
            self._unregisterHooks()
    def restore(self):
        """Re-registers the hooks if the last suspend() actually removed them."""
        if self.actuallyRestore:
            self._registerHooks()
            self.actuallyRestore = False
    def __getitem__(self, idx):
        """Gets one module, or a new HookModule over a slice of the modules.

        :param idx: a slice gives a new :class:`HookModule` holding just those
            modules; anything else indexes the module list directly"""
        # dispatch on slice rather than on `int`, so integer-like indices that
        # are not `int` instances still return a single module
        if isinstance(idx, slice):
            answer = HookModule(self.persistent)
            answer.modules = self.modules[idx]
            return answer
        return self.modules[idx]
    def __len__(self):
        return len(self.modules)
    def __repr__(self):
        f = '\n'.join([f' - {fn.name or str(fn)}' for fn in self.forwardFns])
        f = "" if f == "" else f"Forward hooks:\n{f}\n"
        b = '\n'.join([f' - {fn.name or str(fn)}' for fn in self.backwardFns])
        b = "" if b == "" else f"Backward hooks:\n{b}\n"
        n = '\n'.join([f' {i}. {data.name}' for i, data in enumerate(self)])
        # these `with-` methods are described separately in the usage text below
        excludes = {"withForwardHook", "withBackwardHook", "withHook", "withCheckpoint"}
        withs = '\n'.join([f"- m.{key}()" for key in dir(self) if key.startswith("with") and key not in excludes])
        return f"""{super()._reprHead} with {len(self)} modules:\n{n}\n{f}{b}
Use...
- m.plot("means", "stds"): to plot simple statistics
- m[i]: to get a specific module
- m[a:b]: to get a new HookModule with selected modules
- m.css("..."): to select a specific subset of modules only
- m.withHook(hookCb): to hook a specific callback function
- m.clearHooks(): to clear all hooks
{super()._reprCan}

Built-in `with-` functions:\n{withs}"""
@k1lib.patch(HookModule)
def _start(self):
    """Rebuilds the module list from the learner's model and registers hooks.

    A module is wrapped when its matching selector entry reports it as
    selected for "HookModule"."""
    self.modules = []
    for nnModule, sel in zip(self.l.model.modules(), self.l.selector.modules()):
        if sel.selected("HookModule"):
            self.modules.append(Module(nnModule))
    self._registerHooks()
@k1lib.patch(HookModule)
def _end(self):
    """Runs every clean function on every module's data, then unregisters hooks."""
    for module in self.modules:
        for cleanFn in self.cleanFns:
            cleanFn(module.data)
    self._unregisterHooks()
@k1lib.patch(HookModule)
def withForwardHook(self, hook:_Fn, name:str=None):
    """Adds a hook to the forward pass. See :func:`~k1lib.callbacks.hookModule.HookModule.withHook`"""
    # extended in place: the same list object is what gets handed to registerHooks
    self.forwardFns += [Function(hook, name)]; return self
@k1lib.patch(HookModule)
def withBackwardHook(self, hook:_Fn, name:str=None):
    """Adds a hook to the backward pass. See :func:`~k1lib.callbacks.hookModule.HookModule.withHook`"""
    # extended in place: the same list object is what gets handed to registerHooks
    self.backwardFns += [Function(hook, name)]; return self
@k1lib.patch(HookModule)
def withHook(self, hook:_Fn, name:str=None):
    """Adds a hook to both the forward and backward pass.

    :param hook: this function is expected to take in these parameters: **(data, module, inp, out)**

        :data: the injected dependency for you to store stuff. Initially,
            `data.max` is an empty list, and you can append to it directly,
            like this::

                data.max.append(value) # okay

            Later on, you can do things like::

                m[i].data.forward.max

            and get the data you saved from the hook.
        :module: the module this function hooks into. Please refer to
            :func:`torch.nn.Module.register_forward_hook()` to know more.
        :inp: input (or grad of input) to the module
        :out: output (or grad of output) to the module

    :param name: custom name for the function for nice displaying

    See also: m.withForwardHook(), m.withBackwardHook()"""
    return self.withForwardHook(hook, name).withBackwardHook(hook, name)
@k1lib.patch(HookModule)
def clearHooks(self):
    """Unregisters all hooks and forgets every forward, backward and clean function.

    :return: self, so calls can be chained"""
    self._unregisterHooks()
    self.forwardFns = []; self.backwardFns = []
    self.cleanFns = []; return self
def meanCb(data, m, inp, out):
    """Hook function: appends the mean of ``out`` to ``data.means``."""
    data.means.append(squeeze(out, hard=True).data.mean().item())
@k1lib.patch(HookModule)
def withMeanRecorder(self):
    """Records mean"""
    return self.withHook(meanCb, "mean")
def stdCb(data, m, inp, out):
    """Hook function: appends the standard deviation of ``out`` to ``data.stds``."""
    data.stds.append(squeeze(out, hard=True).data.std().item())
@k1lib.patch(HookModule)
def withStdRecorder(self):
    """Records standard deviation"""
    return self.withHook(stdCb, "std")
def minCb(data, m, inp, out):
    """Hook function: appends the minimum of ``out`` to ``data.mins``."""
    data.mins.append(squeeze(out, hard=True).data.min().item())
@k1lib.patch(HookModule)
def withMinRecorder(self):
    """Records min"""
    return self.withHook(minCb, "min")
def maxCb(data, m, inp, out):
    """Hook function: appends the maximum of ``out`` to ``data.maxs``."""
    data.maxs.append(squeeze(out, hard=True).data.max().item())
@k1lib.patch(HookModule)
def withMaxRecorder(self):
    """Records max"""
    return self.withHook(maxCb, "max")
@k1lib.patch(HookModule)
def css(self, css:str):
    """Selects a subset of the already hooked modules using a css string.

    :param css: selector string handed to ``k1lib.selector.select``
    :return: a new :class:`HookModule` holding the existing module wrappers
        whose ``nn.Module`` is selected for "HookModule" """
    answer = HookModule()
    selector = k1lib.selector.select(self.l.model, css)
    d = {m.nnModule: m for m in self.modules}
    # the model side of the zip is unused, but it still bounds the iteration
    for _nnModule, sel in zip(self.l.model.modules(), selector.modules()):
        if sel.selected("HookModule") and sel.nnModule in d:
            answer.modules.append(d[sel.nnModule])
    return answer
def plotF(modules:HookModule, fields:List[str], rangeSlice:slice):
    """Plots one row per field: forward data on the left, backward on the right.

    :param modules: iterable of modules whose data gets plotted
    :param fields: names of the recorded fields, one subplot row each
    :param rangeSlice: slice forwarded to each module's ``data._plot``"""
    fig, axes = plt.subplots(len(fields), 2, figsize=(10, 3*len(fields)), dpi=100)
    # with a single field the result is 1-D, so force a (rows, 2) layout
    axes = axes.reshape((-1, 2))
    for axs, field in zip(axes, fields):
        for module in modules:
            module.data._plot(axs, field, rangeSlice)
        axs[0].set_title(f"Forward {field}")
        axs[1].set_title(f"Backward {field}")
    plt.figlegend([f"{i}. {module.name}" for i, module in enumerate(modules)], loc='center right')
@k1lib.patch(HookModule)
@k1lib.patch(Module)
def plot(self, *fields:List[str]):
    """Plots every simple (1 number saved/pass/module) fields.

    :param fields: list of fields to plot. If none, then will automatically find all simple fields
    :raises Exception: if there are no modules to plot"""
    modules = [self] if isinstance(self, Module) else self
    if len(modules) == 0: raise Exception("No modules to plot!")
    if len(fields) == 0:
        fields = []; forwardData = modules[0].data.forward
        for field in forwardData.state.keys():
            if field.startswith("_"): continue
            fieldData = forwardData[field]
            # a field may exist as an empty list (declared but never appended
            # to), so check for content before looking at the first element
            if isinstance(fieldData, list) and len(fieldData) > 0 and k1lib.isNumeric(fieldData[0]):
                fields.append(field)
    return k1lib.viz.SliceablePlot(partial(plotF, modules, fields))
@k1lib.patch(Callbacks, docs=HookModule)
def withHookModule(self, persistent=True):
    return self.append(HookModule(persistent).withMeanRecorder().withStdRecorder())