# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
This module is for nice visualization tools. This is exposed automatically with::
   from k1lib.imports import *
   viz.mask # exposed
"""
import k1lib, base64, io, os, matplotlib as mpl, warnings
import k1lib.cli as cli
plt = k1lib.dep("matplotlib.pyplot"); import numpy as np
from typing import Callable, List, Union
from functools import partial, update_wrapper
try: import torch; import torch.nn as nn; hasTorch = True
except:
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
try: import PIL; hasPIL = True
except: hasPIL = False
__all__ = ["SliceablePlot", "plotSegments", "Carousel", "Toggle", "ToggleImage",
           "Scroll", "confusionMatrix", "FAnim", "mask", "PDF"]
class _PlotDecorator:                                                            # _PlotDecorator
    """The idea with decorators is that you can do something like this::
    sp = k1lib.viz.SliceablePlot()
    sp.yscale("log") # will format every plot as if ``plt.yscale("log")`` has been called
This class is not expected to be used by end users though."""                    # _PlotDecorator
    def __init__(self, sliceablePlot:"SliceablePlot", name:str):                 # _PlotDecorator
        """
:param sliceablePlot: the parent plot
:param name: the decorator's name, like "yscale" """                             # _PlotDecorator
        self.sliceablePlot = sliceablePlot                                       # _PlotDecorator
        self.name = name; self.args, self.kwargs = None, None                    # _PlotDecorator
    def __call__(self, *args, **kwargs):                                         # _PlotDecorator
        """Stores all args, then return the parent :class:`SliceablePlot`"""     # _PlotDecorator
        self.args = args; self.kwargs = kwargs; return self.sliceablePlot        # _PlotDecorator
    def run(self): getattr(plt, self.name)(*self.args, **self.kwargs)            # _PlotDecorator
