# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
import k1lib, math, numpy as np, time
from k1lib import cli
from typing import List, Tuple, ContextManager
from contextlib import contextmanager
try: import torch; from torch import nn; hasTorch = True
except: hasTorch = False
try: import matplotlib.animation
except: pass
try: import matplotlib.pyplot as plt
except: pass
__all__ = ["dummy"]

def dummy():
    """Does nothing. Only here so that you can read source code of this file and
    see what's up.

    :return: None"""
    pass
# Register a "monkey" settings group on k1lib's public settings (keeping a
# module-level ``settings`` alias to it) and on the private ``_settings``.
settings = k1lib.settings.add(
    "monkey", k1lib.Settings(), "monkey-patched settings"
).monkey
k1lib._settings.add(
    "monkey", k1lib.Settings(), "monkey-patched settings"
)
if hasTorch:
    @k1lib.patch(nn.Module)
    def importParams(self:nn.Module, params:List[nn.Parameter]):
        """Overwrite this module's parameters with clones of the given ones.

        :param params: list of :class:`torch.nn.parameter.Parameter` or
            :class:`torch.Tensor`, in the same order as :meth:`parameters`.
            Pairing stops at the shorter of the two sequences."""
        pairs = zip(self.parameters(), params)
        for target, source in pairs:
            target.data = source.data.clone()

    @k1lib.patch(nn.Module)
    def exportParams(self:nn.Module) -> List[torch.Tensor]:
        """Snapshot this module's parameters.

        :return: list of clones of each parameter's ``.data`` tensor, in
            :meth:`parameters` order"""
        return [p.data.clone() for p in self.parameters()]

    class ParamsContext:
        """Class-based context manager that snapshots ``m``'s parameters on
        enter (returning the snapshot) and restores them on exit."""
        def __init__(self, m:nn.Module):
            self.m = m

        def __enter__(self):
            self.params = self.m.exportParams()
            return self.params

        def __exit__(self, *ignored):
            self.m.importParams(self.params)
if hasTorch:
    @k1lib.patch(nn.Module)
    @contextmanager
    def paramsContext(self:nn.Module):
        """A nice context manager for :meth:`importParams` and :meth:`exportParams`.
        Returns the old parameters on enter context. Example::

            m = nn.Linear(2, 3)
            with m.paramsContext() as oldParam:
                pass # go wild, train, mutate `m` however much you like
            # m automatically snaps back to the old param

        The yielded list is the exact snapshot that gets restored on exit, so
        mutating those tensors in place changes what ``m`` snaps back to.

        Small reminder that this is not foolproof, as there are some :class:`~torch.nn.Module`
        that stores extra information not accessible from the model itself, like
        :class:`~torch.nn.BatchNorm2d`."""
        params = self.exportParams()
        # yield the snapshot (previously yielded None, contradicting the docs)
        try: yield params
        finally: self.importParams(params)
if hasTorch:
    @k1lib.patch(nn.Module)
    def getParamsVector(model:nn.Module) -> List[torch.Tensor]:
        """For each parameter, returns a normal distributed random tensor
        with the same standard deviation as the original parameter.

        Each returned tensor has the same shape, dtype and device as its
        parameter. Single-element parameters (whose std is undefined) get a
        standard deviation of 1. The results are plain tensors that do not
        track gradients.

        :return: list of random tensors, in :meth:`parameters` order"""
        answer = []
        # the output is pure noise; keep ``param.std()`` out of the autograd graph
        with torch.no_grad():
            for param in model.parameters():
                scale = param.std() if param.numel() > 1 else 1
                # created directly on the param's device/dtype, no host->device copy
                answer.append(torch.randn_like(param) * scale)
        return answer
