# AUTOGENERATED FILE! PLEASE DON'T EDIT
from .callbacks import Callback, Callbacks, Cbs
import k1lib; from k1lib import squeeze
import torch; import torch.nn as nn
from functools import partial
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict, Iterator, Union, Any, Callable
# Only HookModule is exported via `from ... import *`; the helper classes defined
# below (Handles, Data, ModuleData, Function, Module) are not listed here.
__all__ = ["HookModule"]
class Handles:
    """The pair of torch hook handles (forward + backward) registered for one module.

    Both handles are either set together or both ``None``. A half-set pair is
    treated as an internal error by :attr:`active`."""
    def __init__(self):
        self.forward = None; self.backward = None
    def remove(self):
        """Removes both hooks if they are registered, then resets the handles to ``None``.

        Does nothing when no hooks are registered."""
        if self.active:
            self.forward.remove(); self.forward = None
            self.backward.remove(); self.backward = None
    @property
    def active(self):
        """Whether both hooks are currently registered.

        :raises Exception: if only one of the two handles is set"""
        # identity checks: handle objects must not be able to influence this via __eq__/__ne__
        if self.forward is not None and self.backward is not None: return True
        if self.forward is None and self.backward is None: return False
        raise Exception("Supposed to be unreachable")
class Data(k1lib.Object):
    """Attribute bag for hook results in one pass direction (forward or backward).

    Auto-declaration is switched on with a factory that returns a new empty
    list, so hooks can append to a field without creating it first."""
    def __init__(self):
        super(Data, self).__init__()
        def _newList(): return []
        self.withAutoDeclare(_newList)
class ModuleData:
    """Everything the hooks recorded for one module.

    ``forward`` holds data saved during the forward pass, and ``backward``
    holds data saved during the backward pass."""
    def __init__(self):
        self.forward = Data()
        self.backward = Data()
    def _plot(self, axes, field:str, rangeSlice:slice):
        """Draws ``field``'s forward values on ``axes[0]`` and backward values on ``axes[1]``.

        ``rangeSlice`` picks the part of the history to draw. It is mapped onto both
        histories with ``k1lib.Range.proportionalSlice``, and its ``step`` thins the
        points. Nothing is drawn unless both directions have data for ``field``."""
        fwdValues = self.forward[field]
        stride = rangeSlice.step or 1
        bwdValues = self.backward[field]
        if not (len(fwdValues) and len(bwdValues)): return
        fwdRange, bwdRange = k1lib.Range.proportionalSlice(len(fwdValues), len(bwdValues), rangeSlice)
        for col, (rng, values) in enumerate(((fwdRange, fwdValues), (bwdRange, bwdValues))):
            axes[col].plot(rng.range_[::stride], values[rng.slice_][::stride], alpha=0.5)
    def __repr__(self):
        return ("Module's saved data. can...\n"
                "- d.forward: to get data stored during forward pass\n"
                "- d.backward: to get data stored during backward pass")
# Signature expected of a user hook: (data, module, inp, out) -> None, where `data`
# is the hooked module's per-direction `Data` store. Return values are discarded.
# NOTE(review): for forward hooks the last argument is the module's output as-is,
# which may be a single Tensor rather than a tuple; the annotation may be too narrow.
_Fn = Callable[[Data, nn.Module, Tuple[torch.Tensor], Tuple[torch.Tensor]], None]
class Function:
    """A user hook paired with a display name (shown by ``HookModule.__repr__``)."""
    def __init__(self, f:_Fn, name=None):
        self.f = f
        self.name = name if name else "f(<no name>)"
    def __call__(self, *args, **kwargs):
        """Runs the wrapped hook. Its return value is intentionally dropped."""
        self.f(*args, **kwargs)
        return None
def hook(fns:List[Function], *args):
    """Calls each function in ``fns`` with the same positional ``args``, in list order.

    This is what actually gets registered with torch (bound via ``functools.partial``
    in ``Module.registerHooks``). It always returns ``None``. PyTorch documents that a
    forward hook returning a non-``None`` value replaces the module's output, so the
    user functions' results must never escape from here."""
    for fn in fns:
        fn(*args)
