Source code for k1lib.viz

# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
This module is for nice visualization tools. This is exposed automatically with::

   from k1lib.imports import *
   viz.mask # exposed
"""
import k1lib, base64, io, torch, os, matplotlib as mpl
import matplotlib.pyplot as plt, numpy as np
from typing import Callable, List, Union
from functools import partial, update_wrapper
__all__ = ["SliceablePlot", "plotSegments", "Carousel", "confusionMatrix", "FAnim",
           "mask"]
class _PlotDecorator:
    """Records one deferred ``plt.<name>(...)`` call for a :class:`SliceablePlot`.

The idea with decorators is that you can do something like this::

    sp = k1lib.viz.SliceablePlot()
    sp.yscale("log") # will format every plot as if ``plt.yscale("log")`` has been called

This class is not expected to be used by end users though."""
    def __init__(self, sliceablePlot:"SliceablePlot", name:str):
        """
:param sliceablePlot: the parent plot, returned from :meth:`__call__` so calls can be chained
:param name: the decorator's name, like "yscale". It is looked up on ``plt`` when run"""
        self.sliceablePlot = sliceablePlot
        self.name = name
        # Both stay None until __call__ records the arguments. run() uses this to
        # recognise attribute accesses that were never turned into a call.
        self.args, self.kwargs = None, None
    def __call__(self, *args, **kwargs):
        """Stores all args, then returns the parent :class:`SliceablePlot` so calls can be chained."""
        self.args = args; self.kwargs = kwargs
        return self.sliceablePlot
    def run(self):
        """Replays the recorded call as ``plt.<name>(*args, **kwargs)``.

If the decorator was only accessed (``p.grid``) and never called, this does
nothing. Otherwise unpacking the ``None`` placeholders would raise TypeError."""
        if self.args is None: return
        getattr(plt, self.name)(*self.args, **self.kwargs)
