# AUTOGENERATED FILE! PLEASE DON'T EDIT
import k1lib, torch; from .callbacks import Callback, Callbacks, Cbs
from typing import Tuple, List
__all__ = ["Recorder"]
@k1lib.patch(Cbs)
class Recorder(Callback):
    """Records xb, yb and y from a short run. No training involved.

    Example::

        l = k1lib.Learner.sample()
        l.cbs.withRecorder()
        xb, yb, y = l.Recorder.record(1, 2)
        xb # list of x batches passed in
        yb # list of y batches passed in, "the correct label"
        y # list of network's output

    Each call to :meth:`record` starts from empty lists, so the returned
    values only contain batches seen during that call's run.
    """
    def __init__(self):
        super().__init__(); self.order = 20
        # Only record() turns this off, and only for the duration of its own run.
        # Presumably the Callback machinery skips this callback's hooks while
        # suspended -- verify in .callbacks.
        self.suspended = True
        self.xbs = []; self.ybs = []; self.ys = []
    def startBatch(self):
        """Stores the learner's current x and y batch.

        ``.detach()`` drops autograd history, so recorded tensors do not keep
        the computation graph alive. It shares storage; it does not copy."""
        self.xbs += [self.l.xb.detach()]
        self.ybs += [self.l.yb.detach()]
    def endPass(self):
        """Stores the learner's output ``l.y`` (detached) for the current batch."""
        self.ys += [self.l.y.detach()]
    @property
    def values(self):
        """``(xbs, ybs, ys)`` recorded so far: x batches, y batches and outputs."""
        return self.xbs, self.ybs, self.ys
    def record(self, epochs:int=1, batches:int=None) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """Returns recorded xBatch, yBatch and answer y.

        Runs the learner briefly with a ``DontTrain`` and a ``TimeLimit(5)``
        callback added inside ``cbs.context()``, while ``cbs.suspendEval()``
        is active. The recorder is switched on only for the duration of the
        run.

        :param epochs: passed straight through to ``self.l.run``
        :param batches: passed straight through to ``self.l.run``
        :return: ``(xbs, ybs, ys)``, i.e. lists of x batches, y batches and
            network outputs recorded during this call only
        """
        # Fresh lists per call. Previously the lists grew across calls, and a
        # tuple returned by an earlier record() was mutated by later ones.
        self.xbs = []; self.ybs = []; self.ys = []
        self.suspended = False
        try:
            # presumably callbacks appended inside cbs.context() are removed on
            # exit, making DontTrain/TimeLimit temporary -- verify in .callbacks
            with self.cbs.context(), self.cbs.suspendEval():
                self.cbs.withDontTrain().withTimeLimit(5)
                self.l.run(epochs, batches)
        finally:
            # Re-suspend even if the run raises, so later runs aren't recorded.
            self.suspended = True
        return self.values
    def __repr__(self):
        return f"""{self._reprHead}, can...
- r.record(epoch[, batches]): runs for a while, and records x and y batches, and the output
{self._reprCan}"""
@k1lib.patch(Callbacks, docs=Recorder)
def withRecorder(self):
    """Appends a newly constructed :class:`Recorder` to these callbacks.

    Returns whatever ``self.append`` returns (presumably ``self``, which
    allows chaining -- verify in .callbacks)."""
    recorder = Recorder()
    return self.append(recorder)