# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""For not very complicated loss functions"""
from ..callbacks import Callback, Callbacks, Cbs
from typing import Callable, Tuple
import torch, k1lib
__all__ = ["AccF"]
# Signature of an accuracy function. It is invoked as ``accF(y, yb)`` with two
# positional arguments, and its result must support ``.mean().item()``, so it
# takes two tensors and returns a tensor (not a tuple -> float).
AccFSig = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
@k1lib.patch(Cbs)
class AccF(Callback):
    """Callback that evaluates an accuracy function in its ``endLoss`` hook.

    The per-sample result is stored as ``self.l.accuracies`` and its mean, as a
    Python number, as ``self.l.accuracy``."""
    def __init__(self, accF:AccFSig=None):
        """Creates a generic accuracy function that takes in ``y`` and
correct y ``yb`` and returns a tensor of size (n,) containing 1 for accurate, 0 for not.

:param accF: callable invoked as ``accF(y, yb)``. If None, defaults to comparing
    ``y.argmax(1)`` against ``yb``"""
        super().__init__()
        # Explicit None check: `accF or default` would silently discard a
        # user-supplied callable that happens to be falsy (e.g. an object
        # defining `__len__`/`__bool__`, like an empty `torch.nn.Sequential`).
        if accF is None:
            # `+0.0` promotes the boolean comparison result to a float tensor
            accF = lambda y, yb: (y.argmax(1) == yb)+0.0
        self.accF = accF
    def endLoss(self):
        """Computes accuracies from ``self.l.y`` and ``self.l.yb`` and stores
        both the raw per-sample tensor and its mean on ``self.l``."""
        accuracies = self.accF(self.l.y, self.l.yb)
        self.l.accuracies = accuracies # stored exactly as the function returned it
        # `Tensor.mean()` only accepts floating point/complex dtypes, so a custom
        # function returning bool or integer 0/1 values would raise here.
        # Promote those before averaging; float tensors are left untouched.
        if not (accuracies.is_floating_point() or accuracies.is_complex()):
            accuracies = accuracies.float()
        self.l.accuracy = accuracies.mean().item()
@k1lib.patch(Callbacks, docs=AccF.__init__)
def withAccF(self, accF:AccFSig=None, name:str=None):
    # Convenience wrapper: build an AccF callback from `accF`, append it under
    # `name`, and hand back whatever `append` returns.
    # NOTE(review): the docstring is presumably taken from `AccF.__init__`
    # through the `docs=` argument above — confirm against `k1lib.patch`.
    callback = AccF(accF)
    return self.append(callback, name=name)