# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
import k1lib, time; from k1lib import fmt
from .callbacks import Callback, Callbacks, Cbs
import k1lib.cli as cli; plt = k1lib.dep("matplotlib.pyplot")
__all__ = ["ProgressBar"]
[docs]@k1lib.patch(Cbs)
class ProgressBar(Callback):                                                     # ProgressBar
    """Displays the current progress, epoch and batch while running.
Deposits variables into :class:`~k1lib.Learner` at checkpoint ``startBatch``:
- progress: single float from 0 to 1, guaranteed to increase monotonically
- epochThroughput
- remaining: estimated remaining time. Does not take into account callbacks that can potentially cancel the run prematurely""" # ProgressBar
    def startRun(self):                                                          # ProgressBar
        self.startTime = time.time(); self.step = 0; self.l.progress = 0         # ProgressBar
        self.l.loss = float("inf") # to make sure this variable exist            # ProgressBar
        self.aL = self.bL = self.cL = self.dL = self.eL = 0; self.data = []      # ProgressBar
    def startBatch(self):                                                        # ProgressBar
        batch = self.l.batch; batches = self.l.batches; epoch = self.l.epoch; epochs = self.l.epochs # ProgressBar
        elapsedTime = self.elapsedTime = time.time() - self.startTime            # ProgressBar
        if batches is None: progress = self.l.progress = epoch / epochs; batchTh = None # ProgressBar
        else:                                                                    # ProgressBar
            progress = self.l.progress = (batch / batches + epoch) / epochs      # ProgressBar
            batchTh = batches * epochs * progress / elapsedTime                  # ProgressBar
        epochTh = self.l.epochThroughput = epochs * progress / elapsedTime; self.l.batchThroughput = batchTh # ProgressBar
        remaining = self.l.remaining = round(elapsedTime / (progress+1e-7) * (1-progress), 2) if progress > 0 else float('inf') # ProgressBar
                                                                                 # ProgressBar
        a = str(round(100 * progress)); self.aL = max(self.aL, len(a)); a = a.rjust(self.aL) # ProgressBar
        b = f"{epoch}/{epochs} ({fmt.throughput(epochTh, ' epochs')})"; self.bL = max(self.bL, len(b)); b = b.rjust(self.bL) # ProgressBar
        if batches is not None:                                                  # ProgressBar
            c = f"{batch}/{batches} ({fmt.throughput(batchTh, ' batches')})"; self.cL = max(self.cL, len(c)); c = c.rjust(self.cL) # ProgressBar
        else: c = f"{batch}/{batches}"; self.cL = max(self.cL, len(c)); c = c.rjust(self.cL) # ProgressBar
        d = f"{round(elapsedTime, 2)}".rjust(6); self.dL = max(self.dL, len(d)); d = d.rjust(self.dL) # ProgressBar
        e = f"{remaining}"; self.eL = max(self.eL, len(e)); e = e.rjust(self.eL) # ProgressBar
        self.data.append([epoch, batch, elapsedTime, progress, epochTh, batchTh, remaining]) # ProgressBar
        print(f"\rProgress: {a}%, epoch: {b}, batch: {c}, elapsed: {d}s, remaining: {e}s, loss: {self.l.loss}             ", end="") # ProgressBar
[docs]    def plot(self, f=cli.iden(), perEpoch=False, _window=2):                     # ProgressBar
        """Plots detailed partial execution time profile.
:param f: optional post processing step
:param perEpoch: if True, normalize time per epoch, else keep it at time per run
:param _window: number of batches to calculate the processing rate over. Put low
    values (min 2) to make it crisp (and inaccurate), put high values to make it
    smooth (and accurate)"""                                                     # ProgressBar
        if perEpoch: f = cli.apply(cli.op()/self.l.epochs) | f                   # ProgressBar
        self.data | cli.cut(2, 3) | cli.deref() | cli.window(_window)\
        
| cli.apply(cli.rows(0, -1) | cli.transpose() | ~cli.apply(lambda x, y: y-x) | ~cli.aS(lambda x, y: x/y))\
        
| f | cli.deref() | cli.aS(plt.plot)                                     # ProgressBar
        plt.ylabel("Time/epoch (seconds)" if perEpoch else "Time/run (seconds)"); # ProgressBar