# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""For not very complicated accuracies functions"""
from ..callbacks import Callback, Callbacks, Cbs
from typing import Callable, Tuple
import k1lib
try: import torch; hasTorch = True
except: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["AccF"]
AccFSig = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
PredFSig = Callable[[torch.Tensor], torch.Tensor]
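# Illustration only (hypothetical, not used by this module): a function matching
# PredFSig maps raw model outputs to category predictions, e.g. `lambda y: y.argmax(-1)`;
# one matching AccFSig maps ``(predictions, yb)`` to per-element accuracies,
# e.g. `lambda p, yb: (p == yb) + 0`.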
@k1lib.patch(Cbs)
class AccF(Callback):
    " "
    def __init__(self, predF:PredFSig=None, accF:AccFSig=None, integrations:bool=True, variable:str="accuracy", hookToLearner:bool=True):
"""Generic accuracy function.
Built in default accuracies functions are fine, if you don't do something too
dramatic/different. Expected variables in :class:`~k1lib.Learner`:
- y: :class:`~torch.Tensor` of shape (\*N, C)
- yb: :class:`~torch.Tensor` of shape (\*N,)
Deposits variables into :class:`~k1lib.Learner`:
- preds: detached, batched tensor output of ``predF``
- accuracies: detached, batched tensor output of ``accF``
- accuracy: detached, single float, mean of ``accuracies``
Where:
- N is the batch size. Can be multidimensional, but has to agree between ``y`` and ``yb``
- C is the number of categories
:param predF: takes in ``y``, returns predictions (tensor with int elements indicating the categories)
:param accF: takes in ``(predictions, yb)``, returns accuracies (tensor with 0 or 1 elements)
:param integrations: whether to integrate :class:`~k1lib.callbacks.confusionMatrix.ConfusionMatrix` or not.
:param variable: variable to deposit into Learner"""
        super().__init__(); self.order = 10; self.integrations = integrations; self.ownsConMat = False; self.hookToLearner = hookToLearner
        self.predF = predF or (lambda y: y.argmax(-1))
        self.accF = accF or (lambda p, yb: (p == yb)+0); self.variable = variable
        self.conMatCb = None # so detach() is safe even if attached() never set up the ConfusionMatrix integration
    def attached(self):
        if self.integrations:
            if "ConfusionMatrix" not in self.cbs:
                self.conMatCb = Cbs.ConfusionMatrix()
                self.cbs.add(self.conMatCb); self.ownsConMat = True
            else: self.conMatCb = self.cbs.ConfusionMatrix
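    # Shape sketch with the default predF/accF (illustration only): given self.l.y of
    # shape (N, C) and self.l.yb of shape (N,), endLoss below computes
    #   preds = y.argmax(-1)          -> (N,), int category indices
    #   accs  = (preds == yb) + 0     -> (N,), 1 where correct, 0 otherwise
    #   accs.float().mean().item()    -> single Python float, the deposited accuracy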
    def endLoss(self):
        preds = self.predF(self.l.y); accs = self.accF(preds, self.l.yb)
        if self.hookToLearner:
            self.l.preds = preds.detach()
            self.l.accuracies = accs.detach()
            self.l.__dict__[self.variable] = accs.float().mean().item()
    def detach(self):
        super().detach()
        if self.conMatCb is not None:
            if self.ownsConMat: self.conMatCb.detach()
            self.conMatCb = None
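# A minimal usage sketch (assumptions: a fully configured `k1lib.Learner` instance
# named `learner` already exists and is trained with `Learner.run(epochs)`; the
# lambdas below are hypothetical and simply mirror the defaults' signatures):
#
#   learner.cbs.add(Cbs.AccF(
#       predF=lambda y: y.argmax(-1),       # (N, C) outputs -> (N,) predicted categories
#       accF=lambda p, yb: (p == yb) + 0,   # (N,) with 1 where prediction matches yb
#       variable="accuracy"))               # mean deposited as `learner.accuracy` at endLoss
#   learner.run(1)                          # afterwards, `learner.accuracy` holds the last batch's mean accuracy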