Source code for k1lib.callbacks.lossFunctions.accuracy

# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""For not very complicated accuracies functions"""
from ..callbacks import Callback, Callbacks, Cbs
from typing import Callable, Tuple
import torch, k1lib
__all__ = ["AccF"]
AccFSig = Callable[[Tuple[torch.Tensor, torch.Tensor]], float]
PredFSig = Callable[[torch.Tensor], torch.Tensor]
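
# --- Editorial sketch, not part of the autogenerated source: example hooks ---
# matching the aliases above, usable later as AccF(predF=..., accF=...).
# The names `examplePredF`/`withinOneAccF` are illustrative only; `withinOneAccF`
# counts a prediction as correct when it lands within one category of the label
# (e.g. for ordinal targets), a looser criterion than AccF's default exact match.
examplePredF = lambda y: y.argmax(-1)                    # same as AccF's default predF
withinOneAccF = lambda p, yb: ((p - yb).abs() <= 1) + 0  # 1 where |pred - label| <= 1, else 0
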
@k1lib.patch(Cbs)
class AccF(Callback):
    " "
    def __init__(self, predF:PredFSig=None, accF:AccFSig=None, integrations:bool=True):
        """Generic accuracy function. The built-in default accuracy functions
        are fine, as long as you don't do something too dramatic/different.

        Expected variables in :class:`~k1lib.Learner`:

        - y: :class:`~torch.Tensor` of shape (\*N, C)
        - yb: :class:`~torch.Tensor` of shape (\*N,)

        Deposits variables into :class:`~k1lib.Learner`:

        - preds: detached, batched tensor output of ``predF``
        - accuracies: detached, batched tensor output of ``accF``
        - accuracy: detached, single float, mean of ``accuracies``

        Where:

        - N is the batch size. Can be multidimensional, but has to agree between ``y`` and ``yb``
        - C is the number of categories

        :param predF: takes in ``y``, returns predictions (tensor with int elements indicating the categories)
        :param accF: takes in ``(predictions, yb)``, returns accuracies (tensor with 0 or 1 elements)
        :param integrations: whether to integrate with :class:`~k1lib.callbacks.confusionMatrix.ConfusionMatrix` or not."""
        super().__init__(); self.order = 10
        self.integrations = integrations; self.ownsConMat = False
        self.conMatCb = None  # set in attached() when integrations is on; guards detach()
        self.predF = predF or (lambda y: y.argmax(-1))
        self.accF = accF or (lambda p, yb: (p == yb)+0)
    def attached(self):
        if self.integrations:
            if "ConfusionMatrix" not in self.cbs:
                self.conMatCb = Cbs.ConfusionMatrix()
                self.cbs.add(self.conMatCb); self.ownsConMat = True
            else: self.conMatCb = self.cbs.ConfusionMatrix
    def endLoss(self):
        preds = self.predF(self.l.y); self.l.preds = preds.detach()
        accs = self.accF(preds, self.l.yb); self.l.accuracies = accs.detach()
        self.l.accuracy = accs.float().mean().item()
    def detach(self):
        super().detach()
        if self.conMatCb is not None:
            if self.ownsConMat: self.conMatCb.detach()
            self.conMatCb = None
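
# --- Editorial sketch, not part of the library source: a minimal, runnable ---
# demonstration of what AccF's default predF/accF compute and how endLoss
# derives `accuracy` from them. Tensor values are made up for illustration.
if __name__ == "__main__":
    y = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # plays the role of l.y, shape (N, C)
    yb = torch.tensor([1, 0, 0])                             # plays the role of l.yb, shape (N,)
    preds = y.argmax(-1)               # default predF -> tensor([1, 0, 1])
    accs = (preds == yb) + 0           # default accF  -> tensor([1, 1, 0])
    print(accs.float().mean().item())  # mean accuracy -> 0.666..., what ends up in l.accuracy
    # Hedged usage note (assumption, not taken from this file): with a configured
    # k1lib.Learner `l`, something like `l.cbs.add(Cbs.AccF())` attaches this
    # callback; after a run, `l.preds`, `l.accuracies` and `l.accuracy` hold the
    # values described in the __init__ docstring above.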