# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
This is for all short utilities that have a boilerplate feeling
"""
from k1lib.cli.init import patchDefaultDelim, BaseCli, cliSettings, Table, T
import k1lib.cli as cli, numbers, torch, numpy as np
from typing import overload, Iterator, Any, List, Set, Union
import k1lib
# Names exported by a star-import of this module. The helpers exploreSize and
# inv_dereference defined below are not listed here.
__all__ = ["size", "shape", "item", "identity", "iden",
           "toStr", "join", "toNumpy", "toTensor",
           "toList", "wrapList", "toSet", "toIter", "toRange", "toType",
           "equals", "reverse", "ignore",
           "toSum", "toAvg", "toMean", "toMax", "toMin",
           "lengths", "headerIdx", "deref"]
def exploreSize(it):
    """Consumes an iterable and reports its first element and its length.

:param it: any iterable except :class:`str`
:returns: ``(firstElement, count)``, or ``(None, 0)`` if the iterable is empty
:raises TypeError: if ``it`` is a string (deliberate, so that callers walking
    nested structures stop at strings) or is not iterable at all"""
    if isinstance(it, str):
        raise TypeError("Just here to terminate shape()")
    stream = iter(it)
    missing = object() # unique marker, so that a leading None element is not mistaken for "empty"
    first = next(stream, missing)
    if first is missing:
        return None, 0
    remaining = 0
    for _ in stream: # drain the rest, only counting
        remaining += 1
    return first, remaining + 1
class size(BaseCli):
    def __init__(self, idx=None):
        """Returns number of rows and columns in the input.
Example::
    # returns (3, 2)
    [[2, 3], [4, 5, 6], [3]] | size()
    # returns 3
    [[2, 3], [4, 5, 6], [3]] | size(0)
    # returns 2
    [[2, 3], [4, 5, 6], [3]] | size(1)
    # returns (2, 0)
    [[], [2, 3]] | size()
    # returns (3,)
    [2, 3, 5] | size()
    # returns 3
    [2, 3, 5] | size(0)
    # returns (3, 2, 2)
    [[[2, 1], [0, 6, 7]], 3, 5] | size()
    # returns (1,) and not (1, 3)
    ["abc"] | size()
    # returns (1, 2, 3)
    [torch.randn(2, 3)] | size()
    # returns (2, 3, 5)
    size()(np.random.randn(2, 3, 5))
There's also :class:`lengths`, which is sort of a simplified/faster version of
this, but only use it if you are sure that ``len(it)`` can be called.
If encounter PyTorch tensors or Numpy arrays, then this will just get the shape
instead of actually looping over them.
Note that only the first element of every level is explored, and that the
input is consumed while counting.
:param idx: if idx is None return (rows, columns). If 0 or 1, then rows or
    columns"""
        super().__init__(); self.idx = idx
    def __ror__(self, it:Iterator[str]):
        super().__ror__(it)
        if self.idx is not None:
            # item(idx) takes the first element idx times, then we count that level
            return exploreSize(it | cli.item(self.idx))[1]
        answer = []
        try:
            while True:
                if isinstance(it, (torch.Tensor, np.ndarray)):
                    # shape is already known, so stop here instead of looping over the data
                    return tuple(answer + list(it.shape))
                it, s = exploreSize(it)
                answer.append(s)
        except TypeError: pass # reached a string or a non-iterable leaf
        return tuple(answer)
shape = size
class item(BaseCli):
    def __init__(self, amt:int=1):
        """Returns the first row.
Example::
    # returns 0
    iter(range(5)) | item()
    # returns torch.Size([5])
    torch.randn(3,4,5) | item(2) | shape()
Raises :class:`StopIteration` if the input is empty.
:param amt: how many times do you want to call item() back to back?"""
        super().__init__(); self.amt = amt
    def __ror__(self, it:Iterator[str]):
        if self.amt != 1:
            # chain `amt` single-step item() calls, so each one descends one level
            return it | cli.serial(*(item() for _ in range(self.amt)))
        return next(iter(it))
class identity(BaseCli):
    """Returns whatever the input is, unchanged. Useful for multiple streams.
