# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
import k1lib; from k1lib.cli import empty, shape
from .callbacks import Callback, Callbacks, Cbs
from typing import Tuple, List
try: import torch; hasTorch = True
except: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["Recorder"]
@k1lib.patch(Cbs)
class Recorder(Callback):
    """Records xb, yb and y from a short run. No training involved.
Example::

    l = k1lib.Learner.sample()
    l.cbs.add(Cbs.Recorder())
    xbs, ybs, ys = l.Recorder.record(1, 2)
    xbs # list of x batches passed in
    ybs # list of y batches passed in, "the correct label"
    ys #  list of network's output

If you have extra metadata in your dataloader, then the recorder will return
(xb, yb, metab, ys) instead::

    # creating a new dataloader that yields (xb, yb, metadata)
    x = torch.linspace(-5, 5, 1000); meta = torch.tensor(range(1000))
    dl = [x, x+2, meta] | transpose() | randomize(None) | repeatFrom() | batched()\
    | (transpose() | (toTensor() + toTensor() + toTensor())).all() | stagger(50)
    l = k1lib.Learner.sample(); l.data = [dl, []]
    l.cbs.add(Cbs.Recorder())
    xbs, ybs, metabs, ys = l.Recorder.record(1, 2)
"""
    def __init__(self):
        """Creates the recorder with ``order = 20`` and starts it out suspended.

        :meth:`record` is the only place in this class that lifts the
        suspension, and it restores it afterwards."""
        super().__init__()
        self.order = 20
        self.suspended = True

    def startRun(self):
        """Resets all recorded lists at the beginning of a run."""
        self.xbs = []; self.ybs = []; self.metabs = []; self.ys = []

    def startBatch(self):
        """Stores the current batch: detached ``xb`` and ``yb``, and ``metab`` as-is."""
        self.xbs.append(self.l.xb.detach())
        self.ybs.append(self.l.yb.detach())
        self.metabs.append(self.l.metab)

    def endRun(self):
        """Trims every list to the length of the shortest one.

        Inputs are appended in :meth:`startBatch` while outputs are appended in
        :meth:`endPass`, so the lists can end up with different lengths
        (presumably when a run stops between those two events). Trimming keeps
        index ``i`` of every list referring to the same batch."""
        n = min(len(self.xbs), len(self.ybs), len(self.metabs), len(self.ys))
        self.xbs = self.xbs[:n]; self.ybs = self.ybs[:n]
        self.metabs = self.metabs[:n]; self.ys = self.ys[:n]

    def endPass(self):
        """Stores the detached network output ``y`` of the current pass."""
        self.ys.append(self.l.y.detach())

    @property
    def values(self):
        """Recorded data from the last run.

        Returns ``(xbs, ybs, metabs, ys)`` if the metadata check below is
        positive, else ``(xbs, ybs, ys)``. Only valid after a run has started,
        because the lists are created in :meth:`startRun`."""
        # NOTE(review): presumably counts the non-empty metadata entries via
        # k1lib's `~empty()` filter followed by `shape(0)` - confirm against
        # the k1lib.cli documentation.
        hasMeta = self.metabs | ~empty() | shape(0) > 0
        if hasMeta: return self.xbs, self.ybs, self.metabs, self.ys
        else: return self.xbs, self.ybs, self.ys

    def record(self, epochs:int=1, batches:int=None) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """Returns recorded xBatch, yBatch and answer y

        Un-suspends this callback, then runs the learner inside
        ``self.cbs.context()`` and ``self.cbs.suspendEval()`` with
        ``Cbs.DontTrain()`` and ``Cbs.TimeLimit(5)`` added. The callback is
        suspended again afterwards, even if the run raises.

        :param epochs: forwarded to ``self.l.run`` as its first argument
        :param batches: forwarded to ``self.l.run`` as its second argument
        :return: :attr:`values`, which is ``(xbs, ybs, ys)``, or
            ``(xbs, ybs, metabs, ys)`` when metadata was recorded"""
        self.suspended = False
        try:
            with self.cbs.context(), self.cbs.suspendEval():
                self.cbs.add(Cbs.DontTrain()).add(Cbs.TimeLimit(5))
                self.l.run(epochs, batches)
        finally: self.suspended = True  # always go back to the dormant state
        return self.values

    def __repr__(self):
        return f"""{self._reprHead}, can...
- r.record(epoch[, batches]): runs for a while, and records x and y batches, and the output
{self._reprCan}"""