Source code for k1lib._learner

# AUTOGENERATED FILE! PLEASE DON'T EDIT
import os
from time import time as _time

import k1lib, torch.nn as nn, torch, dill
# public names exported by ``from k1lib._learner import *``
__all__ = [
    "CancelRunException",
    "CancelEpochException",
    "CancelBatchException",
    "Learner",
]
class CancelRunException(Exception):
    """Raise inside the core training loop to abandon the whole run."""
class CancelEpochException(Exception):
    """Raise inside the core training loop to abandon the current epoch and
    move on to the next one."""
class CancelBatchException(Exception):
    """Raise inside the core training loop to abandon the current batch and
    move on to the next one."""
class Learner:
    """Central training object.

    Holds the model, data, optimizer and a
    :class:`~k1lib.callbacks.callbacks.Callbacks` collection. Attribute
    lookups that fail on the learner itself are forwarded to :attr:`cbs`, so
    e.g. ``l.Loss`` returns the callback named "Loss"."""

    def __init__(self):
        self._model = None; self._data = None; self._opt = None
        self._cbs = None; self.fileName = None
        # goes through the `css` setter; there is no model yet, so no
        # `selector` attribute is created here
        self.css = "*"; self.exceptionRaised = None  # slowly pops
        self.cbs = k1lib.Callbacks().withBasics().withQOL().withAdvanced()

    @property
    def model(self):
        """Set this to change the model to run"""
        return self._model
    @model.setter
    def model(self, model): self._model = model

    @property
    def data(self):
        """Set this to change the data (of type :class:`k1lib.data.Data`) to run against."""
        return self._data
    @data.setter
    def data(self, data): self._data = data

    @property
    def opt(self):
        """Set this to change the optimizer. If you're making your own
        optimizers, beware to follow the PyTorch's style guide as there are
        callbacks that modifies optimizer internals while training like
        :class:`k1lib.schedule.ParamScheduler`."""
        return self._opt
    @opt.setter
    def opt(self, opt): self._opt = opt

    @property
    def cbs(self):
        """The :class:`~k1lib.callbacks.callbacks.Callbacks` object.
        Initialized to include all the common callbacks. You can set a new
        one if you want to."""
        return self._cbs
    @cbs.setter
    def cbs(self, cbs):
        # the callbacks object keeps a back-reference to this learner
        cbs.l = self; self._cbs = cbs

    @property
    def css(self) -> str:
        """The css selector string. Set this to select other parts of the
        network. After setting, you can access the selector like this:
        :code:`l.selector`

        See also: :class:`~k1lib.selector.ModuleSelector`"""
        return self._css
    @css.setter
    def css(self, css:str):
        self._css = css
        # a selector can only be built once a model exists; until then the
        # `selector` attribute is simply left unset
        if self.model is not None:
            self.selector = k1lib.selector.select(self.model, self.css)

    @property
    def lossF(self):
        """Set this to specify a loss function."""
        raise NotImplementedError("lossF actually doesn't really exist. Used to exist as a core part of Learner, but then has been converted to k1lib.callbacks.lossFunctions.LossLambda")
    @lossF.setter
    def lossF(self, lossF):
        # reuse an existing LossLambda callback if there is one, else add one
        if hasattr(self.cbs, "LossLambda"): self.cbs.LossLambda.lossF = lossF
        else: self.cbs.withLossLambda(lossF)

    def __getattr__(self, attr):
        # Only called when normal lookup fails. "cbs"/"_cbs" must not be
        # forwarded: if `_cbs` is missing (instance built via __new__, or
        # mid-unpickling) forwarding would recurse through the `cbs` property
        # forever instead of raising AttributeError.
        if attr in ("cbs", "_cbs"): raise AttributeError(attr)
        return getattr(self.cbs, attr)

    def __getstate__(self):
        answer = dict(self.__dict__)
        # `selector` is dropped before pickling and rebuilt from `css` in
        # __setstate__. It may legitimately be absent (no model when `css`
        # was last set), so don't assume it exists.
        answer.pop("selector", None)
        return answer

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.css = self.css  # re-runs the setter, rebuilding `selector`
        self.cbs.l = self    # restore the callbacks' back-reference

    def evaluate(self): pass  # supposed to be overriden, to provide functionality here

    @property
    def _warnings(self):
        """Human-readable warnings about missing model, loss callback, data
        or optimizer; empty string when everything is set."""
        warnings = ""
        if self.model is None: warnings += "Warning: no model yet. Set using `l.model = ...`\n"
        lossClasses = tuple(k1lib.Callback.lossCls)
        if not any(isinstance(cb, lossClasses) for cb in self.cbs):
            warnings += "Warning: no loss function callback detected (or you set `lossF` already but then erased all callbacks)! Set using `l.lossF = ...` or `l.cbs.withLossLambda(...)`\n"
        if self.data is None: warnings += "Warning: no data yet. Set using `l.data = ...`\n"
        if self.opt is None: warnings += "Warning: no optimizer yet. Set using `l.opt = ...`\n"
        if warnings != "": warnings += "\n\n"
        return warnings

    def __dir__(self):
        # also advertise the callback names reachable through __getattr__
        answer = list(super().__dir__())
        answer.extend(self.cbs.cbsDict.keys()); return answer

    def __repr__(self):
        return (f"{self._warnings}l.model:\n{k1lib.tab(k1lib.limitLines(str(self.model)))}\n\n"
                f"l.opt:\n{k1lib.tab(k1lib.limitLines(str(self.opt)))}\n\n"
                f"l.cbs:\n{k1lib.tab(k1lib.limitLines(self.cbs.__repr__()))}\n\n"
                "Use...\n"
                "- l.model = ...: to specify a nn.Module object\n"
                "- l.data = ...: to specify data object\n"
                "- l.opt = ...: to specify an optimizer\n"
                "- l.lossF = ...: to specify a loss function\n"
                "- l.css = ...: to select modules using CSS. \"#root\" for root model\n"
                "- l.cbs = ...: to use a custom `Callbacks` object\n"
                "- l.selector: to get the modules selected by `l.css`\n"
                "- l.run(epochs): to run the network\n"
                "- l.Loss: to get a specific callback, this case \"Loss\"\n\n")