from k1lib.cli import apply, deref, op, item # dummy
if hasTorch:
    @k1lib.patch(nn.Module)
    @contextmanager
    def deviceContext(self:nn.Module, buffers:bool=True) -> ContextManager:
        """Preserves the device of whatever operation is inside this.
        Example::

            import torch.nn as nn
            m = nn.Linear(3, 4)
            with m.deviceContext():
                m.cuda() # moves whole model to cuda
            # automatically moves model to cpu

        This is capable of preserving buffers' devices too. But it might be unstable.
        :class:`~torch.nn.parameter.Parameter` are often updated inline, and they keep
        their old identity, which makes it easy to keep track of which device the parameters
        (and their ``.grad`` tensors, if any) are on. However, buffers are rarely updated
        inline, so their identities change all the time. To deal with this, buffers are
        tracked per module, by position. On entry, this records the devices of every
        submodule's own (non-recursive) buffers, and on exit it does roughly::

            for m, devices in recorded:
                for buffer, device in zip(m.buffers(recurse=False), devices):
                    buffer.data = buffer.data.to(device=device)

        This means that if, while inside the context, you add a buffer to a module, that
        module's buffer-device alignment will be shifted and wrong. So, register all
        your buffers (aka Tensors attached to :class:`~torch.nn.Module`) outside this context
        to avoid headaches, or set ``buffers`` option to False.
        If you don't know what I'm talking about, don't worry and just leave as default.

        :param buffers: whether to preserve device of buffers (regular Tensors attached
            to :class:`~torch.nn.Module`) or not."""
        pDevs = [(p, p.device) for p in self.parameters()]
        if buffers:
            # one device list per module, so a misalignment stays local to that module
            pbDevs = [(m, [b.device for b in m.buffers(recurse=False)]) for m in self.modules()]
        try: yield
        finally:
            for p, dev in pDevs:
                p.data = p.data.to(device=dev)
                # moving a module (e.g. .cuda()) also moves existing gradients, so
                # bring them back too, otherwise param and grad end up on different devices
                if p.grad is not None: p.grad.data = p.grad.data.to(device=dev)
            if buffers:
                for m, bDevs in pbDevs:
                    for buf, dev in zip(m.buffers(recurse=False), bDevs):
                        buf.data = buf.data.to(device=dev)
if hasTorch:
    @k1lib.patch(nn.Module)
    @contextmanager
    def gradContext(self):
        """Restores every parameter's ``requires_grad`` flag when the block exits.
        Example::

            m = nn.Linear(2, 3)
            with m.gradContext():
                m.weight.requires_grad = False
            # returns True
            m.weight.requires_grad

        Only parameters are covered. Buffers (Tensors attached to
        :class:`torch.nn.Module`) are not meant to track gradients, so they are
        left alone."""
        params = list(self.parameters())
        flags = [p.requires_grad for p in params]
        try:
            yield
        finally:
            for p, flag in zip(params, flags):
                p.requires_grad_(flag)
if hasTorch:
    @k1lib.patch(nn.Module)
    def __ror__(self, x):
        """Makes ``x | module`` equivalent to ``module(x)``, so modules can sit
        inside :mod:`k1lib.cli` pipelines. Example::

            # returns torch.Size([5, 3])
            torch.randn(5, 2) | nn.Linear(2, 3) | cli.shape()"""
        output = self(x)
        return output
if hasTorch:
    @k1lib.patch(nn.Module, name="nParams")
    @property
    def nParams(self):
        """Total number of scalar values across all of this module's parameters.
        Example::

            # returns 9, because 6 (2*3) for weight, and 3 for bias
            nn.Linear(2, 3).nParams"""
        return sum(p.numel() for p in self.parameters())
