# AUTOGENERATED FILE! PLEASE DON'T EDIT
import os
from time import time as _time
from typing import Union
import k1lib, torch.nn as nn, torch, dill, traceback
from k1lib.callbacks import Cbs
# public names exported by `from <this module> import *`; the patched Learner
# methods further down ride along on the Learner class itself
__all__ = ["CancelRunException", "CancelEpochException", "CancelBatchException",
           "Learner"]
class CancelRunException(Exception):
    """Used in core training loop, to skip the run entirely. Raise it from a
callback to abort the remaining epochs; the loop then calls the ``cancelRun``
and ``endRun`` checkpoints."""
class CancelEpochException(Exception):
    """Used in core training loop, to skip to next epoch. Raise it from a
callback to abandon the rest of the current epoch; the loop then calls the
``cancelEpoch`` and ``endEpoch`` checkpoints and moves on."""
class CancelBatchException(Exception):
    """Used in core training loop, to skip to next batch. Raise it from a
callback to abandon the rest of the current batch; the loop then calls the
``cancelBatch`` and ``endBatch`` checkpoints and continues."""
def _tab(text:Union[list, str], pad="    ") -> Union[list, str]:
    if isinstance(text, str): # this is old function that's replaced in main lib, but still useful
        return "\n".join([pad + line for line in text.split("\n")])
    else: return [pad + line for line in text]