@k1lib.patch(Learner)
def save(self, fileName:str=None):
    """Saves this :class:`Learner` to file.

    See also: :meth:`load`

    :param fileName: if empty, then will save as "learner-0.pth", with 0
        changeable to avoid conflicts. If resave this exact :class:`Learner`,
        then use the old name generated before"""
    self.fileName = fileName or self.fileName
    if self.fileName is None:
        prefix, suffix = "learner-", ".pth"
        # numbers already taken by "learner-<N>.pth" files in the working
        # directory; names not matching that pattern are ignored rather than
        # crashing the int() parse
        taken = set()
        for file in os.listdir():
            stem = file[len(prefix):-len(suffix)]
            if file.startswith(prefix) and file.endswith(suffix) and stem.isdecimal():
                taken.add(int(stem))
        count = 0
        while count in taken: count += 1
        # must use the same pattern that was scanned above, otherwise every
        # auto-named save would overwrite the same file
        self.fileName = f"{prefix}{count}{suffix}"
    torch.save(self, self.fileName, pickle_module=dill)
    print(f"Saved to {self.fileName}")

@k1lib.patch(Learner, static=True)
def load(fileName:str=None):
    """Loads a :class:`Learner` from a file.

    See also: :meth:`save`

    :param fileName: if empty, then will prompt for file name"""
    f = fileName or input("Enter learner file name to load:")
    # NOTE(review): this unpickles arbitrary objects via dill - only load
    # files you trust
    answer = torch.load(f, pickle_module=dill)
    print(f"Loaded from {f}"); return answer

@k1lib.patch(Learner)
def _run1Batch(self):
    """Runs one batch: forward pass, loss, backward, optimizer step and zero
    grad, each of which can be skipped by a callback returning truthy from the
    corresponding "start..." checkpoint."""
    self.cbs("startBatch")
    try:
        self.cbs("startPass", "inPass", "endPass")
        self.cbs("startLoss", "inLoss", "endLoss")
        if not self.cbs("startBackward"): self.lossG.backward()
        if not self.cbs("startStep"): self.opt.step()
        if not self.cbs("startZeroGrad"): self.opt.zero_grad(set_to_none=True)
    except k1lib.CancelBatchException as ex:
        self.cbs("cancelBatch"); print(f"Batch cancelled: {ex}.")
    except (k1lib.CancelEpochException, k1lib.CancelRunException):
        # makes sure cancelBatch and endBatch gets called, for potential
        # cleanups, then reraise the exception (bare raise keeps traceback)
        self.cbs("cancelBatch", "endBatch"); raise
    self.cbs("endBatch")

class DI:
    """Data interceptor: wraps an iterable of batches and adds the time spent
    fetching each item to ``l.cbs.timings.loadData``."""
    def __init__(self, l:Learner, data): self.l = l; self.data = data
    def __len__(self): return len(self.data)
    def __iter__(self):
        data = iter(self.data); timings = self.l.cbs.timings
        while True:
            beginTime = _time()
            # keep only next() inside the try so StopIteration from the data
            # source ends iteration and nothing else is swallowed
            try: d = next(data)
            except StopIteration: return
            timings.loadData += _time() - beginTime
            yield d

