# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""For not very complicated loss functions"""
from ..callbacks import Callback, Callbacks, Cbs
from typing import Callable
import k1lib, math
try: import torch; import torch.nn.functional as F; hasTorch = True
except ImportError: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["LossF", "LossNLLCross"]
LossFSig = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] # takes (y, yb), returns lossG
@k1lib.patch(Cbs)
@k1lib.patch(Callback.lossCls)
class LossF(Callback):                                                           # LossF
    " "                                                                          # LossF
    def __init__(self, lossF:LossFSig):                                          # LossF
        """Generic loss function.
Expected variables in :class:`~k1lib.Learner`:
- y: model output. Auto-included in :class:`~k1lib.callbacks.core.CoreNormal`
  and :class:`~k1lib.callbacks.core.CoreRNN`.
Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inLoss``:
- lossG: single-element loss tensor, still attached to the graph
- loss: the same value as ``lossG``, but detached, as a plain Python float
:param lossF: takes in ``(y, yb)`` and returns ``lossG``"""                      # LossF
        super().__init__()                                                       # LossF
        self.lossF = lossF                                                       # LossF 
    def inLoss(self):                                                            # LossF
        self.l.lossG = self.lossF(self.l.y, self.l.yb)                           # LossF
        self.l.loss = self.l.lossG.detach().item()                               # LossF  
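# A minimal usage sketch (an assumption, not from this file: presumes a Learner
# ``l`` already set up with model, data and optimizer, and ``F`` as imported above):
#   l.cbs.add(Cbs.LossF(lambda y, yb: F.mse_loss(y, yb)))
#   l.run(1)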
class LossNLLCross(Callback):                                                    # LossNLLCross
    " "                                                                          # LossNLLCross
    def __init__(self, nll:bool, integrations:bool):                             # LossNLLCross
        """Adds a cross-entropy/negative log likelihood loss function.
:param nll: if True, use :class:`torch.nn.NLLLoss`, else use :class:`torch.nn.CrossEntropyLoss`
:param integrations: whether to integrate with the :class:`~k1lib.callbacks.accuracy.AccF` callback""" # LossNLLCross
        super().__init__(); self.integrations = integrations; self.ownsAccCb = False # LossNLLCross
        self.order = 11 # to make sure it's after AccF                           # LossNLLCross
        self.lossF = torch.nn.NLLLoss() if nll else torch.nn.CrossEntropyLoss()  # LossNLLCross
        self.accuracyCb = None; self.accuracyCbTop5 = None                       # LossNLLCross
    def attached(self): # delayed initialization, so that learner and cbs have already been attached # LossNLLCross
        if self.integrations:                                                    # LossNLLCross
            if "AccF" not in self.cbs:                                           # LossNLLCross
                self.ownsAccCb = True                                            # LossNLLCross
                self.accuracyCb = Cbs.AccF(); self.cbs.add(self.accuracyCb)      # LossNLLCross
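                # top-5 accuracy: for each sample, check whether the true label
                # appears among the model's 5 highest-scoring classes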
                self.accuracyCbTop5 = Cbs.AccF(lambda y: y,                      # LossNLLCross
                    lambda y, yb: (yb[:,None] == y.topk(5, dim=1).indices).sum(1), # LossNLLCross
                    variable="accTop5", hookToLearner=False)                     # LossNLLCross
                self.cbs.add(self.accuracyCbTop5)                                # LossNLLCross
                self.cbs.add(Cbs.Accuracy("accTop5"), name="AccuracyTop5")       # LossNLLCross
            else: self.accuracyCb = self.cbs.AccF                                # LossNLLCross 
    def inLoss(self):                                                            # LossNLLCross
        self.l.lossG = self.lossF(self.l.y, self.l.yb)                           # LossNLLCross
        self.l.loss = self.l.lossG.detach().item()                               # LossNLLCross 
    def detach(self):                                                            # LossNLLCross
        super().detach()                                                         # LossNLLCross
        if self.accuracyCb is not None:                                          # LossNLLCross
            if self.ownsAccCb: self.accuracyCb.detach()                          # LossNLLCross
            self.accuracyCb = None                                               # LossNLLCross
        if self.accuracyCbTop5 is not None:                                      # LossNLLCross
            if self.ownsAccCb: self.accuracyCbTop5.detach()                      # LossNLLCross
            self.accuracyCbTop5 = None                                           # LossNLLCross
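# Note: LossNLLCross itself is not patched into Cbs; the two thin subclasses
# below, LossCrossEntropy and LossNLL, are the entry points exposed there.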
@k1lib.patch(Cbs)                                                                # LossNLLCross
@k1lib.patch(Callback.lossCls)                                                   # LossNLLCross
class LossCrossEntropy(LossNLLCross):                                            # LossCrossEntropy
    def __init__(self, integrations:bool=True):                                  # LossCrossEntropy
        """Cross entropy loss function. Deposits into :class:`~k1lib.Learner`
the same variables as in :class:`LossF`."""                                      # LossCrossEntropy
        super().__init__(False, integrations)                                    # LossCrossEntropy
@k1lib.patch(Cbs)                                                                # LossCrossEntropy
@k1lib.patch(Callback.lossCls)                                                   # LossCrossEntropy
class LossNLL(LossNLLCross):                                                     # LossNLL
    def __init__(self, integrations:bool=True):                                  # LossNLL
        """Negative log loss function. Deposits into :class:`~k1lib.Learner`
the same variables as in :class:`LossF`."""                                      # LossNLL
        super().__init__(True, integrations)                                     # LossNLL
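# A minimal usage sketch (an assumption, not from this file: presumes a Learner
# ``l`` whose yb holds integer class labels):
#   l.cbs.add(Cbs.LossCrossEntropy())   # model outputs raw logits
#   # or, if the model already ends with log_softmax:
#   # l.cbs.add(Cbs.LossNLL())
#   l.run(1)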