# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
import k1lib, dill, traceback
from k1lib.callbacks import Cbs
from typing import Union
from time import time as _time
try: import torch; import torch.nn as nn; hasTorch = True
except:
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["CancelRunException", "CancelEpochException", "CancelBatchException",
           "Learner"]
class CancelRunException(Exception):
    """Used in core training loop, to skip the run entirely. Raise it (e.g. from
a callback) to stop training: the epoch loop fires ``cancelEpoch`` and
``endEpoch`` then re-raises it, and :meth:`Learner.run` catches it, fires
``cancelRun`` and still fires ``endRun``."""
class CancelEpochException(Exception):
    """Used in core training loop, to skip to next epoch. The epoch loop catches
it, fires ``cancelEpoch``, still fires ``endEpoch`` and then carries on with the
next epoch."""
class CancelBatchException(Exception):
    """Used in core training loop, to skip to next batch. The batch loop catches
it, fires ``cancelBatch``, still fires ``endBatch`` and then moves on to the
next batch."""
def _tab(text:Union[list, str], pad="    ") -> Union[list, str]:
    """Indents text by prefixing ``pad`` to every line. This is an old helper
that has been replaced in the main lib, but it's still handy here.

:param text: either a single string (split on newlines, padded, then joined
    back together) or an iterable of lines
:param pad: prefix put in front of every line, four spaces by default
:return: a string when ``text`` is a string, otherwise a list of padded lines"""
    def padEach(lines): return [pad + line for line in lines]
    if not isinstance(text, str): return padEach(text)
    return "\n".join(padEach(text.split("\n")))