class SliceablePlot:
    """This is a plot that is "sliceable", meaning you can focus into
a particular region of the plot quickly. A minimal example looks something
like this::
    import numpy as np, matplotlib.pyplot as plt, k1lib
    x = np.linspace(-2, 2, 100)
    def normalF():
        plt.plot(x, x**2)
    @k1lib.viz.SliceablePlot.decorate
    def plotF(_slice):
        plt.plot(x[_slice], (x**2)[_slice])
    plotF()[70:] # plots x^2 equation with x in [0.8, 2]
So, ``normalF`` plots the equation :math:`x^2` with x going from -2 to 2.
You can convert this into a :class:`SliceablePlot` by adding a term of
type :class:`slice` to the args, and decorate with :meth:`decorate`. Now,
every time you slice the :class:`SliceablePlot` with a specific range,
``plotF`` will receive it.
How intuitive everything is depends on how you slice your data. ``[70:]``
resulting in x in [0.8, 2] is rather unintuitive. You can change it into
something like this::
    @k1lib.viz.SliceablePlot.decorate
    def niceF(_slice):
        n = 100; r = k1lib.Range(-2, 2)
        x = np.linspace(*r, n)
        _slice = r.toRange(k1lib.Range(n), r.bound(_slice)).slice_
        plt.plot(x[_slice], (x**2)[_slice])
    # this works without a decorator too btw: k1lib.viz.SliceablePlot(niceF)
    niceF()[0.3:0.7] # plots x^2 equation with x in [0.3, 0.7]
    niceF()[0.3:] # plots x^2 equation with x in [0.3, 2]
The idea is to just take the input :class:`slice`, put some bounds on
its parts, then convert that slice from [-2, 2] to [0, 100]. Check
out :class:`k1lib.Range` if it's not obvious how this works.
A really cool feature of :class:`SliceablePlot` looks like this::
    niceF().legend(["A"])[-1:].grid(True).yscale("log")
This will plot :math:`x^2` with range in [-1, 2] with a nice grid, and
with y axis's scale set to log. Essentially, undefined method calls
on a :class:`SliceablePlot` will translate into ``plt`` calls. So the
above is roughly equivalent to this::
    x = np.linspace(-2, 2, 100)
    plt.plot(x, x**2)
    plt.legend(["A"])
    plt.grid(True)
    plt.yscale("log")
.. image:: images/SliceablePlot.png
This works even if you have multiple axes inside your figure. It's
wonderful, isn't it?"""
    # instance fields; never synthesize plt decorators for these (see __getattr__)
    _fields = ("plotF", "slices", "plotDecorators", "docs")
    def __init__(self, plotF:Callable[[slice], None], slices:Union[slice, List[slice]]=slice(None), plotDecorators:List["_PlotDecorator"]=(), docs=""):
        """Creates a new SliceablePlot. Only use params listed below:
:param plotF: function that takes in a :class:`slice` or tuple of :class:`slice`s
:param docs: optional docs for the function that will be displayed in :meth:`__repr__`"""
        self.plotF = plotF
        # a lone slice is wrapped into a list; lists/tuples of slices are kept as-is
        self.slices = [slices] if isinstance(slices, slice) else slices
        self.docs = docs
        # copied, so that sliced plots don't append decorators into each other's lists
        self.plotDecorators = list(plotDecorators)
    @staticmethod
    def decorate(f):
        """Decorates a plotting function so that it becomes a
SliceablePlot."""
        answer = partial(SliceablePlot, plotF=f)
        update_wrapper(answer, f)
        return answer
    @property
    def squeezedSlices(self) -> Union[List[slice], slice]:
        """If :attr:`slices` only has 1 element, then return that
element, else return the entire list."""
        return k1lib.squeeze(self.slices)
    def __getattr__(self, attr):
        """Only called for attributes that don't exist. Any public name is assumed to be
a ``plt.<attr>`` method and is recorded as a :class:`_PlotDecorator`."""
        # Private names, and instance fields that are missing on a half-initialized
        # object (e.g. created via __new__), must not be recorded: looking up
        # self.plotDecorators here would otherwise recurse forever.
        if attr.startswith("_") or attr in SliceablePlot._fields:
            raise AttributeError(attr)
        dec = _PlotDecorator(self, attr)
        self.plotDecorators.append(dec); return dec
    def __getitem__(self, idx):
        """Returns a new :class:`SliceablePlot` focused on ``idx``, which is a slice or
a tuple of slices. Recorded plt decorators are carried over."""
        if isinstance(idx, slice):
            return SliceablePlot(self.plotF, [idx], self.plotDecorators, self.docs)
        if isinstance(idx, tuple) and all(isinstance(elem, slice) for elem in idx):
            return SliceablePlot(self.plotF, idx, self.plotDecorators, self.docs)
        raise Exception(f"Don't understand {idx}")
    def __repr__(self, show=True):
        """Draws the plot, replays every recorded decorator on each axis of the current
figure, optionally shows it, then returns a short usage string."""
        self.plotF(self.squeezedSlices)
        for ax in plt.gcf().get_axes():
            plt.sca(ax)
            for decorator in self.plotDecorators: decorator.run()
        if show: plt.show()
        return f"""Sliceable plot. Can...
- p[a:b]: to focus on a specific range of the plot
- p.yscale("log"): to perform operation as if you're using plt{self.docs}"""
def plotSegments(x:List[float], y:List[float], states:List[int], colors:List[str]=None):
    """Plots a line graph, with multiple segments with different colors.
Idea is, you have a normal line graph, but you want to color parts of
the graph red, other parts blue. Then, you can pass a "state" array, with
the same length as your data, filled with ints, like this::
    y = np.array([ 460800,  921600,  921600, 1445888, 1970176, 1970176, 2301952,
           2633728, 2633728, 3043328, 3452928, 3452928, 3457024, 3461120,
           3463680, 3463680, 3470336, 3470336, 3467776, 3869184, 3865088,
           3865088, 3046400, 2972672, 2972672, 2309632, 2504192, 2504192,
           1456128, 1393664, 1393664,  472576])
    s = np.array([1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           1, 0, 0, 1, 0, 0, 1, 0, 0, 1])
    plotSegments(None, y, s, colors=["tab:blue", "tab:red"])
.. image:: images/plotSegments.png
Each run of consecutive equal states is drawn as one line in ``colors[state]``,
starting from the last point of the previously drawn run so the graph stays
connected. Runs with a negative state are not drawn.
:param x: (nullable) list of x coordinate at each point
:param y: list of y coordinates at each point
:param states: list of color index at each point
:param colors: string colors (matplotlib color strings) to display for each states"""
    import itertools
    if x is None: x = range(len(y))
    # full palette: truncating it could make otherwise valid state indices out of range
    if colors is None: colors = ["tab:blue", "tab:red", "tab:green", "tab:orange", "tab:purple", "tab:brown"]
    last = None # (x, y) of the last plotted point, used to connect consecutive runs
    for state, idxs in itertools.groupby(range(len(x)), key=lambda i: states[i]):
        idxs = list(idxs)
        if state < 0: continue
        _x = [x[i] for i in idxs]; _y = [y[i] for i in idxs]
        if last is not None: _x.insert(0, last[0]); _y.insert(0, last[1])
        plt.plot(_x, _y, colors[state])
        last = (_x[-1], _y[-1])
