# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""For not very complicated accuracies functions"""
from ..callbacks import Callback, Callbacks, Cbs
from typing import Callable, Tuple
import torch, k1lib
__all__ = ["AccF"]
AccFSig = Callable[[Tuple[torch.Tensor, torch.Tensor]], float]
PredFSig = Callable[[torch.Tensor], torch.Tensor]
@k1lib.patch(Cbs)
class AccF(Callback):
    """Generic accuracy-calculating callback. See :meth:`__init__` for details."""
    def __init__(self, predF:PredFSig=None, accF:AccFSig=None, integrations:bool=True):
        """Creates a generic Callback accuracy function.
Built in default accuracy functions are fine, if you don't do
something too dramatic/different. It expects:

- y: to have shape (\\*N, C)
- yb: to have shape (\\*N,)

Where:

- N is the batch size. Can be multidimensional, but has to agree between ``y`` and ``yb``
- C is the number of categories

If these are not your system's requirements, then pass in your own
``predF`` and ``accF``.

Deposits these variables into :class:`~k1lib.Learner`:

- preds: detached, batched tensor output of ``predF``
- accuracies: detached, batched tensor output of ``accF``
- accuracy: detached, single float, mean of ``accuracies``

:param predF: takes in ``y``, returns predictions (tensor with int elements indicating the categories)
:param accF: takes in ``(predictions, yb)``, returns accuracies (tensor with 0 or 1 elements)
:param integrations: whether to integrate :class:`~k1lib.callbacks.confusionMatrix.ConfusionMatrix` or not."""
        super().__init__(); self.order = 10
        self.integrations = integrations; self.ownsConMat = False
        # Always define conMatCb so detach() is safe even if integrations is
        # off or appended() never ran (previously raised AttributeError).
        self.conMatCb = None
        self.predF = predF or (lambda y: y.argmax(-1))
        self.accF = accF or (lambda p, yb: (p == yb)+0)
    def appended(self):
        # Hook: reuse an existing ConfusionMatrix callback if one is already
        # registered, else create our own and remember that we own it.
        if self.integrations:
            if "ConfusionMatrix" not in self.cbs:
                self.conMatCb = Cbs.ConfusionMatrix()
                self.cbs.append(self.conMatCb); self.ownsConMat = True
            else: self.conMatCb = self.cbs.ConfusionMatrix
    def endLoss(self):
        # Deposits preds, accuracies (detached tensors) and the scalar
        # accuracy into the Learner after every loss computation.
        preds = self.predF(self.l.y); self.l.preds = preds.detach()
        accs = self.accF(preds, self.l.yb); self.l.accuracies = accs.detach()
        self.l.accuracy = accs.float().mean().item()
    def detach(self):
        """Detaches this callback, and the ConfusionMatrix callback too if
this callback was the one that created it."""
        super().detach()
        if self.conMatCb is not None:
            if self.ownsConMat: self.conMatCb.detach()
            self.conMatCb = None
@k1lib.patch(Callbacks, docs=AccF.__init__)
def withAccF(self, accF:AccFSig=None, name:str=None):
    # Pass accF by keyword: AccF's first positional parameter is predF, so
    # the previous positional call AccF(accF) silently installed the user's
    # accuracy function as the prediction function.
    return self.append(AccF(accF=accF), name=name)