# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
Clis that typically terminate a pipeline: printing to stdout, writing to files, pretty-printing tables and plotting images.
"""
from collections import defaultdict
from typing import Iterator, Any
from k1lib.cli.init import BaseCli; import k1lib.cli.init as init
import numbers, numpy as np, k1lib, tempfile, os, sys, time, math, json, re
from k1lib import cli; from k1lib.cli.typehint import *
# matplotlib.pyplot is obtained through k1lib.dep() rather than imported directly;
# presumably this defers the import / reports a missing optional dependency — TODO confirm
plt = k1lib.dep("matplotlib.pyplot")
# torch is optional: hasTorch records whether it could be imported (any failure counts as "absent")
try: import torch; hasTorch = True
except: hasTorch = False
# public clis exported by this module
__all__ = ["stdout", "tee", "file", "pretty", "unpretty", "display", "headOut",
           "intercept", "plotImgs"]
# cli-wide settings namespace from k1lib
settings = k1lib.settings.cli
class stdout(BaseCli):
    def __init__(self):
        """Prints out all lines. If not iterable, then print out the input
raw. Example::

    # prints out "0\\n1\\n2"
    range(3) | stdout()
    # same as above, but (maybe?) more familiar
    range(3) > stdout()

This is rarely used alone. It's more common to use :meth:`headOut`
for list of items, and :meth:`display` for tables."""
        super().__init__()
    def _typehint(self, inp): return None
    def __ror__(self, it:Iterator[str]):
        # Only the iter() call decides whether the input is iterable. A TypeError
        # raised while *consuming* the iterator is a real error and must propagate,
        # instead of being mistaken for "not iterable" (which used to print the
        # half-consumed iterator object itself).
        try: lines = iter(it)
        except TypeError: print(it); return
        for line in lines: print(line)
def _defaultTeeF(s):
    """Default :class:`tee` element formatter: the formatted element followed by a newline."""
    return f"{s}\n"
class tee(BaseCli):
    def __init__(self, f=_defaultTeeF, s=None, every:int=1, delay:float=0):
        """Like the Linux ``tee`` command, this prints the elements to another
specified stream, while yielding the elements. Example::

    # prints "0) 0\\n1) 1\\n2) 2\\n3) 3\\n4) 4\\n" and returns [0, 1, 4, 9, 16]
    range(5) | tee() | apply(op() ** 2) | deref()

See also: :class:`~k1lib.cli.modifier.consume`

This cli is not exactly well-thought-out and is a little janky

:param f: element transform function. Defaults to just adding a new
    line at the end
:param s: stream to write to. Defaults to :attr:`sys.stdout`
:param every: only prints out 1 line in ``every`` lines, to limit print rate
:param delay: if subsequent prints are less than this number of seconds apart then don't print them"""
        self.s = s or sys.stdout; self.f = f; self.every = every; self.delay = delay
    def __ror__(self, it):
        s = self.s; f = self.f; every = self.every; delay = self.delay
        # Flush explicitly after each print: in cr()/crt()/autoInc() modes the text
        # never ends with "\n", so a line-buffered stream would hold it back.
        # getattr keeps streams that only implement write() working.
        flush = getattr(s, "flush", None)
        lastTime = float("-inf") # guarantees the first eligible element is printed
        for i, e in enumerate(it):
            if i % every == 0 and time.monotonic() - lastTime > delay:
                print(f"     \r{i}) {f(e)}", end="", file=s)
                if flush is not None: flush()
                lastTime = time.monotonic() # monotonic: immune to wall-clock adjustments
            yield e
    def cr(self):
        """Tee, but replaces the previous line. "cr" stands for carriage return.
Example::

    # prints "4" and returns [0, 1, 4, 9, 16]. Does print all the numbers in the middle, but is overridden
    range(5) | tee().cr() | apply(op() ** 2) | deref()"""
        f = (lambda x: x) if self.f == _defaultTeeF else self.f # drop the default trailing newline
        self.f = lambda s: f"{f(s)}"; return self
    def crt(self):
        """Like :meth:`tee.cr`, but includes an elapsed time text at the end.
The elapsed time is measured from when this method is called.
Example::

    range(5) | tee().crt() | apply(op() ** 2) | deref()"""
        beginTime = time.monotonic()
        f = (lambda x: x) if self.f == _defaultTeeF else self.f # drop the default trailing newline
        self.f = lambda s: f"{f(s)}, {int(time.monotonic() - beginTime)}s elapsed"; return self
    def autoInc(self):
        """Like :meth:`tee.crt`, but instead of printing the object, prints an
auto-incrementing counter (advanced once per printed line) and the elapsed time"""
        beginTime = time.monotonic(); autoInc = k1lib.AutoIncrement()
        self.f = lambda s: f"{autoInc()}, {int(time.monotonic()-beginTime)}s elapsed"; return self
