# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np, math
import k1lib.cli as cli
from functools import partial
plt = k1lib.dep("matplotlib.pyplot")
from typing import Callable
__all__ = ["Loss", "Accuracy"]
def plotF(losses, f): # actual function stored by the sliceable plot             # plotF
    plt.figure(figsize=(10, 3), dpi=100); f = f | cli.deref()                    # plotF
    try:                                                                         # plotF
        plt.subplot(1, 2, 1); plt.plot(range(len(losses.train)) | f, losses.train | f); plt.title(f"Train loss") # plotF
        plt.subplot(1, 2, 2); plt.plot(range(len(losses.valid)) | f, losses.valid | f); plt.title(f"Valid loss") # plotF
    except: pass # best-effort plotting: silently skip empty or malformed loss lists # plotF
def commonPlot(obj, f=cli.iden()):                                               # commonPlot
    plotF(obj, f); return # note: this early return disables the SliceablePlot path below # commonPlot
    return k1lib.viz.SliceablePlot(partial(plotF, obj, f), docs="""\n\nReminder: the actual slice you put in is for the training plot. The valid loss's plot will update automatically to be in the same time frame""") # commonPlot
def nonEmptyList(_list):                                                         # nonEmptyList
    return [0] if _list == [] else _list # so np.mean() on an empty list doesn't return nan and warn # nonEmptyList
@k1lib.patch(Cbs)                                                                # nonEmptyList
class Loss(Callback):                                                            # Loss
    " "                                                                          # Loss
    def __init__(self, f=lambda l: l.loss):                                      # Loss
        """Records losses after each batch.
Expected variables in :class:`~k1lib.Learner`:
- loss: single float value
:param f: optional function to get the loss from :class:`~k1lib.Learner` object""" # Loss
        super().__init__(); self.order = 20; self.f = f                          # Loss
        self.train = []; self.valid = [] # every recorded loss, across all batches and epochs # Loss
        # average stats for each epoch                                           # Loss
        self.epoch = k1lib.Object.fromDict({"train": [], "valid": []})\
                        .withRepr("Use...\n" +\
                                  "- `.train` for epoch-averaged training losses\n" +\
                                  "- `.valid` for epoch-averaged validation losses\n" +\
                                  "- `.plot()` to plot the 2 above")              # Loss
        self.plot = partial(commonPlot, self)                                    # Loss
        self.epoch.plot = partial(commonPlot, self.epoch)                        # Loss
        self._trainLosses = []; self._validLosses = []                           # Loss
        self._landscape = k1lib.callbacks.Landscape(lambda l: l.loss, "_LossLandscape") # Loss 
    def endLoss(self):                                                           # Loss
        loss = self.f(self.l)                                                    # Loss
        if self.l.model.training: self._trainLosses.append(loss)                 # Loss
        else: self._validLosses.append(loss)                                     # Loss 
    def endEpoch(self):                                                          # Loss
        self.train.extend(self._trainLosses); self.epoch.train.append(np.mean(nonEmptyList(self._trainLosses))) # Loss
        self.valid.extend(self._validLosses); self.epoch.valid.append(np.mean(nonEmptyList(self._validLosses))) # Loss
        self._trainLosses = []; self._validLosses = []                           # Loss 
    @property                                                                    # Loss
    def Landscape(self):                                                         # Loss
        """Gets loss-landscape-plotting Callback.
Example::
    l = k1lib.Learner.sample()
    l.cbs.add(Cbs.Loss())
    l.Loss.Landscape.plot()"""                                                   # Loss
        self.cbs.add(self._landscape); return self._landscape                    # Loss
    def detach(self): self._landscape.detach(); return super().detach()          # Loss
    def clear(self):                                                             # Loss
        """Clears saved data"""                                                  # Loss
        self.train = []; self.epoch.train = []                                   # Loss
        self.valid = []; self.epoch.valid = []                                   # Loss 
    def __repr__(self):                                                          # Loss
        return f"""{super()._reprHead}, use...
- cb.train: for all training losses over all epochs and batches (#epochs * #batches)
- cb.valid: for all validation losses over all epochs and batches (#epochs * #batches)
- cb.plot(): to plot the 2 above
- cb.clear(): to clear saved data
- cb.epoch: for average losses of each epoch
- cb.Landscape: for loss-landscape-plotting Callback
{super()._reprCan}"""                                                            # Loss 
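# Illustrative usage sketch for the Loss callback above. It assumes `k1lib.Learner.sample()`
# and `l.run(epochs)` behave as described in k1lib's Learner docs; treat it as a sketch
# rather than canonical API.
def _lossUsageSketch():
    l = k1lib.Learner.sample()   # toy Learner bundled with k1lib (see the Landscape docstring above)
    l.cbs.add(Cbs.Loss())        # records l.loss after every batch
    l.run(3)                     # assumed signature: train for 3 epochs
    l.Loss.plot()                # batch-level train/valid loss curves
    l.Loss.epoch.plot()          # epoch-averaged curves
    return l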
accFMsg = "You have to specify how to compute the accuracy with the AccF callback first" # Loss
@k1lib.patch(Cbs)                                                                # Loss
class Accuracy(Callback):                                                        # Accuracy
    " "                                                                          # Accuracy
    def __init__(self, variable:str="accuracy"):                                 # Accuracy
        """Records accuracies after each batch.
Expected variables in :class:`~k1lib.Learner`:
- accuracy: single float value from 0 to 1
:param variable: name of variable expected to be available in Learner"""         # Accuracy
        super().__init__(); self.order = 20                                      # Accuracy
        self.train = [0]; self.valid = [0]; self.paused = True; self.variable = variable # Accuracy
        self._landscape = k1lib.callbacks.Landscape(lambda l: l.__dict__[variable], "_AccuracyLandscape") # Accuracy 
    @property                                                                    # Accuracy
    def hasAccF(self):                                                           # Accuracy
        return any(isinstance(cb, Cbs.AccF) for cb in self.l.cbs.cbs)            # Accuracy
    def startRun(self):                                                          # Accuracy
        self.paused = not self.hasAccF                                           # Accuracy
        if not self.paused:                                                      # Accuracy
            self.train = list(self.train); self.valid = list(self.valid)         # Accuracy 
    def endRun(self):                                                            # Accuracy
        if not self.paused:                                                      # Accuracy
            self.train = np.array(self.train); self.valid = np.array(self.valid) # Accuracy 
    def endLoss(self):                                                           # Accuracy
        if not self.paused:                                                      # Accuracy
            (self.train if self.l.model.training else self.valid).append(self.l.__dict__[self.variable]) # Accuracy 
    def plot(self, f=cli.iden()):                                                # Accuracy
        """
:param f: optional post-processing cli"""                                        # Accuracy
        if not self.hasAccF: raise RuntimeError(accFMsg)                         # Accuracy
        plt.figure(figsize=(10, 3), dpi=100); f = f | cli.deref()                # Accuracy
        try:                                                                     # Accuracy
            plt.subplot(1, 2, 1); plt.plot(range(len(self.train)) | f, 100*self.train | f); plt.title(f"Train accuracy") # Accuracy
            plt.subplot(1, 2, 2); plt.plot(range(len(self.valid)) | f, 100*self.valid | f); plt.title(f"Valid accuracy") # Accuracy
        except: pass # best-effort plotting: silently skip empty or malformed accuracy data # Accuracy
    @property                                                                    # Accuracy
    def Landscape(self):                                                         # Accuracy
        """Gets accuracy-landscape-plotting Callback.
Example::
    l = k1lib.Learner.sample()
    l.cbs.add(Cbs.Accuracy())
    l.Accuracy.Landscape.plot()
This exact example won't work, as the sample :class:`~k1lib.Learner` task is not
categorical, but the general idea still stands"""                                # Accuracy
        if self.hasAccF:                                                         # Accuracy
            self._landscape.parent = self                                        # Accuracy
            self.cbs.add(self._landscape); return self._landscape                # Accuracy
        else: raise RuntimeError(f"{accFMsg}, before you can view the landscape") # Accuracy
    def clear(self):                                                             # Accuracy
        """Clears saved data."""                                                 # Accuracy
        self.train = [0]; self.valid = [0]                                       # Accuracy 
    def __repr__(self):                                                          # Accuracy
        return f"""{super()._reprHead}{f" (.accuracyF not defined yet)" if not self.hasAccF else ""}, use...
- a.train: for training accuracies over all batches
- a.valid: for validation accuracies over all batches
- a.plot(): to plot the 2 above
- a.clear(): to clear saved data
- a.Landscape: for accuracy-landscape-plotting Callback
{super()._reprCan}"""                                                            # Accuracy
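# Illustrative usage sketch for the Accuracy callback above. Accuracy stays paused unless an
# AccF callback is attached (see `hasAccF`), since something has to fill in `l.accuracy` each
# batch. `Cbs.AccF()`'s constructor arguments and `l.run(epochs)` are assumptions here.
def _accuracyUsageSketch():
    l = k1lib.Learner.sample()   # note: the sample task isn't categorical, so the numbers are
                                 # not meaningful; this only demonstrates the wiring
    l.cbs.add(Cbs.AccF())        # assumed: computes and stores l.accuracy every batch
    l.cbs.add(Cbs.Accuracy())    # records l.accuracy after every batch
    l.run(3)                     # assumed signature: train for 3 epochs
    l.Accuracy.plot()            # train/valid accuracy curves (scaled to %)
    return l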