# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np
plt = k1lib.dep("matplotlib.pyplot")
from functools import partial
__all__ = ["ParamFinder"]
@k1lib.patch(Cbs)
class ParamFinder(Callback):                                                     # ParamFinder
    """Sweeps one optimizer param-group entry (chosen in :meth:`run`) over
log-spaced candidate values. A new candidate is written at the start of every
batch and the resulting loss is recorded, so a good value can be suggested."""  # ParamFinder
    def __init__(self, tolerance:float=10):                                      # ParamFinder
        """Automatically finds out the right value for a specific parameter.
:param tolerance: how much higher should the loss be to be considered a failure?
    The scan is cancelled once the smoothed loss exceeds ``tolerance`` times the
    best smoothed loss seen so far."""                                           # ParamFinder
        super().__init__(); self.order = 23                                      # ParamFinder
        # ``suspended`` is only cleared while :meth:`run` is driving the scan
        self.suspended = True; self.losses = []; self.tolerance = tolerance      # ParamFinder
    @property                                                                    # ParamFinder
    def samples(self):                                                           # ParamFinder
        """Number of candidate param values to scan."""                          # ParamFinder
        return self._samples                                                     # ParamFinder
    @samples.setter                                                              # ParamFinder
    def samples(self, samples):                                                  # ParamFinder
        self._samples = samples                                                  # ParamFinder
        # candidates are log-spaced from 1e-6 up to and including 1e2
        self.potentialValues = 10**np.linspace(-6, 2, samples)                   # ParamFinder
    @property                                                                    # ParamFinder
    def value(self):                                                             # ParamFinder
        """Candidate param value at the current index ``self.idx``.
:raises k1lib.CancelRunException: once every candidate has been tried"""         # ParamFinder
        if self.idx >= len(self.potentialValues): raise k1lib.CancelRunException("Checked all possible param values") # ParamFinder
        return self.potentialValues[self.idx]                                    # ParamFinder
    @property                                                                    # ParamFinder
    def lossAvgs(self):                                                          # ParamFinder
        """Mean of the last (up to) 2 recorded losses, to smooth out noise.
Returns ``inf`` if no loss has been recorded yet."""                             # ParamFinder
        recent = self.losses[-2:]                                                # ParamFinder
        # divide by the actual count: with a single loss, dividing by 2 would
        # halve it and make the very first candidate look unrealistically good
        return sum(recent)/len(recent) if recent else float("inf")               # ParamFinder
    def startBatch(self):                                                        # ParamFinder
        """Advances to the next candidate and writes it into every optimizer
param group under the key ``self.param``."""                                     # ParamFinder
        self.idx += 1                                                            # ParamFinder
        for paramGroup in self.l.opt.param_groups:                               # ParamFinder
            paramGroup[self.param] = self.value                                  # ParamFinder
    @property                                                                    # ParamFinder
    def suggestedValue(self):                                                    # ParamFinder
        """The suggested param value: half of the candidate that produced the
lowest smoothed loss. Has to :meth:`run` first, before this value exists"""      # ParamFinder
        return self.best/2                                                       # ParamFinder
    def endLoss(self):                                                           # ParamFinder
        """Records the batch loss, tracks the best candidate so far, and cancels
the run once the smoothed loss exceeds ``tolerance`` times the best one.
:raises k1lib.CancelRunException: when the loss increases significantly"""       # ParamFinder
        self.losses.append(self.l.loss)                                          # ParamFinder
        lossAvgs = self.lossAvgs                                                 # ParamFinder
        if lossAvgs < self.bestLoss:                                             # ParamFinder
            self.best = self.value                                               # ParamFinder
            self.bestLoss = lossAvgs                                             # ParamFinder
        # NOTE(review): this ratio test assumes positive losses; with negative
        # losses the comparison direction flips — confirm if that can happen
        if lossAvgs > self.bestLoss * self.tolerance: raise k1lib.CancelRunException("Loss increases significantly") # ParamFinder
    def __repr__(self):                                                          # ParamFinder
        return f"""{self._reprHead}, use...
- pf.run(): to start scanning for good params and automatically plots
- pf.plot(): to plot
- pf.samples = ...: to set how many param values to iterate through
{self._reprCan}"""                                                               # ParamFinder
@k1lib.patch(ParamFinder)                                                        # ParamFinder
def run(self, param:str="lr", samples:int=300) -> float:                         # run
    """Finds the optimum param value.
:param param: key of the optimizer param-group entry to scan, e.g. ``"lr"``
:param samples: how many samples to test between :math:`10^{-6}` to :math:`10^2`
:return: the suggested param value"""                                            # run
    self.param = param; self.samples = samples                                   # run
    # startBatch() increments idx *before* reading the candidate, so start at -1:
    # the first batch then uses potentialValues[0], and losses[i] stays aligned
    # with potentialValues[i] (plotF plots them index-for-index)
    self.idx = -1; self.losses = []; self.best = None; self.bestLoss = float("inf") # run
    with self.cbs.suspendEval(less=["ProgressBar"]), self.l.model.paramsContext(): # run
        self.suspended = False                                                   # run
        # 1000 epochs is only an upper bound; the scan is expected to stop via the
        # CancelRunException raised by `value` / `endLoss`
        try: self.l.run(int(1e3))                                                # run
        finally: self.suspended = True  # re-suspend even if training raises     # run
    return self.suggestedValue                                                   # run
def plotF(self, _slice):                                                         # plotF
    """Plots ``self.losses`` against ``self.potentialValues``, sliced identically.
:param _slice: slice expressed over the unit range [0, 1]; it is mapped onto the
    index range of ``self.losses`` via :class:`k1lib.Range`"""                 # plotF
    window = k1lib.Range(len(self.losses)).fromUnit(_slice)                      # plotF
    xs = self.potentialValues[window.slice_]                                     # plotF
    ys = self.losses[window.slice_]                                              # plotF
    plt.plot(xs, ys)                                                             # plotF
    plt.xscale("log")                                                            # plotF
    plt.xlabel(self.param)                                                       # plotF
    plt.ylabel("Loss")                                                           # plotF
@k1lib.patch(ParamFinder)                                                        # plotF
def plot(self, *args, **kwargs):                                                 # plot
    """Plots loss at different param scales. Automatically :meth:`run`
if hasn't, returns a :class:`k1lib.viz.SliceablePlot`.
:param args: positional arguments forwarded to :meth:`run` when no losses have
    been recorded yet. Just for convenience sake
:param kwargs: keyword arguments forwarded to :meth:`run` in the same case"""    # plot
    if not self.losses: self.run(*args, **kwargs)                                # plot
    print(f"Suggested param: {self.suggestedValue}")                             # plot
    plt.figure(dpi=120)                                                          # plot
    drawer = partial(plotF, self)                                                # plot
    return k1lib.viz.SliceablePlot(drawer, docs="\n\nReminder: slice range here is actually [0, 1], because it's kinda hard to slice the normal way") # plot