class Learner:
    """Main training object. Holds the model, the data, the optimizer and a
:class:`~k1lib.callbacks.callbacks.Callbacks` object. Any attribute that isn't
found on the learner itself is looked up on :attr:`cbs` instead, so callbacks
can be reached directly, like ``l.Loss``."""
    def __init__(self):
        self._model = None; self._data = None; self._opt = None
        self._cbs = None; self.fileName = None
        # NOTE(review): original note on exceptionRaised just says "slowly pops"; its users aren't visible here
        self.css = "*"; self.exceptionRaised = None
        self.cbs = k1lib.Callbacks().withBasics().withQOL().withAdvanced()
    @property
    def model(self):
        """Set this to change the model to run"""
        return self._model
    @model.setter
    def model(self, model): self._model = model
    @property
    def data(self):
        """Set this to change the data (list of 2 dataloaders: train, then valid) to run against."""
        return self._data
    @data.setter
    def data(self, data): self._data = data
    @property
    def opt(self):
        """Set this to change the optimizer. If you're making your own
optimizers, be sure to follow PyTorch's optimizer conventions, as there are
callbacks that modify optimizer internals while training, like
:class:`k1lib.schedule.ParamScheduler`."""
        return self._opt
    @opt.setter
    def opt(self, opt): self._opt = opt
    @property
    def cbs(self):
        """The :class:`~k1lib.callbacks.callbacks.Callbacks` object. Initialized to
include all the common callbacks. You can set a new one if you want to."""
        return self._cbs
    @cbs.setter
    def cbs(self, cbs): cbs.l = self; self._cbs = cbs # callbacks get a back-reference to this learner
    @property
    def css(self) -> str:
        """The css selector string. Set this to select other parts of the network.
After setting, you can access the selector like this: :code:`l.selector`
See also: :class:`~k1lib.selector.ModuleSelector`"""
        return self._css
    @css.setter
    def css(self, css:str):
        self._css = css
        # rebuild the selector so it always matches the current model + css
        if self.model is not None: self.selector = k1lib.selector.select(self.model, self.css)
    @property
    def lossF(self):
        """Set this to specify a loss function."""
        raise NotImplementedError("lossF actually doesn't really exist. Used to exist as a core part of Learner, but then has been converted to Cbs.LossF")
    @lossF.setter
    def lossF(self, lossF):
        # reuse an existing LossF callback if there is one, else add a new one
        if hasattr(self.cbs, "LossF"): self.cbs.LossF.lossF = lossF
        else: self.cbs.add(Cbs.LossF(lossF))
    def __getattr__(self, attr):
        # Only reached when normal lookup fails; everything else is forwarded to
        # the Callbacks object (e.g. ``l.Loss``). "cbs" and "_cbs" must never be
        # forwarded: the ``cbs`` property reads ``self._cbs``, so forwarding them
        # on a half-initialized learner used to recurse forever.
        if attr in ("cbs", "_cbs"): raise AttributeError(attr)
        return getattr(self.cbs, attr)
    def __getstate__(self):
        # the selector is rebuilt on load, and data is potentially huge, so neither is pickled
        answer = dict(self.__dict__); answer.pop("selector", None)
        answer.pop("_data", None); return answer
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.__dict__["_data"] = None
        # re-assigning css rebuilds the selector; callbacks need their back-reference again
        self.css = self.css; self.cbs.l = self
    def evaluate(self): pass # supposed to be overridden, to provide functionality here
    @property
    def _warnings(self):
        """Newline-separated warnings about missing pieces (model, loss callback,
data, optimizer), followed by 2 blank lines; empty string if nothing's missing."""
        warnings = "Warning: no model yet. Set using `l.model = ...`\n" if self.model is None else ""
        lossClasses = tuple(k1lib.Callback.lossCls)
        hasLossCb = any(isinstance(cb, lossClasses) for cb in self.cbs)
        warnings += "Warning: no loss function callback detected (or you set `lossF` already but then erased all callbacks)! Set using `l.lossF = ...` or `l.cbs.add(Cbs.LossF(...))`\n" if not hasLossCb else ""
        warnings += "Warning: no data yet. Set using `l.data = ...`\n" if self.data is None else ""
        warnings += "Warning: no optimizer yet. Set using `l.opt = ...`\n" if self.opt is None else ""
        if warnings != "": warnings += "\n\n"
        return warnings
    def __dir__(self):
        # expose callback names (e.g. "Loss") for tab completion, since __getattr__ forwards them
        answer = list(super().__dir__())
        answer.extend(self.cbs.cbsDict.keys()); return answer
    def __repr__(self):
        return f"""{self._warnings}l.model:\n{_tab(k1lib.limitLines(str(self.model)))}
l.opt:\n{_tab(k1lib.limitLines(str(self.opt)))}
l.cbs:\n{_tab(k1lib.limitLines(self.cbs.__repr__()))}
Use...
- l.model = ...: to specify a nn.Module object
- l.data = ...: to specify data object
- l.opt = ...: to specify an optimizer
- l.lossF = ...: to specify a loss function
- l.css = ...: to select modules using CSS. "#root" for root model
- l.cbs = ...: to use a custom `Callbacks` object
- l.selector: to get the modules selected by `l.css`
- l.run(epochs): to run the network
- l.Loss: to get a specific callback, this case "Loss"\n\n"""
@k1lib.patch(Learner)
def save(self, fileName:str=None):
    """Saves this :class:`Learner` to file. See also: :meth:`load`. Does not
save the ``data`` object, because that's potentially very big. Example::

    l = k1.Learner()
    # saves learner to "skip1_128bs.pth" and model to "skip1_128bs.model.pth"
    l.save("skip1_128bs")

:param fileName: name to save file into, without extension. If empty, then
    will prompt for file name, same as :meth:`load` does"""
    # without this, a missing name would silently save to "None.pth"
    fileName = fileName or input("Enter learner file name to save:")
    torch.save(self, f"{fileName}.pth", pickle_module=dill)
    # the model is additionally saved on its own, next to the learner file
    torch.save(self.model, f"{fileName}.model.pth", pickle_module=dill)
    print(f"Saved to {fileName}")
@k1lib.patch(Learner, static=True)
def load(fileName:str=None):
    """Loads a :class:`Learner` from a file. See also: :meth:`save`.
Example::

    # this will load up learner in file "skip1_128bs.pth"
    l = k1.Learner.load("skip1_128bs")

Security: the file is unpickled with ``dill``, which can execute arbitrary
code, so only load files you trust.

:param fileName: name of the file, without the ``.pth`` extension. If empty,
    then will prompt for file name"""
    f = fileName or input("Enter learner file name to load:")
    learner = torch.load(f"{f}.pth", pickle_module=dill)
    # only report success once loading has actually worked
    print(f"Loaded from {f}"); return learner