# Optional Pillow support for :class:`file`. Import the ``PIL.Image`` submodule
# explicitly: a plain ``import PIL`` does not load it, so a later
# ``PIL.Image.Image`` access could raise AttributeError. ``except Exception``
# (rather than a bare except) keeps the best-effort behavior without swallowing
# KeyboardInterrupt/SystemExit.
try:
    import PIL.Image; hasPIL = True
except Exception: hasPIL = False
class file(BaseCli):
    def __init__(self, fileName:str=None, flush:bool=False, mkdir:bool=False):
        """Opens a new file for writing. This will iterate through
the iterator fed to it and put each element on a separate line. Example::

    # writes "0\\n1\\n2\\n" to file
    range(3) | file("test/f.txt")
    # same as above, but (maybe?) more familiar
    range(3) > file("test/f.txt")
    # returns ['0', '1', '2']
    cat("test/f.txt") | deref()

If the input is a string, then it will just put the string into the
file and does not iterate through the string::

    # writes "some text\\n123" to file, default iterator mode like above
    ["some text", "123"] | file("test/f.txt")
    # same as above, but this is a special case when it detects you're piping in a string
    "some text\\n123" | file("test/f.txt")

If the input is a :class:`bytes` object or an iterator of :class:`bytes`, then it
will open the file in binary mode and dumps the bytes in::

    # writes bytes to file
    b'5643' | file("test/a.bin")
    [b'56', b'43'] >> file("test/a.bin")
    # returns ['56435643']
    cat("test/a.bin") | deref()

If the input is a :class:`PIL.Image.Image` object, then it will just save the image in
the file::

    # creates a random image and saves it to a file
    torch.randn(100, 200) | toImg() | file("a.png")

Reminder that the image pixel range is expected to be from 0 to 255. You
can create temporary files on the fly by not specifying a file name::

    # creates temporary file
    url = range(3) > file()
    # returns ['0', '1', '2']
    cat(url) | deref()

This can be especially useful when integrating with shell scripts that want to
read in a file::

    seq1 = "CCAAACCCCCCCTCCCCCGCTTC"
    seq2 = "CCAAACCCCCCCCTCCCCCCGCTTC"
    # use "needle" program to locally align 2 sequences
    None | cmd(f"needle {seq1 > file()} {seq2 > file()} -filter")

You can also append to file with the ">>" operator::

    url = range(3) > file()
    # appended to file
    range(10, 13) >> file(url)
    # returns ['0', '1', '2', '10', '11', '12']
    cat(url) | deref()

:param fileName: if not specified, create new temporary file and returns the url
    when pipes into it
:param flush: whether to flush to file immediately after every iteration
:param mkdir: whether to recursively make directories going to the file location or not"""
        super().__init__(); self.fileName = fileName; self.flush = flush; self.mkdir = mkdir
        self.append = False # whether to append to file rather than erasing it
    def __ror__(self, it:Iterator[str]) -> str:
        """Writes ``it`` to the file and returns the (expanded) file name."""
        from itertools import chain
        super().__ror__(it); fileName = self.fileName; append = self.append; flush = self.flush
        if fileName is None:
            # mkstemp creates the file atomically; reusing the name of an already
            # deleted NamedTemporaryFile leaves a window for another process to grab it
            fd, fileName = tempfile.mkstemp(); os.close(fd)
        fileName = os.path.expanduser(fileName)
        if self.mkdir:
            dirName = os.path.dirname(fileName)
            if dirName: os.makedirs(dirName, exist_ok=True) # bare names like "a.txt" have no directory, and makedirs("") raises
        if hasPIL and isinstance(it, PIL.Image.Image): it.save(fileName); return fileName
        if isinstance(it, str): items = [it]; binary = False # a whole string is one line, not one line per character
        elif isinstance(it, bytes): items = [it]; binary = True
        else:
            it = iter(it); sentinel = object(); first = next(it, sentinel)
            if first is sentinel: # no elements: make sure the file exists, but never truncate it when appending
                with open(fileName, "a" if append else "w"): pass
                return fileName
            # the first element decides text vs binary mode; chaining it back
            # (instead of using None as a marker) keeps a leading None element
            binary = isinstance(first, bytes); items = chain([first], it)
        if binary: mode = "ab" if append else "wb"
        else: mode = "a" if append else "w"
        with open(fileName, mode) as f:
            for item in items:
                f.write(item if binary else f"{item}\n")
                if flush: f.flush()
        return fileName
    def __rrshift__(self, it):
        """Appends ``it`` to the file instead of overwriting it (``it >> file(...)``).
Because ``a | b >> c`` is parsed as ``a | (b >> c)``, when ``it`` is itself a cli
this returns a serial cli that appends once it is executed."""
        # Use a fresh object so that ``self`` is not left stuck in append mode
        # for later ``|`` / ``>`` uses
        appender = type(self)(self.fileName, self.flush, self.mkdir)
        appender.append = True
        if isinstance(it, BaseCli): return cli.serial(it, appender)
        return appender.__ror__(it)
    @property
    def name(self):
        """File name of this :class:`file`"""
        return self.fileName