class _Carousel:                                                                 # _Carousel
    def __init__(self, searchMode, imgs, titles):                                # _Carousel
        self.searchMode = searchMode                                             # _Carousel
        self.imgs:List[Tuple[str, str]] = imgs # Tuple[format, base64 img]       # _Carousel
        self.titles = titles                                                     # _Carousel
    def _repr_html_(self):                                                       # _Carousel
        idx = Carousel._idx(); pre = f"k1c_{idx}"; searchMode = self.searchMode  # _Carousel
        imgs = self.imgs | cli.apply(lambda x: f"`{x}`") | cli.deref(); n = len(imgs) # _Carousel
        titles = self.titles | cli.apply(lambda x: f"`{x}`") | cli.deref()       # _Carousel
        if searchMode > 0: searchBar = f"<input type='text' value='' id='{pre}_search' placeholder='Search in {'content' if searchMode == 1 else 'header'}' style='padding: 4px 4px'>" # _Carousel
        else: searchBar = ""                                                     # _Carousel
        if n > 0: contents = imgs | cli.apply(k1lib.decode) | cli.insertIdColumn() | ~cli.apply(lambda idx, html: f"<div id='{pre}_content{idx}'>{html}</div>") | cli.deref() | cli.join('\n') # _Carousel
        else: contents = "(no pages or images are found)"                        # _Carousel
        #imgs = [f"\"<img alt='' src='data:image/{fmt};base64, {img}' />\"" for fmt, img in self.imgs] # _Carousel
        html = f"""<!-- k1lib.Carousel start -->
<style>
    .{pre}_btn {{
        cursor: pointer;
        padding: 6px 12px;
        /*background: #9e9e9e;*/
        background-color: #eee;
        margin-right: 8px;
        color: #000;
        box-shadow: 0 3px 5px rgb(0,0,0,0.3);
        border-radius: 18px;
        user-select: none;
        -webkit-user-select: none; /* Safari */
        -ms-user-select: none; /* IE 10+ */
    }}
    .{pre}_btn:hover {{
        box-shadow: box-shadow: 0 3px 10px rgb(0,0,0,0.6);
        background: #4caf50;
        color: #fff;
    }}
</style>
{searchBar}
<div>
    <div style="display: flex; flex-direction: row; padding: 8px">
        <div id="{pre}_prevBtn" class="{pre}_btn">Prev</div>
        <div id="{pre}_nextBtn" class="{pre}_btn">Next</div>
    </div>
    <div id="{pre}_status" style="padding: 10px"></div>
</div>
<div id="{pre}_imgContainer">
    {contents}
</div>
<script>
    const {pre}_allImgs = [{','.join(imgs)}];
    let {pre}_imgs = [...Array({pre}_allImgs.length).keys()]; // index of all available images. If searching for something then it will be a subset of allImgs
    const {pre}_titles = [{','.join(titles)}];
    {pre}_imgIdx = 0; // n-th element of pre_imgs, not of pre_allImgs
    {pre}_searchMode = {searchMode};
    function {pre}_show(i) {{ // i here is allImgs index, not of imgs
        document.querySelector(`#{pre}_content${{i}}`).style.display = "block";
    }}
    function {pre}_hide(i) {{ // i here is allImgs index, not of imgs
        document.querySelector(`#{pre}_content${{i}}`).style.display = "none";
    }}
    function {pre}_updatePageCount() {{
        let n = {pre}_imgs.length;
        if (n > 0) document.querySelector("#{pre}_status").innerHTML = "Page: " + ({pre}_imgIdx + 1) + "/" + n;
        else document.querySelector("#{pre}_status").innerHTML = "Page: 0/0"
    }}
    function {pre}_display() {{
        let n = {pre}_imgs.length;
        for (let i = 0; i < {n}; i++) {pre}_hide(i);
        if (n > 0) {pre}_show({pre}_imgs[{pre}_imgIdx]);
        {pre}_updatePageCount();
    }};
    document.querySelector("#{pre}_prevBtn").onclick = () => {{
        {pre}_imgIdx -= 1;
        {pre}_imgIdx = Math.max({pre}_imgIdx, 0);
        {pre}_display();
    }};
    document.querySelector("#{pre}_nextBtn").onclick = () => {{
        {pre}_imgIdx += 1;
        {pre}_imgIdx = Math.min({pre}_imgIdx, {pre}_imgs.length - 1);
        {pre}_display();
    }};
    if ({pre}_searchMode > 0) {{
        {pre}_searchInp = document.querySelector("#{pre}_search");
        {pre}_searchInp.oninput = (value) => {{
            const val = {pre}_searchInp.value;
            {pre}_imgs = ({pre}_searchMode === 1 ? {pre}_allImgs : {pre}_titles).map((e, i) => [window.atob(e).includes(val), i]).filter(e => e[0]).map(e => e[1]);
            {pre}_imgIdx = 0;; {pre}_display();
        }}
    }}
    {pre}_display();
</script>
<!-- k1lib.Carousel end -->"""                                                   # _Carousel
        return html                                                              # _Carousel