if hasTorch:
    @k1lib.patch(torch)
    @k1lib.patch(torch.Tensor)
    def crissCross(*others:Tuple[torch.Tensor]) -> torch.Tensor:
        """Concats multiple 1d tensors, sorts it, and get evenly-spaced values. Also
        available as :meth:`torch.crissCross` and :meth:`~k1lib.cli.others.crissCross`.
        Example::

            a = torch.tensor([2, 2, 3, 6])
            b = torch.tensor([4, 8, 10, 12, 18, 20, 30, 35])

            # returns tensor([2, 3, 6, 10, 18, 30])
            a.crissCross(b)

            # returns tensor([ 2, 4, 8, 10, 18, 20, 30, 35])
            a.crissCross(*([b]*10)) # 1 "a" and 10 "b"s

            # returns tensor([ 2, 2, 3, 6, 18])
            b.crissCross(*([a]*10)) # 1 "b" and 10 "a"s

        Note how in the second case, the length is the same as tensor b, and the contents
        are pretty close to b. In the third case, it's the opposite. Length is almost
        the same as tensor a, and the contents are also pretty close to a.

        The step size equals the number of input tensors, so the output length is
        roughly the average input length."""
        merged = torch.cat([t.flatten() for t in others])
        ordered, _ = merged.sort()
        step = len(others)
        return ordered[::step]
if hasTorch:
    @k1lib.patch(torch)
    def sameStorage(a, b):
        """Check whether 2 (:class:`np.ndarray` or :class:`torch.Tensor`)
        has the same storage or not. Example::

            a = np.linspace(2, 3, 50)
            # returns True
            torch.sameStorage(a, a[:5])
            # returns True
            torch.sameStorage(a[:10], a[5:])
            # returns False
            torch.sameStorage(a[:10], np.linspace(3, 4))

        All examples above should work with PyTorch tensors as well.

        Numpy arrays are compared by the object that ultimately owns their memory
        (found by following ``.base``). Tensors are compared by the data pointer of
        their underlying storage, so views at different offsets still match. Any
        other combination only matches if ``a`` and ``b`` are the same object."""
        def _tensorStoragePtr(t):
            # untyped_storage() is the current API; older PyTorch only has storage()
            storage = t.untyped_storage() if hasattr(t, "untyped_storage") else t.storage()
            return storage.data_ptr()
        def _arrayRoot(x):
            # walk up the view chain to the array (or buffer object) owning the memory
            while isinstance(x, np.ndarray) and x.base is not None: x = x.base
            return x
        if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
            return _tensorStoragePtr(a) == _tensorStoragePtr(b)
        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
            # previously ``a.base is b.base`` made any two base-owning arrays equal
            return _arrayRoot(a) is _arrayRoot(b)
        return a is b
if hasTorch:
    @k1lib.patch(torch.Tensor)
    def histBounds(self:torch.Tensor, bins=100) -> torch.Tensor:
        r"""Flattens and sorts the tensor, then get value of tensor at regular
        linspace intervals. Does not guarantee bounds' uniqueness. Example::

            # Tensor with lots of 2s and 5s
            a = torch.Tensor([2]*5 + [3]*3 + [4] + [5]*4)
            # returns torch.tensor([2., 3., 5.])
            a.histBounds(3).unique()

        The example result essentially shows 3 bins: :math:`[2, 3)`, :math:`[3, 5)` and
        :math:`[5, \infty)`. This might be useful in scaling pixels so that networks handle
        it nicely. Rough idea taken from fastai.medical.imaging.

        :param bins: how many evenly spaced sample points to take
        :return: 1d tensor of ``bins`` values, in ascending order"""
        ordered, _ = torch.sort(self.flatten())
        fractions = torch.linspace(0, 1, bins)
        # pull the last fraction just below 1 so it lands on the final index, not past it
        fractions[-1] = 1 - 1e-6
        positions = (fractions * len(ordered)).long()
        return ordered[positions]