class pretty(BaseCli):
    def __init__(self, delim="", left=True):
        """Pretty-formats a table, or a list of tables.
Example::

    # These 2 statements are pretty much the same
    [range(10), range(10, 20)] | head(5) | pretty() > stdout()
    [range(10), range(10, 20)] | display()

They both print::

    0    1    2    3    4    5    6    7    8    9
    10   11   12   13   14   15   16   17   18   19

This can also pretty-format multiple tables (column widths are shared across
all of them)::

    [[range(10), range(10, 20)], [["abc", "defff"], ["1", "1234567"]]] | ~pretty() | joinStreams() | stdout()

This will print::

    0     1         2    3    4    5    6    7    8    9
    10    11        12   13   14   15   16   17   18   19
    abc   defff
    1     1234567

:param delim: delimiter between elements within a row. You might want
    to set it to "|" to create an artificial border or sth
:param left: whether to left or right-align each element"""
        self.delim = delim; self.inverted = False; self.left = left
    def _typehint(self, inp): return tIter(str)
    def __ror__(self, it) -> Iterator[str]:
        inverted = self.inverted; delim = self.delim
        justify = str.ljust if self.left else str.rjust
        # normal mode: input is a single table; inverted mode: input is a list of tables
        if inverted: tables = [[list(row) for row in table] for table in it]
        else: tables = [[list(row) for row in it]]
        widths = [] # widths[i]: longest stringified cell of column i across every table
        for table in tables:
            for row in table:
                for col, cell in enumerate(row):
                    cell = f"{cell}"; row[col] = cell
                    if col == len(widths): widths.append(len(cell))
                    else: widths[col] = max(widths[col], len(cell))
        def render(table):
            for row in table:
                yield delim.join(justify(cell.rstrip(" "), w + 3) for w, cell in zip(widths, row))
        if not inverted: return render(tables[0])
        return tables | cli.apply(render)
    def __invert__(self): self.inverted = not self.inverted; return self
    def _jsF(self, meta):
        fIdx = init._jsFAuto(); dataIdx = init._jsDAuto()
        return f"const {fIdx} = ({dataIdx}) => {dataIdx}.pretty({json.dumps(self.delim)}, {cli.kjs.v(self.inverted)})", fIdx