class Carousel(cli.BaseCli):
    _idx = k1lib.AutoIncrement.random()
    def __init__(self, searchMode:int=0):
        """Creates a new Carousel that can flip through a list of images/html.
Will work even when you export the notebook as html. Example::
    x = np.linspace(-2, 2); plt.plot(x, x ** 2); im1 = plt.gcf() | toImg()
    x = np.linspace(-1, 3); plt.plot(x, x ** 2); im2 = plt.gcf() | toImg()
    im3 = "<h1>abc</h1><div>Some content</div>" # can add html
    [im1, im2, im3] | viz.Carousel() # displays in notebook cell
.. image:: images/carousel.png
There's also a builtin search functionality that works like this::
    [
        "<h1>abc</h1><div>Some content 1</div>",
        "<h1>def</h1><div>Some other content 2</div>",
        "<h1>ghi</h1><div>Another content 3</div>",
    ] | Carousel(searchMode=1)
    [
        ["<h1>abc</h1>", "<div>Some content 1</div>"],
        ["<h1>def</h1>", "<div>Some other content 2</div>"],
        ["<h1>ghi</h1>", "<div>Another content 3</div>"],
    ] | Carousel(searchMode=2)
The first mode will search for some text inside the html content. The second mode
will search inside the title only, that means it's expecting to receive Iterator[title, html/img].
In the second mode, each title is also displayed at the top of its page.
:param searchMode: 0 for no search, accepts Iterator[html/img],
    1 for search content, accepts Iterator[html/img],
    2 for search title, accepts Iterator[title, html/img]
"""
        self.searchMode = searchMode
    def _process(self, e):
        """Converts 1 page into an html string. Strings are used as-is, and PIL images
are embedded as base64 png ``<img>`` tags. Anything else raises."""
        if isinstance(e, str): return f"{e}"
        elif hasPIL and isinstance(e, PIL.Image.Image):
            return f"<img alt='' style='max-width: 100%' src='data:image/png;base64, {base64.b64encode(e | cli.toBytes()).decode()}' />"
        else: raise Exception(f"Content is not a string nor a PIL image. Can't make a Carousel out of this unknown type: {type(e)}")
    def __ror__(self, it):
        """Converts every page to html, encodes it with :meth:`k1lib.encode` and returns
a displayable object. Raises on an invalid searchMode, or on a non-string title in mode 2."""
        imgs = []; titles = []
        searchMode = self.searchMode
        if searchMode == 0 or searchMode == 1:
            for e in it: imgs.append(k1lib.encode(self._process(e)))
        elif searchMode == 2:
            for title, e in it:
                if not isinstance(title, str): raise Exception("Title is not a string. Can't perform search")
                imgs.append(k1lib.encode(title+self._process(e)))
                titles.append(k1lib.encode(title))
        else: raise Exception(f"Invalid searchMode: {searchMode}")
        return _Carousel(searchMode, imgs, titles)
    def _jsF(self, meta):
        """JS transpiler hook. Returns (JS source, generated function name). Only
searchMode 0 is supported."""
        if self.searchMode != 0: raise Exception("viz.Carousel._jsF() does not support .searchMode!=0. You're using the JS transpiler anyway, you can trivially build your own, more complex search engine!")
        fIdx = cli.init._jsFAuto(); dataIdx = cli.init._jsDAuto(); pre = cli.init._jsDAuto()
        return f"""const {fIdx} = ({dataIdx}) => {{
        return unescape(`<!-- k1lib.Carousel start -->
<style>
    .{pre}_btn {{
        cursor: pointer; padding: 6px 12px; background-color: #eee;
        margin-right: 8px; color: #000;
        box-shadow: 0 3px 5px rgb(0,0,0,0.3);
        border-radius: 18px; user-select: none;
        -webkit-user-select: none; /* Safari */
        -ms-user-select: none; /* IE 10+ */
    }}
    .{pre}_btn:hover {{
        box-shadow: 0 3px 10px rgb(0,0,0,0.6);
        background: #4caf50; color: #fff;
    }}
</style>
<div>
    <div style="display: flex; flex-direction: row; padding: 8px">
        <div id="{pre}_prevBtn" class="{pre}_btn">Prev</div>
        <div id="{pre}_nextBtn" class="{pre}_btn">Next</div>
    </div>
    <div id="{pre}_status" style="padding: 10px"></div>
</div>
<div id="{pre}_imgContainer">
    ${{{dataIdx}.map((x, i) => "<div id='{pre}_content" + i + "'>" + x + "</div>").join("")}}
</div>
%3Cscript%3E
    (async () => {{
        const {pre}_n = ${{{dataIdx}.length}}; {pre}_imgIdx = 0;
        function {pre}_updatePageCount() {{
            if ({pre}_n > 0) document.querySelector("#{pre}_status").innerHTML = "Page: " + ({pre}_imgIdx + 1) + "/" + {pre}_n;
            else document.querySelector("#{pre}_status").innerHTML = "Page: 0/0"
        }}
        function {pre}_display() {{
            for (let i = 0; i < {pre}_n; i++) document.querySelector("#{pre}_content" + i).style.display = "none";
            if ({pre}_n > 0) document.querySelector("#{pre}_content" + {pre}_imgIdx).style.display = "block";
            {pre}_updatePageCount();
        }};
        document.querySelector("#{pre}_prevBtn").onclick = () => {{
            {pre}_imgIdx -= 1;
            {pre}_imgIdx = Math.max({pre}_imgIdx, 0);
            {pre}_display();
        }};
        document.querySelector("#{pre}_nextBtn").onclick = () => {{
            {pre}_imgIdx += 1;
            {pre}_imgIdx = Math.min({pre}_imgIdx, {pre}_n - 1);
            {pre}_display();
        }};
        {pre}_display();
    }})();
%3C/script%3E`) }}
<!-- k1lib.Carousel end -->""", fIdx
k1lib.settings.cli.atomic.deref = (*k1lib.settings.cli.atomic.deref, Carousel)
class Toggle(cli.BaseCli):
    _idx = k1lib.AutoIncrement.random()
    def __init__(self):
        """Button to toggle whether the content is displayed or
not. Useful if the html content is very big in size. Example::
    x = np.linspace(-2, 2); plt.plot(x, x ** 2)
    plt.gcf() | toImg() | toHtml() | viz.Toggle()
This will plot a graph, then create a button where you can toggle the image's visibility"""
        self.content:str = "" # html string
        self._enteredRor = False # becomes True once content has been piped in
    def __ror__(self, it):
        """Stores the piped-in content (strings as-is, anything else through
:class:`~k1lib.cli.toHtml`) and returns self for display."""
        self._enteredRor = True
        self.content = it if isinstance(it, str) else it | cli.toHtml()
        return self
    def __or__(self, it):
        # Once content has been piped in, ``toggle | f`` feeds this Toggle into ``f``
        # as data; before that, it composes clis as usual.
        if self._enteredRor: return it.__ror__(self)
        else: return super().__or__(it)
    def _repr_html_(self):
        """Renders the button + hidden content as html/js, with a unique element-id prefix."""
        pre = f"k1t_{Toggle._idx()}"
        html = f"""<!-- k1lib.Toggle start -->
<style>
    #{pre}_btn {{
        cursor: pointer;
        padding: 6px 12px;
        background: #eee;
        margin-right: 5px;
        color: #000;
        user-select: none;
        -webkit-user-select: none; /* Safari */
        -ms-user-select: none; /* IE 10+ */
        box-shadow: 0 3px 5px rgb(0,0,0,0.3);
        border-radius: 18px;
    }}
    #{pre}_btn:hover {{
        box-shadow: 0 3px 5px rgb(0,0,0,0.6);
        background: #4caf50;
        color: #fff;
    }}
</style>
<div>
    <div style="display: flex; flex-direction: row; padding: 4px">
        <div id="{pre}_btn">Show content</div>
        <div style="flex: 1"></div>
    </div>
    <div id="{pre}_content" style="display: none; margin-top: 12px">{self.content}</div>
</div>
<script>
    console.log("setup script ran for {pre}");
    {pre}_btn = document.querySelector("#{pre}_btn");
    {pre}_content = document.querySelector("#{pre}_content");
    {pre}_displayed = false;
    {pre}_btn.onclick = () => {{
        {pre}_displayed = !{pre}_displayed;
        {pre}_btn.innerHTML = {pre}_displayed ? "Hide content" : "Show content";
        {pre}_content.style.display = {pre}_displayed ? "block" : "none";
    }};
</script>
<!-- k1lib.Toggle end -->"""
        return html
    def _jsF(self, meta):
        """JS transpiler hook. Returns (JS source, generated function name)."""
        fIdx = cli.init._jsFAuto(); dataIdx = cli.init._jsDAuto(); pre = cli.init._jsDAuto()
        return f"""const {fIdx} = ({dataIdx}) => {{
        return unescape(`
<!-- k1lib.Toggle start -->
<style>
    #{pre}_btn {{
        cursor: pointer;
        padding: 6px 12px;
        background: #eee;
        margin-right: 5px;
        color: #000;
        user-select: none;
        -webkit-user-select: none; /* Safari */
        -ms-user-select: none; /* IE 10+ */
        box-shadow: 0 3px 5px rgb(0,0,0,0.3);
        border-radius: 18px;
    }}
    #{pre}_btn:hover {{
        box-shadow: 0 3px 5px rgb(0,0,0,0.6);
        background: #4caf50;
        color: #fff;
    }}
</style>
<div>
    <div style="display: flex; flex-direction: row; padding: 4px">
        <div id="{pre}_btn">Show content</div>
        <div style="flex: 1"></div>
    </div>
    <div id="{pre}_content" style="display: none; margin-top: 12px">${{{dataIdx}}}</div>
</div>
%3Cscript%3E
    (async () => {{
        console.log("setup script ran for {pre}");
        {pre}_btn = document.querySelector("#{pre}_btn");
        {pre}_content = document.querySelector("#{pre}_content");
        {pre}_displayed = false;
        {pre}_btn.onclick = () => {{
            {pre}_displayed = !{pre}_displayed;
            {pre}_btn.innerHTML = {pre}_displayed ? "Hide content" : "Show content";
            {pre}_content.style.display = {pre}_displayed ? "block" : "none";
        }};
    }})();
%3C/script%3E`) }}
<!-- k1lib.Toggle end -->""", fIdx
k1lib.settings.cli.atomic.deref = (*k1lib.settings.cli.atomic.deref, Toggle)
def ToggleImage():
    """Legacy helper, kept for backwards compatibility. It's just
``img | toHtml() | viz.Toggle()`` really: returns a cli that converts the
piped-in object to html, then wraps it in a :class:`Toggle`."""
    return cli.toHtml() | Toggle()