if hasTorch:
    @k1lib.patch(torch.Tensor)
    def histScaled(self:torch.Tensor, bins=100, bounds=None) -> torch.Tensor:
        """Scale tensor's values so that the values are roughly spreaded out in range
        :math:`[0, 1]` to ease neural networks' pain. Rough idea taken from
        fastai.medical.imaging. Example::

            # normal-distributed values
            a = torch.randn(1000)
            # plot #1 shows a normal distribution
            plt.hist(a.numpy(), bins=30); plt.show()
            # plot #2 shows almost-uniform distribution
            plt.hist(a.histScaled().numpy()); plt.show()

        Plot #1:

        .. image:: images/histScaledNormal.png

        Plot #2:

        .. image:: images/histScaledUniform.png

        :param bins: if ``bounds`` not specified, then will scale according to a hist
            with this many bins
        :param bounds: if specified, then ``bins`` is ignored and will scale according to
            this. Expected this to be a sorted tensor (or array) going from ``min(self)`` to
            ``max(self)``.
        :return: float64 tensor with the same shape as ``self``, on ``self``'s device"""
        bounds = self.histBounds(bins) if bounds is None else torch.as_tensor(bounds)
        # np.interp needs host arrays; detach/cpu so cuda and requires_grad
        # tensors work (plain ``.numpy()`` rejects both)
        xp = bounds.unique().detach().cpu().numpy()
        x = self.detach().cpu().numpy().flatten()
        out = np.interp(x, xp, np.linspace(0, 1, len(xp)))
        return torch.tensor(out).reshape(self.shape).to(self.device)
if hasTorch:
    @k1lib.patch(torch.Tensor)
    def positionalEncode(t:torch.Tensor, richFactor:float=2) -> torch.Tensor:
        r"""Position encode a tensor of shape :math:`(L, F)`, where :math:`L`
        is the sequence length, :math:`F` is the encoded features. The encodings
        overwrite the tensor's contents in place (existing values are discarded),
        and the same tensor is returned.

        This is a bit different from the standard implementations that ppl use.
        This is exactly:

        .. math:: p = \frac{i}{F\cdot richFactor}

        .. math:: w = 1/10000^p

        .. math:: pe = \sin(w \cdot l)

        With :math:`i` the feature index in :math:`[0, F)`, :math:`l` the position in
        :math:`[0, L)`, and :math:`p` the "progress". If ``richFactor`` is 1
        (original algo), then ``p`` goes from 0% to 100% of the features. Example::

            import matplotlib.pyplot as plt, torch, k1lib
            plt.figure(dpi=150)
            plt.imshow(torch.zeros(100, 10).positionalEncode().T)

        .. image:: images/positionalEncoding.png

        :param richFactor: the bigger, the richer the features are. A lot of
            times, I observe that the features that are meant to cover huge scales
            are pretty empty and don't really contribute anything useful. So this
            is to bump up the usefulness of those features"""
        seqN, featsN = t.shape
        # build the encodings directly on t's device to avoid cross-device copies
        feats = torch.arange(featsN, device=t.device)
        w = 1/10000**(feats/featsN/richFactor)
        positions = torch.arange(seqN, device=t.device)
        # broadcasting: entry [l, i] is w[i] * l
        t[:] = torch.sin(w[None, :] * positions[:, None]); return t
if hasTorch:
    @k1lib.patch(torch.Tensor)
    def clearNan(self, value:float=0.0) -> torch.Tensor:
        """Replaces every nan entry with ``value``, in place, and returns the tensor.
        Example::

            a = torch.randn(3, 3) * float("nan")
            a.clearNan() # now full of zeros"""
        nanMask = torch.isnan(self)
        self[nanMask] = value
        return self
if hasTorch:
    @k1lib.patch(torch.Tensor)
    def hasNan(self) -> bool:
        """Whether any entry of this Tensor is nan. The result is a 0-dim bool
        tensor, so it works directly in ``if`` statements."""
        return torch.isnan(self).any()
if hasTorch:
    @k1lib.patch(torch.Tensor)
    def stats(self) -> Tuple[float, float]:
        """Returns ``(mean, std)`` of this tensor, each as a 0-dim tensor."""
        mean = self.mean()
        std = self.std()
        return mean, std
