# Source code for k1lib.callbacks.core

# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""This is for core callbacks, that defines how everything is going to go"""
from .callbacks import Callback, Callbacks, Cbs
import k1lib; from typing import List, Tuple, Dict, Iterator, Union, Any, Callable
try: import torch; hasTorch = True
except: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["CoreNormal", "CoreRNN"]
@k1lib.patch(Cbs)
class CoreNormal(Callback):
    """Just a normal, typical feed forward pass.

    Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inPass``:

    - y: attached result tensor after passing through model"""
    def inPass(self):
        # One forward pass: feed the current batch through the model and
        # store the output on the learner for later checkpoints to consume.
        learner = self.l
        learner.y = learner.model(learner.xb)
@k1lib.patch(Cbs)
class CoreRNN(Callback):
    """RNN forward pass.

    Expected variables from :attr:`k1lib.Learner.model`:

    - initHidden: function takes in batch size, returns init hidden tensor

    Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inPass``,
    more specifically ``rnnPass``:

    - y: attached result tensor after pass (``inPass``), after character pass (``rnnPass``)
    """
    def startBatch(self):
        # Fresh hidden state for every batch. ``xb.shape[-2]`` is what gets
        # handed to ``initHidden`` as the batch size.
        # NOTE(review): this assumes xb is laid out (seq, batch, ...); confirm.
        makeHidden = self.l.model.initHidden
        self.hx = makeHidden(self.l.xb.shape[-2])

    def inPass(self):
        learner = self.l
        # Keep the hidden state on the same device as the incoming batch.
        self.hx = self.hx.to(learner.xb.device)
        # Step through xb along its first dimension, threading the hidden state
        # from one step to the next. The "rnnPass" checkpoint fires after every
        # step, so learner.y holds the output of the most recent step.
        for step in learner.xb:
            learner.y, self.hx = learner.model(step, self.hx)
            self.cbs("rnnPass")