# AUTOGENERATED FILE! PLEASE DON'T EDIT
from k1lib.callbacks import Callback, Callbacks
import k1lib, time
__all__ = ["BatchLimit", "EpochLimit", "TimeLimit", "CancelOnExplosion",
           "CancelOnLowLoss", "CancelOnHighAccuracy", "DontTrain",
           "GradientClipping", "GradientClippingNorm"]
@k1lib.patch(Callback.cls)
class BatchLimit(Callback):
    """Cancels the epoch after executed certain number of batches"""
    def __init__(self, limit:int):
        super().__init__(); self.order = 25
        self.limit = limit if limit is not None else float("inf")
    def startEpoch(self): self.currentBatch = 0
    def startBatch(self):
        if self.currentBatch >= self.limit:
            raise k1lib.CancelEpochException(f"Batch {self.limit} reached")
    def endBatch(self): self.currentBatch += 1 
@k1lib.patch(Callbacks, docs=BatchLimit)
def withBatchLimit(self, limit:int, name:str=None): return self.append(BatchLimit(limit), name)
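# Usage sketch (not part of this module; assumes a configured k1lib.Learner named `l`):
#     l.cbs.withBatchLimit(32) # each epoch now cancels itself after 32 batches
#     l.run(1)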
@k1lib.patch(Callback.cls)
class EpochLimit(Callback):
    """Cancels the run after executed certain number of epochs"""
    def __init__(self, limit:int):
        super().__init__(); self.order = 25
        self.limit = limit if limit is not None else float("inf")
    def startRun(self): self.currentEpoch = 0
    def startEpoch(self):
        if self.currentEpoch >= self.limit:
            raise k1lib.CancelRunException(f"Epoch {self.limit} reached!")
    def endEpoch(self): self.currentEpoch += 1 
@k1lib.patch(Callbacks, docs=EpochLimit)
def withEpochLimit(self, limit:int, name:str=None): return self.append(EpochLimit(limit), name)
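# Usage sketch (hedged; `l` is assumed to be a k1lib.Learner set up elsewhere):
#     l.cbs.withEpochLimit(5) # the whole run cancels once 5 epochs have finished
#     l.run(100)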
@k1lib.patch(Callback.cls)
class TimeLimit(Callback):
    """Cancels the run after a certain number of seconds have passed"""
    def __init__(self, seconds=30):
        super().__init__(); self.seconds = seconds; self.order = 25
    def startRun(self): self.startTime = time.time()
    def startBatch(self):
        if time.time() - self.startTime > self.seconds:
            raise k1lib.CancelRunException(f"Takes more than {self.seconds} seconds!") 
@k1lib.patch(Callbacks, docs=TimeLimit)
def withTimeLimit(self, seconds=30, name:str=None):
    return self.append(TimeLimit(seconds), name)
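# Usage sketch (hedged; `l` is assumed to be a k1lib.Learner):
#     l.cbs.withTimeLimit(60) # cancels the run once more than 60 seconds have elapsed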
@k1lib.patch(Callback.cls)
class CancelOnExplosion(Callback):
    """Cancels the run if any of the parameters are larger than a certain limit"""
    def __init__(self, limit:float=1e6):
        super().__init__(); self.order = 25
        self.limit = limit; self.triggered = False
    def startRun(self): self.triggered = False
    def startBatch(self):
        for p in self.l.model.parameters():
            o = p.detach()
            if o.max().float() > self.limit or o.min().float() < -self.limit:
                self.triggered = True
                raise k1lib.CancelRunException("Explosion detected!")
    def __repr__(self):
        return f"""{self._reprHead}, use...
- cb.triggered: to see if there was an explosion on the last run
- cb.progress: to see current progress at explosion time
{self._reprCan}""" 
@k1lib.patch(Callbacks, docs=CancelOnExplosion)
def withCancelOnExplosion(self, limit:float=1e6, name:str=None):
    return self.append(CancelOnExplosion(limit), name)
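# Usage sketch (hedged; `l` is assumed to be a k1lib.Learner). After a run, the
# `triggered` flag mentioned in __repr__ above can be inspected:
#     l.cbs.withCancelOnExplosion(1e6)
#     l.run(10)
#     l.cbs.CancelOnExplosion.triggered # True if some parameter went beyond +-1e6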
@k1lib.patch(Callback.cls)
class CancelOnLowLoss(Callback):
    " "
    def __init__(self, loss:float, epochMode:bool=False):
        """Cancels the run if loss is lower than amount specified.
Original class: :class:`~k1lib.callbacks.limits.CancelOnLowLoss`
:param epochMode: False if use batch loss, True if use valid epoch loss"""
        super().__init__(); self.order = 25; self.dependsOn = ["Loss"]
        self.loss = loss; self.epochMode = epochMode 
    def startRun(self):
        if not hasattr(self.l.cbs, "Loss"):
            raise AttributeError("Learner does not have required `Loss` callback")
        self.v = self.cbs.Loss.valid; self.ve = self.cbs.Loss.epoch.valid # List[float]
    def endBatch(self):
        if self.epochMode:
            if len(self.ve) > 0 and self.ve[-1] < self.loss:
                raise k1lib.CancelRunException(f"Low loss {self.loss} ({self.ve[-3:]} actual) achieved!")
        elif len(self.v) and self.v[-1] < self.loss:
            raise k1lib.CancelRunException(f"Low loss {self.loss} ({self.v[-3:]} actual) achieved!") 
@k1lib.patch(Callbacks, docs=CancelOnLowLoss.__init__)
def withCancelOnLowLoss(self, loss:float, epochMode:bool=False, name:str=None):
    return self.append(CancelOnLowLoss(loss, epochMode), name)
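# Usage sketch (hedged; assumes `l` is a k1lib.Learner that already has the Loss
# callback this class depends on):
#     l.cbs.withCancelOnLowLoss(0.1, epochMode=True) # stop once valid epoch loss drops below 0.1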
@k1lib.patch(Callback.cls)
class CancelOnHighAccuracy(Callback):
    """Cancels the run if accuracy is higher than the amount specified"""
    def __init__(self, accuracy:float):
        super().__init__(); self.order = 25
        self.accuracy = accuracy; self.dependsOn = ["Accuracy"]
    def endBatch(self):
        if not hasattr(self.l, "Accuracy"): raise AttributeError("Learner does not have `Accuracy` callback")
        a = self.Accuracy.valid[-1]
        if a > self.accuracy:
            raise k1lib.CancelRunException(f"High accuracy {self.accuracy} ({a} actual) achieved!") 
@k1lib.patch(Callbacks, docs=CancelOnHighAccuracy)
def withCancelOnHighAccuracy(self, accuracy:float, name:str=None):
    return self.append(CancelOnHighAccuracy(accuracy), name)
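# Usage sketch (hedged; assumes `l` has an Accuracy callback recording valid accuracies):
#     l.cbs.withCancelOnHighAccuracy(0.95) # stop once a batch's valid accuracy exceeds 0.95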
@k1lib.patch(Callback.cls)
class DontTrain(Callback):
    """Don't allow the network to train at all"""
    def startBackward(self): return True # returning True skips the default backward pass
    def startStep(self): return True # returning True skips the default optimizer step
@k1lib.patch(Callbacks, docs=DontTrain)
def withDontTrain(self, name:str=None): return self.append(DontTrain(), name)
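# Usage sketch (hedged; `l` is assumed to be a k1lib.Learner). Useful for dry runs
# that exercise the forward pass and loss computation without updating weights:
#     l.cbs.withDontTrain()
#     l.run(1)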
from torch.nn.utils import clip_grad_value_
@k1lib.patch(Callback.cls)
class GradientClipping(Callback):
    """Clips gradient to a specific max value"""
    def __init__(self, value:float): super().__init__(); self.value = value
    def startStep(self):
        clip_grad_value_(self.l.model.parameters(), self.value) 
@k1lib.patch(Callbacks, docs=GradientClipping)
def withGradientClipping(self, value:float, name:str=None):
    return self.append(GradientClipping(value), name)
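# Usage sketch (hedged; `l` is assumed to be a k1lib.Learner):
#     l.cbs.withGradientClipping(0.5) # clamps every gradient element to [-0.5, 0.5] before each step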
from torch.nn.utils import clip_grad_norm_
@k1lib.patch(Callback.cls)
class GradientClippingNorm(Callback):
    """Clips gradient to a specific max_norm value. Can choose to lump
all params together or do each separately.
See also: :class:`~k1lib.callbacks.limits.GradientClipping` callback."""
    def __init__(self, max_norm:float, each:bool=True):
        super().__init__(); self.max_norm = max_norm; self.each = each
    def startStep(self):
        if self.each:
            for m in self.l.model.parameters():
                clip_grad_norm_(m, self.max_norm)
        else: clip_grad_norm_(self.l.model.parameters(), self.max_norm) 
@k1lib.patch(Callbacks, docs=GradientClippingNorm)
def withGradientClippingNorm(self, max_norm:float, each:bool=True, name:str=None):
    return self.append(GradientClippingNorm(max_norm, each), name)
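# Usage sketch (hedged; `l` is assumed to be a k1lib.Learner):
#     l.cbs.withGradientClippingNorm(1.0) # clip each parameter's gradient norm to 1.0
#     l.cbs.withGradientClippingNorm(1.0, each=False) # or treat all gradients as one big vector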