Source code for k1lib.callbacks.hookParam

# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib
import matplotlib.pyplot as plt
from functools import partial
from typing import List, Tuple, Callable, Union
# torch is optional. If importing it fails, `torch` and `nn` are replaced by
# k1lib.Object placeholders whose auto-declared attributes are (presumably) dummy
# "RandomClass" types, so annotations such as `nn.Parameter` below still evaluate.
# `hasTorch` records which of the two cases happened.
# NOTE(review): the bare `except:` also swallows non-ImportError failures (e.g. a
# broken torch install raising OSError) and even KeyboardInterrupt — confirm intended.
try: import torch; import torch.nn as nn; hasTorch = True
except:
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
# only the callback itself is public; ParamData, Param and plotF are internal helpers
__all__ = ["HookParam"]
class ParamData(k1lib.Object):                                                   # ParamData
    """History of summary statistics (mean, std, min, max) of one parameter.

    Every call to :meth:`update` appends one sample to each of the four lists.
    ``plotF`` iterates ``data.state.keys()`` to decide what to draw, so these four
    lists are presumably the only state fields, and their assignment order sets
    the subplot order. Keep that in mind before adding attributes here."""
    def __init__(self):                                                          # ParamData
        super().__init__()                                                       # ParamData
        self.means = []; self.stds = []                                          # ParamData
        self.mins = []; self.maxs = []                                           # ParamData
    def update(self, torchParam:nn.Parameter):                                   # ParamData
        """Appends the current mean, std, min and max of ``torchParam``.

        :param torchParam: tensor to summarize (``Param.update`` passes a detached one)"""
        # Stack the 4 scalar results so there is a single device->host transfer
        # (.tolist()) instead of four separate .item() syncs. The values are identical.
        mean, std, mn, mx = torch.stack([torchParam.mean(), torchParam.std(),
                                         torchParam.min(), torchParam.max()]).tolist() # ParamData
        self.means.append(mean); self.stds.append(std)                           # ParamData
        self.mins.append(mn); self.maxs.append(mx)                               # ParamData
    def __len__(self): return len(self.means)                                    # ParamData
    def __repr__(self):                                                          # ParamData
        return """Param's saved data. Use...
- d.means: to get list of means
- d.stds: to get list of stds
- d.mins: to get list of mins
- d.maxs: to get list of maxs"""                                                 # ParamData
class Param:                                                                     # Param
    """One named model parameter plus the statistics recorded for it.

    :param name: dotted parameter name (as yielded by ``named_parameters()``)
    :param torchParam: the underlying parameter tensor being tracked"""
    def __init__(self, name:str, torchParam:nn.Parameter):                       # Param
        self.name, self.torchParam = name, torchParam                            # Param
        self.data = ParamData()                                                  # Param
        # Throttles sampling: update() only records when this returns truthy.
        # Presumably that is once every 3 calls — confirm against k1lib.Every.
        self.every = k1lib.Every(3)                                              # Param
    def update(self):                                                            # Param
        """Records a snapshot of the parameter's stats, if this call is a sampling step."""
        if not self.every(): return                                              # Param
        # detach so statistics are computed outside the autograd graph
        self.data.update(self.torchParam.detach())                               # Param
    def __repr__(self):                                                          # Param
        helpLines = [f"Param `{self.name}`. Use...",                             # Param
                     "- p.torchParam: to get actual underlying parameter",       # Param
                     "- p.data: to get data stored",                             # Param
                     "- cb.plot(): to quickly look at everything"]               # Param
        return "\n".join(helpLines)                                              # Param
