# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, time
try: import torch; hasTorch = True
except: hasTorch = False
# Public names of this module. GradientClipping / GradientClippingNorm are real
# implementations only when ``import torch`` succeeded (``hasTorch``); otherwise
# they are placeholder subclasses of Callback with no behavior of their own.
__all__ = ["BatchLimit", "EpochLimit", "TimeLimit", "CancelOnExplosion",
           "CancelOnLowLoss", "CancelOnHighAccuracy", "CancelOnOverfit", "DontTrain",
           "GradientClipping", "GradientClippingNorm", "TrainOnly", "ValidOnly"]
@k1lib.patch(Cbs)
class BatchLimit(Callback):
    """Cancels the current epoch once a certain number of batches have executed.

    :param limit: max number of batches per epoch. ``None`` means no limit."""
    def __init__(self, limit:int):
        super().__init__(); self.order = 25
        self.limit = limit if limit is not None else float("inf")
        # Also initialized here so ``startBatch`` does not fail with an
        # AttributeError if it runs before any ``startEpoch`` on this instance.
        self.currentBatch = 0
    def startEpoch(self): self.currentBatch = 0
    def startBatch(self):
        # The counter is advanced in ``endBatch``, so this check runs before the
        # next batch starts, once ``limit`` batches have been counted.
        if self.currentBatch >= self.limit:
            raise k1lib.CancelEpochException(f"Batch {self.limit} reached")
    def endBatch(self): self.currentBatch += 1
@k1lib.patch(Cbs)
class EpochLimit(Callback):
    """Cancels the run once a certain number of epochs have executed.

    :param limit: max number of epochs per run. ``None`` means no limit."""
    def __init__(self, limit:int):
        super().__init__(); self.order = 25
        self.limit = limit if limit is not None else float("inf")
        # Also initialized here so ``startEpoch`` does not fail with an
        # AttributeError if it runs before any ``startRun`` on this instance.
        self.currentEpoch = 0
    def startRun(self): self.currentEpoch = 0
    def startEpoch(self):
        # The counter is advanced in ``endEpoch``, so this check runs before the
        # next epoch starts, once ``limit`` epochs have been counted.
        if self.currentEpoch >= self.limit:
            raise k1lib.CancelRunException(f"Epoch {self.limit} reached!")
    def endEpoch(self): self.currentEpoch += 1
@k1lib.patch(Cbs)
class TimeLimit(Callback):
    """Cancels the run after a certain number of seconds have passed since it started.

    :param seconds: time budget for the whole run. ``None`` means no limit."""
    def __init__(self, seconds=30):
        super().__init__(); self.order = 25
        self.seconds = seconds if seconds is not None else float("inf")
    def startRun(self):
        # Monotonic clock: the measured duration is unaffected by system clock
        # adjustments, unlike ``time.time()``.
        self.startTime = time.monotonic()
    def startBatch(self):
        # Only checked at batch boundaries, so a long batch can overrun the budget
        # before the run is cancelled.
        if time.monotonic() - self.startTime > self.seconds:
            raise k1lib.CancelRunException(f"Takes more than {self.seconds} seconds!")
@k1lib.patch(Cbs)
class CancelOnExplosion(Callback):
    """Cancels the run if any of the parameters are larger in magnitude than a
    certain limit, or contain NaN.

    :param limit: max absolute value any parameter element may take"""
    def __init__(self, limit:float=1e6):
        super().__init__(); self.order = 25
        self.limit = limit; self.triggered = False
    def startRun(self): self.triggered = False
    def startBatch(self):
        for p in self.l.model.parameters():
            o = p.detach()
            # max()/min() raise on zero-element tensors, and an empty tensor
            # cannot explode anyway.
            if o.numel() == 0: continue
            # NaN compares False against any bound, so the range check alone would
            # never fire on a diverged (NaN) parameter; test for it explicitly.
            if o.max().float() > self.limit or o.min().float() < -self.limit or o.isnan().any():
                self.triggered = True
                raise k1lib.CancelRunException("Explosion detected!")
    def __repr__(self):
        # NOTE(review): ``cb.progress`` is advertised below but is never assigned
        # in this class; confirm the Callback base class provides it.
        return f"""{self._reprHead}, use...
- cb.triggered: to see if there was an explosion on the last run
- cb.progress: to see current progress at explosion time
{self._reprCan}"""
@k1lib.patch(Cbs)
class CancelOnLowLoss(Callback):
    " "
    def __init__(self, loss:float, epochMode:bool=False):
        """Cancels the run if loss is lower than amount specified.
Original class: :class:`~k1lib.callbacks.limits.CancelOnLowLoss`
:param loss: threshold; the run is cancelled once the tracked loss drops below it
:param epochMode: False if use batch loss, True if use valid epoch loss"""
        super().__init__(); self.order = 25; self.dependsOn = ["Loss"]
        self.loss = loss; self.epochMode = epochMode
    def startRun(self):
        # The guard checks the same object that is dereferenced right below, so it
        # cannot pass while the actual access then fails.
        if not hasattr(self.cbs, "Loss"):
            raise AttributeError("Learner does not have required `Loss` callback")
        lossCb = self.cbs.Loss
        # Loss histories, stored by reference and re-read on every batch.
        # NOTE(review): assumes the Loss callback appends (floats) to these lists
        # in place during the run; confirm.
        self.v = lossCb.valid; self.ve = lossCb.epoch.valid
    def endBatch(self):
        history = self.ve if self.epochMode else self.v
        if len(history) > 0 and history[-1] < self.loss:
            raise k1lib.CancelRunException(f"Low loss {self.loss} ({history[-3:]} actual) achieved!")
