Source code for k1lib.callbacks.progress

# AUTOGENERATED FILE! PLEASE DON'T EDIT
import k1lib, time; from k1lib import fmt
from .callbacks import Callback, Callbacks, Cbs
import k1lib.cli as cli; import matplotlib.pyplot as plt
__all__ = ["ProgressBar"]
@k1lib.patch(Cbs)
class ProgressBar(Callback):
    """Displays the current progress, epoch and batch while running.

    Deposits variables into :class:`~k1lib.Learner` at checkpoint ``startBatch``:

    - progress: single float from 0 to 1, increasing monotonically as batches
      and epochs advance
    - epochThroughput: epochs completed per second
    - batchThroughput: batches completed per second, or None when the Learner
      doesn't know how many batches an epoch has
    - remaining: estimated remaining time, in seconds. Does not take into account
      callbacks that can potentially cancel the run prematurely

    Every ``startBatch`` also appends one row ``[epoch, batch, elapsedTime,
    progress, epochThroughput, batchThroughput, remaining]`` to ``self.data``,
    which :meth:`plot` reads."""
    def startRun(self):
        self.startTime = time.time(); self.step = 0; self.l.progress = 0
        self.l.loss = float("inf")  # to make sure this variable exists
        # running max widths of each printed column, so the "\r"-overwritten status
        # line doesn't jitter or leave stale characters when a field gets shorter
        self.aL = self.bL = self.cL = self.dL = self.eL = self.fL = 0
        self.data = []

    def startBatch(self):
        batch, batches = self.l.batch, self.l.batches
        epoch, epochs = self.l.epoch, self.l.epochs
        elapsedTime = self.elapsedTime = time.time() - self.startTime
        # time.time() can have coarse resolution, so the first batch may observe an
        # elapsed time of exactly 0; clamp the divisor so throughput can't raise
        # ZeroDivisionError. The raw elapsed time is still what gets stored/printed.
        divisor = max(elapsedTime, 1e-7)
        if batches is None:  # batch count unknown: progress only advances per epoch
            progress = self.l.progress = epoch / epochs; batchTh = None
        else:
            progress = self.l.progress = (batch / batches + epoch) / epochs
            batchTh = batches * epochs * progress / divisor
        epochTh = self.l.epochThroughput = epochs * progress / divisor
        self.l.batchThroughput = batchTh
        remaining = self.l.remaining = (round(elapsedTime / (progress+1e-7) * (1-progress), 2)
                                        if progress > 0 else float('inf'))

        def fit(attr, text, justify=str.rjust):
            # widen column ``attr`` to the longest text seen so far, then pad to it
            width = max(getattr(self, attr), len(text)); setattr(self, attr, width)
            return justify(text, width)

        a = fit("aL", str(round(100 * progress)))
        b = fit("bL", f"{epoch}/{epochs} ({fmt.throughput(epochTh, ' epochs')})")
        c = f"{batch}/{batches}"
        if batches is not None: c += f" ({fmt.throughput(batchTh, ' batches')})"
        c = fit("cL", c)
        d = fit("dL", f"{round(elapsedTime, 2)}".rjust(6))
        e = fit("eL", f"{remaining}")
        loss = fit("fL", f"{self.l.loss}", str.ljust)
        self.data.append([epoch, batch, elapsedTime, progress, epochTh, batchTh, remaining])
        print(f"\rProgress: {a}%, epoch: {b}, batch: {c}, elapsed: {d}s, remaining: {e}s, loss: {loss} ", end="")

    def plot(self, f=cli.iden(), perEpoch=False, _window=2):
        """Plots detailed partial execution time profile.

        Each plotted point is the time the whole run would take at the rate measured
        over a sliding window of ``_window`` consecutive ``startBatch`` records
        (elapsed-time delta divided by progress delta).

        :param f: optional post processing step
        :param perEpoch: if True, normalize time per epoch, else keep it at time per run
        :param _window: number of batches to calculate the processing rate over.
            Put low values (min 2) to make it crisp (and inaccurate), put high
            values to make it smooth (and accurate)"""
        def timePerRun(dt, dp):
            # progress can stay flat inside a window (e.g. when ``batches`` is None it
            # only moves once per epoch); report NaN, which matplotlib draws as a gap,
            # instead of crashing with ZeroDivisionError
            return dt / dp if dp else float("nan")
        if perEpoch: f = cli.apply(cli.op()/self.l.epochs) | f
        (self.data | cli.cut(2, 3) | cli.deref() | cli.window(_window)
         | cli.apply(cli.rows(0, -1) | cli.transpose() | ~cli.apply(lambda x, y: y-x) | ~cli.aS(timePerRun))
         | f | cli.deref() | cli.aS(plt.plot))
        plt.ylabel("Time/epoch (seconds)" if perEpoch else "Time/run (seconds)")