# Appends Scroll to k1lib.settings.cli.atomic.deref (presumably so that cli.deref()
# treats Scroll objects as atomic values instead of iterating into them).
# NOTE(review): Scroll is listed in __all__, but its class definition is not visible
# in this section. Confirm it is defined before this line, otherwise import raises NameError.
k1lib.settings.cli.atomic.deref = (*k1lib.settings.cli.atomic.deref, Scroll)     # Scroll
def confusionMatrix(matrix:torch.Tensor, categories:List[str]=None, **kwargs):
    """Plots a confusion matrix. Example::
    k1lib.viz.confusionMatrix(torch.rand(5, 5), ["a", "b", "c", "d", "e"])
.. image:: images/confusionMatrix.png
:param matrix: 2d matrix of shape (n, n). A :class:`torch.Tensor` may live on any
    device and may require grad; it is detached and moved to the cpu first
:param categories: list (or any iterable) of string categories. Defaults to "0", "1", ...
:param kwargs: keyword args passed into :meth:`plt.figure`"""
    if isinstance(matrix, torch.Tensor): matrix = matrix.detach().cpu().numpy()
    categories = [f"{e}" for e in range(len(matrix))] if categories is None else list(categories)
    fig = plt.figure(**{"dpi":100, **kwargs}); ax = fig.add_subplot(111)
    cax = ax.matshow(matrix); fig.colorbar(cax)
    # The leading '' pads the first tick, which the MultipleLocator(1) below places
    # just outside the view; fixed tick labels are assigned to ticks in order.
    with k1lib.ignoreWarnings():
        ax.set_xticklabels([''] + categories, rotation=90)
        ax.set_yticklabels([''] + categories)
    # Force label at every tick
    ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(1))
    ax.xaxis.set_label_position('top')
    ax.set_xlabel("Predictions"); ax.set_ylabel("Ground truth")
