# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""This is for core callbacks, that defines how everything is going to go"""
from .callbacks import Callback, Callbacks, Cbs
import k1lib, torch
from typing import List, Tuple, Dict, Iterator, Union, Any, Callable
# Public callbacks exported by ``from <this module> import *``.
__all__ = [
    "CoreNormal",
    "CoreRNN",
]
@k1lib.patch(Cbs)
class CoreNormal(Callback):
    """Just a normal, typical feed forward pass.

    Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inPass``:

    - y: attached result tensor after passing through model"""

    def inPass(self):
        # The whole batch ``xb`` goes through the model in a single call; the
        # output is stored on the Learner so later callbacks can read ``l.y``.
        self.l.y = self.l.model(self.l.xb)
@k1lib.patch(Cbs)
class CoreRNN(Callback):
    """RNN forward pass.

    Expected variables from :attr:`k1lib.Learner.model`:

    - initHidden: function takes in batch size, returns init hidden tensor

    Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inPass``, more
    specifically ``rnnPass``:

    - y: attached result tensor after pass (``inPass``), after character pass (``rnnPass``)
    """

    def startBatch(self):
        # Build a new hidden state from the model each time this checkpoint fires,
        # rather than reusing the one from the previous batch.
        # NOTE(review): batch size is taken from the second-to-last dim of ``xb``;
        # presumably xb is (seqLen, batch, features) -- confirm against the data.
        self.hx = self.l.model.initHidden(self.l.xb.shape[-2])

    def inPass(self):
        # initHidden may create the tensor on a different device than the batch;
        # move it next to ``xb`` before feeding it to the model.
        self.hx = self.hx.to(self.l.xb.device)
        # Iterate over the first dim of ``xb`` one step at a time, threading the
        # hidden state through successive model calls.
        # NOTE(review): if ``xb`` has length 0 along that dim, ``l.y`` is left
        # holding whatever the previous batch produced.
        for item in self.l.xb:
            self.l.y, self.hx = self.l.model(item, self.hx)
            # Fire the per-step "rnnPass" checkpoint so other callbacks can
            # observe each step's output in ``l.y``.
            self.cbs("rnnPass")