inf = float("inf")
if hasTorch:
    @k1lib.patch(torch.Tensor)
    def hasInfs(self):
        """Whether any entry of this Tensor equals positive or negative infinity.
        Returns a 0-dim bool tensor."""
        either = torch.logical_or(self == inf, self == -inf)
        return either.any()
if hasTorch:
    @k1lib.patch(torch)
    def loglinspace(a, b, n=100, **kwargs):
        """Like :meth:`torch.linspace`, but spread the values out in log space,
        instead of linear space. Different from :meth:`torch.logspace`: here ``a``
        and ``b`` are the actual end points, not exponents.

        :param a: first value, must be positive
        :param b: last value, must be positive
        :param n: number of values
        :param kwargs: forwarded to :meth:`torch.linspace`"""
        exponents = torch.linspace(math.log(a), math.log(b), n, **kwargs)
        return math.e ** exponents
try:
    import graphviz

    @k1lib.patch(graphviz.Digraph, "__call__")
    @k1lib.patch(graphviz.Graph, "__call__")
    def _call(self, _from, *tos, **kwargs):
        """Convenience method to quickly construct graphs. Adds one edge from
        ``_from`` to each node in ``tos``.
        Example::

            g = k1lib.graph()
            g("a", "b", "c")
            g # displays arrows from "a" to "b" and "a" to "c"

        :param kwargs: forwarded to every ``self.edge()`` call"""
        for to in tos: self.edge(_from, to, **kwargs)
except Exception: pass  # graphviz is optional: skip the patch if it can't be applied
try:
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D, art3d

    @k1lib.patch(Axes3D)
    def march(self, heatMap, level:float=0, facecolor=(0.45, 0.45, 0.75), edgecolor=None):
        """Use marching cubes to plot surface of a 3d heat map.
        Example::

            plt.k3d(6).march(k1lib.perlin3d(), 0.17)

        .. image:: images/march.png

        A more tangible example::

            t = torch.zeros(100, 100, 100)
            t[20:30,20:30,20:30] = 1
            t[80:90,20:30,40:50] = 1
            plt.k3d().march(t.numpy())

        The function name is "march" because how it works internally is by using
        something called marching cubes.

        :param heatMap: 3d numpy array
        :param level: array value to form the surface on
        :param facecolor: mesh face color, or None to keep matplotlib's default
        :param edgecolor: mesh edge color, or None to keep matplotlib's default
        :return: this axes, for chaining"""
        from skimage import measure
        try: verts, faces, normals, values = measure.marching_cubes(heatMap, level)
        except Exception:
            # alternative function name, presumably for older scikit-image releases
            verts, faces, normals, values = measure.marching_cubes_lewiner(heatMap, level)
        mesh = art3d.Poly3DCollection(verts[faces])
        if facecolor is not None: mesh.set_facecolor(facecolor)
        if edgecolor is not None: mesh.set_edgecolor(edgecolor)
        self.add_collection3d(mesh)
        self.set_xlim(0, heatMap.shape[0])
        self.set_ylim(0, heatMap.shape[1])
        self.set_zlim(0, heatMap.shape[2]); return self

    @k1lib.patch(Axes3D)
    def aspect(self):
        """Use the same aspect ratio for all axes. Each axis' box length is made
        proportional to its current data range. Returns this axes, for chaining."""
        self.set_box_aspect([ub - lb for lb, ub in (getattr(self, f'get_{a}lim')() for a in 'xyz')])
        return self

    @k1lib.patch(plt)
    def k3d(size=8, labels=True, *args, **kwargs):
        """Convenience function to get an :class:`~mpl_toolkits.mplot3d.axes3d.Axes3D`.

        :param labels: whether to include xyz labels or not
        :param size: figure size, either a number (square figure) or a (width, height) pair"""
        if isinstance(size, (int, float)): size = (size, size)
        fig = plt.figure(figsize=size, constrained_layout=True, *args, **kwargs)
        ax = fig.add_subplot(projection="3d")
        if labels:
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_zlabel('z')
        return ax

    @k1lib.patch(plt)
    def animate(azimSpeed=3, azimStart=0, elevSpeed=0.9, elevStart=0, frames=20, close=True):
        """Animates the existing 3d axes by rotating the view each frame.
        Example::

            plt.k3d().scatter(*np.random.randn(3, 10))
            plt.animate()

        :param frames: how many frames to render? Frame rate is 30 fps
        :param close: whether to close the figure (to prevent the animation and
            static plot showing at the same time) or not"""
        fig = plt.gcf()
        def f(frame):
            for ax in fig.axes:
                ax.view_init(elevStart+frame*elevSpeed, azimStart+frame*azimSpeed)
        if close: plt.close()
        return k1lib.viz.FAnim(fig, f, frames)

    @k1lib.patch(plt)
    def getFig():
        """Grab figure of the current plot.
        Example::

            plt.plot() | plt.getFig() | toImg()

        Internally, this just calls ``plt.gcf()`` and that's it, pretty simple.
        But I usually plot things as a part of the cli pipeline, and it's very
        annoying that I can't quite chain ``plt.gcf()`` operation, so I created
        this"""
        def inner(_): return plt.gcf()
        return cli.aS(inner)

    k1lib._settings.monkey.add("capturePlt", False, "whether to intercept matplotlib's show() and turn it into an image or not")
    _oldShow = plt.show; _recentImg = [None]

    @k1lib.patch(plt)
    def show(*args, **kwargs):
        """Replacement for ``plt.show()``. When the ``capturePlt`` monkey setting is
        on, renders the current figure to an image (retrievable with
        ``plt._k1_capturedImg()``) instead of displaying it. Otherwise, or if the
        capture fails, calls matplotlib's original ``show()``."""
        try:
            if k1lib._settings.monkey.capturePlt:
                _recentImg[0] = plt.gcf() | k1lib.cli.toImg()
                return None
        except Exception: pass  # capture failed: fall through to the real show()
        # previously only reached on errors, so show() was a no-op with capture off
        return _oldShow(*args, **kwargs)

    @k1lib.patch(plt)
    def _k1_capturedImg():
        """Returns the image captured by the most recent intercepted ``show()``
        call, or None if nothing has been captured yet."""
        return _recentImg[0]
