# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""For not very complicated loss functions"""
from ..callbacks import Callback, Callbacks, Cbs
from typing import Callable, Tuple
import k1lib, math
try: import torch; import torch.nn.functional as F; hasTorch = True
except: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["LossF", "LossNLLCross"]
# Signature of a loss function usable with LossF: ``lossF(y, yb) -> lossG``.
# It is called with the model output and ``yb`` as two positional arguments
# (LossF.inLoss does ``self.lossF(self.l.y, self.l.yb)``), and must return a
# tensor, since inLoss calls ``.detach().item()`` on the result.
LossFSig = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
@k1lib.patch(Cbs)
@k1lib.patch(Callback.lossCls)
class LossF(Callback):
    " "  # deliberately minimal; the user-facing description lives in __init__'s docstring
    def __init__(self, lossF:LossFSig):
        """Generic loss function.

Expected variables in :class:`~k1lib.Learner`:

- y: result of model. Auto-included in :class:`~k1lib.callbacks.core.CoreNormal`
  and :class:`~k1lib.callbacks.core.CoreRNN`.
- yb: passed as the second argument to ``lossF``

Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inLoss``:

- lossG: single float tensor value, attached to graph
- loss: lossG, but single float value

:param lossF: takes in ``(y, yb)`` and returns ``lossG``"""
        super().__init__()
        self.lossF = lossF
    def inLoss(self):
        learner = self.l
        # graph-attached loss first, so it can still be backpropagated through ...
        learner.lossG = self.lossF(learner.y, learner.yb)
        # ... then a detached plain Python number copy of the same value
        learner.loss = learner.lossG.detach().item()
class LossNLLCross(Callback):
    " "
    def __init__(self, nll:bool, integrations:bool):
        """Adds a cross-entropy/negative log-likelihood loss function.

:param nll: if True, then use :class:`torch.nn.NLLLoss`, else use :class:`torch.nn.CrossEntropyLoss`
:param integrations: whether to integrate with :class:`~k1lib.callbacks.accuracy.AccF` callback.
    If True, on attach an existing ``AccF`` callback is reused; if there is none, this
    callback adds its own ``AccF`` plus top-5 accuracy tracking (variable ``accTop5``)"""
        super().__init__(); self.integrations = integrations; self.ownsAccCb = False
        self.order = 11 # to make sure it's after AccF
        self.lossF = torch.nn.NLLLoss() if nll else torch.nn.CrossEntropyLoss()
        # All helper-callback slots start as None so detach() is safe even when
        # integrations is False and attached() never assigns them.
        self.accuracyCb = None      # AccF in use (ours or a pre-existing one)
        self.accuracyCbTop5 = None  # our AccF computing "accTop5"
        self.accuracyTop5Cb = None  # our Cbs.Accuracy("accTop5") callback
    @staticmethod
    def _top5Correct(y, yb):
        """Per-sample top-5 hit count: 1 if the label in ``yb`` is among the
        highest-scoring indices of ``y`` along dim 1, else 0.

``k`` is clamped to ``y.shape[1]`` because ``topk`` raises when asked for more
elements than the dimension has. Assumes ``y`` is (batch, classes) and ``yb``
holds integer class indices -- TODO confirm against callers."""
        k = min(5, y.shape[1])
        return (yb[:, None] == y.topk(k, dim=1).indices + 0).sum(1)
    def attached(self): # delayed initialization, so that learner and cbs has already been attached
        if not self.integrations: return
        if "AccF" in self.cbs:
            # reuse the existing AccF; we don't own it, so detach() must leave it alone
            self.ownsAccCb = False
            self.accuracyCb = self.cbs.AccF
            return
        self.ownsAccCb = True
        self.accuracyCb = Cbs.AccF(); self.cbs.add(self.accuracyCb)
        self.accuracyCbTop5 = Cbs.AccF(lambda y: y, self._top5Correct, variable="accTop5", hookToLearner=False); self.cbs.add(self.accuracyCbTop5)
        # keep a reference so detach() can remove it together with the AccF feeding it
        self.accuracyTop5Cb = Cbs.Accuracy("accTop5"); self.cbs.add(self.accuracyTop5Cb, name="AccuracyTop5")
    def inLoss(self):
        self.l.lossG = self.lossF(self.l.y, self.l.yb)
        self.l.loss = self.l.lossG.detach().item()
    def detach(self):
        super().detach()
        # getattr() tolerates instances that did not go through this __init__
        # (NOTE(review): e.g. objects restored from an older serialized version)
        helpers = (getattr(self, "accuracyCb", None),
                   getattr(self, "accuracyCbTop5", None),
                   getattr(self, "accuracyTop5Cb", None))
        if getattr(self, "ownsAccCb", False):
            # only detach callbacks this object added itself in attached()
            for cb in helpers:
                if cb is not None: cb.detach()
        self.accuracyCb = None; self.accuracyCbTop5 = None; self.accuracyTop5Cb = None
        self.ownsAccCb = False
@k1lib.patch(Cbs)
@k1lib.patch(Callback.lossCls)
class LossCrossEntropy(LossNLLCross):
    def __init__(self, integrations:bool=True):
        """Cross entropy loss function, backed by :class:`torch.nn.CrossEntropyLoss`.
Deposits into :class:`~k1lib.Learner` the same variables as in :class:`LossF`.

:param integrations: whether to integrate with :class:`~k1lib.callbacks.accuracy.AccF` callback"""
        super().__init__(nll=False, integrations=integrations)
@k1lib.patch(Cbs)
@k1lib.patch(Callback.lossCls)
class LossNLL(LossNLLCross):
    def __init__(self, integrations:bool=True):
        """Negative log-likelihood loss function, backed by :class:`torch.nn.NLLLoss`.
Deposits into :class:`~k1lib.Learner` the same variables as in :class:`LossF`.

NOTE: per the PyTorch docs, ``NLLLoss`` expects log-probabilities as the model output.

:param integrations: whether to integrate with :class:`~k1lib.callbacks.accuracy.AccF` callback"""
        super().__init__(nll=True, integrations=integrations)