# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
This module is for nice visualization tools. This is exposed automatically with::
   from k1lib.imports import *
   viz.mask # exposed
"""
import k1lib, base64, io, torch, os, matplotlib as mpl
import matplotlib.pyplot as plt, numpy as np
from typing import Callable, List, Tuple, Union
from functools import partial, update_wrapper
# Public API of this module (everything else is an implementation detail)
__all__ = [
    "SliceablePlot", "plotSegments", "Carousel",
    "confusionMatrix", "FAnim", "mask",
]
class _PlotDecorator:
    """The idea with decorators is that you can do something like this::
    sp = k1lib.viz.SliceablePlot()
    sp.yscale("log") # will format every plot as if ``plt.yscale("log")`` has been called
This class is not expected to be used by end users though."""
    def __init__(self, sliceablePlot:"SliceablePlot", name:str):
        """
:param sliceablePlot: the parent plot
:param name: the decorator's name, like "yscale" """
        self.sliceablePlot = sliceablePlot
        self.name = name; self.args, self.kwargs = None, None
    def __call__(self, *args, **kwargs):
        """Stores all args, then return the parent :class:`SliceablePlot`"""
        self.args = args; self.kwargs = kwargs; return self.sliceablePlot
    def run(self): getattr(plt, self.name)(*self.args, **self.kwargs)
