# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""For not very complicated accuracies functions"""
from ..callbacks import Callback, Callbacks, Cbs
from typing import Callable, Tuple
import k1lib
try: import torch; hasTorch = True
except Exception: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["AccF"]
AccFSig = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
PredFSig = Callable[[torch.Tensor], torch.Tensor]
@k1lib.patch(Cbs)
class AccF(Callback):                                                            # AccF
    " "                                                                          # AccF
    def __init__(self, predF:PredFSig=None, accF:AccFSig=None, integrations:bool=True, variable:str="accuracy", hookToLearner:bool=True): # AccF
        """Generic accuracy function.
The built-in default accuracy functions are fine if you aren't doing anything too
dramatic/different. Expected variables in :class:`~k1lib.Learner`:
- y:  :class:`~torch.Tensor` of shape (\*N, C)
- yb: :class:`~torch.Tensor` of shape (\*N,)
Deposits variables into :class:`~k1lib.Learner`:
- preds:      detached, batched tensor output of ``predF``
- accuracies: detached, batched tensor output of ``accF``
- accuracy:   detached, single float, mean of ``accuracies``
Where:
- N is the batch size. Can be multidimensional, but has to agree between ``y`` and ``yb``
- C is the number of categories
:param predF: takes in ``y``, returns predictions (tensor with int elements indicating the categories)
:param accF: takes in ``(predictions, yb)``, returns accuracies (tensor with 0 or 1 elements)
:param integrations: whether to integrate :class:`~k1lib.callbacks.confusionMatrix.ConfusionMatrix` or not.
:param variable: variable to deposit into Learner"""                             # AccF
        super().__init__(); self.order = 10; self.integrations = integrations; self.ownsConMat = False; self.conMatCb = None; self.hookToLearner = hookToLearner # AccF
        self.predF = predF or (lambda y: y.argmax(-1))                           # AccF
        self.accF = accF or (lambda p, yb: (p == yb)+0); self.variable = variable # AccF 
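    # Usage sketch (illustrative only, not taken from this file): attaching AccF to a
    # Learner with a custom predF/accF. `k1lib.Learner.sample()` and `l.run(1)` are
    # assumed from the wider k1lib API purely for illustration; any Learner whose `y`
    # has shape (*N, C) and `yb` has shape (*N,) should work:
    #
    #     l = k1lib.Learner.sample()
    #     l.cbs.add(Cbs.AccF(predF=lambda y: y.argmax(-1),           # same as the default predF
    #                        accF=lambda p, yb: (p == yb).float()))  # 1.0 where prediction matches label
    #     l.run(1)
    #     l.accuracy  # mean accuracy, deposited under the name given by `variable`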
    def attached(self):                                                          # AccF
        if self.integrations:                                                    # AccF
            if "ConfusionMatrix" not in self.cbs:                                # AccF
                self.conMatCb = Cbs.ConfusionMatrix()                            # AccF
                self.cbs.add(self.conMatCb); self.ownsConMat = True              # AccF
            else: self.conMatCb = self.cbs.ConfusionMatrix                       # AccF 
    def endLoss(self):                                                           # AccF
        preds = self.predF(self.l.y); accs = self.accF(preds, self.l.yb)         # AccF
        if self.hookToLearner:                                                   # AccF
            self.l.preds = preds.detach()                                        # AccF
            self.l.accuracies = accs.detach()                                    # AccF
        self.l.__dict__[self.variable] = accs.float().mean().item()              # AccF 
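    # Shape sketch (illustrative only), using the default predF/accF on a batch of 3:
    #     y  = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])   # (N, C) = (3, 2)
    #     yb = torch.tensor([1, 0, 0])                              # (N,)   = (3,)
    #     preds      = y.argmax(-1)                       # -> tensor([1, 0, 1])
    #     accuracies = (preds == yb) + 0                  # -> tensor([1, 1, 0])
    #     accuracy   = accuracies.float().mean().item()   # -> 0.667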
    def detach(self):                                                            # AccF
        super().detach()                                                         # AccF
        if self.conMatCb is not None:                                            # AccF
            if self.ownsConMat: self.conMatCb.detach()                           # AccF
            self.conMatCb = None                                                 # AccF