# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
Everything related to data transformation and loading. This is exposed automatically
with::

   from k1lib.imports import *
   kdata.FunctionData # exposed
"""
import torch, numpy as np; from k1lib.cli import *
from typing import Callable, Union, Iterator
import matplotlib.pyplot as plt
__all__ = ["FunctionData", "tfImg", "tfFloat", "analyzeFloat"]
class FunctionData:
    """Namespace of static constructors that build train/valid dataloaders
sampling a scalar function on 1000 evenly spaced points in [-5, 5]."""
    @staticmethod
    def main(f:Callable, bs:int=32, epochs:int=300):
        """Constructs 2 dataloaders, train and valid, for a particular function.
Example::

    trainDl, validDl = kdata.FunctionData.main(torch.exp, 32, 300)
    for epoch in range(3):
        for xb, yb in trainDl:
            model(xb)

:param f: function evaluated on ``torch.linspace(-5, 5, 1000)`` to produce the targets
:param bs: batch size handed to ``batched``
:param epochs: handed to ``stagger`` as ``epochs*.8`` and ``epochs*.2`` for the
    two splits produced by ``splitList(8, 2)``"""
        x = torch.linspace(-5, 5, 1000)
        # pair every input with its target, then shuffle the pairs once up front
        ds = [x, f(x)] | transpose() | randomize(None)
        # Parentheses (instead of backslash continuations) keep this a single
        # expression. ``bs`` is now honoured; it used to be hard-coded to 32.
        return (ds | splitList(8, 2)
                | (repeatFrom() | randomize() | batched(bs)
                   | (transpose() | toTensor()).all()).all()
                | (stagger(epochs*.8) + stagger(epochs*.2))
                | toList())
    @staticmethod
    def exp(bs:int=32, epochs:int=300):
        """Shortcut for ``FunctionData.main(torch.exp, bs, epochs)``."""
        return FunctionData.main(torch.exp, bs, epochs)
    @staticmethod
    def log(bs:int=32, epochs:int=300):
        """Shortcut for ``FunctionData.main(torch.log, bs, epochs)``.

NOTE(review): inputs are sampled on [-5, 5], so ``torch.log`` yields NaN for
the negative ones - confirm this is intended."""
        return FunctionData.main(torch.log, bs, epochs)
    @staticmethod
    def inverse(bs:int=32, epochs:int=300):
        """Shortcut for ``FunctionData.main(lambda x: 1/x, bs, epochs)``."""
        return FunctionData.main(lambda x: 1/x, bs, epochs)
    @staticmethod
    def linear(bs:int=32, epochs:int=300):
        """Shortcut for ``FunctionData.main(lambda x: 2*x+8, bs, epochs)``."""
        return FunctionData.main(lambda x: 2*x+8, bs, epochs)
    @staticmethod
    def sin(bs:int=32, epochs:int=300):
        """Shortcut for ``FunctionData.main(torch.sin, bs, epochs)``."""
        return FunctionData.main(torch.sin, bs, epochs)
# Short alias for applyS; tfImg uses it to wrap each torchvision transform.
aS = applyS
def tfImg(size:int=None, /, flip=True) -> BaseCli:
    """Get typical image transforms.
Example::

    "path/img.png" | toPIL() | kdata.tfImg(224)

:param size: positional-only. If truthy, the image is first passed through
    ``Resize(size)`` and ``CenterCrop(size)``; otherwise no resizing is done
:param flip: if True, appends a ``RandomHorizontalFlip()`` at the end
:return: a cli chaining the selected torchvision transforms"""
    # imported lazily so torchvision is only required when this is actually used
    import torchvision.transforms as tf
    op = identity()
    if size:
        op |= aS(tf.Resize(size)) | aS(tf.CenterCrop(size))
    # augmentations that are always applied
    op |= aS(tf.ColorJitter(0.2, 0.2, 0.2)) | aS(tf.RandomAffine(5))
    if flip:
        op |= aS(tf.RandomHorizontalFlip())
    return op
def tensorGuard(t, force:bool):
    """Coerces ``t`` into a :class:`torch.Tensor`.

:param t: a Tensor (returned untouched), a numpy array (wrapped with
    ``torch.tensor``), or anything else, which is piped through
    ``toFloat(force=force) | deref() | toTensor()``
:param force: forwarded to ``toFloat``; only used on the generic path"""
    if isinstance(t, torch.Tensor):
        return t
    if isinstance(t, np.ndarray):
        return torch.tensor(t)
    return t | toFloat(force=force) | deref() | toTensor()
def tfFloat(t:Union[Iterator[float], torch.Tensor], force=True) -> BaseCli:
    """Suggested float input transformation function.
Example::

    # before training
    data = torch.randn(10, 20) * 100 + 20 # weird data with weird hist distribution
    f = kdata.tfFloat(data)
    # while training
    newData = torch.randn(10, 20) * 105 + 15
    newData | f # nicely formatted Tensor, going uniformly from -1 to 1

:param t: reference data; its ``histBounds()`` are computed once, here, and
    reused for every input later piped into the returned cli
:param force: if True, forces weird values to 0.0, else filters out all weird rows.
:return: cli that converts its input to a Tensor and returns
    ``histScaled(0, bounds)*2 - 1``"""
    # bounds are captured by the closure below, so they stay fixed after this call
    bounds = tensorGuard(t, force).histBounds()
    # the lambda parameter is named ``x`` so it no longer shadows the outer ``t``
    return applyS(lambda x: tensorGuard(x, force).histScaled(0, bounds)*2 - 1)
@applyS
def analyzeFloat(l:Iterator[float]):
    """Preliminary input float stream analysis.
Example::

    torch.linspace(-2, 2, 50) | kdata.analyzeFloat

Prints how many elements remain after ``toFloat()`` relative to ``shape(0)``
of the input, then the mean, std, min and max of the remaining values, and
plots two 30-bin histograms: one of the values and one of ``histScaled()``.

:param l: stream of values to analyze"""
    l = l | deref(False)             # dereferenced once so it can be read twice below
    lf = l | toFloat() | toTensor()  # values kept by toFloat(), as a Tensor
    nl = l | shape(0)
    nlf = len(lf)
    print(f"Percent of useful data: {nlf}/{nl} ({round(100*nlf/nl)}%)")
    print(f"- Mean: {lf.mean()}")
    print(f"- Std: {lf.std()}")
    print(f"- Min: {lf.min()}")
    print(f"- Max: {lf.max()}")
    plt.hist(lf.numpy(), bins=30)
    plt.title("Values histogram")
    plt.ylabel("Frequency")
    plt.show()
    # NOTE(review): unlike the first figure there is no plt.show() here -
    # confirm whether the second figure is deliberately left open.
    plt.hist(lf.histScaled().numpy(), bins=30)
    plt.ylabel("Frequency")
    plt.title("Scaled histogram")