# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, os, warnings
from typing import Callable
try: import torch; hasTorch = True
except: hasTorch = False
# Public callbacks of this module; every class below is decorated with
# ``k1lib.patch(Cbs)``.
__all__ = [
    "Autosave", "DontTrainValid", "InspectLoss", "ModifyLoss",
    "Cpu", "Cuda", "DType",
    "InspectBatch", "ModifyBatch", "InspectOutput", "ModifyOutput",
    "Beep", "OnProgress",
]
[docs]@k1lib.patch(Cbs)
class Autosave(Callback):                                                        # Autosave
    """Autosaves 3 versions of the network to disk.

At the end of every run, older checkpoints are rotated one slot down
(``autosave-1.pth`` -> ``autosave-0.pth``, then ``autosave-2.pth`` ->
``autosave-1.pth``) and the learner is saved to ``autosave-2.pth``. So
``autosave-2.pth`` is always the newest and ``autosave-0.pth`` the oldest.
All paths are relative."""                                                      # Autosave
    def __init__(self): super().__init__(); self.order = 23                      # Autosave
    def endRun(self):                                                            # Autosave
        # Rotate with os.replace instead of shelling out to `mv`. It works on
        # every OS (including Windows), spawns no subprocess, and overwrites
        # the destination just like `mv` does.
        for src, dst in (("autosave-1.pth", "autosave-0.pth"),                 # Autosave
                         ("autosave-2.pth", "autosave-1.pth")):                # Autosave
            try: os.replace(src, dst)                                            # Autosave
            # A missing source only means fewer than 3 runs have happened so
            # far; there is nothing to rotate yet.
            except FileNotFoundError: pass                                       # Autosave
            # Rotation is best-effort: report it, but still save the newest
            # checkpoint below.
            except OSError as e: warnings.warn(f"Autosave: couldn't rotate {src} -> {dst}: {e}") # Autosave
        self.l.save("autosave-2.pth")                                            # Autosave
[docs]@k1lib.patch(Cbs)                                                                # Autosave
class DontTrainValid(Callback):                                                  # DontTrainValid
    """Skips ``m.backward()`` and ``opt.step()`` while the model is not training.

The core training loop in :class:`k1lib.Learner` does not do this by itself,
because there may be unusual cases where you also want to train on the
validation data."""                                                            # DontTrainValid
    def _common(self):                                                           # DontTrainValid
        # True in eval mode, None otherwise. The hooks below return this,
        # presumably so that a True value tells the learner to skip the stage.
        return None if self.l.model.training else True                           # DontTrainValid
    def startBackward(self):                                                     # DontTrainValid
        return self._common()                                                    # DontTrainValid
    def startStep(self):                                                         # DontTrainValid
        return self._common()                                                    # DontTrainValid
[docs]@k1lib.patch(Cbs)                                                                # DontTrainValid
class InspectLoss(Callback):                                                     # InspectLoss
    """Calls ``f`` with the loss once it has been computed, for inspection only.

Expected ``f`` to take in 1 argument: the loss, detached from the autograd
graph via ``.detach()``. It is therefore not a plain float; call ``.item()``
if you need a number. The return value of ``f`` is ignored. Runs with order
15, presumably after :class:`ModifyLoss` (order 13)."""                           # InspectLoss
    def __init__(self, f):                                                       # InspectLoss
        super().__init__()                                                       # InspectLoss
        self.f = f                                                               # InspectLoss
        self.order = 15                                                          # InspectLoss
    def endLoss(self):                                                           # InspectLoss
        detachedLoss = self.loss.detach()                                        # InspectLoss
        self.f(detachedLoss)                                                     # InspectLoss
[docs]@k1lib.patch(Cbs)                                                                # InspectLoss
class ModifyLoss(Callback):                                                      # ModifyLoss
    """Replaces the loss with ``f(loss)`` right after it is computed.

Expected ``f`` to take in the current loss and return the new loss. The
result is written back to the learner's ``loss``."""                            # ModifyLoss
    def __init__(self, f):                                                       # ModifyLoss
        super().__init__()                                                       # ModifyLoss
        self.f = f                                                               # ModifyLoss
        self.order = 13                                                          # ModifyLoss
    def endLoss(self):                                                           # ModifyLoss
        newLoss = self.f(self.loss)                                              # ModifyLoss
        self.l.loss = newLoss                                                    # ModifyLoss
[docs]@k1lib.patch(Cbs)                                                                # ModifyLoss
class Cuda(Callback):                                                            # Cuda
    """Moves batch and model to the default GPU.

The model is moved once at the start of the run, and ``xb``/``yb`` at the
start of every batch. If either one can't be moved (e.g. it has no
``.cuda()`` method), a warning is emitted and that value is left as is."""      # Cuda
    def startRun(self): self.l.model.cuda()                                      # Cuda
    def startBatch(self):                                                        # Cuda
        for attr in ("xb", "yb"):                                                # Cuda
            try: setattr(self.l, attr, getattr(self.l, attr).cuda())             # Cuda
            except Exception as e: warnings.warn(f"{attr} can't be moved to the GPU: {e}") # Cuda
[docs]@k1lib.patch(Cbs)                                                                # Cuda
class Cpu(Callback):                                                             # Cpu
    """Moves batch and model to CPU.

Unlike :class:`Cuda`, failures are not caught: an ``xb``/``yb`` without a
``.cpu()`` method raises."""                                                    # Cpu
    def startRun(self):                                                          # Cpu
        self.l.model.cpu()                                                       # Cpu
    def startBatch(self):                                                        # Cpu
        learner = self.l                                                         # Cpu
        learner.xb = learner.xb.cpu()                                            # Cpu
        learner.yb = learner.yb.cpu()                                            # Cpu