@k1lib.patch(Learner)
def _run1Epoch(self):
    """Runs one epoch: all train batches (model in train mode), then all valid
    batches (model in eval mode) unless a "startValidBatches" callback returns
    truthy. Valid batch indices continue on from the train ones."""
    self.cbs("startEpoch")
    try:
        train, valid = self.data; train = DI(self, train); valid = DI(self, valid)
        # best effort: data sources like generators have no len()
        try: self.batches = len(train) + len(valid)
        except Exception: self.batches = None
        self.model.train()
        # start at 0 so an empty train split doesn't reuse a stale (or
        # missing) self.batch from a previous epoch
        trainLen = 0
        for self.batch, (self.xb, self.yb) in enumerate(train):
            self._run1Batch(); trainLen = self.batch + 1
        if not self.cbs("startValidBatches"):
            self.model.eval()
            for self.batch, (self.xb, self.yb) in enumerate(valid):
                self.batch += trainLen; self._run1Batch()
    except k1lib.CancelEpochException as ex:
        self.cbs("cancelEpoch"); print(f"Epoch cancelled: {ex}.")
    except k1lib.CancelRunException:
        self.cbs("cancelEpoch", "endEpoch"); raise
    self.cbs("endEpoch")

@k1lib.patch(Learner)
def run(self, epochs:int, batches:int=None):
    """Main run function.

    :param epochs: number of epochs to run. 1 epoch is the length of the dataset
    :param batches: if set, then cancels the epoch after reaching the specified batch
    :return: self, or None if the user declines to continue past warnings"""
    if self._warnings != "":
        if not input(f"""You still have these warnings:\n\n{self._warnings} Do you want to continue? (y/n) """).lower().startswith("y"):
            print("Run ended"); return
    self.epochs = epochs; self.css = self.css  # update module selector
    with self.cbs.context():
        if batches is not None: self.cbs.withBatchLimit(batches)
        self.cbs("startRun")
        try:
            for self.epoch in range(epochs): self._run1Epoch()
        except k1lib.CancelRunException as ex:
            self.cbs("cancelRun"); print(f"Run cancelled: {ex}.")
        self.cbs("endRun"); return self

@k1lib.patch(Learner)
def __call__(self, xb, yb=None):
    """Executes just a small batch. Convenience method to query how the
    network is doing.

    :param xb: x batch
    :param yb: y batch. If specified, return (y, loss), else return y alone"""
    # compare against None explicitly: `yb or ...` calls bool() on a tensor,
    # which raises for tensors with more than one element
    ybOrDummy = torch.tensor(0) if yb is None else yb
    def gen(): yield xb, ybOrDummy
    oldData = self.data; self.data = k1lib.data.Data(gen(), iter(range(0)))
    try:
        with self.cbs.suspendEval(), self.cbs.context():
            # stop the batch before loss (no yb) or before backward (with yb)
            ex = lambda _: k1lib.raiseEx(k1lib.CancelBatchException)
            self.cbs.append(k1lib.Callback().withCheckpoint("startLoss" if yb is None else "startBackward", ex))
            self.run(1, 1)
    finally: self.data = oldData  # always restore the user's data
    return self.y if yb is None else (self.y, self.loss)

@k1lib.patch(Learner)
def evaluate(self):
    """Function to visualize quickly how the network is doing.

    Undefined by default, just placed here as a convention, so you have to do
    something like this::

        l = k1lib.Learner()
        def evaluate(self):
            xbs, ybs, ys = self.Recorder.record(1, 3)
            plt.plot(torch.vstack(xbs), torch.vstack(ys))
        l.evaluate = partial(evaluate, l)
    """
    raise NotImplementedError("You have to define evaluate() by yourself")

from k1lib.bioinfo.cli import *

def getXbs():
    """x batches for :meth:`Learner.sample`: ``torch.linspace(0, 10, 50)``
    piped through the cli ops ``repeatFrom()``, ``batched(32)`` and
    ``toTensor().all()``."""
    return torch.linspace(0, 10, 50) | repeatFrom() | batched(32) | toTensor().all()

@k1lib.patch(Learner, static=True)
def sample() -> Learner:
    """Creates an example learner, just for simple testing stuff anywhere.
    The network tries to learn the function y=x."""
    class DS:
        def __iter__(self): return [getXbs(), getXbs()] | transpose() | head(100)
    l = Learner(); l.data = k1lib.data.Data(DS(), range(0))
    class Model(torch.nn.Module):
        def __init__(self): super().__init__(); self.linear = torch.nn.Linear(1, 1)
        def forward(self, x): x = x[:, None]; return self.linear(x + 2).squeeze()
    l.model = Model(); l.cbs = k1lib.Callbacks().withCoreNormal().withLoss().withProgressBar()
    l.lossF = lambda y, yb: ((y - yb) ** 2).sum()
    l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-3); return l