def display(lines:int=10):
    """Convenience cli for displaying a table. Pretty much equivalent to
``head(lines) | pretty() | stdout()``. If ``lines`` is None, no ``head`` is
applied and the whole input is displayed.
See also: :class:`pretty`"""
    shower = pretty() | stdout()
    return shower if lines is None else cli.head(lines) | shower
def headOut(lines:int=10):
    """Convenience cli for ``head(lines) | stdout()``. If ``lines`` is None,
no ``head`` is applied and every element is printed."""
    printer = stdout()
    return printer if lines is None else cli.head(lines) | printer
class unpretty(BaseCli):
    def __init__(self, ncols:int=None, left=True, headers=None):
        """Takes in a stream of strings, assumes it's a table, and tries to
split every line into multiple columns. Example::

    # returns ['0   1   2   ', '3   4   5   ', '6   7   8   ']
    a = range(10) | batched(3) | pretty() | deref()
    # returns [['0   ', '1   ', '2   '], ['3   ', '4   ', '5   '], ['6   ', '7   ', '8   ']]
    a | unpretty(3) | deref()

This cli will take the number of columns requested and try to split into a table by analyzing
at what character column does it transition from a space to a non-space (left align), or from
a non-space to a space (right align). Then the first ``ncols`` most popular transitions are
selected.

Sometimes this is not robust enough, may be some of your columns have lots of empty elements,
then the transition counts will be skewed, making it split up at strange places. In those cases,
you can specify the headers directly, like this::

    # returns ['a   b   c    ', '3   5   11   ', '4   6   7    ']
    a = [["a", 3, 4], ["b", 5, 6], ["c", 11, 7]] | transpose() | pretty() | deref()
    # returns [['a   ', 'b   ', 'c    '], ['3   ', '5   ', '11   '], ['4   ', '6   ', '7    ']]
    a | unpretty(headers=["a", "b", "c"]) | deref()

:param ncols: number of columns. Required unless ``headers`` is given
:param left: whether the data is left or right aligned
:param headers: column headers, as they appear in the first line. If given, ``ncols``
    is ignored and every column starts where its header is found in the first line"""
        self.ncols = ncols; self.left = left; self.headers = headers
        self.pat = re.compile(" [^ ]+") if left else re.compile("[^ ]+ ")
    def __ror__(self, it):
        ncols = self.ncols; left = self.left; pat = self.pat; headers = self.headers
        if headers is not None: ncols = len(headers)
        if ncols is None: raise ValueError("unpretty() needs either `ncols` or `headers`")
        if ncols < 1: raise Exception(f"Does not make sense to unpretty() into {ncols} columns")
        if ncols == 1: return it
        if headers is None:
            try: len(it)
            except TypeError: it = list(it) # the lines are scanned twice: once for split points, once to split
            # character positions where a column is likely to start (left) / end (right)
            if left: transitions = cli.apply(lambda x: (m.start()+1 for m in pat.finditer(x)))
            else: transitions = cli.apply(lambda x: (m.end()-1 for m in pat.finditer(x)))
            # keep the ncols-1 most frequent transition positions, in increasing order
            splits = (it | cli.head(10000) | transitions | cli.joinSt() | cli.count() | ~cli.sort()
                      | cli.cut(1) | cli.head(ncols-1) | cli.sort(None) | cli.aS(list))
        else:
            peeked = it | cli.peek()
            if peeked == []: return [] # empty input: peek() gives [], which can't be unpacked
            firstRow, it = peeked
            positions = [firstRow.find(h) for h in headers]
            missing = [h for h, p in zip(headers, positions) if p < 0]
            if missing: raise ValueError(f"unpretty(): headers {missing} not found in the first line")
            splits = sorted(positions)[1:] # the first column always starts at 0
        # slice every row between consecutive split points: [0:s0], [s0:s1], ..., [sk:]
        bounds = [0, *splits, None]
        return ([row[a:b] for a, b in zip(bounds, bounds[1:])] for row in it)