class SliceablePlot:
    """This is a plot that is "sliceable", meaning you can focus into a particular
region of the plot quickly. A minimal example looks something like this::

    import numpy as np, matplotlib.pyplot as plt, k1lib

    x = np.linspace(-2, 2, 100)

    def normalF():
        plt.plot(x, x**2)

    @k1lib.viz.SliceablePlot.decorate
    def plotF(_slice):
        plt.plot(x[_slice], (x**2)[_slice])
    plotF()[70:] # plots x^2 equation with x in [0.8, 2]

So, ``normalF`` plots the equation :math:`x^2` with x going from -2 to 2.
You can convert this into a :class:`SliceablePlot` by adding a term of
type :class:`slice` to the args, and decorate with :meth:`decorate`. Now,
every time you slice the :class:`SliceablePlot` with a specific range,
``plotF`` will receive it.

How intuitive everything is depends on how you slice your data. Having
``[70:]`` map to x in [0.8, 2] is rather unintuitive. You can change it
into something like this::

    @k1lib.viz.SliceablePlot.decorate
    def niceF(_slice):
        n = 100; r = k1lib.Range(-2, 2)
        x = np.linspace(*r, n)
        _slice = r.toRange(k1lib.Range(n), r.bound(_slice)).slice_
        plt.plot(x[_slice], (x**2)[_slice])
    # this works without a decorator too btw: k1lib.viz.SliceablePlot(niceF)

    niceF()[0.3:0.7] # plots x^2 equation with x in [0.3, 0.7]
    niceF()[0.3:] # plots x^2 equation with x in [0.3, 2]

The idea is to just take the input :class:`slice`, put some bounds on
its parts, then convert that slice from [-2, 2] to [0, 100]. Check
out :class:`k1lib.Range` if it's not obvious how this works.

A really cool feature of :class:`SliceablePlot` looks like this::

    niceF().legend(["A"])[-1:].grid(True).yscale("log")

This will plot :math:`x^2` with range in [-1, 2] with a nice grid, and
with y axis's scale set to log. Essentially, undefined method calls on
a :class:`SliceablePlot` will translate into ``plt`` calls. So the above
is roughly equivalent to this::

    x = np.linspace(-2, 2, 100)
    plt.plot(x, x**2)
    plt.legend(["A"])
    plt.grid(True)
    plt.yscale("log")

.. image:: images/SliceablePlot.png

This works even if you have multiple axes inside your figure. It's
wonderful, isn't it?"""
    def __init__(self, plotF:Callable[[slice], None], slices:Union[slice, List[slice]]=slice(None), plotDecorators:Union[List["_PlotDecorator"], None]=None, docs=""):
        """Creates a new SliceablePlot. Only use params listed below:

:param plotF: function that takes in a :class:`slice` or tuple of :class:`slice`s
:param docs: optional docs for the function that will be displayed in :meth:`__repr__`"""
        self.plotF = plotF
        self.slices = [slices] if isinstance(slices, slice) else slices
        self.docs = docs
        # Always copy, so a plot derived via __getitem__ appending new decorators
        # never mutates its parent's list (and no default list is shared).
        self.plotDecorators = list(plotDecorators) if plotDecorators is not None else []
    @staticmethod
    def decorate(f):
        """Decorates a plotting function so that it becomes a SliceablePlot.

Calling the decorated function (with no args) returns ``SliceablePlot(plotF=f)``."""
        answer = partial(SliceablePlot, plotF=f)
        update_wrapper(answer, f)
        return answer
    @property
    def squeezedSlices(self) -> Union[List[slice], slice]:
        """If :attr:`slices` only has 1 element, then return that element, else return the entire list."""
        return k1lib.squeeze(self.slices)
    def __getattr__(self, attr):
        """Treats any unknown public attribute as a deferred ``plt.<attr>`` call.

The returned :class:`_PlotDecorator` is registered immediately and replayed
on every axes when the plot is displayed via :meth:`__repr__`."""
        if attr.startswith("_"): raise AttributeError(attr)
        # __getattr__ only runs when normal lookup fails. On an instance whose
        # __init__ never ran (e.g. created via __new__), reading
        # self.plotDecorators would re-enter __getattr__ and recurse forever.
        # Read through __dict__ instead and fail cleanly.
        decorators = self.__dict__.get("plotDecorators")
        if decorators is None: raise AttributeError(attr)
        # automatically assume the attribute is a plt.attr method
        dec = _PlotDecorator(self, attr)
        decorators.append(dec)
        return dec
    def __getitem__(self, idx):
        """Returns a new :class:`SliceablePlot` focused on ``idx``.

``idx`` is a single slice, or a tuple of slices for multi-dimensional data.
Already-registered decorators are carried over (copied)."""
        if isinstance(idx, slice):
            return SliceablePlot(self.plotF, [idx], self.plotDecorators, self.docs)
        if isinstance(idx, tuple) and all(isinstance(elem, slice) for elem in idx):
            return SliceablePlot(self.plotF, idx, self.plotDecorators, self.docs)
        raise Exception(f"Don't understand {idx}")
    def __repr__(self):
        """Draws the plot (side effect: calls ``plt.show()``), then returns a usage hint."""
        self.plotF(self.squeezedSlices)
        for ax in plt.gcf().get_axes():
            plt.sca(ax)
            for decorator in self.plotDecorators: decorator.run()
        plt.show()
        return f"""Sliceable plot. Can...
- p[a:b]: to focus on a specific range of the plot
- p.yscale("log"): to perform operation as if you're using plt{self.docs}"""
def _segmentRuns(x, y, states):
    """Yields ``(xs, ys, state)`` for each maximal run of consecutive points sharing one state.

Runs whose state is negative are skipped. Every yielded run after the first
is prefixed with the last point of the previously yielded run, so adjacent
segments join up when plotted."""
    from itertools import groupby
    last = None # (x, y) of the final point of the previously yielded run
    for state, idxs in groupby(range(len(x)), key=lambda i: states[i]):
        if state < 0: continue
        idxs = list(idxs)
        xs = [x[i] for i in idxs]; ys = [y[i] for i in idxs]
        if last is not None: xs.insert(0, last[0]); ys.insert(0, last[1])
        last = (xs[-1], ys[-1])
        yield xs, ys, state

def plotSegments(x:List[float], y:List[float], states:List[int], colors:List[str]=None):
    """Plots a line graph, with multiple segments with different colors.

Idea is, you have a normal line graph, but you want to color parts of
the graph red, other parts blue. Then, you can pass a "state" array, with
the same length as your data, filled with ints, like this::

    y = np.array([ 460800,  921600,  921600, 1445888, 1970176, 1970176, 2301952,
           2633728, 2633728, 3043328, 3452928, 3452928, 3457024, 3461120,
           3463680, 3463680, 3470336, 3470336, 3467776, 3869184, 3865088,
           3865088, 3046400, 2972672, 2972672, 2309632, 2504192, 2504192,
           1456128, 1393664, 1393664,  472576])
    s = np.array([1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           1, 0, 0, 1, 0, 0, 1, 0, 0, 1])
    plotSegments(None, y, s, colors=["tab:blue", "tab:red"])

.. image:: images/plotSegments.png

:param x: (nullable) list of x coordinate at each point. Defaults to ``range(len(y))``
:param y: list of y coordinates at each point
:param states: list of states at each point. Each state is an index into ``colors``; points with a negative state are not drawn
:param colors: string colors (matplotlib color strings) to display for each state"""
    if x is None: x = range(len(y))
    if colors is None:
        # The full palette is kept. Truncating it to len(x) (the number of
        # *points*) made high state indices raise IndexError on short inputs.
        colors = ["tab:blue", "tab:red", "tab:green", "tab:orange", "tab:purple", "tab:brown"]
    for xs, ys, state in _segmentRuns(x, y, states):
        plt.plot(xs, ys, colors[state])