except Exception: pass  # matplotlib is optional: skip these patches if unavailable
try:
    @k1lib.patch(Axes3D)
    def plane(self, origin, v1, v2=None, s1:float=1, s2:float=1, **kwargs):
        """Plots a 3d plane.

        :param origin: origin vector, shape (3,)
        :param v1: 1st vector, shape (3,)
        :param v2: optional 2nd vector, shape(3,). If specified, plots a plane created
            by 2 vectors. If not, plots a plane perpendicular to the 1st vector
        :param s1: optional, how much to scale 1st vector by
        :param s2: optional, how much to scale 2nd vector by
        :param kwargs: keyword arguments passed to :meth:`~mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`"""
        v1 = (v1 if isinstance(v1, torch.Tensor) else torch.tensor(v1)).float()
        if v2 is None:
            v = v1
            # pick any vector perpendicular to the normal ``v``
            if v[2] != 0: v1 = torch.tensor([1.0, 1.0, float(-(v[0]+v[1])/v[2])])
            # normal lies in the xy-plane (previously divided by zero here), so the
            # z axis is perpendicular to it
            else: v1 = torch.tensor([0.0, 0.0, 1.0])
            v2 = torch.cross(v, v1, dim=0)
        v2 = (v2 if isinstance(v2, torch.Tensor) else torch.tensor(v2)).float()
        origin = (origin if isinstance(origin, torch.Tensor) else torch.tensor(origin)).float()
        x = torch.linspace(-1, 1, 50)[:,None]
        # (50, 1, 3) + (1, 50, 3) grid of offsets around the origin
        v1 = (v1[None,:]*x*s1)[:,None]
        v2 = (v2[None,:]*x*s2)[None,:]
        origin = origin[None,None,:]
        grid = (origin + v1 + v2).permute(2, 0, 1)  # -> X, Y, Z each (50, 50)
        self.plot_surface(*grid.numpy(), **kwargs)
