# AUTOGENERATED FILE! PLEASE DON'T EDIT
from k1lib.callbacks import Callback, Callbacks
import k1lib, torch.nn as nn
import matplotlib.pyplot as plt
from functools import partial
from typing import List, Tuple, Callable, Union
__all__ = ["HookParam"]
class ParamData(k1lib.Object):
    """Running record of summary statistics for a single parameter.

    Every call to :meth:`update` appends exactly one value to each of the
    four lists ``means``, ``stds``, ``mins`` and ``maxs``, so the lists always
    share the same length.
    """
    def __init__(self):
        super().__init__()
        # NOTE(review): `plotF` in this file iterates `data.state.keys()` to pick
        # what to plot; presumably those keys are the attributes assigned here,
        # in this order -- confirm before adding or reordering attributes.
        self.means = []; self.stds = []
        self.mins = []; self.maxs = []
    def update(self, torchParam:nn.Parameter):
        """Appends the mean, std, min and max of ``torchParam`` to the record.

        :param torchParam: tensor-like object exposing ``mean()``, ``std()``,
            ``min()`` and ``max()``, each returning something with ``.item()``
        """
        self.means.append(torchParam.mean().item())
        self.stds.append(torchParam.std().item())
        self.mins.append(torchParam.min().item())
        self.maxs.append(torchParam.max().item())
    def __len__(self):
        """Number of snapshots recorded so far."""
        return len(self.means)
    def __repr__(self):
        # Continuation lines sit at column 0 on purpose: they are part of the
        # string and must not pick up the method's indentation.
        return """Param's saved data. Use...
- d.means: to get list of means
- d.stds: to get list of stds
- d.mins: to get list of mins
- d.maxs: to get list of maxs"""
class Param:
    """Pairs a named parameter with the statistics recorded for it.

    Attributes are ``name`` (the parameter's name), ``torchParam`` (the
    underlying parameter object) and ``data`` (a ``ParamData`` record).
    """
    def __init__(self, name:str, torchParam:nn.Parameter):
        self.data = ParamData()
        self.torchParam = torchParam
        self.name = name
    def update(self):
        """Records one snapshot of the detached parameter into ``self.data``."""
        snapshot = self.torchParam.detach()
        self.data.update(snapshot)
    def __repr__(self):
        helpLines = (
            f"Param `{self.name}`. Use...",
            "- p.torchParam: to get actual underlying parameter",
            "- p.data: to get data stored",
            "- cb.plot(): to quickly look at everything",
        )
        return "\n".join(helpLines)
@k1lib.patch(Callback.cls)
class HookParam(Callback):
    """Records means, stds, mins and maxs of all selected parameters"""
    def __init__(self):
        ""
        # NOTE(review): the empty docstring above looks deliberate (presumably
        # to keep an inherited docstring out of generated docs) -- confirm
        # before replacing it.
        super().__init__(); self.params:List[Param] = []
    def __getitem__(self, idx:Union[int, slice]):
        """Returns a single :class:`Param` for an index, or a new
        :class:`HookParam` holding the sliced params for a slice."""
        # Branch on "is it a slice" rather than "is it exactly int", so any
        # index-like object returns one Param instead of being stored as the
        # params list of a new HookParam.
        if isinstance(idx, slice):
            answer = HookParam(); answer.params = self.params[idx]; return answer
        return self.params[idx]
    def __len__(self):
        """Number of parameters being tracked."""
        return len(self.params)
    def _selected(self, paramName:str):
        """Whether the dotted ``paramName`` belongs to a module selected for
        "HookParam" in the learner's current selector.

        Walks the selector by every dotted component except the last one, then
        checks that the resulting node is selected and has the last component
        as an attribute. A missing component (``KeyError``) means "not selected".
        """
        splits = paramName.split(".")
        try:
            mS = self.l.selector
            for split in splits[:-1]: mS = mS[split]
            return mS.selected("HookParam") and hasattr(mS, splits[-1])
        except KeyError: return False
    def startRun(self):
        """Collects the selected named parameters of the model, first run only."""
        if len(self) == 0: # set things up first time only
            self.params = [Param(k, v) for k, v in self.l.model.named_parameters() if self._selected(k)]
    def startBatch(self):
        """Records one snapshot of every tracked parameter."""
        for param in self.params: param.update()
    def css(self, css:str):
        """Creates a new HookParam object with selected modules. May be useful
for displaying a subset of the recorded data"""
        oldSelector = self.l.selector; answer = HookParam()
        self.l.selector = k1lib.selector.select(self.l.model, css)
        try:
            answer.params = [param for param in self.params if self._selected(param.name)]
        finally:
            # Always put the learner's real selector back, even if filtering
            # raised, so a failed css() call cannot leave it swapped.
            self.l.selector = oldSelector
        return answer
    def __repr__(self):
        s = f", {len(self.params[0].data)} means and stds each" if len(self) > 0 else ""
        names = "\n".join([f" {i}. {p.name}" for i, p in enumerate(self.params)])
        # Continuation lines sit at column 0 on purpose: they are part of the
        # string and must not pick up the method's indentation.
        return f"""{super()._reprHead}: {len(self)} params{s}:\n{names}\n
Use...
- p.plot(): to quickly look at everything
- p[i]: to view a single param
- p[a:b]: to get a new HookParam with selected params
- p.css("..."): to select a specific subset of modules only
{super()._reprCan}"""
def plotF(params:Union[HookParam, Param, List[Param]], rangeSlice:slice):
    """Draws one subplot per recorded statistic on a 2x2 grid, with one line
    per parameter, and adds a figure legend of parameter names.

    :param params: a single :class:`Param` (wrapped into a list), or anything
        indexable/iterable that yields :class:`Param` objects
    :param rangeSlice: slice choosing which recorded steps to show; its
        ``step`` (default 1) is used as the stride between plotted points
    """
    if type(params) is Param:
        params = [params]
    # The first param's data decides which statistics get a subplot.
    statNames = params[0].data.state.keys()
    stride = rangeSlice.step or 1
    _, grid = plt.subplots(2, 2, figsize=(10, 6), dpi=100)
    for statName, axis in zip(statNames, grid.flatten()):
        for entry in params:
            series = entry.data[statName]
            # NOTE(review): presumably k1lib.Range maps the requested slice onto
            # the recorded length, giving x-values (range_) and a matching
            # slice (slice_) -- confirm against k1lib.Range.
            window = k1lib.Range(len(series))[rangeSlice]
            xs = window.range_[::stride]
            ys = series[window.slice_][::stride]
            axis.plot(xs, ys)
        axis.set_title(statName.capitalize())
    legendLabels = [entry.name for entry in params]
    plt.figlegend(legendLabels, loc='right')
@k1lib.patch(HookParam)
@k1lib.patch(Param)
def plot(self):
    """Returns a ``k1lib.viz.SliceablePlot`` built from ``plotF`` with this
    object bound as its first argument."""
    boundPlotter = partial(plotF, self)
    return k1lib.viz.SliceablePlot(boundPlotter)
@k1lib.patch(Callbacks, docs=HookParam)
def withHookParam(self):
    # Appends a fresh HookParam callback and returns whatever `append` returns.
    # Documentation is supplied through `docs=HookParam`, so no docstring here.
    hook = HookParam()
    return self.append(hook)