class Learner:
    """Central training object. Holds the model, data, optimizer and a
:class:`~k1lib.callbacks.callbacks.Callbacks` object. Any attribute that is not
found on the learner itself is looked up on :attr:`cbs` instead, which is how
callbacks can be reached directly, e.g. ``l.Loss``."""
    def __init__(self):
        self._model = None; self._data = None; self._opt = None
        self._cbs = None; self.fileName = None
        self.css = "*"; self.exceptionRaised = None # slowly pops
        self.cbs = k1lib.Callbacks().withBasics().withQOL().withAdvanced()
    @property
    def model(self):
        """Set this to change the model to run"""
        return self._model
    @model.setter
    def model(self, model): self._model = model
    @property
    def data(self):
        """Set this to change the data (list of 2 dataloader) to run against."""
        return self._data
    @data.setter
    def data(self, data): self._data = data
    @property
    def opt(self):
        """Set this to change the optimizer. If you're making your own
optimizers, beware to follow the PyTorch's style guide as there are callbacks
that modifies optimizer internals while training like
:class:`k1lib.schedule.ParamScheduler`."""
        return self._opt
    @opt.setter
    def opt(self, opt): self._opt = opt
    @property
    def cbs(self):
        """The :class:`~k1lib.callbacks.callbacks.Callbacks` object. Initialized to
include all the common callbacks. You can set a new one if you want to; setting
it also points the new object's ``l`` attribute back at this learner."""
        return self._cbs
    @cbs.setter
    def cbs(self, cbs): cbs.l = self; self._cbs = cbs
    @property
    def css(self) -> str:
        """The css selector string. Set this to select other parts of the network.
After setting, you can access the selector like this: :code:`l.selector`
See also: :class:`~k1lib.selector.ModuleSelector`"""
        return self._css
    @css.setter
    def css(self, css:str):
        self._css = css
        # a selector can only be built once there is a model to select from
        if self.model is not None: self.selector = k1lib.selector.select(self.model, self.css)
    @property
    def lossF(self):
        """Set this to specify a loss function. Write-only: setting updates the
existing ``LossF`` callback, or adds a new one if there is none."""
        raise NotImplementedError("lossF actually doesn't really exist. Used to exist as a core part of Learner, but then has been converted to Cbs.LossF")
    @lossF.setter
    def lossF(self, lossF):
        if hasattr(self.cbs, "LossF"): self.cbs.LossF.lossF = lossF
        else: self.cbs.add(Cbs.LossF(lossF))
    def __getattr__(self, attr):
        # Only reached when normal lookup fails; forwards to the callbacks object.
        # "_cbs" must be guarded as well as "cbs": the cbs property reads
        # self._cbs, so a learner without _cbs would otherwise bounce between
        # this method and the property until RecursionError.
        if attr in ("cbs", "_cbs"): raise AttributeError(attr)
        return getattr(self.cbs, attr)
    def __getstate__(self):
        # the selector is rebuilt in __setstate__, and data is potentially very
        # big, so neither is pickled
        answer = dict(self.__dict__); answer.pop("selector", None)
        answer.pop("_data", None); return answer
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.__dict__["_data"] = None
        self.css = self.css; self.cbs.l = self # rebuild selector, re-link callbacks
    def evaluate(self): pass # supposed to be overridden, to provide functionality here
    @property
    def _warnings(self) -> str:
        """Newline-terminated warnings about missing setup (model, loss
callback, data, optimizer), followed by two extra newlines; "" when nothing
is missing."""
        warnings = []
        if self.model is None: warnings.append("Warning: no model yet. Set using `l.model = ...`\n")
        lossClasses = tuple(k1lib.Callback.lossCls)
        if not any(isinstance(cb, lossClasses) for cb in self.cbs):
            warnings.append("Warning: no loss function callback detected (or you set `lossF` already but then erased all callbacks)! Set using `l.lossF = ...` or `l.cbs.add(Cbs.LossF(...))`\n")
        # `is None` rather than `== None`: data may be array-like, where `==`
        # compares elementwise instead of returning a bool
        if self.data is None: warnings.append("Warning: no data yet. Set using `l.data = ...`\n")
        if self.opt is None: warnings.append("Warning: no optimizer yet. Set using `l.opt = ...`\n")
        return "".join(warnings) + "\n\n" if warnings else ""
    def __dir__(self):
        # expose callback names for tab completion, since __getattr__ forwards them
        answer = list(super().__dir__())
        answer.extend(self.cbs.cbsDict.keys()); return answer
    def __repr__(self):
        return f"""{self._warnings}l.model:\n{_tab(k1lib.limitLines(str(self.model)))}
l.opt:\n{_tab(k1lib.limitLines(str(self.opt)))}
l.cbs:\n{_tab(k1lib.limitLines(self.cbs.__repr__()))}
Use...
- l.model = ...: to specify a nn.Module object
- l.data = ...: to specify data object
- l.opt = ...: to specify an optimizer
- l.lossF = ...: to specify a loss function
- l.css = ...: to select modules using CSS. "#root" for root model
- l.cbs = ...: to use a custom `Callbacks` object
- l.selector: to get the modules selected by `l.css`
- l.run(epochs): to run the network
- l.Loss: to get a specific callback, this case "Loss"\n\n"""
@k1lib.patch(Learner)
def save(self, fileName:str=None):
    """Saves this :class:`Learner` to file. See also: :meth:`load`. Does not
save the ``data`` object, because that's potentially very big.

:param fileName: if empty, then will save as "learner-0.pth", with 0
    changeable to avoid conflicts. If resave this exact :class:`Learner`, then
    use the old name generated before"""
    self.fileName = fileName or self.fileName
    if self.fileName is None:
        # Pick the smallest N such that "learner-N.pth" does not exist yet in
        # the current directory. The generated name must match the scanned
        # pattern, otherwise the counter never advances. Files that merely start
        # with "learner" without following the pattern are ignored, not fatal.
        prefix, suffix = "learner-", ".pth"; taken = set()
        for file in os.listdir():
            if file.startswith(prefix) and file.endswith(suffix):
                stem = file[len(prefix):-len(suffix)]
                if stem.isdecimal(): taken.add(int(stem))
        count = 0
        while count in taken: count += 1
        self.fileName = f"{prefix}{count}{suffix}"
    # data is excluded from the pickle by Learner.__getstate__
    torch.save(self, self.fileName, pickle_module=dill)
    print(f"Saved to {self.fileName}")
@k1lib.patch(Learner, static=True)
def load(fileName:str=None):
    """Loads a :class:`Learner` from a file. See also: :meth:`save`

:param fileName: if empty, then will prompt for file name
:return: the unpickled :class:`Learner` (its ``data`` is None, since data is
    never saved)"""
    f = fileName or input("Enter learner file name to load:")
    # SECURITY NOTE: this unpickles arbitrary objects (via dill), which can
    # execute code. Only load files you trust.
    l = torch.load(f, pickle_module=dill)
    # report success only after the load actually succeeded
    print(f"Loaded from {f}"); return l
