# AUTOGENERATED FILE! PLEASE DON'T EDIT
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np, math
from functools import partial
import matplotlib.pyplot as plt
from typing import Callable
# Public names of this module: the two callback classes defined below.
__all__ = ["Loss", "Accuracy"]
def plotF(losses, _slice): # actual function stored by the sliceable plot
    """Plots train and valid losses side by side for the requested slice.

    :param losses: object exposing sliceable ``.train`` and ``.valid`` sequences
    :param _slice: slice object; its ``step`` (default 1) is applied to both
        plots, and ``k1lib.Range.proportionalSlice`` turns it into one range
        for each of the two sequences"""
    plt.figure(figsize=(10, 3), dpi=100); step = _slice.step or 1
    tR, vR = k1lib.Range.proportionalSlice(len(losses.train), len(losses.valid), _slice)
    try:
        plt.subplot(1, 2, 1); plt.plot(tR.range_[::step], losses.train[tR.slice_][::step]); plt.title("Train loss")
        plt.subplot(1, 2, 2); plt.plot(vR.range_[::step], losses.valid[vR.slice_][::step]); plt.title("Valid loss")
    except Exception:
        # Plotting stays best-effort (a bad slice should not crash the caller),
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        pass
def commonPlot(obj):
    """Wraps :func:`plotF`, bound to ``obj``, in a ``k1lib.viz.SliceablePlot``.

    :param obj: passed to :func:`plotF` as its ``losses`` argument, so it must
        expose ``.train`` and ``.valid``
    :return: the constructed ``k1lib.viz.SliceablePlot``"""
    boundPlot = partial(plotF, obj)
    reminder = """\n\nReminder: the actual slice you put in is for the training plot. The valid loss's plot will update automatically to be in the same time frame"""
    return k1lib.viz.SliceablePlot(boundPlot, docs=reminder)
def nonEmptyList(_list):
    """Returns ``[0]`` when ``_list`` equals the empty list, else ``_list`` itself.

    The input object is returned as-is (not copied) when it is not empty."""
    if _list == []:
        return [0]
    return _list