class Module:
    """Wraps one hooked :class:`torch.nn.Module` together with its hook handles
    and the data its hooks record."""
    def __init__(self, module:nn.Module):
        self.nnModule = module
        self.handles = Handles()
        self.data = ModuleData()
        self.name = module.__class__.__name__
    def registerHooks(self, forwardFns:List[Function], backwardFns:List[Function]):
        """Registers a forward hook and a full backward hook on the wrapped module.

        Each torch hook calls every function in the given list, in order, with this
        module's ``data.forward`` / ``data.backward`` as the first argument. The lists
        are captured by reference, so functions appended to them later also run.
        Calling this again while hooks are registered replaces the old hooks rather
        than stacking a second copy.

        :returns: ``self``"""
        # Detach any previous registration first. Otherwise its handles would be
        # overwritten (so they could never be removed) and every hook would fire twice.
        if self.handles.forward is not None: self.handles.forward.remove()
        if self.handles.backward is not None: self.handles.backward.remove()
        self.handles.forward = self.nnModule.register_forward_hook(partial(hook, forwardFns, self.data.forward))
        self.handles.backward = self.nnModule.register_full_backward_hook(partial(hook, backwardFns, self.data.backward))
        return self
    def unregisterHooks(self):
        """Removes both hooks if they are registered. Does nothing if none are."""
        self.handles.remove()
    def __repr__(self):
        return f"""Module `{self.name}`. Use...
- m.data: to get data stored
- m.nnModule: to get actual nn.Module object
- m.plot("means", "stds"): to plot simple statistics"""
@k1lib.patch(Cbs)
class HookModule(Callback):
    """Hooks into selected modules in the network, and
    executes functions like .mean(), .std(). This is fairly
    complicated, and I highly recommend displaying this
    callback in a cell for more info"""
    def __init__(self, persistent:bool=False):
        """
        :param persistent: whether to save results across
            runs. If true, hooks and recorded data survive from one
            run to the next, and you can execute `.reset()` to
            reset everything. If false, every run starts fresh."""
        super().__init__()
        self.modules:List[Module] = []
        self.forwardFns:List[Function] = []
        self.backwardFns:List[Function] = []
        self.cleanFns = []; self.persistent = persistent
        # One entry per outstanding suspend() call: whether that call actually
        # removed the hooks, so the matching restore() knows to put them back.
        self._suspendStack:List[bool] = []
    def reset(self):
        """Intended to be called by end user only, to reset
        everything if choose to persist results across runs."""
        self._end(); self._start()
    def startRun(self):
        # persistent mode only (re)builds the module list the first time
        if (not self.persistent) or (len(self.modules) == 0): self._start()
    def _registerHooks(self):
        for module in self.modules:
            module.registerHooks(self.forwardFns, self.backwardFns)
    def _unregisterHooks(self):
        for module in self.modules: module.unregisterHooks()
    def endRun(self):
        if not self.persistent: self._end()
    def suspend(self):
        """Temporarily unregisters all hooks. Undo with :meth:`restore`.

        Calls may be nested. Each ``suspend()`` records whether it actually removed
        the hooks, and only the matching ``restore()`` puts them back. An inner
        suspend/restore pair therefore cannot lose hooks removed by an outer one."""
        actuallyRestore = len(self) == 0 or self[0].handles.active
        self._suspendStack.append(actuallyRestore)
        if actuallyRestore: self._unregisterHooks()
    def restore(self):
        """Re-registers the hooks removed by the matching :meth:`suspend` call.

        Does nothing if there is no outstanding ``suspend()``."""
        if self._suspendStack and self._suspendStack.pop():
            self._registerHooks()
    def __getitem__(self, idx):
        """``m[i]`` returns the i-th :class:`Module`. ``m[a:b]`` returns a new
        :class:`HookModule` (same ``persistent`` flag) holding the selected modules."""
        if isinstance(idx, slice):
            answer = HookModule(self.persistent)
            answer.modules = self.modules[idx]
            return answer
        # ints, bools and other integer-like indexes go straight to the list.
        # Iteration also relies on the IndexError raised here to stop.
        return self.modules[idx]
    def __len__(self): return len(self.modules)
    def __repr__(self):
        f = '\n'.join([f' - {fn.name or str(fn)}' for fn in self.forwardFns])
        f = "" if f == "" else f"Forward hooks:\n{f}\n"
        b = '\n'.join([f' - {fn.name or str(fn)}' for fn in self.backwardFns])
        b = "" if b == "" else f"Backward hooks:\n{b}\n"
        n = '\n'.join([f' {i}. {data.name}' for i, data in enumerate(self)])
        excludes = {"withForwardHook", "withBackwardHook", "withHook", "withCheckpoint"}
        withs = '\n'.join([f"- m.{key}()" for key in dir(self) if key.startswith("with") and key not in excludes])
        return f"""{super()._reprHead} with {len(self)} modules:\n{n}\n{f}{b}
Use...
- m.plot("means", "stds"): to plot simple statistics
- m[i]: to get a specific module
- m[a:b]: to get a new HookModule with selected modules
- m.css("..."): to select a specific subset of modules only
- m.withHook(hookCb): to hook a specific callback function
- m.clearHooks(): to clear all hooks
{super()._reprCan}
Built-in `with-` functions:\n{withs}"""
@k1lib.patch(HookModule)
def _start(self):
    # Rebuild the hooked-module list from scratch. Keep every model submodule whose
    # selector entry is marked "HookModule", then attach the hooks to all of them.
    self.modules = []
    candidates = zip(self.l.model.modules(), self.l.selector.modules())
    self.modules.extend(Module(nnMod) for nnMod, selMod in candidates if selMod.selected("HookModule"))
    self._registerHooks()