@k1lib.patch(Learner)
def _run1Batch(self):
    """Runs a single batch: forward pass, loss, backward, optimizer step and
gradient zeroing, each surrounded by its callback checkpoints. The backward,
step and zero-grad stages run only when their "start" checkpoint returns a
falsy value. A :class:`CancelBatchException` ends just this batch; epoch/run
cancellations are re-raised after the batch cleanup checkpoints."""
    self.cbs("startBatch")
    try:
        self.cbs("startPass", "inPass", "endPass")
        self.cbs("startLoss", "inLoss", "endLoss")
        if not self.cbs("startBackward"): self.lossG.backward()
        if not self.cbs("startStep"): self.opt.step()
        if not self.cbs("startZeroGrad"): self.opt.zero_grad(set_to_none=True)
    except k1lib.CancelBatchException as ex:
        # batch-level cancel is local: notify, then fall through to endBatch
        self.cbs("cancelBatch"); print(f"Batch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
    except (k1lib.CancelEpochException, k1lib.CancelRunException):
        # makes sure cancelBatch and endBatch get called for potential
        # cleanups, then re-raise the same exception (bare `raise` keeps the
        # original traceback instead of re-raising from this line)
        self.cbs("cancelBatch", "endBatch"); raise
    self.cbs("endBatch")
class DI:
    """Data interceptor, just to record data loading times. Wraps a loader and
adds the wall-clock time spent fetching each item to
``l.cbs.timings.loadData``; otherwise yields exactly what the wrapped loader
yields."""
    def __init__(self, l:"Learner", data): self.l = l; self.data = data
    def __len__(self): return len(self.data)
    def __iter__(self):
        data = iter(self.data); timings = self.l.cbs.timings
        while True:
            beginTime = _time()
            # guard only next(): StopIteration here means the loader is
            # exhausted, while one from anywhere else (e.g. the timings lookup)
            # should not be mistaken for end-of-data
            try: d = next(data)
            except StopIteration: return
            timings.loadData += _time() - beginTime
            yield d
@k1lib.patch(Learner)
def _run1Epoch(self):
    """Runs one epoch: every training batch with the model in train mode, then
every validation batch in eval mode unless the ``startValidBatches``
checkpoint returns a truthy value. ``self.batch`` counts continuously across
both phases, so validation batches are numbered after the training ones.
A :class:`CancelEpochException` ends just this epoch; a
:class:`CancelRunException` is re-raised after the epoch cleanup checkpoints."""
    self.cbs("startEpoch")
    try:
        train, valid = self.data; train = DI(self, train); valid = DI(self, valid)
        # best effort: iterable-style loaders may not support len(). Exception
        # (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        try: self.batches = len(train) + len(valid)
        except Exception: pass
        self.model.train(); trainLen = 0
        for self.batch, (self.xb, self.yb, *self.metab) in enumerate(train):
            self._run1Batch(); trainLen = self.batch + 1
        # counts are tracked locally: reading self.batch after a loop that
        # yielded nothing would see a stale value from an earlier epoch
        nBatches = trainLen
        if not self.cbs("startValidBatches"):
            self.model.eval()
            for self.batch, (self.xb, self.yb, *self.metab) in enumerate(valid):
                self.batch += trainLen; self._run1Batch(); nBatches = self.batch + 1
        if self.batches is None: self.batches = nBatches
    except k1lib.CancelEpochException as ex:
        self.cbs("cancelEpoch"); print(f"Epoch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
    except k1lib.CancelRunException:
        # clean up this epoch, then let run() handle the cancellation
        self.cbs("cancelEpoch", "endEpoch"); raise
    self.cbs("endEpoch")
@k1lib.patch(Learner)
def run(self, epochs:int, batches:int=None):
    """Main run function.

:param epochs: number of epochs to run. 1 epoch is the length of the dataset
:param batches: if set, then cancels the epoch after reaching the specified batch
:return: this learner, or None if setup warnings exist and the user declines
    to continue"""
    pending = self._warnings # computed once: building it walks every callback
    if pending != "":
        if not input(f"""You still have these warnings:\n\n{pending}
Do you want to continue? (y/n) """).lower().startswith("y"):
            print("Run ended"); return
    self.epochs = epochs; self.batches = None
    self.css = self.css # update module selector
    with self.cbs.context():
        # NOTE(review): callbacks added inside this context (BatchLimit) are
        # presumably dropped when it exits — confirm against Callbacks.context
        if batches is not None: self.cbs.add(Cbs.BatchLimit(batches))
        self.cbs("startRun")
        try:
            for self.epoch in range(epochs): self._run1Epoch()
        except k1lib.CancelRunException as ex:
            self.cbs("cancelRun"); print(f"Run cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
        self.cbs("endRun"); return self
@k1lib.patch(Learner)
def __call__(self, xb, yb=None):
    """Executes just a small batch. Convenience method to query how the network is
doing. No backward pass or optimizer step is taken.

:param xb: x batch
:param yb: y batch. If specified, return (y, loss), else return y alone
"""
    oldData = self.data
    # `yb or ...` would call bool() on a tensor, which raises for tensors with
    # more than one element, so test for None explicitly
    self.data = [[(xb, yb if yb is not None else torch.tensor(0))], []]
    try:
        with self.cbs.suspendEval(), self.cbs.context():
            # raise CancelBatchException at startLoss (no yb) or startBackward
            # (with yb); _run1Batch then skips backward, step and zero_grad
            ex = lambda _: k1lib.raiseEx(k1lib.CancelBatchException)
            self.cbs.add(k1lib.Callback().withCheckpoint("startLoss" if yb is None else "startBackward", ex))
            self.run(1, 1)
    finally: self.data = oldData # restore the user's data even if the run raised
    return self.y if yb is None else (self.y, self.loss)
@k1lib.patch(Learner)
def evaluate(self):
    """Function to visualize quickly how the network is doing. Undefined by default,
just placed here as a convention, so you have to do something like this::

    from functools import partial
    l = k1lib.Learner()
    def evaluate(self):
        xbs, ybs, ys = self.Recorder.record(1, 3)
        plt.plot(torch.vstack(xbs), torch.vstack(ys))
    l.evaluate = partial(evaluate, l)

:raises NotImplementedError: always, until replaced as shown above
"""
    raise NotImplementedError("You have to define evaluate() by yourself")
from k1lib.cli import *
@k1lib.patch(Learner, static=True)
def sample() -> Learner:
    """Creates an example learner, just for simple testing stuff anywhere. The
network tries to learn the function y=x. Only bare minimum callbacks are
included: ``CoreNormal``, ``Loss`` and ``ProgressBar``, plus the ``LossF``
callback that setting ``l.lossF`` adds. The loss is a summed squared error and
the optimizer is Adam with lr=3e-3.

:return: a fully set-up :class:`Learner`"""
    l = Learner(); l.data = k1lib.kdata.FunctionData.main(lambda x: x)
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = k1lib.knn.LinBlock(1, 3) # presumably 1 -> 3 features (lin2 takes 3)
            self.lin2 = nn.Linear(3, 1)
        def forward(self, x):
            # x[:, None] inserts a feature axis (assumes x is 1-D, (batch,) — TODO
            # confirm against FunctionData); `|` pipes the tensor through each
            # layer (presumably enabled by the k1lib.cli import above); squeeze()
            # then drops the size-1 dims
            return ((x[:, None] + 2) | self.lin1 | self.lin2).squeeze()
    # order matters: replacing l.cbs discards the default callbacks, so lossF
    # must be set afterwards for its LossF callback to land in the new set
    l.model = Model(); l.cbs = k1lib.Callbacks().add(Cbs.CoreNormal()).add(Cbs.Loss()).add(Cbs.ProgressBar())
    l.lossF = lambda y, yb: ((y - yb) ** 2).sum()
    l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-3); return l