# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""This is for core callbacks, that defines how everything is going to go"""
from .callbacks import Callback, Callbacks, Cbs
import k1lib; from typing import List, Tuple, Dict, Iterator, Union, Any, Callable
try: import torch; hasTorch = True
except: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["CoreNormal", "CoreRNN"]
@k1lib.patch(Cbs)
class CoreNormal(Callback):
    """Just a normal, typical feed forward pass.

    Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inPass``:

    - y: attached result tensor after passing through model
    """

    def inPass(self):
        """Call the learner's model on the batch ``xb`` and store the result in ``y``."""
        learner = self.l
        learner.y = learner.model(learner.xb)
@k1lib.patch(Cbs)
class CoreRNN(Callback):
    """RNN forward pass.

    Expected variables from :attr:`k1lib.Learner.model`:

    - initHidden: function takes in batch size, returns init hidden tensor

    Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inPass``, more
    specifically ``rnnPass``:

    - y: attached result tensor after pass (``inPass``), after character pass (``rnnPass``)
    """

    def startBatch(self):
        """Create a fresh hidden state for the batch via ``model.initHidden``."""
        # The second-to-last dimension of xb is passed as the batch size.
        # NOTE(review): presumably xb is laid out (sequence, batch, features) - confirm.
        self.hx = self.l.model.initHidden(self.l.xb.shape[-2])

    def inPass(self):
        """Feed ``xb`` through the model item by item, carrying the hidden state.

        After each item, the model's output is stored in ``self.l.y``, the
        returned hidden state replaces ``self.hx``, and the ``rnnPass``
        checkpoint is fired.
        """
        # Make sure the hidden state lives on the same device as the input batch.
        self.hx = self.hx.to(self.l.xb.device)
        # Iterating a tensor walks its first dimension; the model is expected to
        # return an (output, new hidden state) pair for every item.
        for item in self.l.xb:
            self.l.y, self.hx = self.l.model(item, self.hx)
            # NOTE(review): indentation was lost in this copy of the file; the
            # class docstring describes ``rnnPass`` as the per-character
            # checkpoint, so it is fired once per item here - confirm upstream.
            self.cbs("rnnPass")