def confusionMatrix(matrix:torch.Tensor, categories:List[str]=None, **kwargs):
    """Plots a confusion matrix. Example::

    k1lib.viz.confusionMatrix(torch.rand(5, 5), ["a", "b", "c", "d", "e"])

.. image:: images/confusionMatrix.png

:param matrix: 2d matrix of shape (n, n)
:param categories: list of string categories, one per row/column. Defaults to "0", "1", ...
:param kwargs: keyword args passed into :meth:`plt.figure`"""
    if isinstance(matrix, torch.Tensor):
        # detach + cpu so tensors that require grad or live on a GPU can be plotted too
        matrix = matrix.detach().cpu().numpy()
    if categories is None: categories = [f"{e}" for e in range(len(matrix))]
    categories = list(categories) # also accept tuples and other sequences
    fig = plt.figure(**{"dpi":100, **kwargs}); ax = fig.add_subplot(111)
    cax = ax.matshow(matrix); fig.colorbar(cax)
    # Place exactly one tick at each cell centre (0..n-1) and label it. Label i
    # then always sits on row/column i. The old ['']-padding + MultipleLocator
    # trick depended on where the locator happened to put its first tick.
    ticks = list(range(len(categories)))
    ax.set_xticks(ticks); ax.set_xticklabels(categories, rotation=90)
    ax.set_yticks(ticks); ax.set_yticklabels(categories)
    ax.xaxis.set_label_position('top')
    plt.xlabel("Predictions"); plt.ylabel("Ground truth")
def FAnim(fig, f, frames, *args, **kwargs):
    """Matplotlib function animation, 30fps (one frame every 1000/30 ms by default). Example::

    # line below so that the animation is displayed in the notebook. Included in :mod:`k1lib.imports` already, so you don't really have to do this!
    plt.rcParams["animation.html"] = "jshtml"
    x = np.linspace(-2, 2); y = x**2
    fig, ax = plt.subplots()
    plt.close() # close cause it'll display 1 animation, 1 static if we don't do this
    def f(frame):
        ax.clear()
        ax.set_ylim(0, 4); ax.set_xlim(-2, 2)
        ax.plot(x[:frame], y[:frame])
    k1lib.FAnim(fig, f, len(x)) # plays animation in cell

:param fig: figure object from `plt.figure(...)` command
:param f: function that accepts 1 frame from `frames`.
:param frames: number of frames, or iterator, to pass into function
:param args: extra positional args forwarded to :class:`matplotlib.animation.FuncAnimation`
:param kwargs: extra keyword args forwarded as well. An explicit ``interval`` overrides the 30fps default"""
    # ``import matplotlib`` alone does not guarantee that the ``animation``
    # submodule is loaded (so ``mpl.animation`` may not exist); import it explicitly.
    from matplotlib import animation
    return animation.FuncAnimation(fig, f, frames, *args, **{"interval": 1000/30, **kwargs})
from torch import nn from k1lib.cli import op
def mask(img:torch.Tensor, act:torch.Tensor) -> torch.Tensor:
    """Shows which part of the image the network is focusing on.

:param img: the image, expected to have dimension of (3, h, w)
:param act: the activation, expected to have dimension of (x, y), and with elements from 0 to 1.
:return: ``img`` weighted by the resized activation, permuted to channels-last (h, w, 3)"""
    *_, h, w = img.shape
    # The activation is pooled to a coarse (h/16, w/16) grid, then resized to
    # (h/8, w/8) and finally (h, w). Sizes are clamped to >= 1: for images
    # smaller than 16px, h//16 == 0 would request a 0-sized pooling output and
    # the later stages would fail.
    coarse = [max(1, h//16), max(1, w//16)]
    mid = [max(1, h//8), max(1, w//8)]
    weights = act[None,] | nn.AdaptiveAvgPool2d(coarse) | nn.AdaptiveAvgPool2d(mid) | nn.AdaptiveAvgPool2d([h, w])
    return weights * img | op().permute(1, 2, 0)