Source code for k1lib.callbacks.paramFinder

# AUTOGENERATED FILE! PLEASE DON'T EDIT
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np
import matplotlib.pyplot as plt
from functools import partial
__all__ = ["ParamFinder"]
# Callback that sweeps one optimizer hyper-parameter (e.g. "lr") over a
# log-spaced grid, one grid value per batch, recording the loss after each
# batch and remembering the value that produced the lowest smoothed loss.
@k1lib.patch(Cbs)
class ParamFinder(Callback):
    # NOTE(review): the single-space docstring is kept verbatim; it may be read
    # through ``__doc__`` elsewhere (``withParamFinder`` passes
    # ``docs=ParamFinder``) -- confirm before changing it.
    " "

    def __init__(self, tolerance:float=10):
        """Automatically finds out the right value for a specific parameter.

        :param tolerance: how much higher should the loss be to be considered a failure?"""
        super().__init__()
        self.order = 23  # ordering value consumed by the callback machinery
        self.suspended = True  # stays suspended until :meth:`run` activates it
        self.losses = []
        self.tolerance = tolerance

    @property
    def samples(self):
        """Number of candidate values in the scan grid."""
        return self._samples

    @samples.setter
    def samples(self, samples):
        self._samples = samples
        # Log-spaced candidates from 10^-6 to 10^2.
        self.potentialValues = 10**np.linspace(-6, 2, samples)

    @property
    def value(self):
        """Candidate value at the current index.

        :raises k1lib.CancelRunException: once every candidate has been used"""
        if self.idx >= len(self.potentialValues):
            raise k1lib.CancelRunException("Checked all possible param values")
        return self.potentialValues[self.idx]

    @property
    def lossAvgs(self):
        """Mean of the (up to) two most recent losses.

        Divides by the number of losses actually available. Dividing by a
        fixed 2 would halve the very first average, which pins ``bestLoss``
        at an artificially low level and keeps ``best`` stuck on the first
        sampled value. Returns 0.0 when no loss has been recorded yet."""
        recent = self.losses[-2:]
        if not recent:
            return 0.0
        return sum(recent) / len(recent)

    def startBatch(self):
        """Advances to the next candidate and writes it into every param group."""
        self.idx += 1
        for paramGroup in self.l.opt.param_groups:
            paramGroup[self.param] = self.value

    @property
    def suggestedValue(self):
        """The suggested param value. Has to :meth:`run` first, before this value exists.

        This is half of the best value found during the scan."""
        return self.best/2

    def endLoss(self):
        """Records the latest loss, updates the best candidate and stops the
        run once the smoothed loss exceeds ``bestLoss * tolerance``.

        :raises k1lib.CancelRunException: when the loss blows up"""
        self.losses.append(self.l.loss)
        lossAvgs = self.lossAvgs
        if lossAvgs < self.bestLoss:
            self.best = self.value
            self.bestLoss = lossAvgs
        if lossAvgs > self.bestLoss * self.tolerance:
            raise k1lib.CancelRunException("Loss increases significantly")

    def __repr__(self):
        # NOTE(review): line breaks inside this message may have been lost when
        # the source was flattened; the text is kept exactly as found.
        return f"""{self._reprHead}, use... - pf.run(): to start scanning for good params and automatically plots - pf.plot(): to plot - pf.samples = ...: to set how many param values to iterate through {self._reprCan}"""
@k1lib.patch(ParamFinder)
def run(self, param:str="lr", samples:int=300) -> float:
    """Finds the optimal param value.

    :param param: key to overwrite in each of the optimizer's param groups
    :param samples: how many samples to test between :math:`10^{-6}` to :math:`10^2`
    :return: the suggested param value"""
    self.param = param
    self.samples = samples
    # ``startBatch`` increments ``idx`` before reading the candidate, so start
    # one slot before the grid. That way the first candidate (10^-6) is tested
    # and ``losses[i]`` lines up with ``potentialValues[i]`` when plotting.
    self.idx = -1
    self.losses = []
    self.best = None
    self.bestLoss = float("inf")
    with self.cbs.suspendEval(less=["ProgressBar"]), self.l.model.paramsContext():
        self.suspended = False
        try:
            self.l.run(int(1e3))
        finally:
            # Always re-suspend, even if the run raises, so this callback
            # cannot keep overwriting the param during later runs.
            self.suspended = True
    return self.suggestedValue


def plotF(self, _slice):
    """Draws recorded losses against the scanned param values on a log x-axis.

    :param _slice: unit-range slice, converted through :class:`k1lib.Range`"""
    r = k1lib.Range(len(self.losses)).fromUnit(_slice)
    plt.plot(self.potentialValues[r.slice_], self.losses[r.slice_])
    plt.xscale("log")
    plt.xlabel(self.param)
    plt.ylabel("Loss")


@k1lib.patch(ParamFinder)
def plot(self, *args, **kwargs):
    """Plots loss at different param scales. Automatically :meth:`run` if
    hasn't, returns a :class:`k1lib.viz.SliceablePlot`.

    :param args: Arguments to pass through to :meth:`run` if a run is
        required. Just for convenience sake"""
    if not self.losses:
        self.run(*args, **kwargs)
    print(f"Suggested param: {self.suggestedValue}")
    plt.figure(dpi=120)
    return k1lib.viz.SliceablePlot(partial(plotF, self), docs="\n\nReminder: slice range here is actually [0, 1], because it's kinda hard to slice the normal way")


# Convenience constructor patched onto Callbacks: builds a ParamFinder with the
# given tolerance and appends it under the optional name.
@k1lib.patch(Callbacks, docs=ParamFinder)
def withParamFinder(self, tolerance:float=10, name:str=None):
    return self.append(ParamFinder(tolerance), name)