# Source code for k1lib.callbacks.shorts

# AUTOGENERATED FILE! PLEASE DON'T EDIT
from .callbacks import Callback, Callbacks, Cbs
import k1lib, os, torch
from typing import Callable
# Public callbacks exported by this module (each is also patched onto Cbs).
__all__ = [
    "Autosave",
    "DontTrainValid",
    "InspectLoss",
    "ModifyLoss",
    "Cpu",
    "Cuda",
    "DType",
    "InspectBatch",
    "ModifyBatch",
    "InspectOutput",
    "ModifyOutput",
    "Beep",
    "OnProgress",
]
@k1lib.patch(Cbs)
class Autosave(Callback):
    """Autosaves 3 versions of the network to disk.

    At the end of every run the existing checkpoints are rotated
    (``autosave-1.pth`` -> ``autosave-0.pth``, then
    ``autosave-2.pth`` -> ``autosave-1.pth``), and the learner is then saved
    as ``autosave-2.pth``. So ``autosave-2.pth`` is always the most recent
    save and ``autosave-0.pth`` the oldest."""
    def __init__(self):
        super().__init__()
        self.order = 23

    def endRun(self):
        # Rotate in-process instead of shelling out to ``mv``. os.replace
        # overwrites the destination just like ``mv``, but it also works on
        # Windows and doesn't spawn a shell. A checkpoint that doesn't exist
        # yet (e.g. during the first runs) is deliberately skipped, matching
        # the old best-effort behaviour without printing shell errors.
        for src, dst in (("autosave-1.pth", "autosave-0.pth"),
                         ("autosave-2.pth", "autosave-1.pth")):
            try:
                os.replace(src, dst)
            except FileNotFoundError:
                pass
        self.l.save("autosave-2.pth")
@k1lib.patch(Cbs)
class DontTrainValid(Callback):
    """Skips ``m.backward()`` and ``opt.step()`` whenever the model is not in
    training mode.

    The core training loop in :class:`k1lib.Learner` doesn't do this by
    itself, because there may be unusual cases where you also want to train on
    the validation set."""
    def _common(self):
        # A truthy return from startBackward/startStep is what makes this
        # callback skip those steps (see class docstring). Returning None
        # leaves them alone.
        if self.l.model.training:
            return None
        return True

    def startBackward(self):
        return self._common()

    def startStep(self):
        return self._common()
@k1lib.patch(Cbs)
class InspectLoss(Callback):
    """Calls ``f`` with the loss once it has been computed (``endLoss``).

    ``f`` receives a single argument: the loss detached from the autograd
    graph via ``.detach()``, so it is a tensor rather than a Python float.
    Its return value is ignored."""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 15

    def endLoss(self):
        detachedLoss = self.loss.detach()
        self.f(detachedLoss)
@k1lib.patch(Cbs)
class ModifyLoss(Callback):
    """Replaces the learner's loss with ``f(loss)`` at ``endLoss``.

    ``f`` takes the current loss and must return the new loss, which is
    written back to the learner."""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 13

    def endLoss(self):
        newLoss = self.f(self.loss)
        self.l.loss = newLoss
@k1lib.patch(Cbs)
class Cuda(Callback):
    """Moves the model to the default GPU when the run starts, and moves each
    batch (``xb`` and ``yb``) there at the start of every batch."""
    def startRun(self):
        self.l.model.cuda()

    def startBatch(self):
        learner = self.l
        learner.xb = learner.xb.cuda()
        learner.yb = learner.yb.cuda()
@k1lib.patch(Cbs)
class Cpu(Callback):
    """Moves the model to the CPU when the run starts, and moves each batch
    (``xb`` and ``yb``) there at the start of every batch."""
    def startRun(self):
        self.l.model.cpu()

    def startBatch(self):
        learner = self.l
        learner.xb = learner.xb.cpu()
        learner.yb = learner.yb.cpu()
@k1lib.patch(Cbs)
class DType(Callback):
    """Converts the model and every batch to a specified data type.

    :param dtype: target dtype passed to ``.to()``. It is applied to the model
        when the run starts and to both ``xb`` and ``yb`` on every batch."""
    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype

    def startRun(self):
        self.l.model = self.l.model.to(self.dtype)

    def startBatch(self):
        target = self.dtype
        self.l.xb = self.l.xb.to(target)
        # NOTE(review): yb is cast too. If it holds integer class labels they
        # become ``target`` as well; confirm this suits the loss function.
        self.l.yb = self.l.yb.to(target)
@k1lib.patch(Cbs)
class InspectBatch(Callback):
    """Calls ``f(xb, yb)`` at the start of every batch, without modifying it.

    :param f: function that takes in 2 tensors (the batch's ``xb`` and
        ``yb``). Its return value is ignored."""
    # The annotation previously used the builtin *function* ``callable``,
    # which is not a type. ``typing.Callable`` (already imported by this
    # module and used by OnProgress) is the correct hint.
    def __init__(self, f:Callable[[torch.Tensor, torch.Tensor], None]):
        super().__init__()
        self.f = f
        self.order = 15

    def startBatch(self):
        self.f(self.l.xb, self.l.yb)
@k1lib.patch(Cbs)
class ModifyBatch(Callback):
    """Modifies ``xb`` and ``yb`` on the fly at the start of every batch.

    ``f`` takes in the 2 tensors ``(xb, yb)`` and must return 2 tensors, which
    replace the learner's batch."""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 13

    def startBatch(self):
        # Unpack fully before assigning, so a wrong-length result raises
        # before either attribute is touched.
        newXb, newYb = self.f(self.l.xb, self.l.yb)
        self.l.xb = newXb
        self.l.yb = newYb
@k1lib.patch(Cbs)
class InspectOutput(Callback):
    """Calls ``f`` with the model output ``y`` at ``endPass``.

    ``f`` takes in 1 tensor. Its return value is ignored."""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 15

    def endPass(self):
        # NOTE(review): unlike InspectLoss, the output is passed without
        # ``.detach()``, so ``f`` sees the live autograd tensor.
        output = self.y
        self.f(output)
@k1lib.patch(Cbs)
class ModifyOutput(Callback):
    """Modifies the model output on the fly at ``endPass``.

    ``f`` takes in 1 tensor (the current output ``y``) and must return
    1 tensor, which is written back to the learner."""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 13

    def endPass(self):
        newOutput = self.f(self.y)
        self.l.y = newOutput
@k1lib.patch(Cbs)
class Beep(Callback):
    """Plays a beep sound once the run finishes (``endRun``)."""
    def endRun(self):
        # Delegates to k1lib's own beep helper.
        k1lib.beep()
@k1lib.patch(Cbs)
class OnProgress(Callback):
    """Triggers a specific function once when a certain progress is reached.

    :param f: function that takes in a :class:`~k1lib.Learner` object
    :param progress: progress to trigger the function, from 0 to 1. ``f``
        fires at the start of the first batch whose learner progress is at
        least this value, at most once per run."""
    def __init__(self, f:Callable[["k1lib.Learner"], None], progress:float):
        super().__init__()
        self.f = f
        self.progress = progress
        self.ran = False

    def startRun(self):
        # Re-arm the trigger for every new run.
        self.ran = False

    def startBatch(self):
        # ``>=`` rather than ``>``: with ``>``, a threshold of 0 skipped the
        # very first batch, and a threshold equal to a progress value the
        # learner actually reports never fired.
        if self.ran or self.l.progress < self.progress:
            return
        # Mark as ran before calling f, as the original did, so a re-entrant
        # or failing f doesn't cause repeated triggers.
        self.ran = True
        self.f(self.l)