@k1lib.patch(Cbs)
class CancelOnHighAccuracy(Callback):
    """Cancels the run if accuracy is higher than the amount specified

    :param accuracy: threshold; the run is cancelled once the latest recorded
        valid accuracy exceeds it"""
    def __init__(self, accuracy:float):
        super().__init__(); self.order = 25
        self.accuracy = accuracy; self.dependsOn = ["Accuracy"]
    def endBatch(self):
        if not hasattr(self.l, "Accuracy"): raise AttributeError("Learner does not have `Accuracy` callback")
        valid = self.l.Accuracy.valid
        # Nothing recorded yet (presumably before the first validation pass):
        # indexing [-1] would raise IndexError, so there is nothing to compare.
        if len(valid) == 0: return
        a = valid[-1]
        if a > self.accuracy:
            raise k1lib.CancelRunException(f"High accuracy {self.accuracy} ({a} actual) achieved!")
@k1lib.patch(Cbs)
class CancelOnOverfit(Callback):
    def __init__(self, ratio:float=1.2, alpha:float=0.99, after:int=10):
        """Cancels the run if overfit is detected.
:param ratio: Max ratio between the lowest loss and the current loss before cancelling the run
:param alpha: Moving average's alpha, used for both minLoss and loss estimates
:param after: After how many epochs should the overfit detection be activated?"""
        super().__init__()
        self.ratio = ratio
        self.after = after
        self.count = 0  # epochs completed in the current run
        # Smoothed validation loss, plus a smoothed record of its lowest level.
        self.minLoss = k1lib.MovingAvg(alpha=alpha, debias=True)
        self.loss = k1lib.MovingAvg(alpha=alpha, debias=True)
    def startRun(self): self.count = 0
    def endEpoch(self): self.count += 1
    def endBatch(self):
        # Only non-training (validation) batches feed the estimates.
        if self.l.model.training: return
        rawLoss = self.l.loss
        self.loss(rawLoss)
        # Pull the min estimate down whenever the smoothed loss dips below it;
        # a value of 0 is treated as "not initialized yet".
        if self.loss.value < self.minLoss.value or self.minLoss.value == 0:
            self.minLoss(self.loss.value)
        warmedUp = self.count > self.after
        if warmedUp and self.loss.value > self.minLoss.value * self.ratio:
            raise k1lib.CancelRunException(f"Overfit detected! Smoothed min loss: {self.minLoss.value}, loss: {rawLoss}")
@k1lib.patch(Cbs)
class DontTrain(Callback):
    """Don't allow the network to train at all.

    Both the backward-pass hook and the optimizer-step hook return ``True``.
    NOTE(review): presumably the learner treats a truthy hook return as
    "skip this stage"; confirm against the Learner implementation."""
    def startBackward(self):
        return True
    def startStep(self):
        return True
if not hasTorch:
    class GradientClipping(Callback):
        """Placeholder used when torch could not be imported; it adds nothing to
        Callback and is not registered on ``Cbs``."""
else:
    from torch.nn.utils import clip_grad_value_
    @k1lib.patch(Cbs)
    class GradientClipping(Callback):
        """Clips gradient to a specific max value.

        :param value: every gradient element is clamped into ``[-value, value]``"""
        def __init__(self, value:float):
            super().__init__()
            self.value = value
        def startStep(self):
            # Clip in place right before the optimizer step consumes the gradients.
            params = self.l.model.parameters()
            clip_grad_value_(params, self.value)
if not hasTorch:
    class GradientClippingNorm(Callback):
        """Placeholder used when torch could not be imported; it adds nothing to
        Callback and is not registered on ``Cbs``."""
else:
    from torch.nn.utils import clip_grad_norm_
    @k1lib.patch(Cbs)
    class GradientClippingNorm(Callback):
        """Clips gradient to a specific max_norm value. Can choose to lump
all params together or do each separately.
See also: :class:`~k1lib.callbacks.limits.GradientClipping` callback.

:param max_norm: max allowed gradient norm
:param each: True to clip every parameter's gradient independently, False to
    clip using one norm computed over all parameters together"""
        def __init__(self, max_norm:float, each:bool=True):
            super().__init__()
            self.max_norm = max_norm
            self.each = each
        def startStep(self):
            params = self.l.model.parameters()
            if not self.each:
                # A single global norm across every parameter's gradient.
                clip_grad_norm_(params, self.max_norm)
                return
            # Each parameter's gradient is rescaled on its own.
            for param in params:
                clip_grad_norm_(param, self.max_norm)
@k1lib.patch(Cbs)
class TrainOnly(Callback):
    " "
    def __init__(self, cb):
        """Only executes specified callback when training. This modifies the callback's
``suspended`` variable, so it may interfere with :meth:`k1lib.callbacks.callbacks.Callbacks.suspend`
by setting it to different values while in the context.

:param cb: callback to gate; its ``suspended`` flag is overwritten at the start of every batch"""
        super().__init__()
        self.cb = cb
    def startBatch(self):
        # Suspend the wrapped callback whenever the model is not in training mode.
        inTraining = self.l.model.training
        self.cb.suspended = not inTraining
@k1lib.patch(Cbs)
class ValidOnly(Callback):
    " "
    def __init__(self, cb):
        """Same as :class:`TrainOnly`, but only executes specified callback when doing
validation.

:param cb: callback to gate; its ``suspended`` flag is overwritten at the start of every batch"""
        super().__init__()
        self.cb = cb
    def startBatch(self):
        # Suspend the wrapped callback whenever the model is in training mode.
        inTraining = self.l.model.training
        self.cb.suspended = inTraining