@k1lib.patch(Cbs)                                                                # Param
class HookParam(Callback):                                                       # HookParam
    """Records means, stds, mins and maxs of all selected parameters (those whose
    owning module is tagged "HookParam" in the learner's selector)"""            # HookParam
    def __init__(self):                                                          # HookParam
        ""                                                                       # HookParam
        super().__init__(); self.params:List[Param] = []                         # HookParam
    def __getitem__(self, idx:Union[int, slice]):                                # HookParam
        """Returns a single :class:`Param` for an integer-like index, or a new
        :class:`HookParam` holding the selected params for a slice."""
        # Dispatch on slice rather than `type(idx) == int`, so numpy ints / bools
        # return a Param instead of a broken HookParam whose .params is not a list.
        if isinstance(idx, slice):                                               # HookParam
            answer = HookParam(); answer.params = self.params[idx]; return answer # HookParam
        return self.params[idx]                                                  # HookParam
    def __len__(self): return len(self.params)                                   # HookParam
    def _selected(self, paramName:str):                                          # HookParam
        """Whether parameter ``paramName`` (dotted path) should be tracked.

        Walks ``self.l.selector`` along every path component except the last.
        Returns True when that node contains "HookParam" and has an attribute
        named after the last component. Returns False if the path is missing
        (KeyError)."""
        splits = paramName.split(".")                                            # HookParam
        try:                                                                     # HookParam
            mS = self.l.selector                                                 # HookParam
            for split in splits[:-1]: mS = mS[split]                             # HookParam
            return "HookParam" in mS and hasattr(mS, splits[-1])                 # HookParam
        except KeyError: return False                                            # HookParam
    def startRun(self):                                                          # HookParam
        if len(self) == 0: # set things up first time only                       # HookParam
            self.params = [Param(k, v) for k, v in self.l.model.named_parameters() if self._selected(k)] # HookParam
    def startBatch(self):                                                        # HookParam
        for param in self.params: param.update()                                 # HookParam
    def css(self, css:str):                                                      # HookParam
        """Creates a new HookParam object with selected modules. May be useful for displaying a subset of the recorded data""" # HookParam
        oldSelector = self.l.selector; answer = HookParam()                      # HookParam
        try:                                                                     # HookParam
            self.l.selector = k1lib.selector.select(self.l.model, css)           # HookParam
            answer.params = [param for param in self.params if self._selected(param.name)] # HookParam
        finally: self.l.selector = oldSelector # always restore the learner's selector, even on error # HookParam
        return answer                                                            # HookParam
    def __repr__(self):                                                          # HookParam
        s = f", {len(self[0].data)} means and stds each" if len(self) > 0 else "" # HookParam
        names = "\n".join([f"  {i}. {p.name}" for i, p in enumerate(self)])      # HookParam
        return f"""{super()._reprHead}: {len(self)} params{s}:\n{names}\n
Use...
- p.plot(): to quickly look at everything
- p[i]: to view a single param
- p[a:b]: to get a new HookParam with selected params
- p.css("..."): to select a specific subset of modules only
{super()._reprCan}"""                                                            # HookParam
def plotF(params:Union[HookParam, Param, List[Param]], rangeSlice:slice):       # plotF
    """Plots each recorded statistic of one or more params on a 2x2 grid.

    :param params: a single :class:`Param`, a list of them, or a :class:`HookParam`
    :param rangeSlice: slice of recorded samples to show; its ``step`` (default 1)
        also thins out the plotted points
    :raises ValueError: if there are no params to plot"""
    if isinstance(params, Param): params = [params]                              # plotF
    # fail clearly instead of with an opaque IndexError from params[0]
    if len(params) == 0: raise ValueError("No params to plot")                   # plotF
    # one subplot per recorded field (means, stds, mins, maxs on ParamData)
    fields = params[0].data.state.keys(); step = rangeSlice.step or 1            # plotF
    _, axes = plt.subplots(2, 2, figsize=(10, 6), dpi=100)                       # plotF
    axes = axes.flatten()                                                        # plotF
    for field, ax in zip(fields, axes):                                          # plotF
        for param in params:                                                     # plotF
            fieldData = param.data[field]                                        # plotF
            r = k1lib.Range(len(fieldData))[rangeSlice]                          # plotF
            ax.plot(r.range_[::step], fieldData[r.slice_][::step])               # plotF
        ax.set_title(field.capitalize())                                         # plotF
    plt.figlegend([p.name for p in params], loc='right')                         # plotF
@k1lib.patch(HookParam)                                                          # plotF
@k1lib.patch(Param)                                                              # plotF
def plot(self):                                                                  # plot
    """Returns a k1lib SliceablePlot that draws :func:`plotF` for this object."""
    return k1lib.viz.SliceablePlot(partial(plotF, self))                         # plot