# AUTOGENERATED FILE! PLEASE DON'T EDIT
import k1lib, time; from k1lib import fmt
from .callbacks import Callback, Callbacks, Cbs
import k1lib.cli as cli; import matplotlib.pyplot as plt
__all__ = ["ProgressBar"]
@k1lib.patch(Cbs)
class ProgressBar(Callback):
    """Displays the current progress, epoch and batch while running.
    Deposits variables into :class:`~k1lib.Learner` at checkpoint ``startBatch``:

    - progress: single float from 0 to 1, guaranteed to increase monotonically
      (if the number of batches is unknown, i.e. ``batches`` is None, it only
      advances once per epoch)
    - epochThroughput: epochs completed per second so far
    - batchThroughput: batches completed per second so far, or None if the
      number of batches is unknown
    - remaining: estimated remaining time, in seconds. Does not take into account
      callbacks that can potentially cancel the run prematurely"""
    def startRun(self):
        """Resets the timer, the printed column widths and the recorded history."""
        self.startTime = time.time(); self.step = 0; self.l.progress = 0
        self.l.loss = float("inf") # to make sure this variable exist
        # widths of the 5 printed columns; they only ever grow so the line doesn't jitter
        self.aL = self.bL = self.cL = self.dL = self.eL = 0
        # one row per batch: [epoch, batch, elapsedTime, progress, epochThroughput, batchThroughput, remaining]
        self.data = []
    def startBatch(self):
        """Recomputes progress, throughputs and remaining time, stores them on
        the learner and in ``self.data``, then redraws the status line in place."""
        batch = self.l.batch; batches = self.l.batches
        epoch = self.l.epoch; epochs = self.l.epochs
        elapsedTime = self.elapsedTime = time.time() - self.startTime
        # time.time() can have coarse resolution, so the first batch may see an
        # elapsed time of exactly 0. Clamp only the throughput divisor; the
        # recorded elapsedTime is left untouched
        divisor = max(elapsedTime, 1e-7)
        if batches is None:
            progress = self.l.progress = epoch / epochs; batchTh = None
        else:
            progress = self.l.progress = (batch / batches + epoch) / epochs
            batchTh = batches * epochs * progress / divisor
        epochTh = self.l.epochThroughput = epochs * progress / divisor
        self.l.batchThroughput = batchTh
        remaining = self.l.remaining = round(elapsedTime / (progress+1e-7) * (1-progress), 2) if progress > 0 else float('inf')
        a = str(round(100 * progress)); self.aL = max(self.aL, len(a)); a = a.rjust(self.aL)
        b = f"{epoch}/{epochs} ({fmt.throughput(epochTh, ' epochs')})"; self.bL = max(self.bL, len(b)); b = b.rjust(self.bL)
        c = f"{batch}/{batches}"
        if batches is not None: c += f" ({fmt.throughput(batchTh, ' batches')})"
        self.cL = max(self.cL, len(c)); c = c.rjust(self.cL)
        d = f"{round(elapsedTime, 2)}".rjust(6); self.dL = max(self.dL, len(d)); d = d.rjust(self.dL)
        e = f"{remaining}"; self.eL = max(self.eL, len(e)); e = e.rjust(self.eL)
        self.data.append([epoch, batch, elapsedTime, progress, epochTh, batchTh, remaining])
        # "\r" rewinds to the start of the line, which never ends with "\n";
        # flush so the bar updates even when stdout is line-buffered (interactive terminal)
        print(f"\rProgress: {a}%, epoch: {b}, batch: {c}, elapsed: {d}s, remaining: {e}s, loss: {self.l.loss} ", end="", flush=True)
    def plot(self, f=cli.iden(), perEpoch=False, _window=2):
        """Plots detailed partial execution time profile.

        :param f: optional post processing step
        :param perEpoch: if True, normalize time per epoch, else keep it at time per run
        :param _window: number of batches to calculate the processing rate over. Put low
            values (min 2) to make it crisp (and inaccurate), put high values to make it
            smooth (and accurate)"""
        if perEpoch: f = cli.apply(cli.op()/self.l.epochs) | f
        # For each window of recorded (elapsedTime, progress) rows, take the first
        # and last rows and compute d(elapsedTime) / d(progress), i.e. seconds per
        # whole run. Progress can stay flat across a window (batches is None, or
        # _window < 2), so such windows yield NaN (drawn as a gap) instead of
        # raising ZeroDivisionError
        self.data | cli.cut(2, 3) | cli.deref() | cli.window(_window)\
            | cli.apply(cli.rows(0, -1) | cli.transpose() | ~cli.apply(lambda x, y: y-x) | ~cli.aS(lambda dt, dp: dt/dp if dp else float("nan")))\
            | f | cli.deref() | cli.aS(plt.plot)
        plt.ylabel("Time/epoch (seconds)" if perEpoch else "Time/run (seconds)")