@k1lib.patch(HookModule)
def _end(self):
    # Give every clean-up function a chance to post-process each module's
    # recorded data, then detach all hooks.
    for mod in self.modules:
        for clean in self.cleanFns:
            clean(mod.data)
    self._unregisterHooks()
@k1lib.patch(HookModule)
def withForwardHook(self, hook:_Fn, name:str=None):
    """Adds a hook that runs on the forward pass only. Returns ``self`` for chaining.

    See :func:`~k1lib.callbacks.hookModule.HookModule.withHook`"""
    wrapped = Function(hook, name)
    # mutate in place: already-registered modules hold a reference to this list
    self.forwardFns.append(wrapped)
    return self
@k1lib.patch(HookModule)
def withBackwardHook(self, hook:_Fn, name:str=None):
    """Adds a hook that runs on the backward pass only. Returns ``self`` for chaining.

    See :func:`~k1lib.callbacks.hookModule.HookModule.withHook`"""
    wrapped = Function(hook, name)
    # mutate in place: already-registered modules hold a reference to this list
    self.backwardFns.append(wrapped)
    return self
@k1lib.patch(HookModule)
def withHook(self, hook:_Fn, name:str=None):
    """Adds a hook to both the forward and backward pass. Returns ``self`` for chaining.

    :param hook: this function is expected to take in these parameters: **(data, module, inp, out)**.
        Its return value is ignored.

        :data: the injected dependency for you to store stuff.
            Initially, `data.max` is an empty list, and you can
            append to it directly, like this::

                data.max.append(3) # okay

            Later on, you can do things like::

                m[i].data.forward.max   # or m[i].data.backward.max

            and get the data you saved from the hook.
        :module: the module this function hooks into. Please
            refer to :meth:`torch.nn.Module.register_forward_hook` and
            :meth:`torch.nn.Module.register_full_backward_hook` to
            know more.
        :inp: input (or grad of input) to the module
        :out: output (or grad of output) to the module
    :param name: custom name for the function for nice displaying

    See also: m.withForwardHook(), m.withBackwardHook()"""
    return self.withForwardHook(hook, name).withBackwardHook(hook, name)
@k1lib.patch(HookModule)
def clearHooks(self):
    """Detaches all hooks from the modules and forgets every registered forward,
    backward and clean-up function. Returns ``self`` for chaining."""
    self._unregisterHooks()
    self.forwardFns, self.backwardFns, self.cleanFns = [], [], []
    return self
def meanCb(data, m, inp, out):
    """Hook that appends the mean of ``out`` (after k1lib's hard squeeze) to ``data.means``."""
    means = data.means
    means.append(squeeze(out, hard=True).data.mean().item())
@k1lib.patch(HookModule)
def withMeanRecorder(self):
    """Records mean of each module's output (forward) and output grad (backward) into ``data.means``."""
    return self.withHook(meanCb, name="mean")
def stdCb(data, m, inp, out):
    """Hook that appends the standard deviation of ``out`` (after k1lib's hard squeeze) to ``data.stds``."""
    stds = data.stds
    stds.append(squeeze(out, hard=True).data.std().item())
@k1lib.patch(HookModule)
def withStdRecorder(self):
    """Records standard deviation of each module's output (forward) and output grad (backward) into ``data.stds``."""
    return self.withHook(stdCb, name="std")
