# AUTOGENERATED FILE! PLEASE DON'T EDIT
from k1lib.callbacks import Callback, Callbacks
import k1lib, numpy as np
import matplotlib.pyplot as plt
from functools import partial
__all__ = ["ParamFinder"]
@k1lib.patch(Callback.cls)
class ParamFinder(Callback):
    """Automatically finds out the right value for
    a specific parameter (e.g. the learning rate).

    While active, it writes one candidate value per batch into every optimizer
    param group, records the loss after each batch, and remembers the
    candidate with the lowest smoothed (2-sample average) loss."""
    def __init__(self):
        super().__init__(); self.order = 23
        # inactive until run() explicitly un-suspends it
        self.suspended = True; self.losses = []
    @property
    def samples(self):
        """Number of candidate values scanned."""
        return self._samples
    @samples.setter
    def samples(self, samples):
        self._samples = samples
        # log-spaced candidates from 1e-6 to 1e2 (both ends included)
        self.potentialValues = 10**np.linspace(-6, 2, samples)
    @property
    def value(self):
        """Candidate value at the current index. Raises
        :class:`k1lib.CancelRunException` once every candidate has been
        used, which ends the scan."""
        if self.idx >= len(self.potentialValues): raise k1lib.CancelRunException("Checked all possible param values")
        return self.potentialValues[self.idx]
    @property
    def lossAvgs(self):
        """Mean of the last (up to) 2 recorded losses, smoothing per-batch noise."""
        window = self.losses[-2:]
        # divide by the real window size: a fixed ``/2`` halves the very first
        # loss, making that candidate look artificially good and letting the
        # 10x divergence check in endLoss fire after only a 5x rise
        return sum(window)/max(len(window), 1)
    def startBatch(self):
        # advance to the next candidate, then apply it to every param group
        self.idx += 1
        for paramGroup in self.l.opt.param_groups:
            paramGroup[self.param] = self.value
    @property
    def suggestedValue(self):
        """The suggested param value. Has to :meth:`run` first, before
        this value exists. It is half of the best-scoring candidate
        (presumably as a safety margin below the loss minimum)."""
        return self.best/2
    def endLoss(self):
        self.losses.append(self.l.loss)
        lossAvgs = self.lossAvgs
        # NaN compares False against everything, so neither check below could
        # ever stop a scan whose loss has diverged to NaN; stop explicitly
        if lossAvgs != lossAvgs: raise k1lib.CancelRunException("Loss is NaN")
        if lossAvgs < self.bestLoss:
            self.best = self.value
            self.bestLoss = lossAvgs
        # stop once the smoothed loss exceeds 10x the best seen so far
        if lossAvgs > self.bestLoss * 10: raise k1lib.CancelRunException("Loss increases significantly")
    def __repr__(self):
        return f"""{self._reprHead}, use...
- pf.run(): to start scanning for good params and automatically plots
- pf.plot(): to plot
- pf.samples = ...: to set how many param values to iterate through
{self._reprCan}"""
@k1lib.patch(ParamFinder)
def run(self, param:str="lr", samples:int=300) -> float:
    """Finds the optimum param value.

    :param param: param-group key to scan in the optimizer (e.g. ``"lr"``)
    :param samples: how many samples to test between :math:`10^{-6}` to :math:`10^2`
    :return: the suggested param value"""
    self.param = param; self.samples = samples
    # startBatch() increments idx *before* reading the value, so start at -1:
    # the first batch then uses potentialValues[0], and losses[i] stays aligned
    # with potentialValues[i] (plotF plots them index-by-index)
    self.idx = -1; self.losses = []; self.best = None; self.bestLoss = float("inf")
    # NOTE(review): presumably suspends every callback except ProgressBar for
    # the duration of the scan - confirm against Callbacks.suspendEval
    with self.cbs.suspendEval(less=["ProgressBar"]):
        self.suspended = False; ogParams = self.l.model.exportParams()
        try:
            # large epoch cap; the scan is expected to end via the
            # CancelRunException raised by `value` / `endLoss`
            self.l.run(int(1e3))
        finally:
            # always deactivate the finder and restore the original weights,
            # even if the run is interrupted by an unexpected exception
            self.suspended = True; self.l.model.importParams(ogParams)
    return self.suggestedValue
def plotF(self, _slice):
    """Draw loss against param value for the part of the scan selected by
    ``_slice``, which is expressed in unit ([0, 1]) coordinates and mapped
    onto the recorded-loss indices. The x axis is log-scaled."""
    selected = k1lib.Range(len(self.losses)).fromUnit(_slice).slice_
    xs = self.potentialValues[selected]
    ys = self.losses[selected]
    plt.plot(xs, ys)
    plt.xscale("log")
    plt.xlabel(self.param)
    plt.ylabel("Loss")
@k1lib.patch(ParamFinder)
def plot(self, *args, **kwargs):
    """Plots loss at different param scales. Automatically :meth:`run`
    if hasn't, returns a :class:`k1lib.viz.SliceablePlot`.

    :param args: Arguments to pass through to :meth:`run` if a run is
        required. Just for convenience sake
    :param kwargs: Keyword arguments to pass through to :meth:`run` if a
        run is required"""
    # run first when there is nothing to plot yet, then fall through to the
    # plotting path; previously run()'s float was returned and nothing plotted
    if len(self.losses) == 0: self.run(*args, **kwargs)
    print(f"Suggested param: {self.suggestedValue}"); plt.figure(dpi=120)
    return k1lib.viz.SliceablePlot(partial(plotF, self), docs="\n\nReminder: slice range here is actually [0, 1], because it's kinda hard to slice the normal way")
@k1lib.patch(Callbacks, docs=ParamFinder)
def withParamFinder(self):
    # build a fresh finder and hand back whatever Callbacks.append returns
    finder = ParamFinder()
    return self.append(finder)