class SliceablePlot:
    """This is a plot that is "sliceable", meaning you can focus into
a particular region of the plot quickly. A minimal example looks something
like this::

    import numpy as np, matplotlib.pyplot as plt, k1lib
    x = np.linspace(-2, 2, 100)

    def normalF():
        plt.plot(x, x**2)

    @k1lib.viz.SliceablePlot.decorate
    def plotF(_slice):
        plt.plot(x[_slice], (x**2)[_slice])

    plotF()[70:] # plots x^2 equation with x in [0.8, 2]

So, ``normalF`` plots the equation :math:`x^2` with x going from -2 to 2.
You can convert this into a :class:`SliceablePlot` by adding a term of
type :class:`slice` to the args, and decorate with :meth:`decorate`. Now,
every time you slice the :class:`SliceablePlot` with a specific range,
``plotF`` will receive it.

How intuitive everything is depends on how you slice your data. That ``[70:]``
results in x in [0.8, 2] is rather unintuitive. You can change it into
something like this::

    @k1lib.viz.SliceablePlot.decorate
    def niceF(_slice):
        n = 100; r = k1lib.Range(-2, 2)
        x = np.linspace(*r, n)
        _slice = r.toRange(k1lib.Range(n), r.bound(_slice)).slice_
        plt.plot(x[_slice], (x**2)[_slice])
    # this works without a decorator too btw: k1lib.viz.SliceablePlot(niceF)
    niceF()[0.3:0.7] # plots x^2 equation with x in [0.3, 0.7]
    niceF()[0.3:] # plots x^2 equation with x in [0.3, 2]

The idea is to just take the input :class:`slice`, put some bounds on
its parts, then convert that slice from [-2, 2] to [0, 100]. Check
out :class:`k1lib.Range` if it's not obvious how this works.

A really cool feature of :class:`SliceablePlot` looks like this::

    niceF().legend(["A"])[-1:].grid(True).yscale("log")

This will plot :math:`x^2` with range in [-1, 2] with a nice grid, and
with y axis's scale set to log. Essentially, undefined method calls
on a :class:`SliceablePlot` will translate into ``plt`` calls. So the
above is roughly equivalent to this::

    x = np.linspace(-2, 2, 100)
    plt.plot(x, x**2)
    plt.legend(["A"])
    plt.grid(True)
    plt.yscale("log")

.. image:: images/SliceablePlot.png

This works even if you have multiple axes inside your figure. It's
wonderful, isn't it?"""
    def __init__(self, plotF:Callable[[slice], None], slices:Union[slice, List[slice]]=slice(None), plotDecorators:"List[_PlotDecorator]"=None, docs=""):
        """Creates a new SliceablePlot. Only use params listed below:

:param plotF: function that takes in a :class:`slice` or a tuple of :class:`slice` objects
:param docs: optional docs for the function that will be displayed in :meth:`__repr__`"""
        self.plotF = plotF
        self.slices = [slices] if isinstance(slices, slice) else slices
        # always copy, so that decorating a derived (sliced) plot never mutates its parent's list
        self.docs = docs; self.plotDecorators = list(plotDecorators or [])
    @staticmethod
    def decorate(f):
        """Decorates a plotting function so that it becomes a
SliceablePlot. Calling the result creates a fresh :class:`SliceablePlot`
wrapping ``f``; any keyword args are forwarded to :meth:`__init__`."""
        answer = partial(SliceablePlot, plotF=f)
        update_wrapper(answer, f)
        return answer
    @property
    def squeezedSlices(self) -> Union[List[slice], slice]:
        """If :attr:`slices` only has 1 element, then return that
element, else return the entire list."""
        return k1lib.squeeze(self.slices)
    def __getattr__(self, attr):
        # private/dunder lookups (e.g. display hooks, copy/pickle protocol probes)
        # must fail normally instead of being recorded as plt calls
        if attr.startswith("_"): raise AttributeError(attr)
        # automatically assume the attribute is a plt.attr method
        dec = _PlotDecorator(self, attr)
        self.plotDecorators.append(dec); return dec
    def __getitem__(self, idx):
        """``p[a:b]`` or ``p[a:b, c:d]`` returns a new plot focused on that
region, carrying over all recorded decorators.

:raises TypeError: if ``idx`` is neither a slice nor a tuple of slices"""
        if isinstance(idx, slice):
            return SliceablePlot(self.plotF, [idx], self.plotDecorators, self.docs)
        if isinstance(idx, tuple) and all(isinstance(elem, slice) for elem in idx):
            return SliceablePlot(self.plotF, idx, self.plotDecorators, self.docs)
        raise TypeError(f"Don't understand {idx}")
    def __repr__(self):
        # rendering happens here, so simply displaying the object in a notebook plots it
        self.plotF(self.squeezedSlices)
        for ax in plt.gcf().get_axes():
            plt.sca(ax) # replay every decorator on every axes of the figure
            for decorator in self.plotDecorators: decorator.run()
        plt.show()
        return f"""Sliceable plot. Can...
- p[a:b]: to focus on a specific range of the plot
- p.yscale("log"): to perform operation as if you're using plt{self.docs}"""
def plotSegments(x:List[float], y:List[float], states:List[int], colors:List[str]=None):
    """Plots a line graph, with multiple segments with different colors.
Idea is, you have a normal line graph, but you want to color parts of
the graph red, other parts blue. Then, you can pass a "state" array, with
the same length as your data, filled with ints, like this::

    y = np.array([ 460800,  921600,  921600, 1445888, 1970176, 1970176, 2301952,
           2633728, 2633728, 3043328, 3452928, 3452928, 3457024, 3461120,
           3463680, 3463680, 3470336, 3470336, 3467776, 3869184, 3865088,
           3865088, 3046400, 2972672, 2972672, 2309632, 2504192, 2504192,
           1456128, 1393664, 1393664,  472576])
    s = np.array([1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           1, 0, 0, 1, 0, 0, 1, 0, 0, 1])
    plotSegments(None, y, s, colors=["tab:blue", "tab:red"])

.. image:: images/plotSegments.png

Consecutive points sharing a state are drawn as one stretch, and each stretch
is joined to the end of the previously drawn stretch so the line stays
continuous. Points with a negative state are not drawn.

:param x: (nullable) list of x coordinate at each point. Defaults to 0, 1, 2, ...
:param y: list of y coordinates at each point
:param states: list of state indices (ints) at each point, used to index into ``colors``
:param colors: string colors (matplotlib color strings) to display for each state.
    Defaults to matplotlib's 10 "tab:" colors"""
    if x is None: x = range(len(y))
    # full palette: previously this was sliced to len(x), which made state
    # indices >= len(x) raise IndexError on short inputs
    if colors is None: colors = ["tab:blue", "tab:red", "tab:green", "tab:orange", "tab:purple",
                                 "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan"]
    def _draw(xs, ys, state, last):
        """Plots one stretch, prefixed with ``last`` (the end point of the
        previous drawn stretch) if any. Returns this stretch's end point."""
        if last is not None: xs = [last[0]] + xs; ys = [last[1]] + ys
        plt.plot(xs, ys, colors[state])
        return xs[-1], ys[-1]
    segX, segY = [], []; state = -1 # current stretch and its state
    last = None # end point of the last drawn stretch, for plot autocompletion
    for i, xi in enumerate(x):
        if states[i] != state:
            if segX and state >= 0: last = _draw(segX, segY, state, last)
            segX = [xi]; segY = [y[i]]; state = states[i]
        else: segX.append(xi); segY.append(y[i])
    if segX and state >= 0: _draw(segX, segY, state, last)