Example::
    # returns range(5)
    range(5) | identity()"""
    def __ror__(self, it:Iterator[Any]):
        return it # same object handed back, nothing is iterated
iden = identity
class toStr(BaseCli):
    def __init__(self, column:int=None):
        """Converts every line to a string.
Example::
    # returns ['2', 'a']
    [2, "a"] | toStr() | deref()
    # returns [[2, 'a'], [3, '5']]
    [[2, "a"], [3, 5]] | toStr(1) | deref()
:param column: if None, every incoming element is converted with ``str()``.
    Otherwise every element is treated as a row, and only the cell at this
    index is converted"""
        super().__init__(); self.column = column
    def __ror__(self, it:Iterator[str]):
        c = self.column
        if c is None:
            for line in it: yield str(line)
        else:
            for row in it:
                # rebuild the row as a list, stringifying only column c
                yield [e if i != c else str(e) for i, e in enumerate(row)]
class join(BaseCli):
    def __init__(self, delim:str=None):
        r"""Merges all strings into 1, with `delim` in the middle. Basically
:meth:`str.join`, except that every element is converted to a string first.
Example::
    # returns '2\na'
    [2, "a"] | join("\n")
:param delim: separator between elements. It is passed through
    ``patchDefaultDelim`` first, which presumably substitutes the library's
    default delimiter when None is given (that helper lives in
    ``k1lib.cli.init``, please verify there)"""
        super().__init__(); self.delim = patchDefaultDelim(delim)
    def __ror__(self, it:Iterator[str]):
        super().__ror__(it)
        return self.delim.join(it | toStr()) # toStr() so that non-string elements are accepted
class toNumpy(BaseCli):
    """Converts generator to numpy array. Essentially ``np.array(list(it))``"""
    def __ror__(self, it:Iterator[float]) -> np.ndarray:
        return np.array(list(it)) # materialize first, as np.array() does not consume generators
class toTensor(BaseCli):
    def __init__(self, dtype=torch.float32):
        """Converts generator to :class:`torch.Tensor`. Essentially
``torch.tensor(list(it))``.
Also checks if input is a PIL Image. If yes, turn it into a :class:`torch.Tensor`
in channels-first (CHW) layout and return. The PIL check is skipped if PIL is
not installed.
:param dtype: data type the resulting tensor is converted to"""
        super().__init__(); self.dtype = dtype
    def __ror__(self, it:Iterator[float]) -> torch.Tensor:
        # `import PIL` alone does not load the PIL.Image submodule, so import it explicitly
        try: import PIL.Image
        except ImportError: PIL = None
        if PIL is not None and isinstance(it, PIL.Image.Image): # stolen from torchvision ToTensor transform
            pic = it
            mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
            img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
            if pic.mode == '1': img = 255 * img
            # pic.size is (width, height), so this gives height x width x channels
            img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
            return img.permute((2, 0, 1)).contiguous().to(self.dtype) # put it from HWC to CHW format
        return torch.tensor(list(it)).to(self.dtype)
class toList(BaseCli):
    """Converts generator to list. :class:`list` would do the
same, but this is just to maintain the style"""
    def __ror__(self, it:Iterator[Any]) -> List[Any]:
        return list(it) # consumes the whole input
class wrapList(BaseCli):
    """Wraps inputs inside a list. There's a more advanced cli tool
built from this, which is :meth:`~k1lib.cli.structural.unsqueeze`.
Example::
    # returns [5]
    5 | wrapList()"""
    def __ror__(self, it:T) -> List[T]:
        return [it] # the input itself becomes the single element, it is not iterated
class toSet(BaseCli):
    """Converts generator to set. :class:`set` would do the
same, but this is just to maintain the style"""
    def __ror__(self, it:Iterator[T]) -> Set[T]:
        return set(it) # consumes the whole input, elements must be hashable
class toIter(BaseCli):
    """Converts object to iterator. `iter()` would do the
same, but this is just to maintain the style"""
    def __ror__(self, it:List[T]) -> Iterator[T]:
        return iter(it)
class toRange(BaseCli):
    """Returns iter(range(len(it))), effectively. Works lazily, so the
input does not need to support ``len()``.
Example::
    # returns [0, 1, 2]
    ["a", "b", "c"] | toRange() | deref()"""
    def __ror__(self, it:Iterator[Any]) -> Iterator[int]:
        for i, _ in enumerate(it): yield i # one index per consumed element
class toType(BaseCli):
    """Converts every incoming element to its type.
