Source code for k1lib.kdata

# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
Everything related to data transformation and loading. This is exposed automatically
with::

   from k1lib.imports import *
   kdata.FunctionData # exposed
"""
import torch, numpy as np; from k1lib.cli import *
from typing import Callable, Union, Iterator
import matplotlib.pyplot as plt
# Public names exported by ``from k1lib.kdata import *``; helpers such as
# tensorGuard and the ``aS`` alias are intentionally left out.
__all__ = ["FunctionData", "tfImg", "tfFloat", "analyzeFloat"]
class FunctionData:
    """Synthetic 1-D regression datasets: samples ``y = f(x)`` on a fixed grid over
    [-5, 5] and serves the pairs as train/valid batch streams."""

    @staticmethod
    def main(f:Callable, bs:int=32, epochs:int=300):
        """Constructs 2 dataloaders, train and valid, for a particular function. Example::

            trainDl, validDl = kdata.FunctionData.main(torch.exp, 32, 300)
            for epoch in range(3):
                for xb, yb in trainDl:
                    model(xb)

        :param f: elementwise function applied to a tensor of 1000 points evenly
            spaced over [-5, 5] (``torch.linspace(-5, 5, 1000)``)
        :param bs: number of samples in each yielded ``(xb, yb)`` batch
        :param epochs: scaled by 0.8 for the train stream and 0.2 for the valid
            stream, then passed to ``stagger``
        :returns: a 2-element list ``[trainDl, validDl]``"""
        x = torch.linspace(-5, 5, 1000)
        # pair every x with f(x) as rows, then shuffle the rows before splitting
        ds = [x, f(x)] | transpose() | randomize(None)
        # applied to each split separately: repeat, reshuffle, batch with the
        # requested batch size, then turn each batch of (x, y) rows back into an
        # (xb, yb) pair of tensors
        toBatches = repeatFrom() | randomize() | batched(bs) | (transpose() | toTensor()).all()
        # splitList(8, 2): presumably an 80/20 train/valid split — TODO confirm
        return ds | splitList(8, 2) | toBatches.all() \
            | (stagger(epochs*.8) + stagger(epochs*.2)) | toList()

    @staticmethod
    def exp(bs=32, epochs=300):
        """Dataloaders for ``y = exp(x)``."""
        return FunctionData.main(torch.exp, bs, epochs)

    @staticmethod
    def log(bs=32, epochs=300):
        """Dataloaders for ``y = log(x)``. Half of the [-5, 5] grid is negative,
        so those targets come out as NaN."""
        return FunctionData.main(torch.log, bs, epochs)

    @staticmethod
    def inverse(bs=32, epochs=300):
        """Dataloaders for ``y = 1/x`` (targets get large in magnitude near x = 0)."""
        return FunctionData.main(lambda x: 1/x, bs, epochs)

    @staticmethod
    def linear(bs=32, epochs=300):
        """Dataloaders for ``y = 2x + 8``."""
        return FunctionData.main(lambda x: 2*x+8, bs, epochs)

    @staticmethod
    def sin(bs=32, epochs=300):
        """Dataloaders for ``y = sin(x)``."""
        return FunctionData.main(torch.sin, bs, epochs)
# Short module-level alias for k1lib.cli's applyS; tfImg uses it to wrap each
# torchvision transform as a cli so they can be chained with ``|``.
aS = applyS
def tfImg(size:Union[int, None]=None, /, flip=True) -> BaseCli:
    """Get typical image transforms. Example::

        "path/img.png" | toPIL() | kdata.tfImg(224)

    :param size: (positional-only) if truthy, the image is first resized with
        ``tf.Resize(size)`` and then center-cropped with ``tf.CenterCrop(size)``
    :param flip: if True, a random horizontal flip is appended at the end
    :returns: a cli chaining all the transforms, each wrapped with ``applyS``"""
    import torchvision.transforms as tf
    op = identity()
    if size:
        op |= aS(tf.Resize(size)) | aS(tf.CenterCrop(size))
    # light augmentation, always applied: brightness/contrast/saturation jitter
    # of 0.2, and a random affine with rotation limited to +-5 degrees
    op |= aS(tf.ColorJitter(0.2, 0.2, 0.2)) | aS(tf.RandomAffine(5))
    if flip:
        op |= aS(tf.RandomHorizontalFlip())
    return op
def tensorGuard(t, force:bool):
    """Coerce ``t`` into a ``torch.Tensor``.

    Numpy arrays are copied via ``torch.tensor``. Existing tensors are returned
    unchanged (same object). Anything else is piped through
    ``toFloat(force=force) | deref() | toTensor()``; per tfFloat's docstring,
    ``force`` decides whether weird values become 0.0 or are filtered out."""
    if isinstance(t, np.ndarray):
        return torch.tensor(t)
    if isinstance(t, torch.Tensor):
        return t
    return t | toFloat(force=force) | deref() | toTensor()
def tfFloat(t:Union[Iterator[float], torch.Tensor], force=True) -> BaseCli:
    """Suggested float input transformation function. Example::

        # before training
        data = torch.randn(10, 20) * 100 + 20 # weird data with weird hist distribution
        f = kdata.tfFloat(data)

        # while training
        newData = torch.randn(10, 20) * 105 + 15
        newData | f # nicely formatted Tensor, going uniformly from -1 to 1

    Histogram bounds are fitted once, here, on ``t``. The returned cli reuses
    those bounds for every later input.

    :param force: if True, forces weird values to 0.0, else filters out all weird rows."""
    reference = tensorGuard(t, force)
    bounds = reference.histBounds()

    def _scale(batch):
        # coerce the incoming data the same way as the reference, scale it with
        # the frozen bounds, then stretch the result to the [-1, 1] range
        return tensorGuard(batch, force).histScaled(0, bounds) * 2 - 1

    return applyS(_scale)
@applyS
def analyzeFloat(l:Iterator[float]):
    """Preliminary input float stream analysis. Example::

        torch.linspace(-2, 2, 50) | kdata.analyzeFloat

    Prints how many of the input elements remain after ``toFloat()``, followed
    by the mean, std, min and max of those values. It then shows two
    histograms: the raw values and their ``histScaled()`` version.

    If the input is empty, or none of its elements remain after ``toFloat()``,
    only the count line is printed and nothing is plotted."""
    l = l | deref(False)
    nl = l | shape(0)
    if nl == 0:
        # nothing to analyze; the percentage below would divide by zero
        print("Percent of useful data: 0/0 (no data)")
        return
    lf = l | toFloat() | toTensor(); nlf = len(lf)
    print(f"Percent of useful data: {nlf}/{nl} ({round(100*nlf/nl)}%)")
    if nlf == 0:
        # min()/max() raise on an empty tensor, and there is nothing to plot
        return
    print(f"- Mean: {lf.mean()}"); print(f"- Std: {lf.std()}")
    print(f"- Min: {lf.min()}"); print(f"- Max: {lf.max()}")
    plt.hist(lf.numpy(), bins=30); plt.title("Values histogram"); plt.ylabel("Frequency"); plt.show()
    plt.hist(lf.histScaled().numpy(), bins=30); plt.ylabel("Frequency"); plt.title("Scaled histogram")
    # show the second figure explicitly too, like the first one
    plt.show()