def minCb(data, m, inp, out):
    """Hook that appends the minimum of ``out`` (after k1lib's hard squeeze) to ``data.mins``."""
    mins = data.mins
    mins.append(squeeze(out, hard=True).data.min().item())
@k1lib.patch(HookModule)
def withMinRecorder(self):
    """Records min of each module's output (forward) and output grad (backward) into ``data.mins``."""
    return self.withHook(minCb, name="min")
def maxCb(data, m, inp, out):
    """Hook that appends the maximum of ``out`` (after k1lib's hard squeeze) to ``data.maxs``."""
    maxs = data.maxs
    maxs.append(squeeze(out, hard=True).data.max().item())
@k1lib.patch(HookModule)
def withMaxRecorder(self):
    """Records max of each module's output (forward) and output grad (backward) into ``data.maxs``."""
    return self.withHook(maxCb, name="max")
@k1lib.patch(HookModule)
def css(self, css:str):
    """Selects a subset of the currently hooked modules with a k1lib css selector.

    :param css: css-like selector string, passed to :func:`k1lib.selector.select`
        together with the learner's model
    :returns: a new :class:`HookModule` with the same ``persistent`` flag (like
        ``m[a:b]``). It shares the matching :class:`Module` objects, and so their
        recorded data, with this one."""
    answer = HookModule(self.persistent)
    selector = k1lib.selector.select(self.l.model, css)
    # look up each hooked wrapper by the torch module it wraps
    byNnModule = {m.nnModule: m for m in self.modules}
    for sel in selector.modules():
        if sel.selected("HookModule") and sel.nnModule in byNnModule:
            answer.modules.append(byNnModule[sel.nnModule])
    return answer
def plotF(modules:HookModule, fields:List[str], rangeSlice:slice):
    """Renders ``fields`` for ``modules``.

    Draws one subplot row per field, with forward data in the left column and
    backward data in the right, and one line per module. A figure legend labels
    the modules by index and class name."""
    rowCount = len(fields)
    _fig, grid = plt.subplots(rowCount, 2, figsize=(10, 3*rowCount), dpi=100)
    # normalise to a (rows, 2) array even when there is only a single row
    grid = grid.reshape((-1, 2))
    for row, field in zip(grid, fields):
        fwdAx, bwdAx = row
        for module in modules:
            module.data._plot(row, field, rangeSlice)
        fwdAx.set_title(f"Forward {field}")
        bwdAx.set_title(f"Backward {field}")
    labels = [f"{idx}. {mod.name}" for idx, mod in enumerate(modules)]
    plt.figlegend(labels, loc='center right')
@k1lib.patch(HookModule)
@k1lib.patch(Module)
def plot(self, *fields:List[str]):
    """Plots every simple (one number saved per pass per module) field.

    Patched onto both :class:`HookModule` (plots all of its modules) and
    :class:`Module` (plots just that one module).

    :param fields: names of the fields to plot. If none are given, the function
        plots every field of the first module's forward data that does not start
        with an underscore and holds a non-empty list whose first element is numeric
    :returns: a :class:`k1lib.viz.SliceablePlot` that draws the fields
    :raises Exception: if there are no modules, or if no field was given and
        none could be detected"""
    modules = [self] if isinstance(self, Module) else self
    if len(modules) == 0: raise Exception("No modules to plot!")
    if len(fields) == 0:
        fields = []; forwardData = modules[0].data.forward
        for field in forwardData.state.keys():
            if field.startswith("_"): continue
            fieldData = forwardData[field]
            # a field that was declared but never appended to is an empty list,
            # which has no first element to inspect
            if isinstance(fieldData, list) and len(fieldData) > 0 and k1lib.isNumeric(fieldData[0]):
                fields.append(field)
        # otherwise plotF would be asked to build a 0-row subplot grid and fail later
        if len(fields) == 0: raise Exception("No simple fields to plot!")
    return k1lib.viz.SliceablePlot(partial(plotF, modules, fields))
@k1lib.patch(Callbacks, docs=HookModule)
def withHookModule(self, persistent=True):
    # Convenience constructor: appends a HookModule that already records
    # per-module means and standard deviations. The default here is
    # persistent=True, unlike HookModule's own default of False.
    recorder = HookModule(persistent).withMeanRecorder()
    recorder = recorder.withStdRecorder()
    return self.append(recorder)