Example::
    # returns [int, float, str, torch.Tensor]
    [2, 3.5, "ah", torch.randn(2, 3)] | toType() | deref()"""
    def __ror__(self, it:Iterator[T]) -> Iterator[type]:
        for e in it: yield type(e)
class _EarlyExp(Exception): pass # no longer raised by equals, kept so the module-level name stays available
class equals:
    """Checks if all incoming columns/streams are identical.
Example::
    # returns [True, True, False]
    list([[1, 2, 3], [1, 2, 4]] | equals())
Streams are walked in lockstep, and for every position this yields True if
all streams agree there, else False. Comparison stops at the end of the
shortest stream, so a difference in length alone is not reported."""
    # NOTE(review): unlike its sibling tools, this class does not derive from BaseCli; confirm whether that is intended
    def __ror__(self, streams:Iterator[Iterator[str]]):
        streams = list(streams)
        for row in zip(*streams):
            sampleElem = row[0]
            # any() stops at the first mismatch, same as the previous early exit
            yield not any(sampleElem != elem for elem in row)
class reverse(BaseCli):
    """Reverses incoming list.
Example::
    # returns [3, 5, 2]
    [2, 5, 3] | reverse() | deref()
The whole input is loaded into memory first, and the result is an iterator,
not a list."""
    def __ror__(self, it:Iterator[str]) -> Iterator[str]:
        return reversed(list(it)) # list() first, as reversed() needs a sequence
class ignore(BaseCli):
    r"""Just loops through everything, ignoring the output. Returns None.
Example::
    # will just return an iterator, and not print anything
    [2, 3] | apply(lambda x: print(x))
    # will prints "2\n3"
    [2, 3] | apply(lambda x: print(x)) | ignore()"""
    def __ror__(self, it:Iterator[Any]):
        for _ in it: pass # drain the input purely for its side effects
class toSum(BaseCli):
    """Calculates the sum of list of numbers. Can pipe in :class:`torch.Tensor`.
Example::
    # returns 45
    range(10) | toSum()
For a tensor input, the result of ``Tensor.sum()`` is returned, so it is a
tensor as well."""
    def __ror__(self, it:Iterator[float]):
        if isinstance(it, torch.Tensor): return it.sum() # reduce inside torch rather than looping in Python
        return sum(it)
class toAvg(BaseCli):
    """Calculates average of list of numbers. Can pipe in :class:`torch.Tensor`.
Example::
    # returns 4.5
    range(10) | toAvg()
    # returns nan
    [] | toAvg()
The ``nan`` result for empty input only applies when ``cliSettings["strict"]``
is falsy. In strict mode, empty input raises :class:`ZeroDivisionError`."""
    def __ror__(self, it:Iterator[float]):
        if isinstance(it, torch.Tensor): return it.mean()
        total = 0; n = 0
        for n, v in enumerate(it, 1): total += v # n ends up being the element count
        if n == 0 and not cliSettings["strict"]: return float("nan")
        return total / n
toMean = toAvg
class toMax(BaseCli):
    """Calculates the max of a bunch of numbers. Can pipe in :class:`torch.Tensor`.
Example::
    # returns 6
    [2, 5, 6, 1, 2] | toMax()
For a tensor input, the result of ``Tensor.max()`` is returned. Otherwise
this is the builtin :func:`max`, which raises :class:`ValueError` on empty
input."""
    def __ror__(self, it:Iterator[float]) -> float:
        if isinstance(it, torch.Tensor): return it.max()
        return max(it)
class toMin(BaseCli):
    """Calculates the min of a bunch of numbers. Can pipe in :class:`torch.Tensor`.
Example::
    # returns 1
    [2, 5, 6, 1, 2] | toMin()
For a tensor input, the result of ``Tensor.min()`` is returned. Otherwise
this is the builtin :func:`min`, which raises :class:`ValueError` on empty
input."""
    def __ror__(self, it:Iterator[float]) -> float:
        if isinstance(it, torch.Tensor): return it.min()
        return min(it)
class lengths(BaseCli):
    """Returns the lengths of each element.
