# AUTOGENERATED FILE! PLEASE DON'T EDIT
import k1lib, torch.nn as nn, torch, dill
from time import time as _time
__all__ = ["CancelRunException", "CancelEpochException", "CancelBatchException",
"Learner"]
class CancelRunException(Exception):
    """Used in core training loop, to skip the run entirely.

    Raise it (for example from a callback) to abort every remaining epoch."""
class CancelEpochException(Exception):
    """Used in core training loop, to skip to next epoch.

    Raise it (for example from a callback) to abandon the current epoch only."""
class CancelBatchException(Exception):
    """Used in core training loop, to skip to next batch.

    Raise it (for example from a callback) to abandon the current batch only."""
class Learner:
    """Core training object: bundles a model, data, an optimizer and a
    :class:`~k1lib.callbacks.callbacks.Callbacks` object.

    Attribute lookups that fail on the Learner itself are forwarded to
    :attr:`cbs`, so e.g. ``l.Loss`` returns the callback named "Loss"."""
    def __init__(self):
        self._model = None; self._data = None; self._opt = None
        self._cbs = None; self.fileName = None
        self.css = "*"; self.exceptionRaised = None # slowly pops
        self.cbs = k1lib.Callbacks().withBasics().withQOL().withAdvanced()
    @property
    def model(self):
        """Set this to change the model to run"""
        return self._model
    @model.setter
    def model(self, model): self._model = model
    @property
    def data(self):
        """Set this to change the data (list of 2 dataloader) to run against."""
        return self._data
    @data.setter
    def data(self, data): self._data = data
    @property
    def opt(self):
        """Set this to change the optimizer. If you're making your own
        optimizers, beware to follow the PyTorch's style guide as there are callbacks
        that modifies optimizer internals while training like
        :class:`k1lib.schedule.ParamScheduler`."""
        return self._opt
    @opt.setter
    def opt(self, opt): self._opt = opt
    @property
    def cbs(self):
        """The :class:`~k1lib.callbacks.callbacks.Callbacks` object. Initialized to
        include all the common callbacks. You can set a new one if you want to."""
        return self._cbs
    @cbs.setter
    def cbs(self, cbs): cbs.l = self; self._cbs = cbs
    @property
    def css(self) -> str:
        """The css selector string. Set this to select other parts of the network.
        After setting, you can access the selector like this: :code:`l.selector`
        See also: :class:`~k1lib.selector.ModuleSelector`"""
        return self._css
    @css.setter
    def css(self, css:str):
        self._css = css
        # the selector can only be built once a model exists
        if self.model is not None: self.selector = k1lib.selector.select(self.model, self.css)
    @property
    def lossF(self):
        """Set this to specify a loss function."""
        raise NotImplementedError("lossF actually doesn't really exist. Used to exist as a core part of Learner, but then has been converted to k1lib.callbacks.lossFunctions.LossLambda")
    @lossF.setter
    def lossF(self, lossF):
        if hasattr(self.cbs, "LossLambda"): self.cbs.LossLambda.lossF = lossF
        else: self.cbs.withLossLambda(lossF)
    def __getattr__(self, attr):
        """Forwards unknown attributes to :attr:`cbs` (e.g. ``l.Loss``)."""
        # "_cbs" is guarded too: on an instance whose __dict__ is not populated
        # yet (e.g. created with Learner.__new__ before __setstate__ runs),
        # self.cbs -> self._cbs would otherwise re-enter this method forever
        # and raise RecursionError instead of AttributeError.
        if attr in ("cbs", "_cbs"): raise AttributeError(attr)
        return getattr(self.cbs, attr)
    def __getstate__(self):
        answer = dict(self.__dict__)
        # the selector is rebuilt from css in __setstate__. It only exists if css
        # was set while a model was present, so it may legitimately be missing.
        answer.pop("selector", None); return answer
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.css = self.css; self.cbs.l = self
    def evaluate(self): pass # supposed to be overridden, to provide functionality here
    @property
    def _warnings(self):
        """Human-readable list of missing pieces (model, loss callback, data,
        optimizer), followed by a blank line; empty string if nothing is missing."""
        warnings = "Warning: no model yet. Set using `l.model = ...`\n" if self.model is None else ""
        lossClasses = tuple(k1lib.Callback.lossCls)
        hasLossCb = any(isinstance(cb, lossClasses) for cb in self.cbs)
        warnings += "Warning: no loss function callback detected (or you set `lossF` already but then erased all callbacks)! Set using `l.lossF = ...` or `l.cbs.withLossLambda(...)`\n" if not hasLossCb else ""
        warnings += "Warning: no data yet. Set using `l.data = ...`\n" if self.data is None else ""
        warnings += "Warning: no optimizer yet. Set using `l.opt = ...`\n" if self.opt is None else ""
        if warnings != "": warnings += "\n\n"
        return warnings
    def __dir__(self):
        answer = list(super().__dir__())
        answer.extend(self.cbs.cbsDict.keys()); return answer
    def __repr__(self):
        return f"""{self._warnings}l.model:\n{k1lib.tab(k1lib.limitLines(str(self.model)))}
l.opt:\n{k1lib.tab(k1lib.limitLines(str(self.opt)))}
l.cbs:\n{k1lib.tab(k1lib.limitLines(self.cbs.__repr__()))}
Use...
- l.model = ...: to specify a nn.Module object
- l.data = ...: to specify data object
- l.opt = ...: to specify an optimizer
- l.lossF = ...: to specify a loss function
- l.css = ...: to select modules using CSS. "#root" for root model
- l.cbs = ...: to use a custom `Callbacks` object
- l.selector: to get the modules selected by `l.css`
- l.run(epochs): to run the network
- l.Loss: to get a specific callback, this case "Loss"\n\n"""
@k1lib.patch(Learner)
def save(self, fileName:str=None):
    """Saves this :class:`Learner` to file. See also: :meth:`load`

    :param fileName: if empty, then will save as "learner-0.pth" in the current
        directory, with 0 replaced by the smallest index not already used by an
        existing "learner-<n>.pth" file. If resave this exact :class:`Learner`,
        then use the old name generated before"""
    import os
    import re
    self.fileName = fileName or self.fileName
    if self.fileName is None:
        # only exact "learner-<digits>.pth" names count as taken; other files that
        # happen to start with "learner" are ignored instead of crashing the parse
        pattern = re.compile(r"learner-(\d+)\.pth")
        taken = {int(m.group(1)) for m in map(pattern.fullmatch, os.listdir()) if m}
        count = 0
        while count in taken: count += 1
        # same naming scheme as the scan above, so later saves see this file
        self.fileName = f"learner-{count}.pth"
    torch.save(self, self.fileName, pickle_module=dill)
    print(f"Saved to {self.fileName}")
