# AUTOGENERATED FILE! PLEASE DON'T EDIT
from .callbacks import Callback, Callbacks, Cbs
import k1lib, os, warnings
from typing import Callable
try: import torch; hasTorch = True
except: hasTorch = False
# Public callbacks exported by this module (each is also patched onto Cbs).
__all__ = [
    "Autosave", "DontTrainValid",
    "InspectLoss", "ModifyLoss",
    "Cpu", "Cuda", "DType",
    "InspectBatch", "ModifyBatch",
    "InspectOutput", "ModifyOutput",
    "Beep", "OnProgress",
]
@k1lib.patch(Cbs)
class Autosave(Callback):
    """Autosaves 3 versions of the network to disk.

    At the end of every run the existing checkpoints are rotated
    (``autosave-1.pth`` -> ``autosave-0.pth``, then ``autosave-2.pth`` ->
    ``autosave-1.pth``) and the learner is saved to ``autosave-2.pth``. So
    ``autosave-2.pth`` is always the newest and ``autosave-0.pth`` the oldest.
    Paths are relative, i.e. resolved against the current working directory."""
    def __init__(self): super().__init__(); self.order = 23
    def endRun(self):
        # os.replace instead of shelling out to `mv`. It works on every OS,
        # spawns no subprocess, and overwrites the destination like `mv` does.
        rotations = (("autosave-1.pth", "autosave-0.pth"),
                     ("autosave-2.pth", "autosave-1.pth"))
        for src, dst in rotations:
            try: os.replace(src, dst)
            # Normal during the first runs, before 3 checkpoints exist.
            except FileNotFoundError: pass
        self.l.save("autosave-2.pth")
@k1lib.patch(Cbs)
class DontTrainValid(Callback):
    """Skips ``m.backward()`` and ``opt.step()`` whenever the model is not in
    training mode.

    The core training loop in k1lib.Learner doesn't do this by itself, because
    there may be unusual cases where you also want to train on the validation set."""
    def _common(self):
        # Presumably the Learner treats a truthy return from startBackward /
        # startStep as "skip this stage" (per the class contract above).
        # In training mode nothing is returned, so the stage runs normally.
        if self.l.model.training:
            return None
        return True

    def startBackward(self):
        return self._common()

    def startStep(self):
        return self._common()
@k1lib.patch(Cbs)
class InspectLoss(Callback):
    """Hands the current loss to `f` for inspection, without changing it.

    :param f: function taking 1 argument, ``loss.detach()`` (presumably a scalar
        tensor; call ``.item()`` on it for a Python float). Because the loss is
        detached, `f` cannot affect the autograd graph. Its return value is ignored."""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 15
    def endLoss(self):
        detached = self.loss.detach()
        self.f(detached)
@k1lib.patch(Cbs)
class ModifyLoss(Callback):
    """Replaces the learner's loss with ``f(loss)`` at the end of the loss stage.

    :param f: function taking the current loss (not detached, so gradients can
        still flow through whatever `f` computes) and returning the new loss"""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 13
    def endLoss(self):
        newLoss = self.f(self.loss)
        self.l.loss = newLoss
@k1lib.patch(Cbs)
class Cuda(Callback):
    """Moves the model (at the start of the run) and every batch (at the start
    of each batch) to the default GPU.

    If ``xb`` or ``yb`` can't be moved (e.g. it has no ``.cuda()`` method), a
    warning is emitted and that attribute is left as is. The run is not aborted."""
    def startRun(self): self.l.model.cuda()
    def startBatch(self):
        learner = self.l
        for name in ("xb", "yb"):
            try:
                setattr(learner, name, getattr(learner, name).cuda())
            except Exception as e:
                warnings.warn(f"{name} can't be moved to the GPU: {e}")