@k1lib.patch(Cbs)
class Loss(Callback):
    """Records losses after each batch.

    Per-batch values accumulate in ``.train`` / ``.valid``; the mean of each
    epoch's values accumulates in ``.epoch.train`` / ``.epoch.valid``."""
    def __init__(self):
        super().__init__()
        self.order = 20
        # every recorded loss, over all epochs
        self.train = []
        self.valid = []
        # per-epoch averages live on a small attribute container with its own repr
        epochRepr = ("Use...\n"
                     "- `.train` for epoch-averaged training losses\n"
                     "- `.valid` for epoch-averaged validation losses\n"
                     "- `.plot()` to plot the 2 above")
        self.epoch = k1lib.Object.fromDict({"train": [], "valid": []}).withRepr(epochRepr)
        self.plot = partial(commonPlot, self)
        self.epoch.plot = partial(commonPlot, self.epoch)
        # losses of the epoch in progress; flushed and reset by endEpoch()
        self._trainLosses = []
        self._validLosses = []
        self._landscape = k1lib.callbacks.Landscape(lambda l: l.loss, "_LossLandscape")
    def endLoss(self):
        """Stores ``self.l.loss`` in the train or valid buffer, chosen by ``self.l.model.training``."""
        buffer = self._trainLosses if self.l.model.training else self._validLosses
        buffer.append(self.l.loss)
    def endEpoch(self):
        """Flushes the current epoch's buffers into the full history and records their means."""
        pairs = ((self.train, self.epoch.train, self._trainLosses),
                 (self.valid, self.epoch.valid, self._validLosses))
        for history, epochMeans, pending in pairs:
            history.extend(pending)
            # an empty buffer is averaged as [0], so the recorded mean is 0
            epochMeans.append(np.mean(nonEmptyList(pending)))
        self._trainLosses = []
        self._validLosses = []
    @property
    def Landscape(self):
        """Gets loss-landscape-plotting Callback"""
        landscape = self._landscape
        self.cbs.add(landscape)
        return landscape
    def detach(self):
        """Detaches the landscape callback first, then this callback itself."""
        self._landscape.detach()
        return super().detach()
    def __repr__(self):
        head = super()._reprHead
        can = super()._reprCan
        return f"""{head}, use...
- cb.train: for all training losses over all epochs and batches (#epochs * #batches)
- cb.valid: for all validation losses over all epochs and batches (#epochs * #batches)
- cb.plot(): to plot the 2 above
- cb.epoch: for average losses of each epochs
- cb.Landscape: for loss-landscape-plotting Callback
{can}"""
# Message of the RuntimeError raised by Accuracy.plot() and Accuracy.Landscape
# when no Cbs.AccF callback is found among the learner's callbacks.
accFMsg = "You have to specify how to compute the accuracy with the AccF callback first"
@k1lib.patch(Cbs)
class Accuracy(Callback):
    """Records accuracies after each batch.

    Recording is only active when a ``Cbs.AccF`` callback is among the
    learner's callbacks; otherwise this callback pauses itself for the run."""
    def __init__(self):
        """Records accuracies after each batch."""
        super().__init__(); self.order = 20
        self.train = [0]; self.valid = [0]
        self._landscape = k1lib.callbacks.Landscape(lambda l: l.accuracy, "_AccuracyLandscape")
    @property
    def hasAccF(self):
        """Whether any callback in ``self.l.cbs.cbs`` is a ``Cbs.AccF`` instance."""
        return any(isinstance(cb, Cbs.AccF) for cb in self.l.cbs.cbs)
    def startRun(self):
        """Pauses when there is no AccF; otherwise makes the histories appendable lists."""
        self.paused = not self.hasAccF
        if not self.paused:
            # endRun() may have left these as numpy arrays, which have no .append()
            self.train = list(self.train); self.valid = list(self.valid)
    def endRun(self):
        """Converts the recorded histories to numpy arrays, unless paused."""
        if not self.paused:
            self.train = np.array(self.train); self.valid = np.array(self.valid)
    def endLoss(self):
        """Appends ``self.l.accuracy`` to the train or valid history, chosen by ``self.l.model.training``."""
        if not self.paused:
            (self.train if self.l.model.training else self.valid).append(self.l.accuracy)
    def plot(self):
        """Returns a ``k1lib.viz.SliceablePlot`` of train/valid accuracies, in percent.

        :raises RuntimeError: if no AccF callback is attached"""
        if not self.hasAccF: raise RuntimeError(accFMsg)
        def plotF(_slice):
            plt.figure(figsize=(10, 3), dpi=100); step = _slice.step or 1
            tR, vR = k1lib.Range.proportionalSlice(len(self.train), len(self.valid), _slice)
            try:
                # np.asarray: the histories are plain lists until endRun() has run,
                # and `100*list` would repeat the list instead of scaling its values
                plt.subplot(1, 2, 1); plt.plot(tR.range_[::step], 100*np.asarray(self.train)[tR.slice_][::step]); plt.title("Train accuracy")
                plt.subplot(1, 2, 2); plt.plot(vR.range_[::step], 100*np.asarray(self.valid)[vR.slice_][::step]); plt.title("Valid accuracy")
            except Exception:
                # Plotting stays best-effort, but KeyboardInterrupt/SystemExit
                # are no longer swallowed.
                pass
        return k1lib.viz.SliceablePlot(plotF)
    @property
    def Landscape(self):
        """Gets accuracy-landscape-plotting Callback

        :raises RuntimeError: if no AccF callback is attached"""
        if self.hasAccF:
            self._landscape.parent = self
            self.cbs.add(self._landscape); return self._landscape
        else: raise RuntimeError(f"{accFMsg}, before you can view the landscape")
    def __repr__(self):
        missing = "" if self.hasAccF else " (.accuracyF not defined yet)"
        return f"""{super()._reprHead}{missing}, use...
- a.train: for train accuracies over all batches
- a.valid: for valid accuracies over all batches
- a.plot(): to plot the 2 above
- a.Landscape: for accuracy-landscape-plotting Callback
{super()._reprCan}"""