# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
Everything related to data transformation and loading. This is exposed automatically
with::

    from k1lib.imports import *
    kdata.FunctionData # exposed
"""
import torch, numpy as np; from k1lib.cli import *
from typing import Callable, Union, Iterator
import matplotlib.pyplot as plt
# Names exported by ``from <this module> import *``.
__all__ = [
    "FunctionData",
    "tfImg",
    "tfFloat",
    "analyzeFloat",
]
class FunctionData:
    """Synthetic 1-D regression datasets built from a plain function.

    Every helper returns the 2-element list ``[trainDl, validDl]`` produced by
    :meth:`main`."""
    @staticmethod
    def main(f:Callable, bs:int=32, epochs:int=300, start:float=-5, end:float=5, n:int=1000):
        """Constructs 2 dataloaders, train and valid, for a particular function.

        Example::

            trainDl, validDl = kdata.FunctionData.main(torch.exp, 32, 300)
            for epoch in range(3):
                for xb, yb in trainDl:
                    model(xb)

        :param f: applied to the sampled inputs to produce the targets
        :param bs: batch size of every (xb, yb) pair
        :param epochs: split 80/20 between the train and valid loaders and handed
            to ``stagger``. NOTE(review): this looks like the number of batches one
            pass over a loader yields rather than a number of epochs — confirm
        :param start: first sampled input (inclusive)
        :param end: last sampled input (inclusive)
        :param n: number of evenly spaced inputs between ``start`` and ``end``"""
        x = torch.linspace(start, end, n)
        # pair every input with its target, then shuffle all the pairs once
        ds = [x, f(x)] | transpose() | randomize(None)
        # applied to each side of the 8:2 split: loop forever, reshuffle, batch, and
        # turn every batch into an (xb, yb) pair of tensors.
        # batched(bs): the batch size used to be hard-coded to 32, ignoring `bs`.
        perSplit = repeatFrom() | randomize() | batched(bs) | (transpose() | toTensor()).all()
        loaders = ds | splitList(8, 2) | perSplit.all()
        # epochs*0.8 / epochs*0.2 are floats; cast explicitly since they are used as counts
        return loaders | (stagger(int(epochs*0.8)) + stagger(int(epochs*0.2))) | toList()

    @staticmethod
    def exp(bs=32, epochs=300): return FunctionData.main(torch.exp, bs, epochs)

    @staticmethod
    def log(bs=32, epochs=300):
        # torch.log is NaN for x < 0 and -inf at 0, so the default [-5, 5] range would
        # put non-finite targets into half of the dataset; sample strictly positive x
        return FunctionData.main(torch.log, bs, epochs, start=0.01)

    @staticmethod
    def inverse(bs=32, epochs=300): return FunctionData.main(lambda x: 1/x, bs, epochs)

    @staticmethod
    def linear(bs=32, epochs=300): return FunctionData.main(lambda x: 2*x+8, bs, epochs)

    @staticmethod
    def sin(bs=32, epochs=300): return FunctionData.main(torch.sin, bs, epochs)
# Short alias for ``applyS``; tfImg uses it to wrap each torchvision transform.
aS = applyS
def tfImg(size:int=None, /, flip=True) -> BaseCli:
    """Builds a typical augmentation pipeline for a single image.

    Example::

        "path/img.png" | toPIL() | kdata.tfImg(224)

    :param size: if truthy, the image is first resized and center-cropped to ``size``
    :param flip: whether to end the pipeline with a random horizontal flip"""
    import torchvision.transforms as tf
    transforms = [tf.Resize(size), tf.CenterCrop(size)] if size else []
    transforms += [tf.ColorJitter(0.2, 0.2, 0.2), tf.RandomAffine(5)]
    if flip:
        transforms.append(tf.RandomHorizontalFlip())
    pipeline = identity()
    for step in transforms:
        # chain each transform, in order, as its own applyS stage
        pipeline |= aS(step)
    return pipeline
def tensorGuard(t, force:bool):
    """Coerces ``t`` into a ``torch.Tensor``.

    Tensors are returned unchanged, numpy arrays go through ``torch.tensor``, and
    anything else is piped through ``toFloat(force=force) | deref() | toTensor()``.

    :param force: forwarded to ``toFloat``; only matters for non-tensor, non-array input"""
    if isinstance(t, torch.Tensor):
        return t
    if isinstance(t, np.ndarray):
        return torch.tensor(t)
    return t | toFloat(force=force) | deref() | toTensor()
def tfFloat(t:Union[Iterator[float], torch.Tensor], force=True) -> BaseCli:
    """Suggested float input transformation function.

    Histogram bounds are measured once on the reference data ``t``. The returned
    cli rescales any later input with those same bounds, then maps the result
    through ``* 2 - 1``.

    Example::

        # before training
        data = torch.randn(10, 20) * 100 + 20 # weird data with weird hist distribution
        f = kdata.tfFloat(data)
        # while training
        newData = torch.randn(10, 20) * 105 + 15
        newData | f # nicely formatted Tensor, going uniformly from -1 to 1

    :param t: reference data; non-tensors are converted with ``tensorGuard``
    :param force: if True, forces weird values to 0.0, else filters out all weird rows."""
    reference = tensorGuard(t, force)
    bounds = reference.histBounds()
    def _rescale(x):
        # new data is coerced the same way as the reference before scaling
        return tensorGuard(x, force).histScaled(0, bounds) * 2 - 1
    return applyS(_rescale)
@applyS
def analyzeFloat(l:Iterator[float]):
    """Preliminary input float stream analysis.

    Prints how many elements survive ``toFloat()``, then their mean/std/min/max.
    Finally it shows a histogram of the raw values and one of the histogram-scaled
    values.

    Example::

        torch.linspace(-2, 2, 50) | kdata.analyzeFloat

    :param l: stream of values to analyze; elements ``toFloat()`` does not keep are
        counted as not useful"""
    l = l | deref(False); lf = l | toFloat() | toTensor()
    nl = l | shape(0); nlf = len(lf)
    if nl == 0:
        # nothing to analyze, and the percentage below would divide by zero
        print("Percent of useful data: 0/0 (empty input)"); return
    print(f"Percent of useful data: {nlf}/{nl} ({round(100*nlf/nl)}%)")
    if nlf == 0:
        # min()/max() of an empty tensor raise, and there is nothing to plot
        print("- No values could be converted to float"); return
    print(f"- Mean: {lf.mean()}"); print(f"- Std: {lf.std()}")
    print(f"- Min: {lf.min()}"); print(f"- Max: {lf.max()}")
    plt.hist(lf.numpy(), bins=30); plt.title("Values histogram"); plt.ylabel("Frequency"); plt.show()
    # show the second figure explicitly too, like the first, so it is not left
    # pending when this runs outside a notebook
    plt.hist(lf.histScaled().numpy(), bins=30); plt.ylabel("Frequency"); plt.title("Scaled histogram"); plt.show()