class Carousel:
    # Read (via ``.value``) in _repr_html_ to build a DOM id prefix, presumably so
    # several carousels on one page don't clash - TODO confirm that
    # k1lib.AutoIncrement.value advances on every read
    _idx = k1lib.AutoIncrement.random()
    def __init__(self):
        """Creates a new Carousel. You can then add images and whatnot.
Works even when you export the notebook as html. Example::

    import numpy as np, matplotlib.pyplot as plt, k1lib
    c = k1lib.viz.Carousel()
    x = np.linspace(-2, 2); plt.plot(x, x ** 2); c.savePlt()
    x = np.linspace(-1, 3); plt.plot(x, x ** 2); c.savePlt()
    c # displays in notebook cell

.. image:: images/carousel.png
"""
        self.imgs:List[Tuple[str, str]] = [] # (format, base64-encoded image)
        # format used by savePlt/savePIL and for images saved without an explicit format
        self.defaultFormat = "jpeg"
    def saveBytes(self, _bytes:bytes, fmt:str=None):
        """Saves bytes as another image.

:param _bytes: raw, already-encoded image bytes
:param fmt: format of image, like "png". Defaults to :attr:`defaultFormat`"""
        self.imgs.append((fmt or self.defaultFormat, base64.b64encode(_bytes).decode()))
    def save(self, f:Callable[[io.BytesIO], None], fmt:str=None):
        """Generic image save function. Treat :class:`io.BytesIO` as if it's
a file when you're doing this::

    with open("file.txt") as f:
        pass # "f" is similar to io.BytesIO

So, you can do stuff like::

    import matplotlib.pyplot as plt, numpy as np
    x = np.linspace(-2, 2)
    plt.plot(x, x**2)
    c = k1lib.viz.Carousel()
    c.save(lambda io: plt.savefig(io, format="png"), "png")

:param f: function that receives an empty :class:`io.BytesIO` and writes the image into it
:param fmt: format of the written image. Defaults to :attr:`defaultFormat`, so pass
    this whenever ``f`` writes some other format"""
        byteArr = io.BytesIO(); f(byteArr)
        self.saveBytes(byteArr.getvalue(), fmt)
    def savePlt(self):
        """Saves current plot from matplotlib (in :attr:`defaultFormat`), then
clears the current figure"""
        self.save(lambda byteArr: plt.savefig(byteArr, format=self.defaultFormat))
        plt.clf()
    def savePIL(self, image):
        """Saves a PIL image (in :attr:`defaultFormat`)"""
        self.save(lambda byteArr: image.save(byteArr, format=self.defaultFormat))
    def saveFile(self, fileName:str, fmt:str=None):
        """Saves image from file.

:param fileName: path of the image file
:param fmt: format of the file. Will figure out from file extension
    automatically if left empty
"""
        with open(fileName, "rb") as f:
            if fmt is None: # automatically infer image format
                baseName = os.path.basename(fileName)
                if "." in baseName: fmt = baseName.split(".")[-1]
            self.saveBytes(f.read(), fmt)
    def saveGraphviz(self, g):
        """Saves a graphviz graph, rendered as jpeg.

Rendering happens inside a temporary directory that is deleted afterwards,
so neither the dot source nor the rendered image is left on disk."""
        import tempfile
        with tempfile.TemporaryDirectory() as tmpDir:
            base = os.path.join(tmpDir, "graph")
            g.render(base, format="jpeg"); self.saveFile(f"{base}.jpeg")
    def pop(self):
        """Pops last image, as a (format, base64 string) tuple"""
        return self.imgs.pop()
    def __getitem__(self, idx): return self.imgs[idx]
    def _repr_html_(self):
        # data-URI MIME subtypes that differ from the file extension / savefig format name
        mime = {"jpg": "jpeg", "svg": "svg+xml"}
        imgs = []
        for fmt, img in self.imgs:
            subtype = mime.get(str(fmt).lower(), fmt)
            imgs.append(f"\"<img src='data:image/{subtype};base64, {img}' />\"")
        idx = Carousel._idx.value
        pre = f"k1c_{idx}" # prefix for every DOM id / JS global below
        html = f"""
<style>
    .{pre}_btn {{
        cursor: pointer;
        padding: 10px 15px;
        background: #9e9e9e;
        float: left;
        margin-right: 5px;
        color: #000;
        user-select: none
    }}
    .{pre}_btn:hover {{
        background: #4caf50;
        color: #fff;
    }}
</style>
<div>
    <div id="{pre}_prevBtn" class="{pre}_btn">Prev</div>
    <div id="{pre}_nextBtn" class="{pre}_btn">Next</div>
    <div style="clear:both"></div>
    <div id="{pre}_status" style="padding: 10px"></div>
</div>
<div id="{pre}_imgContainer"></div>
<script>
    var {pre}_imgs = [{','.join(imgs)}];
    var {pre}_imgIdx = 0;
    function {pre}_display() {{
        if ({pre}_imgs.length === 0) {{
            document.querySelector("#{pre}_status").innerHTML = "No images";
            return;
        }}
        document.querySelector("#{pre}_imgContainer").innerHTML = {pre}_imgs[{pre}_imgIdx];
        document.querySelector("#{pre}_status").innerHTML = "Image: " + ({pre}_imgIdx + 1) + "/" + {pre}_imgs.length;
    }};
    document.querySelector("#{pre}_prevBtn").onclick = () => {{
        {pre}_imgIdx -= 1;
        {pre}_imgIdx = Math.max({pre}_imgIdx, 0);
        {pre}_display();
    }};
    document.querySelector("#{pre}_nextBtn").onclick = () => {{
        {pre}_imgIdx += 1;
        {pre}_imgIdx = Math.min({pre}_imgIdx, {pre}_imgs.length - 1);
        {pre}_display();
    }};
    {pre}_display();
</script>
        """
        return html
