# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, warnings
from typing import List, Callable
try: import torch; hasTorch = True
except: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["ConfusionMatrix"]
@k1lib.patch(Cbs)
class ConfusionMatrix(Callback):
    " "
    categories:List[str]
    """String categories for displaying the matrix. You can set this
so that it displays what you want, in case this Callback is included
automatically."""
    matrix:torch.Tensor
    """The recorded confusion matrix. Rows are indexed by the learner's ``yb``
(true category ids), columns by ``preds`` (predicted category ids)."""
    def __init__(self, categories:List[str]=None, condF:Callable[["ConfusionMatrix"], bool]=lambda _: True):
        """Records what categories the network is confused the most. Expected
variables in :class:`~k1lib.Learner`:

- yb: integer tensor with the true category ids of the batch.
- preds: long tensor with categories id of batch before checkpoint ``endLoss``.
  Auto-included in :class:`~k1lib.callbacks.lossFunctions.accuracy.AccF` and
  :class:`~k1lib.callbacks.lossFunctions.shorts.LossNLLCross`.

:param categories: optional list of category names
:param condF: takes in this cb and returns whether to record at this
    particular ``endLoss`` checkpoint"""
        super().__init__(); self.categories = categories
        # start with one row/column per known category (at least 2); the
        # matrix grows later in _adapt if larger category ids show up
        self.n = len(categories or []) or 2; self.condF = condF
        self.matrix = torch.zeros(self.n, self.n)
        self.wipeOnAdd = False # flag to wipe matrix on adding new data points
    def _adapt(self, idxs):
        """Grows the internal matrix so it can hold every category id in
``idxs``, then moves it onto ``idxs``'s device. Returns ``idxs`` unchanged."""
        if idxs.numel() > 0: # .max() raises on an empty tensor (empty batch)
            m = idxs.max().item() + 1 # +1 because max index = len() - 1
            if m > self.n:
                # keep the old counts in the top-left corner of the larger matrix
                matrix = torch.zeros(m, m, dtype=self.matrix.dtype, device=self.matrix.device)
                matrix[:self.n, :self.n] = self.matrix
                self.matrix = matrix; self.n = len(self.matrix)
        self.matrix = self.matrix.to(idxs.device); return idxs
    def startEpoch(self):
        # defer clearing until the first recorded batch of the new epoch, so the
        # previous epoch's matrix stays inspectable until then
        self.wipeOnAdd = True
    def endLoss(self):
        """Adds this batch's (``yb``, ``preds``) pairs to the matrix, if
``condF`` allows it."""
        if not self.condF(self): return
        if self.wipeOnAdd:
            self.matrix = torch.zeros(self.n, self.n)
            self.wipeOnAdd = False
        yb = self._adapt(self.l.yb); preds = self._adapt(self.l.preds)
        # ``self.matrix[yb, preds] += 1`` counts a repeated (yb, pred) pair only
        # once per batch, so accumulate every occurrence explicitly instead
        ones = torch.ones_like(yb, dtype=self.matrix.dtype)
        self.matrix.index_put_((yb, preds), ones, accumulate=True)
    @property
    def goodMatrix(self) -> torch.Tensor:
        """Clears all inf, nans and whatnot from the matrix, then returns it.

Each row is divided by its largest entry. Rows with no recorded samples
stay all-zero instead of turning into nans."""
        n = self.n; m = self.matrix
        # shrink from the bottom-right until every remaining value is finite
        while not torch.isfinite(m).all():
            n -= 1; m = m[:n, :n]
        if n != self.n: warnings.warn(f"Originally, the confusion matrix has {self.n} categories, now it has {n} only, after filtering, because there are some nans and infinite values.")
        if self.categories is not None:
            n = len(self.categories); m = m[:n, :n]
        if m.numel() == 0: return m # max(dim=1) raises on an empty matrix
        # clamp so an all-zero row divides by 1 (giving zeros), not by 0 (giving nans)
        return m/m.max(dim=1).values.clamp(min=1)[:,None]
    def plot(self):
        """Plots everything"""
        m = self.goodMatrix
        # default labels must match the (possibly trimmed) matrix actually plotted
        k1lib.viz.confusionMatrix(m, self.categories or list(range(len(m))))
    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""