@k1lib.patch(Cbs)
class Cpu(Callback):
    """Moves the model (at the start of the run) and every batch (at the start
    of each batch) to the CPU.

    Like the Cuda callback, if ``xb`` or ``yb`` can't be moved (e.g. it has no
    ``.cpu()`` method), a warning is emitted and that attribute is left as is,
    instead of aborting the run."""
    def startRun(self): self.l.model.cpu()
    def startBatch(self):
        try: self.l.xb = self.l.xb.cpu()
        except Exception as e: warnings.warn(f"xb can't be moved to the CPU: {e}")
        try: self.l.yb = self.l.yb.cpu()
        except Exception as e: warnings.warn(f"yb can't be moved to the CPU: {e}")
@k1lib.patch(Cbs)
class DType(Callback):
    """Casts the model (at the start of the run) and every batch (at the start
    of each batch) to a fixed data type.

    :param dtype: target dtype passed to ``.to()``. Note that both ``xb`` and
        ``yb`` are cast, so e.g. integer class labels change type when `dtype`
        is a floating type."""
    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype
    def startRun(self):
        castModel = self.l.model.to(self.dtype)
        self.l.model = castModel
    def startBatch(self):
        target = self.dtype
        self.l.xb = self.l.xb.to(target)
        self.l.yb = self.l.yb.to(target)
@k1lib.patch(Cbs)
class InspectBatch(Callback):
    """Calls ``f(xb, yb)`` at the start of every batch, for inspection only.

    :param f: function taking 2 arguments, the current ``xb`` and ``yb``
        (expected to be tensors). Its return value is ignored and the batch is
        not modified."""
    # Annotated with typing.Callable. The builtin `callable` is a predicate
    # function, not a type.
    def __init__(self, f:Callable): super().__init__(); self.f = f; self.order = 15
    def startBatch(self): self.f(self.l.xb, self.l.yb)
@k1lib.patch(Cbs)
class ModifyBatch(Callback):
    """Rewrites ``xb`` and ``yb`` on the fly at the start of every batch.

    :param f: function taking the current ``xb`` and ``yb`` (expected to be
        tensors) and returning a 2-element sequence ``(newXb, newYb)``"""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 13
    def startBatch(self):
        # Unpack before assigning, so a bad return value from `f` leaves both
        # xb and yb untouched.
        newXb, newYb = self.f(self.l.xb, self.l.yb)
        self.l.xb = newXb
        self.l.yb = newYb
@k1lib.patch(Cbs)
class InspectOutput(Callback):
    """Calls ``f(y)`` with the model output at the end of every pass, for
    inspection only.

    :param f: function taking 1 argument, the output ``y`` (expected to be a
        tensor). Its return value is ignored."""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 15
    def endPass(self):
        output = self.y
        self.f(output)
@k1lib.patch(Cbs)
class ModifyOutput(Callback):
    """Replaces the model output ``y`` with ``f(y)`` at the end of every pass.

    :param f: function taking the output ``y`` (expected to be a tensor) and
        returning the replacement output"""
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.order = 13
    def endPass(self):
        replaced = self.f(self.y)
        self.l.y = replaced
@k1lib.patch(Cbs)
class Beep(Callback):
    """Plays a beep sound (via ``k1lib.beep()``) once the run has finished."""
    def endRun(self):
        k1lib.beep()
@k1lib.patch(Cbs)
class OnProgress(Callback):
    """Calls a function once per run, when training progress first reaches a
    given threshold.

    :param f: function that takes in a :class:`~k1lib.Learner` object
    :param progress: progress at which to trigger the function, from 0 to 1"""
    def __init__(self, f:Callable[["k1lib.Learner"], None], progress:float):
        super().__init__(); self.f = f; self.progress = progress; self.ran = False
    # Re-arm at the start of every run, so `f` fires once per run.
    def startRun(self): self.ran = False
    def startBatch(self):
        # ">=" honors "reached" literally. With the old strict ">", a threshold
        # of 0 could not fire at a batch whose progress is exactly 0.
        if (not self.ran) and self.l.progress >= self.progress:
            # Mark as ran before calling `f`, so an exception in `f` is not
            # retried on every following batch.
            self.ran = True; self.f(self.l)