def confusionMatrix(matrix:torch.Tensor, categories:List[str]=None, **kwargs):
    """Plots a confusion matrix. Example::

    k1lib.viz.confusionMatrix(torch.rand(5, 5), ["a", "b", "c", "d", "e"])

.. image:: images/confusionMatrix.png

:param matrix: 2d matrix of shape (n, n). A :class:`torch.Tensor` is detached and
    copied to the cpu first, so tensors that require grad or live on a gpu work too
:param categories: list of string categories. Defaults to "0", "1", ..., "n-1"
:param kwargs: keyword args passed into :meth:`plt.figure`. ``dpi`` defaults to 100"""
    if isinstance(matrix, torch.Tensor): matrix = matrix.detach().cpu().numpy()
    if categories is None: categories = [str(e) for e in range(len(matrix))]
    labels = [""] + list(categories) # list() so tuples/arrays of names work too
    fig = plt.figure(**{"dpi":100, **kwargs}); ax = fig.add_subplot(111)
    cax = ax.matshow(matrix); fig.colorbar(cax)
    with k1lib.ignoreWarnings():
        # the leading "" pads the labels so they line up with the cells; presumably
        # the first tick placed by MultipleLocator(1) below falls just outside the
        # matrix - TODO confirm against the matplotlib version in use
        ax.set_xticklabels(labels, rotation=90)
        ax.set_yticklabels(labels)
    # Force label at every tick
    ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(1))
    ax.xaxis.set_label_position('top')
    plt.xlabel("Predictions"); plt.ylabel("Ground truth")
def FAnim(fig, f, frames, *args, **kwargs):
    """Matplotlib function animation, 30fps by default. Example::

    # line below so that the animation is displayed in the notebook. Included in :mod:`k1lib.imports` already, so you don't really have to do this!
    plt.rcParams["animation.html"] = "jshtml"
    x = np.linspace(-2, 2); y = x**2
    fig, ax = plt.subplots()
    plt.close() # close cause it'll display 1 animation, 1 static if we don't do this
    def f(frame):
        ax.clear()
        ax.set_ylim(0, 4); ax.set_xlim(-2, 2)
        ax.plot(x[:frame], y[:frame])
    k1lib.viz.FAnim(fig, f, len(x)) # plays animation in cell

:param fig: figure object from ``plt.figure(...)`` command
:param f: function that accepts 1 frame from ``frames``.
:param frames: number of frames, or iterator, to pass into function
:param args: extra positional args forwarded to :class:`matplotlib.animation.FuncAnimation`
:param kwargs: extra keyword args forwarded too. Pass ``interval`` (milliseconds
    between frames) to override the default of 1000/30
:returns: the animation object; keep a reference to it (or make it the cell's
    last expression) so it gets displayed"""
    # ``import matplotlib`` alone is not guaranteed to load this submodule
    from matplotlib import animation
    return animation.FuncAnimation(fig, f, frames, *args, **{"interval": 1000/30, **kwargs})
from torch import nn
from k1lib.cli import op
def mask(img:torch.Tensor, act:torch.Tensor) -> torch.Tensor:
    """Shows which part of the image the network is focusing on.

:param img: the image, expected to have dimension of (3, h, w)
:param act: the activation, expected to have dimension of (x, y), and with
    elements from 0 to 1.
:returns: ``img`` weighted by the activation resized to (h, w), permuted to
    channels-last (h, w, 3), the layout :meth:`plt.imshow` expects"""
    *_, h, w = img.shape
    heat = act[None] # add a channel dim: AdaptiveAvgPool2d wants (c, H, W)
    # resize in stages (1/16 -> 1/8 -> full size), presumably so the upsampled
    # mask comes out smoother than one direct jump - TODO confirm intent.
    # max(..., 1) keeps the coarse stages valid for images smaller than 16px.
    for sh, sw in ((h//16, w//16), (h//8, w//8), (h, w)):
        heat = nn.AdaptiveAvgPool2d([max(sh, 1), max(sw, 1)])(heat)
    return (heat * img).permute(1, 2, 0)