Source code for k1lib._learner

# AUTOGENERATED FILE! PLEASE DON'T EDIT
import k1lib, dill, traceback
from k1lib.callbacks import Cbs
from typing import Union
from time import time as _time
try: import torch; import torch.nn as nn; hasTorch = True
except:
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["CancelRunException", "CancelEpochException", "CancelBatchException",
           "Learner"]
class CancelRunException(Exception):
    """Used in core training loop, to skip the run entirely"""
    pass
class CancelEpochException(Exception):
    """Used in core training loop, to skip to next epoch"""
    pass
class CancelBatchException(Exception):
    """Used in core training loop, to skip to next batch"""
    pass
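# Illustrative sketch (not part of the original source): these exceptions are meant
# to be raised from inside callbacks to short-circuit the training loop at different
# granularities. Mirroring the `withCheckpoint` pattern that `__call__` uses further
# down in this file, one could skip the backward/step phase of every batch like so:
#
#     ex = lambda _: k1lib.raiseEx(k1lib.CancelBatchException)
#     l.cbs.add(k1lib.Callback().withCheckpoint("startBackward", ex))
#
# `_run1Batch` below catches CancelBatchException, fires the "cancelBatch" event for
# cleanup, and the loop then moves on to the next batch.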
def _tab(text:Union[list, str], pad=" ") -> Union[list, str]: # an old function that's been replaced in the main lib, but still useful here
    if isinstance(text, str): return "\n".join([pad + line for line in text.split("\n")])
    else: return [pad + line for line in text]
class Learner:
    def __init__(self):
        self._model = None; self._data = None; self._opt = None
        self._cbs = None; self.fileName = None
        self.css = "*"; self.exceptionRaised = None # slowly pops
        self.cbs = k1lib.Callbacks().withBasics().withQOL().withAdvanced()
    @property
    def model(self):
        """Set this to change the model to run"""
        return self._model
    @model.setter
    def model(self, model): self._model = model
    @property
    def data(self):
        """Set this to change the data (a list of 2 dataloaders) to run against."""
        return self._data
    @data.setter
    def data(self, data): self._data = data
    @property
    def opt(self):
        """Set this to change the optimizer. If you're making your own optimizers,
        be careful to follow PyTorch's style guide, as there are callbacks that
        modify optimizer internals while training, like
        :class:`k1lib.schedule.ParamScheduler`."""
        return self._opt
    @opt.setter
    def opt(self, opt): self._opt = opt
    @property
    def cbs(self):
        """The :class:`~k1lib.callbacks.callbacks.Callbacks` object. Initialized to
        include all the common callbacks. You can set a new one if you want to."""
        return self._cbs
    @cbs.setter
    def cbs(self, cbs): cbs.l = self; self._cbs = cbs
    @property
    def css(self) -> str:
        """The css selector string. Set this to select other parts of the network.
        After setting, you can access the selector like this: :code:`l.selector`

        See also: :class:`~k1lib.selector.ModuleSelector`"""
        return self._css
    @css.setter
    def css(self, css:str):
        self._css = css
        if self.model is not None: self.selector = k1lib.selector.select(self.model, self.css)
    @property
    def lossF(self):
        """Set this to specify a loss function."""
        raise NotImplementedError("lossF doesn't actually exist. It used to be a core part of Learner, but has since been converted to Cbs.LossF")
    @lossF.setter
    def lossF(self, lossF):
        if hasattr(self.cbs, "LossF"): self.cbs.LossF.lossF = lossF
        else: self.cbs.add(Cbs.LossF(lossF))
    def __getattr__(self, attr):
        if attr == "cbs": raise AttributeError()
        return getattr(self.cbs, attr)
    def __getstate__(self):
        answer = dict(self.__dict__); answer.pop("selector", None)
        answer.pop("_data", None); return answer
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.__dict__["_data"] = None
        self.css = self.css; self.cbs.l = self
    def evaluate(self): pass # supposed to be overridden, to provide functionality here
    @property
    def _warnings(self):
        warnings = "Warning: no model yet. Set using `l.model = ...`\n" if self.model is None else ""
        lossClasses = tuple([*k1lib.Callback.lossCls])
        lossFnCbs = [True for cb in self.cbs if isinstance(cb, lossClasses)]
        warnings += "Warning: no loss function callback detected (or you set `lossF` already but then erased all callbacks)! Set using `l.lossF = ...` or `l.cbs.add(Cbs.LossF(...))`\n" if len(lossFnCbs) == 0 else ""
        warnings += "Warning: no data yet. Set using `l.data = ...`\n" if self.data is None else ""
        warnings += "Warning: no optimizer yet. Set using `l.opt = ...`\n" if self.opt is None else ""
        if warnings != "": warnings += "\n\n"
        return warnings
    def __dir__(self):
        answer = list(super().__dir__())
        answer.extend(self.cbs.cbsDict.keys()); return answer
    def __repr__(self):
        return f"""{self._warnings}l.model:\n{_tab(k1lib.limitLines(str(self.model)))}
l.opt:\n{_tab(k1lib.limitLines(str(self.opt)))}
l.cbs:\n{_tab(k1lib.limitLines(self.cbs.__repr__()))}
Use...
- l.model = ...: to specify a nn.Module object
- l.data = ...: to specify the data object
- l.opt = ...: to specify an optimizer
- l.lossF = ...: to specify a loss function
- l.css = ...: to select modules using CSS. "#root" for the root model
- l.cbs = ...: to use a custom `Callbacks` object
- l.selector: to get the modules selected by `l.css`
- l.run(epochs): to run the network
- l.Loss: to get a specific callback, in this case "Loss"\n\n"""
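# Illustrative sketch (not part of the original source): typical manual wiring of a
# Learner, following the "Use..." list printed by __repr__ above. `train_dl` and
# `valid_dl` stand in for two hypothetical PyTorch DataLoaders yielding (xb, yb) pairs:
#
#     l = Learner()
#     l.model = nn.Linear(10, 1)
#     l.data = [train_dl, valid_dl]                    # list of 2 dataloaders
#     l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-3)
#     l.lossF = lambda y, yb: ((y - yb) ** 2).mean()   # routed to Cbs.LossF by the setter
#     l.run(5)                                         # train for 5 epochs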
@k1lib.patch(Learner)
def save(self, fileName:str=None):
    """Saves this :class:`Learner` to file. See also: :meth:`load`. Does not save
    the ``data`` object, because that's potentially very big. Example::

        l = k1.Learner()
        # saves learner to "skip1_128bs.pth" and model to "skip1_128bs.model.pth"
        l.save("skip1_128bs")

    :param fileName: name to save the file into"""
    torch.save(self, f"{fileName}.pth", pickle_module=dill)
    torch.save(self.model, f"{fileName}.model.pth", pickle_module=dill)
    print(f"Saved to {fileName}")
@k1lib.patch(Learner, static=True)
def load(fileName:str=None):
    """Loads a :class:`Learner` from a file. See also: :meth:`save`. Example::

        # this will load up the learner in file "skip1_128bs.pth"
        l = k1.Learner.load("skip1_128bs")

    :param fileName: if empty, then will prompt for the file name"""
    f = fileName or input("Enter learner file name to load:")
    print(f"Loaded from {f}"); return torch.load(f"{f}.pth", pickle_module=dill)
@k1lib.patch(Learner)
def _run1Batch(self):
    self.cbs("startBatch")
    try:
        self.cbs("startPass", "inPass", "endPass")
        self.cbs("startLoss", "inLoss", "endLoss")
        if not self.cbs("startBackward"): self.lossG.backward()
        if not self.cbs("startStep"): self.opt.step()
        if not self.cbs("startZeroGrad"): self.opt.zero_grad(set_to_none=True)
    except k1lib.CancelBatchException as ex:
        self.cbs("cancelBatch"); print(f"Batch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
    except (k1lib.CancelEpochException, k1lib.CancelRunException) as ex:
        # make sure cancelBatch and endBatch get called, for potential
        # cleanups, then reraise the exception
        self.cbs("cancelBatch", "endBatch"); raise ex
    self.cbs("endBatch")
class DI: # data interceptor, just to record data loading times
    def __init__(self, l:Learner, data): self.l = l; self.data = data
    def __len__(self): return len(self.data)
    def __iter__(self):
        try:
            data = iter(self.data); timings = self.l.cbs.timings
            while True:
                beginTime = _time(); d = next(data)
                timings.loadData += _time() - beginTime; yield d
        except StopIteration: pass
@k1lib.patch(Learner)
def _run1Epoch(self):
    self.cbs("startEpoch")
    try:
        train, valid = self.data; train = DI(self, train); valid = DI(self, valid)
        try: self.batches = len(train) + len(valid)
        except: pass
        self.model.train()
        for self.batch, (self.xb, self.yb, *self.metab) in enumerate(train):
            self._run1Batch()
        trainLen = self.batch + 1
        if not self.cbs("startValidBatches"):
            self.model.eval()
            for self.batch, (self.xb, self.yb, *self.metab) in enumerate(valid):
                self.batch += trainLen; self._run1Batch()
        if self.batches is None: self.batches = self.batch + 1
    except k1lib.CancelEpochException as ex:
        self.cbs("cancelEpoch"); print(f"Epoch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
    except k1lib.CancelRunException as ex:
        self.cbs("cancelEpoch", "endEpoch"); raise ex
    self.cbs("endEpoch")
@k1lib.patch(Learner)
def run(self, epochs:int, batches:int=None):
    """Main run function.

    :param epochs: number of epochs to run. 1 epoch is the length of the dataset
    :param batches: if set, then cancels the epoch after reaching the specified batch"""
    if self._warnings != "":
        if not input(f"""You still have these warnings:\n\n{self._warnings}Do you want to continue?
(y/n) """).lower().startswith("y"):
            print("Run ended"); return
    self.epochs = int(epochs); self.batches = None
    self.css = self.css # update module selector
    with self.cbs.context():
        if batches is not None: self.cbs.add(Cbs.BatchLimit(int(batches)))
        self.cbs("startRun")
        try:
            for self.epoch in range(self.epochs): self._run1Epoch()
        except k1lib.CancelRunException as ex:
            self.cbs("cancelRun"); print(f"Run cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
        self.cbs("endRun"); return self
@k1lib.patch(Learner)
def __call__(self, xb, yb=None):
    """Executes just a small batch. Convenience method to query how the network is doing.

    :param xb: x batch
    :param yb: y batch. If specified, returns (y, loss), else returns y alone"""
    oldData = self.data; self.data = [[(xb, (yb or torch.tensor(0)))], []]
    with self.cbs.suspendEval(), self.cbs.context():
        ex = lambda _: k1lib.raiseEx(k1lib.CancelBatchException)
        self.cbs.add(k1lib.Callback().withCheckpoint("startLoss" if yb is None else "startBackward", ex))
        self.run(1, 1)
    self.data = oldData; return self.y if yb is None else (self.y, self.loss)
@k1lib.patch(Learner)
def evaluate(self):
    """Function to visualize quickly how the network is doing. Undefined by default,
    just placed here as a convention, so you have to do something like this::

        l = k1lib.Learner()
        def evaluate(self):
            xbs, ybs, ys = self.Recorder.record(1, 3)
            plt.plot(torch.vstack(xbs), torch.vstack(ys))
        l.evaluate = partial(evaluate, l)
    """
    raise NotImplementedError("You have to define evaluate() by yourself")
from k1lib.cli import *
@k1lib.patch(Learner, static=True)
def sample() -> Learner:
    """Creates an example learner, just for simple testing of stuff anywhere. The
    network tries to learn the function y=x. Only bare minimum callbacks are included."""
    l = Learner(); x = torch.linspace(-5, 5, 1000)
    l.data = [x, x] | transpose() | randomize(None) | splitW() | (repeatFrom() | randomize() | batched(32) | (transpose() | toTensor()).all()).all() | stagger.tv(300) | toList()
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = k1lib.knn.LinBlock(1, 3)
            self.lin2 = nn.Linear(3, 1)
        def forward(self, x):
            return ((x[:, None] + 2) | self.lin1 | self.lin2).squeeze()
    l.model = Model()
    l.cbs = k1lib.Callbacks().add(Cbs.CoreNormal()).add(Cbs.Loss()).add(Cbs.ProgressBar())
    l.lossF = lambda y, yb: ((y - yb) ** 2).sum()
    l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-3); return l
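# Illustrative sketch (not part of the original source): a quick end-to-end check
# using the sample learner above, plus a save/load round trip. The file name
# "sample_run" is hypothetical:
#
#     l = Learner.sample()            # minimal learner that learns y = x
#     l.run(5)                        # train for 5 epochs
#     l(torch.linspace(-5, 5, 32))    # push a small batch through, returns y
#     l.save("sample_run")            # writes sample_run.pth and sample_run.model.pth
#     l2 = Learner.load("sample_run")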