@k1lib.patch(Learner, static=True)
def load(fileName:str=None):
    """Loads a :class:`Learner` from a file. See also: :meth:`save`

    :param fileName: if empty, then will prompt for file name
    :return: the unpickled :class:`Learner`"""
    f = fileName or input("Enter learner file name to load:")
    # NOTE(security): this unpickles arbitrary objects via dill. Only load files
    # you trust.
    learner = torch.load(f, pickle_module=dill)
    # report only after the load actually succeeded
    print(f"Loaded from {f}"); return learner
@k1lib.patch(Learner)
def _run1Batch(self):
    """Runs the current batch through the callback checkpoints.

    Order: startBatch; startPass/inPass/endPass; startLoss/inLoss/endLoss; then
    backward, optimizer step and gradient zeroing, each skipped when its
    ``start*`` checkpoint returns a truthy value; finally endBatch.

    A ``CancelBatchException`` fires "cancelBatch" and the batch still ends
    with "endBatch". Epoch/run cancellations fire "cancelBatch" and "endBatch"
    for cleanup, then propagate to the caller."""
    self.cbs("startBatch")
    try:
        for group in (("startPass", "inPass", "endPass"),
                      ("startLoss", "inLoss", "endLoss")):
            self.cbs(*group)
        # lambdas keep attribute lookups lazy: each runs only after its checkpoint
        guardedSteps = (("startBackward", lambda: self.lossG.backward()),
                        ("startStep", lambda: self.opt.step()),
                        ("startZeroGrad", lambda: self.opt.zero_grad(set_to_none=True)))
        for checkpoint, step in guardedSteps:
            if not self.cbs(checkpoint): step()
    except k1lib.CancelBatchException as ex:
        self.cbs("cancelBatch")
        print(f"Batch cancelled: {ex}.")
    except (k1lib.CancelEpochException, k1lib.CancelRunException):
        # give callbacks a chance to clean up before the outer loop handles it
        self.cbs("cancelBatch", "endBatch")
        raise
    self.cbs("endBatch")
class DI:
    """Data interceptor: wraps a data loader and adds the time spent fetching
    each item to ``l.cbs.timings.loadData``."""
    def __init__(self, l:Learner, data):
        self.l = l
        self.data = data
    def __len__(self):
        return len(self.data)
    def __iter__(self):
        source = iter(self.data)
        timings = self.l.cbs.timings
        while True:
            started = _time()
            try:
                item = next(source)
            except StopIteration:
                return  # underlying loader exhausted
            timings.loadData += _time() - started
            yield item
