# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib; plt = k1lib.dep("matplotlib.pyplot")
from functools import partial
from typing import List, Tuple, Callable, Union
try: import torch; import torch.nn as nn; hasTorch = True
except:
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
# Public API of this module: a star-import (`from ... import *`) only exposes HookParam;
# ParamData, Param, plotF and plot are internal helpers.
__all__ = ["HookParam"]
class ParamData(k1lib.Object):                                                   # ParamData
    """Recorded summary statistics of a single parameter over time.

Each call to :meth:`update` appends one sample to each of ``means``, ``stds``,
``mins`` and ``maxs``, so (barring an exception part-way through an update)
the four lists stay the same length."""                                          # ParamData
    def __init__(self):                                                          # ParamData
        super().__init__()                                                       # ParamData
        self.means = []; self.stds = []                                          # ParamData
        self.mins = []; self.maxs = []                                           # ParamData
    def update(self, torchParam:nn.Parameter):                                   # ParamData
        """Appends the current mean, std, min and max of ``torchParam`` as
Python floats (via ``.item()``).

:param torchParam: tensor to summarize. The caller in this file passes an
    already-detached tensor, so no autograd graph is built here."""             # ParamData
        self.means.append(torchParam.mean().item())                              # ParamData
        self.stds.append(torchParam.std().item())                                # ParamData
        self.mins.append(torchParam.min().item())                                # ParamData
        self.maxs.append(torchParam.max().item())                                # ParamData
    def __len__(self): return len(self.means) # number of samples recorded so far # ParamData
    def __repr__(self):                                                          # ParamData
        return """Param's saved data. Use...
- d.means: to get list of means
- d.stds: to get list of stds
- d.mins: to get list of mins
- d.maxs: to get list of maxs"""                                                 # ParamData
class Param:                                                                     # Param
    """A single named model parameter, plus the statistics sampled from it.

``name`` is the dotted name the parameter was found under and ``data`` is a
:class:`ParamData` that accumulates samples whenever :meth:`update` lets one
through its ``Every(3)`` gate."""                                                # Param
    def __init__(self, name:str, torchParam:nn.Parameter):                       # Param
        self.name = name                                                         # Param
        self.torchParam = torchParam                                             # Param
        self.data = ParamData()                                                  # Param
        self.every = k1lib.Every(3) # throttles how often update() records       # Param
    def update(self):                                                            # Param
        """Samples the parameter's statistics, but only when the gate fires."""  # Param
        if not self.every(): return                                              # Param
        detached = self.torchParam.detach() # no autograd tracking for stats     # Param
        self.data.update(detached)                                               # Param
    def __repr__(self):                                                          # Param
        return f"""Param `{self.name}`. Use...
- p.torchParam: to get actual underlying parameter
- p.data: to get data stored
- cb.plot(): to quickly look at everything"""                                    # Param
@k1lib.patch(Cbs)                                                                # HookParam
class HookParam(Callback):                                                       # HookParam
    """Records means, stds, mins and maxs of all selected parameters"""          # HookParam
    def __init__(self):                                                          # HookParam
        ""                                                                       # HookParam
        super().__init__(); self.params:List[Param] = []                         # HookParam
    def __getitem__(self, idx:Union[int, slice]):                                # HookParam
        # A slice gives a new HookParam sharing the selected Param objects; any
        # other index (int, numpy int, ...) is forwarded to the list -> a Param
        if not isinstance(idx, slice): return self.params[idx]                   # HookParam
        answer = HookParam(); answer.params = self.params[idx]; return answer    # HookParam
    def __len__(self): return len(self.params)                                   # HookParam
    def _selected(self, paramName:str):                                          # HookParam
        # paramName is dotted (as yielded by named_parameters()): walk the
        # learner's selector down the module path, then require the "HookParam"
        # mark on that node and an attribute named after the last component
        splits = paramName.split(".")                                            # HookParam
        try:                                                                     # HookParam
            mS = self.l.selector                                                 # HookParam
            for split in splits[:-1]: mS = mS[split]                             # HookParam
            return "HookParam" in mS and hasattr(mS, splits[-1])                 # HookParam
        except KeyError: return False # module path not present in the selector  # HookParam
    def startRun(self):                                                          # HookParam
        if not self.params: # set things up first time only                      # HookParam
            self.params = [Param(k, v) for k, v in self.l.model.named_parameters() if self._selected(k)] # HookParam
    def startBatch(self):                                                        # HookParam
        for param in self.params: param.update()                                 # HookParam
    def css(self, css:str):                                                      # HookParam
        """Creates a new HookParam object with selected modules. May be useful
for displaying a subset of the recorded data.

The returned object shares the same :class:`Param` objects (and thus the same
recorded data) as this one; nothing is copied. The learner's selector is
swapped temporarily to evaluate ``css`` and is always restored afterwards.

:param css: selector string passed to ``k1lib.selector.select``"""             # HookParam
        oldSelector = self.l.selector; answer = HookParam()                      # HookParam
        self.l.selector = k1lib.selector.select(self.l.model, css)               # HookParam
        try: answer.params = [param for param in self.params if self._selected(param.name)] # HookParam
        finally: self.l.selector = oldSelector # restore even if selection raises # HookParam
        return answer                                                            # HookParam
    def __repr__(self):                                                          # HookParam
        s = f", {len(self[0].data)} means and stds each" if len(self) > 0 else "" # HookParam
        names = "\n".join([f"  {i}. {p.name}" for i, p in enumerate(self)])      # HookParam
        return f"""{super()._reprHead}: {len(self)} params{s}:\n{names}\n
Use...
- p.plot(): to quickly look at everything
- p[i]: to view a single param
- p[a:b]: to get a new HookParam with selected params
- p.css("..."): to select a specific subset of modules only
{super()._reprCan}"""                                                            # HookParam
def plotF(params:Union[HookParam, Param, List[Param]], rangeSlice:slice):        # plotF
    """Draws the recorded statistics of one or more params.

The fields of the first param's ``data.state`` (for :class:`ParamData`,
presumably means/stds/mins/maxs) are drawn on a 2x2 grid, one subplot per
field and one line per param; any fields beyond four are dropped by the zip.

:param params: a single :class:`Param`, a :class:`HookParam`, or a list of Params
:param rangeSlice: selects which recorded samples to draw; its ``step``
    (default 1) is used to subsample the points that are plotted
:raises ValueError: if there are no params to plot"""                            # plotF
    if isinstance(params, Param): params = [params]                              # plotF
    if len(params) == 0: raise ValueError("plotF: no params to plot")            # plotF
    fields = params[0].data.state.keys(); step = rangeSlice.step or 1            # plotF
    fig, axes = plt.subplots(2, 2, figsize=(10, 6), dpi=100)                     # plotF
    axes = axes.flatten()                                                        # plotF
    for field, ax in zip(fields, axes):                                          # plotF
        for param in params:                                                     # plotF
            fieldData = param.data[field]                                        # plotF
            r = k1lib.Range(len(fieldData))[rangeSlice]                          # plotF
            ax.plot(r.range_[::step], fieldData[r.slice_][::step])               # plotF
        ax.set_title(field.capitalize())                                         # plotF
    # labels only: matplotlib pairs them with the first lines drawn, i.e. the
    # first subplot's lines, which were plotted in the same param order
    plt.figlegend([p.name for p in params], loc='right')                         # plotF
@k1lib.patch(HookParam)                                                          # plotF
@k1lib.patch(Param)                                                              # plotF
def plot(self):                                                                  # plot
    """Returns a ``k1lib.viz.SliceablePlot`` that draws this object's recorded
statistics through :func:`plotF`. Since ``plotF`` also takes a ``rangeSlice``,
the SliceablePlot presumably supplies it when the plot is sliced or shown.
NOTE(review): confirm against SliceablePlot."""                                  # plot
    drawer = partial(plotF, self) # bind the data source; the slice comes later  # plot
    return k1lib.viz.SliceablePlot(drawer)                                       # plot