# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""For not very complicated loss functions"""
from k1lib.callbacks import Callback, Callbacks
from typing import Callable, Tuple
import torch, k1lib
__all__ = ["LossLambda", "LossNLLCross"]
LossFSig = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] # (y, yb) -> scalar loss tensor, matching how inLoss() calls it
@k1lib.patch(Callback.cls)
class LossLambda(Callback):
    " "
    def __init__(self, lossF:LossFSig):
        """Creates a generic loss function that takes in ``y`` and
correct y ``yb`` and return a single loss float (still attached to graph)."""
        super().__init__()
        self.lossF = lossF 
    def inLoss(self):
        # lossG stays attached to the graph (for backward()); loss is the detached float for logging
        self.l.lossG = self.lossF(self.l.y, self.l.yb)
        self.l.loss = self.l.lossG.detach().item()
@k1lib.patch(Callbacks, docs=LossLambda.__init__)
def withLossLambda(self, lossF:LossFSig, name:str=None):
    return self.append(LossLambda(lossF), name=name)
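# A minimal usage sketch (hypothetical; assumes `l` is a k1lib learner whose
# callbacks live at `l.cbs`, and that `l.y`/`l.yb` are the model output and the
# target, as used in inLoss() above):
#   import torch.nn.functional as F
#   l.cbs.withLossLambda(lambda y, yb: F.mse_loss(y, yb))
# The passed function must return a scalar tensor still attached to the graph,
# since inLoss() calls .detach().item() on it to log the plain float.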
def accF(l):
    """Fraction of predictions whose argmax over dim 1 matches the label."""
    a = (l.y.argmax(dim=1) == l.yb)
    return a.sum() / a.numel()
def accCb():
    """Default :class:`~k1lib.callbacks.loss_accuracy.Accuracy` callback used by :class:`LossNLLCross`."""
    return k1lib.callbacks.Accuracy(accF)
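# A small worked example of what accF computes (hypothetical numbers):
#   l.y  = tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  -> argmax(dim=1) = [1, 0, 1]
#   l.yb = tensor([1, 1, 1])
#   matches = [True, False, True], so accuracy = 2/3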
@k1lib.patch(Callback.cls)
class LossNLLCross(Callback):
    " "
    def __init__(self, nll:bool, integrations:bool):
        """
:param nll: if True, then use :class:`torch.nn.NLLLoss`, else use :class:`torch.nn.CrossEntropyLoss`
:param integrations: whether to integrate with
    :class:`~k1lib.callbacks.loss_accuracy.Accuracy` callback"""
        super().__init__(); self.integrations = integrations; self.accuracyCb = None
        self.lossF = torch.nn.NLLLoss() if nll else torch.nn.CrossEntropyLoss() 
    def appended(self):
        # when this callback is attached to the Callbacks object, optionally attach Accuracy too
        if self.integrations:
            self.accuracyCb = accCb()
            self.cbs.append(self.accuracyCb)
    def inLoss(self):
        self.l.lossG = self.lossF(self.l.y, self.l.yb)
        self.l.loss = self.l.lossG.detach().item()
    def detach(self):
        """Detaches the integrated accuracy callback, if it exists."""
        if self.accuracyCb is not None:
            self.accuracyCb.detach(); self.accuracyCb = None
@k1lib.patch(Callbacks, docs=LossNLLCross.__init__)
def withLossNLL(self, integrations:bool=True, name:str=None):
    return self.append(LossNLLCross(True, integrations), name=name)
@k1lib.patch(Callbacks, docs=LossNLLCross.__init__)
def withLossCrossEntropy(self, integrations:bool=True, name:str=None):
    return self.append(LossNLLCross(False, integrations), name=name)
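# A minimal usage sketch (hypothetical; assumes `l` is a k1lib learner with a
# `cbs` Callbacks object). withLossCrossEntropy() expects raw logits from the
# model, while withLossNLL() expects log-probabilities (e.g. after LogSoftmax);
# both also attach the Accuracy callback unless integrations=False:
#   l.cbs.withLossCrossEntropy()                 # CrossEntropyLoss + Accuracy
#   # or: l.cbs.withLossNLL(integrations=False)  # NLLLoss only, no Accuracy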