@k1lib.patch(Learner)
def _run1Epoch(self):
    """Runs one epoch: all training batches, then (unless the
    "startValidBatches" checkpoint returns a truthy value) all validation batches.

    ``self.batch`` is a global index across both phases. Validation indices
    continue after the last training index. ``self.batches`` is set to the total
    batch count when both loaders support ``len()``. Otherwise, if it is still
    None, it is filled in from the number of batches seen once the validation
    phase finishes.

    A ``CancelEpochException`` fires "cancelEpoch" and the epoch still ends with
    "endEpoch". A ``CancelRunException`` fires "cancelEpoch" and "endEpoch" for
    cleanup, then propagates."""
    self.cbs("startEpoch")
    try:
        train, valid = self.data; train = DI(self, train); valid = DI(self, valid)
        try: self.batches = len(train) + len(valid)
        except Exception: pass  # loaders without len(): counted after validation
        self.model.train()
        # counted explicitly so an empty training loader yields 0 instead of a
        # stale self.batch from a previous epoch (or an AttributeError)
        seen = 0
        for self.batch, (self.xb, self.yb) in enumerate(train):
            self._run1Batch(); seen = self.batch + 1
        trainLen = seen
        if not self.cbs("startValidBatches"):
            self.model.eval()
            for self.batch, (self.xb, self.yb) in enumerate(valid):
                self.batch += trainLen; self._run1Batch(); seen = self.batch + 1
            if self.batches is None: self.batches = seen
    except k1lib.CancelEpochException as ex:
        self.cbs("cancelEpoch"); print(f"Epoch cancelled: {ex}.")
    except k1lib.CancelRunException:
        self.cbs("cancelEpoch", "endEpoch"); raise
    self.cbs("endEpoch")
@k1lib.patch(Learner)
def run(self, epochs:int, batches:int=None):
    """Main run function.

    :param epochs: number of epochs to run. 1 epoch is the length of the dataset
    :param batches: if set, then cancels the epoch after reaching the specified batch
    :return: this :class:`Learner`, or None if the user declines to continue
        past the setup warnings"""
    warnings = self._warnings  # computed once: the property scans every callback
    if warnings != "":
        prompt = f"You still have these warnings:\n\n{warnings}\nDo you want to continue? (y/n) "
        if not input(prompt).lower().startswith("y"):
            print("Run ended"); return
    self.epochs = epochs; self.batches = None
    self.css = self.css # update module selector
    with self.cbs.context():
        if batches is not None: self.cbs.withBatchLimit(batches)
        self.cbs("startRun")
        try:
            for self.epoch in range(epochs): self._run1Epoch()
        except k1lib.CancelRunException as ex:
            self.cbs("cancelRun"); print(f"Run cancelled: {ex}.")
        self.cbs("endRun"); return self
@k1lib.patch(Learner)
def __call__(self, xb, yb=None):
    """Executes just a small batch. Convenience method to query how the network is
    doing.

    :param xb: x batch
    :param yb: y batch. If specified, return (y, loss), else return y alone
    """
    # `yb or ...` would raise "Boolean value of Tensor ... is ambiguous" for
    # multi-element tensors, so test against None explicitly
    targets = yb if yb is not None else torch.tensor(0)
    oldData = self.data; self.data = [[(xb, targets)], []]
    try:
        with self.cbs.suspendEval(), self.cbs.context():
            ex = lambda _: k1lib.raiseEx(k1lib.CancelBatchException)
            # cancel the batch before the loss (no yb) or before backward (yb
            # given), so backward/step/zero_grad never run
            self.cbs.append(k1lib.Callback().withCheckpoint("startLoss" if yb is None else "startBackward", ex))
            self.run(1, 1)
    finally:
        self.data = oldData  # restore the user's data even if the run raised
    return self.y if yb is None else (self.y, self.loss)
@k1lib.patch(Learner)
def evaluate(self):
    """Function to visualize quickly how the network is doing. Undefined by default,
    just placed here as a convention, so you have to do something like this::

        from functools import partial
        l = k1lib.Learner()
        def evaluate(self):
            xbs, ybs, ys = self.Recorder.record(1, 3)
            plt.plot(torch.vstack(xbs), torch.vstack(ys))
        l.evaluate = partial(evaluate, l)

    :raises NotImplementedError: always, until you assign your own ``evaluate``
    """
    raise NotImplementedError("You have to define evaluate() by yourself")
from k1lib.cli import *
@k1lib.patch(Learner, static=True)
def sample() -> Learner:
    """Builds a small ready-to-run :class:`Learner` for quick tests.

    Data comes from ``k1lib.kdata.FunctionData.main`` with the identity
    function, so the network tries to learn y=x. The model is a single
    ``Linear(1, 1)`` applied to the input shifted by +2. Training uses a summed
    squared-error loss and Adam with lr=3e-3. Callbacks: core-normal, loss and
    progress bar."""
    learner = Learner()
    learner.data = k1lib.kdata.FunctionData.main(lambda x: x)
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1, 1)
        def forward(self, x):
            # assumes x is 1-D (N,): add a feature axis for Linear, then squeeze back
            return self.linear(x[:, None] + 2).squeeze()
    learner.model = Model()
    learner.cbs = k1lib.Callbacks().withCoreNormal().withLoss().withProgressBar()
    learner.lossF = lambda y, yb: ((y - yb) ** 2).sum()
    learner.opt = torch.optim.Adam(learner.model.parameters(), lr=3e-3)
    return learner