Source code for k1lib.callbacks.recorder

# AUTOGENERATED FILE! PLEASE DON'T EDIT
import k1lib, torch; from k1lib.cli import empty, shape
from .callbacks import Callback, Callbacks, Cbs
from typing import Tuple, List
__all__ = ["Recorder"]
@k1lib.patch(Cbs)
class Recorder(Callback):
    """Records xb, yb and y from a short run. No training involved. Example::

        l = k1lib.Learner.sample()
        l.cbs.add(Cbs.Recorder())
        xbs, ybs, ys = l.Recorder.record(1, 2)
        xbs # list of x batches passed in
        ybs # list of y batches passed in, "the correct label"
        ys  # list of network's output

    If you have extra metadata in your dataloader, then the recorder will return
    (xb, yb, metab, ys) instead::

        # creating a new dataloader that yields (xb, yb, metadata)
        x = torch.linspace(-5, 5, 1000); meta = torch.tensor(range(1000))
        dl = [x, x+2, meta] | transpose() | randomize(None) | repeatFrom() | batched()\
            | (transpose() | (toTensor() + toTensor() + toTensor())).all() | stagger(50)
        l = k1lib.Learner.sample(); l.data = [dl, []]
        l.cbs.add(Cbs.Recorder())
        xbs, ybs, metabs, ys = l.Recorder.record(1, 2)
    """
    def __init__(self):
        # starts out suspended; record() temporarily un-suspends it
        super().__init__(); self.order = 20; self.suspended = True
    def startRun(self):
        self.xbs = []; self.ybs = []; self.metabs = []; self.ys = []
    def startBatch(self):
        # store detached copies of the incoming batches, plus any metadata
        self.xbs.append(self.l.xb.detach())
        self.ybs.append(self.l.yb.detach())
        self.metabs.append(self.l.metab)
    def endRun(self):
        # trim all lists to the same length, in case the run was cut off mid-batch
        n = min(len(self.xbs), len(self.ybs), len(self.metabs), len(self.ys))
        self.xbs = self.xbs[:n]; self.ybs = self.ybs[:n]
        self.metabs = self.metabs[:n]; self.ys = self.ys[:n]
    def endPass(self):
        # store the network's detached output for this batch
        self.ys.append(self.l.y.detach())
    @property
    def values(self):
        # return the 4-tuple form if metadata was recorded, else the 3-tuple form
        hasMeta = self.metabs | ~empty() | shape(0) > 0
        if hasMeta: return self.xbs, self.ybs, self.metabs, self.ys
        else: return self.xbs, self.ybs, self.ys
    def record(self, epochs:int=1, batches:int=None) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """Returns the recorded x batches, y batches and network outputs y
        (plus metadata batches, if the dataloader yields them)"""
        self.suspended = False
        try:
            with self.cbs.context(), self.cbs.suspendEval():
                # don't update any weights, and cap how long the recording run can take
                self.cbs.add(Cbs.DontTrain()).add(Cbs.TimeLimit(5))
                self.l.run(epochs, batches)
        finally: self.suspended = True
        return self.values
    def __repr__(self):
        return f"""{self._reprHead}, can...
- r.record(epoch[, batches]): runs for a while, and records x and y batches, and the output
{self._reprCan}"""
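
# A minimal usage sketch, pieced together from the class docstring above. The
# import path for Cbs is an assumption (this module pulls it from its sibling
# callbacks module), and the recorded tensors' shapes depend on the toy learner::
#
#     import k1lib
#     from k1lib.callbacks import Cbs           # assumed re-export location
#
#     l = k1lib.Learner.sample()                # small toy learner with synthetic data
#     l.cbs.add(Cbs.Recorder())                 # the recorder starts out suspended
#     xbs, ybs, ys = l.Recorder.record(1, 2)    # 1 epoch, at most 2 batches, no weight updates
#     assert len(xbs) == len(ybs) == len(ys)    # endRun() trims the lists to equal length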