# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib
import matplotlib.pyplot as plt
from functools import partial
from typing import List, Tuple, Callable, Union
# torch is an optional dependency. When it cannot be loaded, `torch` and `nn`
# are replaced by auto-declaring placeholder objects so that annotations such
# as `nn.Parameter` further down still resolve at import time.
try:
    import torch
    import torch.nn as nn
    hasTorch = True
except Exception: # not a bare `except:` so Ctrl-C / SystemExit during the import still propagate
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    hasTorch = False
__all__ = ["HookParam"]
class ParamData(k1lib.Object):
    """Recorded summary statistics of a single parameter.

    Holds four lists, ``means``, ``stds``, ``mins`` and ``maxs``. Every call
    to :meth:`update` appends one value to each of them.
    """
    def __init__(self):
        super().__init__()
        # NOTE: plotF() in this file iterates `data.state.keys()` to decide
        # which fields to draw, so the creation order below is kept as-is
        # (presumably it determines subplot order -- confirm against k1lib.Object).
        self.means = []; self.stds = []
        self.mins = []; self.maxs = []
    def update(self, torchParam:nn.Parameter):
        """Appends the current mean, std, min and max of ``torchParam``.

        Each statistic is converted with ``.item()`` before being stored."""
        self.means.append(torchParam.mean().item())
        self.stds.append(torchParam.std().item())
        self.mins.append(torchParam.min().item())
        self.maxs.append(torchParam.max().item())
    def __len__(self):
        """Number of recorded means."""
        return len(self.means)
    def __repr__(self):
        return """Param's saved data. Use...
- d.means: to get list of means
- d.stds: to get list of stds
- d.mins: to get list of mins
- d.maxs: to get list of maxs"""
class Param:
    """A named parameter together with the statistics recorded for it.

    Attributes:
        name: name the parameter was registered under
        torchParam: the underlying parameter object
        data: :class:`ParamData` receiving the recorded statistics
        every: ``k1lib.Every(3)`` gate consulted on each :meth:`update`
            (presumably lets through every 3rd call -- confirm against k1lib.Every)
    """
    def __init__(self, name:str, torchParam:nn.Parameter):
        self.data = ParamData()
        self.every = k1lib.Every(3)
        self.name = name
        self.torchParam = torchParam
    def update(self):
        """Records one snapshot of the detached parameter, if the gate allows it."""
        if not self.every(): return
        snapshot = self.torchParam.detach()
        self.data.update(snapshot)
    def __repr__(self):
        helpLines = [
            f"Param `{self.name}`. Use...",
            "- p.torchParam: to get actual underlying parameter",
            "- p.data: to get data stored",
            "- cb.plot(): to quickly look at everything",
        ]
        return "\n".join(helpLines)
@k1lib.patch(Cbs)
class HookParam(Callback):
    """Records means, stds, mins and maxs of the model's selected parameters"""
    def __init__(self):
        ""
        super().__init__(); self.params:List[Param] = []
    def __getitem__(self, idx:Union[int, slice]):
        """Returns a single :class:`Param` for an index, or a new
        :class:`HookParam` holding the sliced params for a slice."""
        # Branch on slice rather than on `type(idx) == int`, so that bools and
        # other integer-like indices return a single Param instead of ending
        # up stored as the `params` of a new HookParam.
        if not isinstance(idx, slice): return self.params[idx]
        answer = HookParam(); answer.params = self.params[idx]; return answer
    def __len__(self): return len(self.params)
    def _selected(self, paramName:str):
        """Whether the parameter called ``paramName`` is selected for this callback.

        Walks ``self.l.selector`` down every dot-separated part of the name
        except the last one, then requires ``"HookParam"`` to be in the node
        reached and that node to have an attribute named like the last part.
        A ``KeyError`` while walking means "not selected"."""
        splits = paramName.split(".")
        try:
            mS = self.l.selector
            for split in splits[:-1]: mS = mS[split]
            return "HookParam" in mS and hasattr(mS, splits[-1])
        except KeyError: return False
    def startRun(self):
        """Collects the selected named parameters of ``self.l.model``, first time only."""
        if len(self) == 0: # set things up first time only
            self.params = [Param(k, v) for k, v in self.l.model.named_parameters() if self._selected(k)]
    def startBatch(self):
        """Asks every tracked param to record a snapshot."""
        for param in self.params: param.update()
    def css(self, css:str):
        """Creates a new HookParam object with selected modules. May be useful
for displaying a subset of the recorded data"""
        oldSelector = self.l.selector; answer = HookParam()
        # _selected() reads self.l.selector, so it is swapped temporarily
        self.l.selector = k1lib.selector.select(self.l.model, css)
        try: answer.params = [param for param in self.params if self._selected(param.name)]
        finally: self.l.selector = oldSelector # restore even if filtering raises
        return answer
    def __repr__(self):
        s = f", {len(self[0].data)} means and stds each" if len(self) > 0 else ""
        names = "\n".join([f" {i}. {p.name}" for i, p in enumerate(self.params)])
        return f"""{super()._reprHead}: {len(self)} params{s}:\n{names}\n
Use...
- p.plot(): to quickly look at everything
- p[i]: to view a single param
- p[a:b]: to get a new HookParam with selected params
- p.css("..."): to select a specific subset of modules only
{super()._reprCan}"""
def plotF(params:Union[HookParam, Param, List[Param]], rangeSlice:slice):
    """Plots the recorded statistics of one or more params on a 2x2 grid.

    One subplot is drawn per field found in the first param's
    ``data.state``, with one line per param, plus a figure legend listing
    the param names. Fields are paired with the four axes by ``zip``, so at
    most four fields are drawn.

    :param params: a single :class:`Param`, a list of them, or a
        :class:`HookParam`. Must contain at least one param, as the field
        names are taken from ``params[0]``
    :param rangeSlice: slice selecting which recorded points to draw; its
        ``step`` (default 1) is applied to both axes of every line"""
    # isinstance rather than `type(...) == Param`, so subclasses of Param are
    # wrapped too instead of failing at `params[0]` below
    if isinstance(params, Param): params = [params]
    fields = params[0].data.state.keys(); step = rangeSlice.step or 1
    fig, axes = plt.subplots(2, 2, figsize=(10, 6), dpi=100)
    axes = axes.flatten()
    for field, ax in zip(fields, axes):
        for param in params:
            fieldData = param.data[field]
            r = k1lib.Range(len(fieldData))[rangeSlice]
            ax.plot(r.range_[::step], fieldData[r.slice_][::step])
        ax.set_title(field.capitalize())
    plt.figlegend([p.name for p in params], loc='right')
@k1lib.patch(HookParam)
@k1lib.patch(Param)
def plot(self):
    """Returns a ``k1lib.viz.SliceablePlot`` wrapping :func:`plotF` with this
    object bound as its ``params`` argument."""
    boundPlotF = partial(plotF, self)
    return k1lib.viz.SliceablePlot(boundPlotF)