except Exception: pass  # Axes3D unavailable (matplotlib missing): skip the patch
try:
    @k1lib.patch(Axes3D)
    def point(self, v, **kwargs):
        """Plots a 3d point.

        :param v: point location, shape (3,). Anything :func:`numpy.asarray` accepts
            (list, tuple, numpy array, cpu tensor), so torch is not required
        :param kwargs: keyword argument passed to :meth:`~mpl_toolkits.mplot3d.axes3d.Axes3D.scatter`"""
        self.scatter(*np.asarray(v, dtype=np.float32), **kwargs)

    @k1lib.patch(Axes3D)
    def line(self, v1, v2, **kwargs):
        """Plots a 3d line.

        :param v1: 1st point location, shape (3,)
        :param v2: 2nd point location, shape (3,)
        :param kwargs: keyword argument passed to :meth:`~mpl_toolkits.mplot3d.axes3d.Axes3D.plot`"""
        pts = np.array([np.asarray(v1, dtype=np.float32), np.asarray(v2, dtype=np.float32)])
        # transpose (2, 3) -> (3, 2): one row of 2 coordinates per axis
        self.plot(*pts.T, **kwargs)
except Exception: pass  # Axes3D unavailable (matplotlib missing): skip the patches
try:
    @k1lib.patch(Axes3D)
    def surface(self, z, **kwargs):
        """Plots 2d surface in 3d. Pretty much exactly the same as
        :meth:`~mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`, but fields x and y
        are filled in automatically. Example::

            x, y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2))
            plt.k3d(6).surface(x**3 + y**3)

        .. image:: images/surface.png

        :param z: 2d array for the heights (numpy array, nested list, or tensor
            on any device)
        :param kwargs: keyword arguments passed to ``plot_surface``"""
        # detach/cpu so cuda and requires_grad tensors work too
        if hasTorch and isinstance(z, torch.Tensor): z = z.detach().cpu().numpy()
        z = np.asarray(z)
        rows, cols = z.shape
        x, y = np.meshgrid(np.arange(cols), np.arange(rows))
        return self.plot_surface(x, y, z, **kwargs)
except Exception: pass  # Axes3D unavailable (matplotlib missing): skip the patch
try:
    import pandas as pd

    @k1lib.patch(pd.DataFrame)
    def table(self):
        """Converts a :class:`pandas.core.frame.DataFrame` to a table (with column
        headers), so that it can be more easily manipulated with cli tools. Example::

            pd.read_csv("abc.csv").table()

        This is a generator. It first yields the column names as a list, then
        yields one row at a time, each row being a numpy array taken from
        ``self.values``."""
        yield self.columns.to_list()
        yield from self.values
except Exception: pass  # pandas is optional: skip the patch if unavailable
try:
    import forbiddenfruit

    def splitCamel(s):
        """Splits a string up based on camel case. A new word starts at every
        uppercase character that directly follows a lowercase one.
        Example::

            # returns ['IHave', 'No', 'Idea', 'What', 'To', 'Put', 'Here']
            "IHaveNoIdeaWhatToPutHere".splitCamel()

        Returns an empty list for an empty string."""
        if not s: return []  # previously raised IndexError on s[0]
        words = [[s[0]]]
        for c in s[1:]:
            if words[-1][-1].islower() and c.isupper():
                words.append(list(c))
            else: words[-1].append(c)
        return [''.join(word) for word in words]

    forbiddenfruit.curse(str, "splitCamel", splitCamel)
