Source code for k1lib.callbacks.confusionMatrix

# AUTOGENERATED FILE! PLEASE DON'T EDIT
from .callbacks import Callback, Callbacks, Cbs
import k1lib, torch, warnings
from typing import List
# Names exported by `from <this module> import *`: only the ConfusionMatrix callback.
__all__ = ["ConfusionMatrix"]
# Callback that accumulates a confusion matrix while a learner runs.
# Rows are indexed by the values in `self.l.yb`, columns by the values in
# `self.l.preds`; the matrix is zeroed at the start of every epoch and grows
# automatically when an index larger than the current size is seen.
# NOTE(review): `self.l` is presumably the learner attached by the `Callback`
# base class, and `hasNan()` / `hasInfs()` are presumably tensor helpers
# provided by k1lib -- neither is defined in this file; confirm.
@k1lib.patch(Cbs)
class ConfusionMatrix(Callback):
    " "
    categories:List[str]
    """String categories for displaying the matrix. You can set this so that
it displays what you want, in case this Callback is included automatically."""
    matrix:torch.Tensor
    """The recorded confusion matrix."""
    def __init__(self, categories:List[str]=None):
        """Records what categories the network is confused the most. Expected
variables in :class:`~k1lib.Learner`:

- preds: long tensor with categories id of batch before checkpoint ``endLoss``.
  Auto-included in :class:`~k1lib.callbacks.lossFunctions.accuracy.AccF` and
  :class:`~k1lib.callbacks.lossFunctions.shorts.LossNLLCross`.

:param categories: optional list of category names"""
        super().__init__()
        self.categories = categories
        # Start with one row/column per given category; fall back to a 2x2
        # matrix when no (or an empty) category list is supplied.
        self.n = len(categories or []) or 2
        self.matrix = torch.zeros(self.n, self.n)
    def _adapt(self, idxs):
        """Adapts the internal matrix so that it supports new categories.

:param idxs: tensor of category indices about to be recorded
:return: ``idxs``, unchanged, so the call can be used inline"""
        # Nothing to adapt to; also avoids `max()` raising on an empty tensor.
        if idxs.numel() == 0: return idxs
        # +1 because max index = len() - 1
        m = idxs.max().item() + 1
        if m > self.n:
            # Grow to m x m, keeping every count recorded so far in the
            # top-left corner.
            matrix = torch.zeros(m, m)
            matrix[:self.n, :self.n] = self.matrix
            self.matrix = matrix
            self.n = len(self.matrix)
        return idxs
    def startEpoch(self):
        """Resets the matrix to all zeros, keeping its current size."""
        self.matrix = torch.zeros(self.n, self.n)
    def endLoss(self):
        """Adds the current batch's (target, prediction) pairs to the matrix."""
        yb = self._adapt(self.l.yb)
        preds = self._adapt(self.l.preds)
        # `self.matrix[yb, preds] += 1` would increment a cell only once even
        # when the same (target, prediction) pair occurs several times in the
        # batch. Accumulating index_put_ counts every occurrence.
        one = torch.ones((), dtype=self.matrix.dtype, device=self.matrix.device)
        self.matrix.index_put_((yb, preds), one, accumulate=True)
    @property
    def goodMatrix(self) -> torch.Tensor:
        """Clears all inf, nans and whatnot from the matrix, then returns it.

The returned matrix is cropped to the number of categories (if given) and each
row is divided by that row's maximum value. Rows that are entirely zero are
returned as zeros instead of nans."""
        n = self.n
        m = self.matrix
        # Shrink from the bottom-right corner until no nan/inf is left.
        while m.hasNan() or m.hasInfs():
            n -= 1
            m = m[:n, :n]
        if n != self.n:
            warnings.warn(f"Originally, the confusion matrix has {self.n} categories, now it has {n} only, after filtering, because there are some nans and infinite values.")
        if self.categories is not None:
            n = len(self.categories)
            m = m[:n, :n]
        # A zero-size matrix has no row maximum to normalize by.
        if m.numel() == 0: return m
        rowMax = m.max(dim=1).values[:, None]
        # Avoid 0/0 = nan for categories that have no recorded samples.
        rowMax = torch.where(rowMax == 0, torch.ones_like(rowMax), rowMax)
        return m / rowMax
    def plot(self):
        """Plots everything"""
        # Fall back to integer labels when no category names were given.
        k1lib.viz.confusionMatrix(self.goodMatrix, self.categories or list(range(self.n)))
    def __repr__(self):
        # NOTE(review): this text is reproduced exactly as it appears in the
        # visible source; line breaks inside the literal may have been lost
        # in extraction -- confirm against the upstream file.
        return f"""{super()._reprHead}, use... - l.plot(): to plot everything {super()._reprCan}"""