def FAnim(fig, f, frames, *args, **kwargs):
    """Matplotlib function animation, 30fps by default. Example::
    # line below so that the animation is displayed in the notebook. Included in :mod:`k1lib.imports` already, so you don't really have to do this!
    plt.rcParams["animation.html"] = "jshtml"
    x = np.linspace(-2, 2); y = x**2
    fig, ax = plt.subplots()
    plt.close() # close cause it'll display 1 animation, 1 static if we don't do this
    def f(frame):
        ax.clear()
        ax.set_ylim(0, 4); ax.set_xlim(-2, 2)
        ax.plot(x[:frame], y[:frame])
    k1lib.viz.FAnim(fig, f, len(x)) # plays animation in cell
:param fig: figure object from `plt.figure(...)` command
:param f: function that accepts 1 frame from `frames`.
:param frames: number of frames, or iterator, to pass into function
:param args: extra positional args forwarded to :class:`matplotlib.animation.FuncAnimation`
:param kwargs: extra keyword args forwarded to :class:`matplotlib.animation.FuncAnimation`.
    Pass ``interval=...`` (milliseconds between frames) to override the default 1000/30"""
    # `import matplotlib` alone does not load the animation submodule, so import it explicitly
    from matplotlib.animation import FuncAnimation
    return FuncAnimation(fig, f, frames, *args, **{"interval": 1000/30, **kwargs})