[docs]@k1lib.patch(Cbs)                                                                # Cpu
class DType(Callback):                                                           # DType
    """Moves batch and model to a specified data type.

:param dtype: target data type, passed to ``.to()`` of the model (at the
    start of the run) and of ``xb``/``yb`` (at the start of every batch).

NOTE(review): ``yb`` is cast unconditionally as well, so integer targets
(e.g. class indices) also become ``dtype``. Confirm that this suits your loss."""  # DType
    def __init__(self, dtype):                                                   # DType
        super().__init__()                                                       # DType
        self.dtype = dtype                                                       # DType
    def startRun(self):                                                          # DType
        self.l.model = self.l.model.to(self.dtype)                               # DType
    def startBatch(self):                                                        # DType
        learner, target = self.l, self.dtype                                     # DType
        learner.xb = learner.xb.to(target)                                       # DType
        learner.yb = learner.yb.to(target)                                       # DType
[docs]@k1lib.patch(Cbs)                                                                # DType
class InspectBatch(Callback):                                                    # InspectBatch
    """Calls ``f`` with the current batch, for inspection only.

Expected ``f`` to take in 2 tensors, ``xb`` and ``yb``. Its return value is
ignored. Runs with order 15, presumably after :class:`ModifyBatch`
(order 13)."""                                                                  # InspectBatch
    # Annotated with typing.Callable, not the builtin function `callable`,
    # which is not a type (and matches OnProgress's annotation style).
    def __init__(self, f:Callable): super().__init__(); self.f = f; self.order = 15 # InspectBatch
    def startBatch(self): self.f(self.l.xb, self.l.yb)                           # InspectBatch
[docs]@k1lib.patch(Cbs)                                                                # InspectBatch
class ModifyBatch(Callback):                                                     # ModifyBatch
    """Modifies ``xb`` and ``yb`` on the fly.

Expected ``f`` to take in 2 tensors (``xb``, ``yb``) and return 2 tensors,
which replace the learner's current batch."""                                   # ModifyBatch
    def __init__(self, f):                                                       # ModifyBatch
        super().__init__()                                                       # ModifyBatch
        self.f = f                                                               # ModifyBatch
        self.order = 13                                                          # ModifyBatch
    def startBatch(self):                                                        # ModifyBatch
        learner = self.l                                                         # ModifyBatch
        newXb, newYb = self.f(learner.xb, learner.yb)                            # ModifyBatch
        learner.xb, learner.yb = newXb, newYb                                    # ModifyBatch
[docs]@k1lib.patch(Cbs)                                                                # ModifyBatch
class InspectOutput(Callback):                                                   # InspectOutput
    """Calls ``f`` with the model's output after each pass, for inspection only.

Expected ``f`` to take in 1 tensor (``y``). Its return value is ignored."""     # InspectOutput
    def __init__(self, f):                                                       # InspectOutput
        super().__init__()                                                       # InspectOutput
        self.f = f                                                               # InspectOutput
        self.order = 15                                                          # InspectOutput
    def endPass(self):                                                           # InspectOutput
        output = self.y                                                          # InspectOutput
        self.f(output)                                                           # InspectOutput
[docs]@k1lib.patch(Cbs)                                                                # InspectOutput
class ModifyOutput(Callback):                                                    # ModifyOutput
    """Modifies the model's output on the fly.

Expected ``f`` to take in 1 tensor (``y``) and return 1 tensor, which is
written back to the learner's ``y``."""                                         # ModifyOutput
    def __init__(self, f):                                                       # ModifyOutput
        super().__init__()                                                       # ModifyOutput
        self.f = f                                                               # ModifyOutput
        self.order = 13                                                          # ModifyOutput
    def endPass(self):                                                           # ModifyOutput
        newY = self.f(self.y)                                                    # ModifyOutput
        self.l.y = newY                                                          # ModifyOutput
[docs]@k1lib.patch(Cbs)                                                                # ModifyOutput
class Beep(Callback):                                                            # Beep
    """Plays a beep sound (via ``k1lib.beep()``) when the run is over."""        # Beep
    def endRun(self):                                                            # Beep
        k1lib.beep()                                                             # Beep
[docs]@k1lib.patch(Cbs)                                                                # Beep
class OnProgress(Callback):                                                      # OnProgress
    """Triggers a specific function once when a certain progress is reached.

The check runs at the start of every batch. The "already ran" flag is reset
at the start of every run, so ``f`` fires at most once per run.

:param f: function that takes in a :class:`~k1lib.Learner` object
:param progress: progress to trigger the function, from 0 to 1. ``f`` fires on
    the first batch whose learner progress is at least this value"""           # OnProgress
    def __init__(self, f:Callable[["k1lib.Learner"], None], progress:float):     # OnProgress
        super().__init__(); self.f = f; self.progress = progress; self.ran = False # OnProgress
    def startRun(self): self.ran = False                                         # OnProgress
    def startBatch(self):                                                        # OnProgress
        # ">=" (not ">") so f fires once the threshold is *reached*, as
        # documented. With ">", a batch whose progress equals the threshold
        # exactly was skipped (e.g. progress=0 if the first batch reports 0).
        if (not self.ran) and self.l.progress >= self.progress:                  # OnProgress
            # Mark as ran before calling f, so an exception in f isn't retried
            self.ran = True; self.f(self.l)                                      # OnProgress