except Exception: pass  # forbiddenfruit is optional: skip patching str
try:
    import ray

    @ray.remote
    class RayProgress:
        """Ray actor holding ``n`` progress values (expected between 0 and 1) and a
        stop flag read by the printer thread."""
        def __init__(self, n): self.values = [0]*n; self.thStop = False
        def update(self, idx:int, val:float): self.values[idx] = val; return self.values[idx]
        def stop(self): self.thStop = True
        def content(self): return self.thStop, self.values | cli.apply(lambda x: f"{round(x*100)}%") | cli.join(" | ")

    def startRayProgressThread(rp, title:str="Progress"):
        """Starts a background poller that prints ``rp``'s progress on one line
        about every 10ms, until ``rp.stop`` has been called."""
        def inner(x):
            # applyTh runs this on 0 and 1; presumably the x == 0 call exists so that
            # item() returns right away while the x == 1 poller keeps running
            if x == 0: return
            print("Starting...\r", end="")
            beginTime = time.time()
            while True:
                stop, content = ray.get(rp.content.remote())
                print(f"{title}: {content}, {round(time.time()-beginTime)}s elapsed \r", end="")
                if stop: break
                time.sleep(0.01)
        [0, 1] | cli.applyTh(inner, timeout=1e9, prefetch=10) | cli.item()

    @k1lib.patch(ray)
    @contextmanager
    def progress(n:int, title:str="Progress"):
        """Manages multiple progress bars distributedly.
        Example::

            with ray.progress(5) as rp:
                def process(idx:int):
                    for i in range(100):
                        time.sleep(0.05) # do some processing
                        rp.update.remote(idx, (i+1)/100) # update progress. Expect number between 0 and 1
                range(5) | applyCl(process) | deref() # execute function in multiple nodes

        This will print out a progress bar that looks like this::

            Progress: 100% | 100% | 100% | 100% | 100%

        :param n: number of progresses to keep track of
        :param title: title of the progress to show"""
        rp = RayProgress.remote(n); startRayProgressThread(rp, title)
        try: yield rp
        finally:
            # always stop the printer, even if the body raised, otherwise its
            # ``while True`` loop would keep polling forever
            ray.get(rp.stop.remote()); time.sleep(0.1)
except Exception: pass  # ray is optional: skip these helpers if unavailable
@k1lib.patch(np)
def gather(self, dim, index):
    """Gathers values along an axis specified by ``dim``, like :func:`torch.gather`.
    For a 3-D array the output is specified by::

        out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

    Original implementation credit: https://stackoverflow.com/questions/46868056/how-to-gather-elements-of-specific-indices-in-numpy

    :param dim: the axis along which to index. Negative values count from the end
    :param index: integer array of indices of elements to gather. Must have the same
        shape as ``self`` except along ``dim``
    :raises ValueError: if ``self`` and ``index`` disagree in shape outside ``dim``
    :raises TypeError: if ``index`` does not have an integer dtype"""
    arr = np.asarray(self); index = np.asarray(index)
    if dim < 0: dim += arr.ndim
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = arr.shape[:dim] + arr.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape: raise ValueError("Except for dimension " + str(dim) + ", all dimensions of index and self should be the same size")
    # accept every integer dtype, not just the platform default ``int_``
    if not np.issubdtype(index.dtype, np.integer): raise TypeError("The values of index must be integers")
    # take_along_axis has exactly these semantics and, unlike np.choose, no cap on
    # the number of entries along ``dim`` (np.choose accepts only 32/64 choices)
    return np.take_along_axis(arr, index, axis=dim)
try:
    import forbiddenfruit

    def expand(self, sh):
        """Broadcast this array to shape ``sh`` without copying (numpy analogue of
        :meth:`torch.Tensor.expand`). Returns the read-only view produced by
        :func:`numpy.broadcast_to`."""
        return np.broadcast_to(self, sh)

    forbiddenfruit.curse(np.ndarray, "expand", expand)
except Exception: pass  # forbiddenfruit is optional: skip patching ndarray