# AUTOGENERATED FILE! PLEASE DON'T EDIT
import k1lib, torch.nn as nn, torch, dill, os
class Learner:
    def __init__(self):
        self._model = None; self._data = None; self._opt = None
        self._cbs = None; self.fileName = None
        self.css = "*"; self.exceptionRaised = None # slowly pops
        self.cbs = k1lib.Callbacks().withBasics().withQOL().withAdvanced()
    @property
    def model(self):
        """Set this to change the model to run"""
        return self._model
    @model.setter
    def model(self, model): self._model = model
    @property
    def data(self):
        """Set this to change the data (of type :class:`k1lib.data.Data`) to
run against."""
        return self._data
    @data.setter
    def data(self, data): self._data = data
    @property
    def opt(self):
        """Set this to change the optimizer. If you're making your own
optimizers, beware to follow the PyTorch's style guide as there are callbacks
that modifies optimizer internals while training like
:class:`k1lib.schedule.ParamScheduler`."""
        return self._opt
    @opt.setter
    def opt(self, opt): self._opt = opt
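    # Usage sketch (illustrative only, not executed here; the Adam optimizer
    # and the learning rate are example choices, not defaults of this class):
    #   l = k1lib.Learner()
    #   l.model = nn.Linear(10, 2)
    #   l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-4)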
    @property
    def cbs(self):
        """The :class:`~k1lib.callbacks.callbacks.Callbacks` object. Initialized to
include all the common callbacks. You can set a new one if you want to."""
        return self._cbs
    @cbs.setter
    def cbs(self, cbs): cbs.l = self; self._cbs = cbs
    @property
    def css(self) -> str:
        """The css selector string. Set this to select other parts of the network.
After setting, you can access the selector like this: :code:`l.selector`
See also: :class:`~k1lib.selector.ModuleSelector`"""
        return self._css
    @css.setter
    def css(self, css:str):
        self._css = css
        if self.model != None: self.selector = k1lib.selector.select(self.model, self.css)
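    # Selector sketch (illustrative; "#root" selects the whole model, as noted
    # in __repr__ below, while richer css strings follow
    # k1lib.selector.ModuleSelector's syntax):
    #   l.css = "#root"       # select the root module
    #   print(l.selector)     # inspect what got selected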
    @property
    def lossF(self):
        """Set this to specify a loss function."""
        raise NotImplementedError("lossF actually doesn't really exist. Used to exist as a core part of Learner, but then has been converted to k1lib.callbacks.lossFunctions.LossLambda")
    @lossF.setter
    def lossF(self, lossF):
        if hasattr(self.cbs, "LossLambda"): self.cbs.LossLambda.lossF = lossF
        else: self.cbs.withLossLambda(lossF)
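    # lossF sketch (illustrative; any callable loss works, a standard PyTorch
    # loss module is assumed here):
    #   l.lossF = nn.CrossEntropyLoss()
    #   # equivalently, register the callback directly:
    #   # l.cbs.withLossLambda(nn.CrossEntropyLoss())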
    def __getattr__(self, attr):
        if attr == "cbs": raise AttributeError()
        return getattr(self.cbs, attr)
    def __getstate__(self):
        answer = dict(self.__dict__); answer.pop("selector", None); return answer
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.css = self.css; self.cbs.l = self
    def evaluate(self): pass # supposed to be overridden later, to provide functionality here
    @property
    def _warnings(self):
        warnings = "Warning: no model yet. Set using `l.model = ...`\n" if self.model == None else ""
        lossFnCbs = [True for cb in self.cbs if cb.__module__.startswith("k1lib.callbacks.lossFunctions.")]
        warnings += "Warning: no loss function callback detected! Set using `l.lossF = ...` or `l.cbs.withLossLambda(...)`\n" if len(lossFnCbs) == 0 else ""
        warnings += "Warning: no data yet. Set using `l.data = ...`\n" if self.data == None else ""
        warnings += "Warning: no optimizer yet. Set using `l.opt = ...`\n" if self.opt == None else ""
        if warnings != "": warnings += "\n\n"
        return warnings
    def __dir__(self):
        answer = list(super().__dir__())
        answer.extend(self.cbs.cbsDict.keys())
        return answer
    def __repr__(self):
        return f"""{self._warnings}l.model:\n{k1lib.tab(k1lib.limitLines(str(self.model)))}
l.opt:\n{k1lib.tab(k1lib.limitLines(str(self.opt)))}
l.cbs:\n{k1lib.tab(k1lib.limitLines(self.cbs.__repr__()))}
Use...
- l.model = ...: to specify a nn.Module object
- l.data = ...: to specify data object
- l.opt = ...: to specify an optimizer
- l.lossF = ...: to specify a loss function
- l.css = ...: to select modules using CSS. "#root" for root model
- l.cbs = ...: to use a custom `Callbacks` object
- l.selector: to get the modules selected by `l.css`
- l.run(epochs): to run the network
- l.Loss: to get a specific callback, this case "Loss"\n\n""" 
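# End-to-end usage sketch (illustrative only; the model, optimizer, learning
# rate and loss below are placeholders, not defaults of this class):
#   l = k1lib.Learner()
#   l.model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 2))
#   l.data = ...                        # a k1lib.data.Data object
#   l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-4)
#   l.lossF = nn.CrossEntropyLoss()
#   l.run(10)                           # train for 10 epochs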
@k1lib.patch(Learner)
def save(self, fileName:str=None):
    """Saves this :class:`Learner` to file. See also: :meth:`load`
:param fileName: if empty, then will save as "learner-0.pth", with 0
    changeable to avoid conflicts. If resave this exact :class:`Learner`, then
    use the old name generated before"""
    self.fileName = fileName or self.fileName
    if self.fileName == None:
        files = [file for file in os.listdir() if file.startswith("learner") and file.endswith(".pth")]
        files = set([int(file.split(".pth")[0].split("learner-")[1]) for file in files])
        count = 0;
        while count in files: count += 1
        self.fileName = f"l-{count}.pth"
    torch.save(self, self.fileName, pickle_module=dill)
    print(f"Saved to {self.fileName}")
@k1lib.patch(Learner, static=True)
def load(fileName:str=None):
    """Loads a :class:`Learner` from a file. See also: :meth:`save`
:param fileName: if empty, then will prompt for file name"""
    f = fileName or input("Enter learner file name to load:")
    print(f"Loaded from {f}"); return torch.load(f, pickle_module=dill)
@k1lib.patch(Learner)
def _run1Batch(self):
    self.cbs("startBatch")
    try:
        self.cbs("startPass", "inPass", "endPass")
        self.cbs("startLoss", "inLoss", "endLoss")
        if not self.cbs("startBackward"): self.lossG.backward()
        if not self.cbs("startStep"):  self.opt.step()
        if not self.cbs("startZeroGrad"): self.opt.zero_grad(set_to_none=True)
    except k1lib.CancelBatchException as ex:
        self.cbs("cancelBatch"); print(f"Batch cancelled: {ex}.")
    except (k1lib.CancelEpochException, k1lib.CancelRunException) as ex:
        # makes sure cancelBatch and endBatch gets called, for potential
        # cleanups, then reraise the exception
        self.cbs("cancelBatch", "endBatch"); raise ex
    self.cbs("endBatch")
@k1lib.patch(Learner)
def _run1Epoch(self):
    self.cbs("startEpoch")
    try:
        try: self.batches = len(self.data.train) + len(self.data.valid)
        except Exception: self.batches = None
        self.model.train()
        for self.batch, (self.xb, self.yb) in enumerate(self.data.train):
            self._run1Batch()
        trainLen = self.batch + 1
        if not self.cbs("startValidBatches"):
            self.model.eval();
            for self.batch, (self.xb, self.yb) in enumerate(self.data.valid):
                self.batch += trainLen; self._run1Batch()
    except k1lib.CancelEpochException as ex:
        self.cbs("cancelEpoch"); print(f"Epoch cancelled: {ex}.")
    except k1lib.CancelRunException as ex:
        self.cbs("cancelEpoch", "endEpoch"); raise ex
    self.cbs("endEpoch")
@k1lib.patch(Learner)
def run(self, epochs:int, batches:int=None):
    """Main run function.
:param epochs: number of epochs to run. 1 epoch is one pass over the whole dataset
:param batches: if set, then cancels the epoch after reaching the specified batch"""
    if self._warnings != "":
        if not input(f"""You still have these warnings:\n\n{self._warnings}
Do you want to continue? (y/n) """).lower().startswith("y"):
            print("Run ended"); return
    self.epochs = epochs; self.css = self.css # update module selector
    with self.cbs.context():
        if batches != None: self.cbs.withBatchLimit(batches)
        self.cbs("startRun")
        try:
            for self.epoch in range(epochs): self._run1Epoch()
        except k1lib.CancelRunException as ex:
            self.cbs("cancelRun"); print(f"Run cancelled: {ex}.")
        self.cbs("endRun"); return self
@k1lib.patch(Learner)
def __call__(self, xb, yb=None):
    """Executes just a small batch. Convenience method to query how the network is
doing.
:param xb: x batch
:param yb: y batch. If specified, return (y, loss), else return y alone
"""
    def gen(): yield xb, (yb if yb is not None else torch.tensor(0))
    oldData = self.data; self.data = k1lib.data.Data(gen(), iter(range(0)))
    with self.cbs.suspendEval(), self.cbs.context():
        ex = lambda _: k1lib.raiseEx(k1lib.CancelBatchException)
        self.cbs.append(k1lib.Callback().withCheckpoint("startLoss" if yb is None else "startBackward", ex))
        self.run(1, 1)
    self.data = oldData; return self.y if yb is None else (self.y, self.loss)
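# Quick-query sketch (illustrative; xb and yb are placeholder batches):
#   y = l(xb)             # forward pass only
#   y, loss = l(xb, yb)   # forward pass + loss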
@k1lib.patch(Learner)
def evaluate(self):
    """Function to visualize quickly how the network is doing. Undefined by default,
just placed here as a convention, so you have to do something like this::
    l = k1lib.Learner()
    def evaluate(self):
        xbs, ybs, ys = self.Recorder.record(1, 3)
        plt.plot(torch.vstack(xbs), torch.vstack(ys))
    l.evaluate = partial(evaluate(l))
"""
    raise NotImplementedError("You have to define evaluate() by yourself")