@k1lib.patch(Learner)
def _run1Batch(self):
    """Runs a single batch: forward pass, loss, then backward, optimizer step and
gradient zeroing. Each of the last three is skipped when its ``start*``
checkpoint returns something truthy, so a callback can take over that step.

A :class:`CancelBatchException` is handled here (``cancelBatch`` fires, then
``endBatch`` as usual). Epoch/run cancellations fire ``cancelBatch`` and
``endBatch`` for cleanup and are then re-raised to the outer loops."""
    self.cbs("startBatch")
    # (checkpoint, default action) pairs; the lambdas defer attribute lookups
    # (e.g. self.lossG) until after the callbacks for that checkpoint ran
    overridableSteps = (
        ("startBackward", lambda: self.lossG.backward()),
        ("startStep", lambda: self.opt.step()),
        ("startZeroGrad", lambda: self.opt.zero_grad(set_to_none=True)),
    )
    try:
        self.cbs("startPass", "inPass", "endPass")
        self.cbs("startLoss", "inLoss", "endLoss")
        for checkpoint, defaultAction in overridableSteps:
            if not self.cbs(checkpoint):
                defaultAction()
    except k1lib.CancelBatchException as ex:
        self.cbs("cancelBatch")
        print(f"Batch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
    except (k1lib.CancelEpochException, k1lib.CancelRunException):
        # give callbacks a chance to clean up before the outer loop handles it
        self.cbs("cancelBatch", "endBatch")
        raise
    self.cbs("endBatch")
class DI: # data interceptor, just to record data loading times
    """Wraps one data loader so that the time spent fetching each batch is added
to ``l.cbs.timings.loadData``. Length and iteration otherwise pass straight
through to the wrapped object.

:param l: the learner whose ``cbs.timings`` gets updated
:param data: any iterable of batches (only needs ``len()`` if ``len(DI)`` is used)"""
    def __init__(self, l:"Learner", data): self.l = l; self.data = data
    def __len__(self): return len(self.data)
    def __iter__(self):
        data = iter(self.data); timings = self.l.cbs.timings
        while True:
            beginTime = _time()
            # keep the try tight around next(): only the loader's own exhaustion
            # should end iteration, not a StopIteration from anywhere else
            try: d = next(data)
            except StopIteration: return
            timings.loadData += _time() - beginTime
            yield d
@k1lib.patch(Learner)
def _run1Epoch(self):
    """Runs one epoch: all training batches in train mode, then (unless a
``startValidBatches`` callback returns something truthy) all validation batches
in eval mode. Validation batch indices continue on from the training ones, so
``self.batch`` counts across both loaders.

If the data has no length, ``self.batches`` is filled in at the end of the
epoch from the number of batches actually run.

:raises CancelRunException: re-raised after ``cancelEpoch`` and ``endEpoch``
    fired, so :meth:`run` can handle it"""
    self.cbs("startEpoch")
    try:
        train, valid = self.data; train = DI(self, train); valid = DI(self, valid)
        try: self.batches = len(train) + len(valid)
        except Exception: pass # unsized data (e.g. generators); self.batches inferred below
        self.model.train()
        # count batches explicitly instead of reading self.batch after the loop:
        # an empty loader never assigns self.batch, so it'd hold a stale value
        # from an earlier epoch (or not exist at all on the first one)
        nBatches = 0
        for self.batch, (self.xb, self.yb, *self.metab) in enumerate(train):
            self._run1Batch(); nBatches = self.batch + 1
        trainLen = nBatches
        if not self.cbs("startValidBatches"):
            self.model.eval()
            for self.batch, (self.xb, self.yb, *self.metab) in enumerate(valid):
                self.batch += trainLen; self._run1Batch(); nBatches = self.batch + 1
        if self.batches is None: self.batches = nBatches
    except k1lib.CancelEpochException as ex:
        self.cbs("cancelEpoch"); print(f"Epoch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
    except k1lib.CancelRunException:
        # let callbacks clean up this epoch, then hand the cancellation to run()
        self.cbs("cancelEpoch", "endEpoch"); raise
    self.cbs("endEpoch")
@k1lib.patch(Learner)
def run(self, epochs:int, batches:int=None):
    """Main run function. If anything is still missing (model, data, optimizer or
a loss callback), asks on stdin whether to continue before starting.

:param epochs: number of epochs to run. 1 epoch is the length of the dataset
:param batches: if set, then cancels the epoch after reaching the specified batch
:return: this learner, or ``None`` if the user chose not to continue"""
    pendingWarnings = self._warnings
    if pendingWarnings != "":
        reply = input(f"You still have these warnings:\n\n{pendingWarnings}\nDo you want to continue? (y/n) ")
        if not reply.lower().startswith("y"):
            print("Run ended")
            return None
    self.epochs = int(epochs)
    self.batches = None # filled in by every epoch
    self.css = self.css # re-assigning runs the css setter, which refreshes the module selector
    with self.cbs.context():
        if batches is not None:
            self.cbs.add(Cbs.BatchLimit(int(batches)))
        self.cbs("startRun")
        try:
            for self.epoch in range(self.epochs):
                self._run1Epoch()
        except k1lib.CancelRunException as ex:
            self.cbs("cancelRun")
            print(f"Run cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
        self.cbs("endRun")
        return self
@k1lib.patch(Learner)
def __call__(self, xb, yb=None):
    """Executes just a small batch. Convenience method to query how the network is
doing. The batch is fed in as a one-batch training set with an empty validation
set. The batch is then cancelled before the loss (no ``yb``) or before the
backward pass (with ``yb``), so the optimizer never steps.

:param xb: x batch
:param yb: y batch. If specified, return (y, loss), else return y alone
"""
    oldData = self.data
    # ``yb or ...`` would call bool() on a tensor, which raises for
    # multi-element targets; only use the placeholder when yb is truly missing
    self.data = [[(xb, yb if yb is not None else torch.tensor(0))], []]
    try:
        with self.cbs.suspendEval(), self.cbs.context():
            ex = lambda _: k1lib.raiseEx(k1lib.CancelBatchException)
            self.cbs.add(k1lib.Callback().withCheckpoint("startLoss" if yb is None else "startBackward", ex))
            self.run(1, 1)
    finally: self.data = oldData # restore the user's data even if the run fails
    return self.y if yb is None else (self.y, self.loss)
@k1lib.patch(Learner)
def evaluate(self):
    """Function to visualize quickly how the network is doing. Undefined by default,
just placed here as a convention, so you have to do something like this::

    l = k1lib.Learner()
    def evaluate(self):
        xbs, ybs, ys = self.Recorder.record(1, 3)
        plt.plot(torch.vstack(xbs), torch.vstack(ys))
    l.evaluate = partial(evaluate, l)

:raises NotImplementedError: always, until you assign your own ``evaluate``
"""
    raise NotImplementedError("You have to define evaluate() by yourself")
from k1lib.cli import *                                                          # evaluate
@k1lib.patch(Learner, static=True)
def sample() -> Learner:
    """Creates an example learner, just for simple testing stuff anywhere. The
network tries to learn the function y=x. Only bare minimum callbacks are included.

:return: a :class:`Learner` with model, data, callbacks, loss function and
    optimizer all set"""
    # 1000 evenly spaced points in [-5, 5]; the same tensor is used as input and target
    l = Learner(); x = torch.linspace(-5, 5, 1000)
    # k1lib.cli pipeline producing [train, valid] loaders of (xb, yb) batches of 32.
    # NOTE(review): presumably pairs each x with itself, shuffles and splits; the
    # split ratio (splitW default) and what stagger.tv(300) does live in k1lib.cli — confirm there
    l.data = [x, x] | transpose() | randomize(None) | splitW() | (repeatFrom() | randomize() | batched(32) | (transpose() | toTensor()).all()).all() | stagger.tv(300) | toList()
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = k1lib.knn.LinBlock(1, 3)
            self.lin2 = nn.Linear(3, 1)
        def forward(self, x):
            # x[:, None] adds a size-1 feature dim, input is shifted by 2, piped
            # through lin1 then lin2 (k1lib's `|` operator), and squeeze() drops the dim again
            return ((x[:, None] + 2) | self.lin1 | self.lin2).squeeze()
    # deliberately minimal callback set, instead of the full default one from Learner()
    l.model = Model(); l.cbs = k1lib.Callbacks().add(Cbs.CoreNormal()).add(Cbs.Loss()).add(Cbs.ProgressBar())
    # sum of squared errors; the lossF setter wraps it into a Cbs.LossF callback
    l.lossF = lambda y, yb: ((y - yb) ** 2).sum()
    l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-3); return l