def tab(text, pad="    "):
    """Indents ``text`` by prefixing ``pad`` to every line (split on ``"\\n"``),
including empty lines and the empty line after a trailing newline."""
    return pad + text.replace("\n", "\n" + pad)
class intercept(BaseCli):
    def __init__(self, f=None, raiseError:bool=True):
        """Intercepts the flow at a particular point: prints the type of the object
piped in and a summary of it, then (by default) raises an error to stop the flow.
Example::

    3 | intercept()

:param f: prints out the object transformed by this function. Defaults to :class:`~k1lib.cli.shape`
:param raiseError: whether to raise error when executed or not. If False, the
    object is passed through unchanged"""
        self.f = cli.shape() if not f else f
        self.raiseError = raiseError
    def __ror__(self, s):
        print(type(s))
        print(self.f(s))
        if not self.raiseError: return s
        raise RuntimeError("intercepted")
class plotImgs(BaseCli):
    def __init__(self, col=5, aspect=1, fac=2, axis=False, table=False, im=False):
        """Plots a bunch of images at the same time in a table.
Example::

    # plots all images
    [torch.randn(10, 20), torch.randn(20, 10)] | plotImgs()
    # plots all images with titles
    [[torch.randn(10, 20), "img 1"], [torch.randn(20, 10), "img 2"]] | plotImgs()

If you have multiple rows with different number of images, you can
plot that with this too, just set ``table=True`` like this::

    [[torch.randn(10, 20), torch.randn(20, 10)], [torch.randn(10, 20)]] | plotImgs(table=True)

There's another cli that kinda does what this does: :class:`~k1lib.cli.utils.sketch`. You have
more control over there, and it does roughly what this cli does, but the typical usage is
different. This is more for plotting static, throwaway list of 2d arrays, like training set
images, whereas :class:`~k1lib.cli.utils.sketch` is more about plotting results of detailed
analyses.

:param col: number of columns in the table. If explicitly None, it will turn
    into the number of images fed. Not available if ``table=True``
:param aspect: aspect ratio of each image, or ratio between width and height
:param fac: figsize factor. The higher, the more resolution
:param axis: whether to display the axis or not
:param table: whether to plot using table mode
:param im: if True, returns an image"""
        self.col = col; self.fac = fac; self.axis = axis; self.aspect = aspect; self.table = table; self.im = im
    def _drawOne(self, ax, im):
        """Draws one image (or an ``[image, title]`` pair) onto ``ax``."""
        plt.sca(ax)
        if isinstance(im, (list, tuple)): plt.imshow(im[0]); plt.title(im[1])
        else: plt.imshow(im)
        if not self.axis: ax.axis("off")
    def __ror__(self, imgs):
        imgs = imgs | cli.deref(); col = self.col; fac = self.fac; aspect = self.aspect**0.5
        if not self.table: # main code
            if len(imgs) == 0: return
            if col is None or col > len(imgs): col = len(imgs)
            n = math.ceil(len(imgs)/col)
            # squeeze=False: always a 2d array of axes, even for a single row/column/image
            fig, axes = plt.subplots(n, col, figsize=(col*fac*aspect, n*fac/aspect), squeeze=False)
            axes = axes.flatten()
            for ax, im in zip(axes, imgs): self._drawOne(ax, im)
            for ax in axes[len(imgs):]: ax.remove() # removing leftover axes
        else:
            if col != 5: raise Exception("Currently in table mode, can't set `col` parameter") # change this value to match col's default value
            if len(imgs) == 0: return
            h = imgs | cli.shape(0); w = imgs | cli.shape(0).all() | cli.toMax()
            # squeeze=False keeps axes 2d when h or w is 1, so the row/column zips below work
            fig, axes = plt.subplots(h, w, figsize=(w*fac*aspect, h*fac/aspect), squeeze=False)
            for rAx, rIm in zip(axes, imgs):
                for cAx, cIm in zip(rAx, rIm): self._drawOne(cAx, cIm)
                for ax in rAx[len(rIm):]: ax.remove() # removing leftover axes
        plt.tight_layout()
        if self.im: return plt.gcf() | cli.toImg()