from k1lib.cli import op                                                         # FAnim
def mask(img:torch.Tensor, act:torch.Tensor) -> torch.Tensor:
    """Shows which part of the image the network is focusing on.
:param img: the image, expected to have dimension of (3, h, w)
:param act: the activation, expected to have dimension of (x, y), and with
    elements from 0 to 1.
:return: ``img`` weighted by a smoothed, resized ``act``, permuted to (h, w, 3)"""
    *_, h, w = img.shape
    # Average-pool to coarse grids, then to the full image size, which smooths the
    # activation map. max(1, ...) keeps the coarse sizes valid for images under 16px.
    weights = act[None,]
    for size in ([max(1, h//16), max(1, w//16)], [max(1, h//8), max(1, w//8)], [h, w]):
        weights = nn.AdaptiveAvgPool2d(size)(weights)
    return (weights * img).permute(1, 2, 0)
class PDF(object):
    """Displays a pdf file in the notebook: an ``<iframe>`` for html output, and an
includegraphics command for latex output."""
    def __init__(self, pdf:str=None, size=(700,500)):
        """Displays pdf in the notebook.
Example::
    viz.PDF("a.pdf")
    "a.pdf" | viz.PDF()
    viz.PDF("a.pdf", (700, 500))
    "a.pdf" | viz.PDF(size=(700, 500))
If you're exporting this notebook as html, then you have to make sure
you place the generated html file in the correct directory so that it
can reference those pdf files.
:param pdf: relative path to pdf file
:param size: (width, height) of the iframe"""
        self.pdf = pdf; self.size = size
    def __ror__(self, pdf):
        """Lets you write ``"a.pdf" | viz.PDF()``. Sets the path and returns self."""
        self.pdf = pdf; return self
    def _repr_html_(self):
        # quote + escape attribute values, so paths with spaces or quotes don't break the tag
        from html import escape
        src = escape(str(self.pdf), quote=True)
        width = escape(str(self.size[0]), quote=True); height = escape(str(self.size[1]), quote=True)
        return f'<iframe src="{src}" width="{width}" height="{height}"></iframe>'
    def _repr_latex_(self): return r'\includegraphics[width=1.0\textwidth]{{{0}}}'.format(self.pdf)