Example::
    # returns [5, 10]
    [range(5), range(10)] | lengths() | deref()
This is a simpler (and faster!) version of :class:`shape`. It assumes each element
can be called with ``len(x)``, while :class:`shape` iterates through every elements
to get the length, and thus is much slower."""
    def __ror__(self, it:Iterator[List[Any]]) -> Iterator[int]:
        for e in it: yield len(e)
Tensor = torch.Tensor
# Types that deref/inv_dereference treat as leaves and never loop over
atomicTypes = (numbers.Number, np.number, str, dict, torch.nn.Module)
class inv_dereference(BaseCli):
    def __init__(self, ignoreTensors=False):
        """Kinda the inverse to :class:`deref`. Lazily walks the input and
turns every nested iterable into a generator, while None and objects of the
atomic types are passed through untouched.
:param ignoreTensors: if False, 0-dimensional :class:`torch.Tensor` elements
    are unwrapped into plain Python numbers. Tensors are yielded as-is
    otherwise"""
        super().__init__(); self.ignoreTensors = ignoreTensors
    def __ror__(self, it:Iterator[Any]) -> Iterator[Any]:
        super().__ror__(it); ignoreTensors = self.ignoreTensors
        for e in it:
            if e is None or isinstance(e, atomicTypes): yield e
            elif isinstance(e, Tensor):
                if not ignoreTensors and len(e.shape) == 0: yield e.item()
                else: yield e
            else:
                # Keep the yield outside the try block: a handler wrapped around a
                # yield would also swallow GeneratorExit when this generator is closed.
                try: sub = e | self
                except Exception: sub = e # element can't be piped, pass it through as-is
                yield sub
class deref(BaseCli):
    def __init__(self, ignoreTensors=True, maxDepth=float("inf")):
        """Recursively converts any iterator into a list. Objects of the types in
``atomicTypes`` (:class:`str`, :class:`dict`, :class:`numbers.Number`, numpy
scalars and :class:`~torch.nn.Module`) and anything that can't be iterated over
are not converted. Example::
    # returns something like "<range_iterator at 0x7fa8c52ca870>"
    iter(range(5))
    # returns [0, 1, 2, 3, 4]
    iter(range(5)) | deref()
    # returns [2, 3], yieldSentinel stops things early
    [2, 3, yieldSentinel, 6] | deref()
You can also specify a ``maxDepth``::
    # returns something like "<list_iterator at 0x7f810cf0fdc0>"
    iter([range(3)]) | deref(maxDepth=0)
    # returns [range(3)]
    iter([range(3)]) | deref(maxDepth=1)
    # returns [[0, 1, 2]]
    iter([range(3)]) | deref(maxDepth=2)
:param ignoreTensors: if True, then don't loop over :class:`torch.Tensor`
    internals
:param maxDepth: maximum depth to dereference. Starts at 0 for not doing anything
    at all
.. warning::
    Can work well with PyTorch Tensors, but not Numpy's array as they screw things up
    with the __ror__ operator, so do torch.from_numpy(...) first. Don't worry about
    unnecessary copying, as numpy and torch both utilize the buffer protocol."""
        super().__init__(); self.ignoreTensors = ignoreTensors
        self.maxDepth = maxDepth; self.depth = 0 # depth = current recursion level
    def __ror__(self, it:Iterator[T]) -> List[T]:
        super().__ror__(it); ignoreTensors = self.ignoreTensors
        if self.depth >= self.maxDepth: return it
        elif isinstance(it, atomicTypes): return it
        elif isinstance(it, Tensor):
            if ignoreTensors: return it
            if len(it.shape) == 0: return it.item()
        try: iter(it)
        except Exception: return it # not iterable, so it's a leaf
        self.depth += 1; answer = []
        # Restore depth on every exit path (yieldSentinel, exceptions), so a
        # reused instance with a finite maxDepth does not start from a stale depth.
        try:
            for e in it:
                if e is cli.yieldSentinel: break
                answer.append(self.__ror__(e))
        finally: self.depth -= 1
        return answer
    def __invert__(self) -> BaseCli:
        """Returns a :class:`~k1lib.cli.init.BaseCli` that makes
everything an iterator. Not entirely sure when this comes in handy, but it's
there."""
        return inv_dereference(self.ignoreTensors)