# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""This is for core callbacks, that defines how everything is going to go"""
from .callbacks import Callback, Callbacks, Cbs
import k1lib, torch
from typing import List, Tuple, Dict, Iterator, Union, Any, Callable
__all__ = ["CoreNormal", "CoreRNN"]
@k1lib.patch(Cbs)
class CoreNormal(Callback):
    """Just a normal, typical feed forward pass: the whole batch goes through
    the model in a single call, ``l.y = l.model(l.xb)``."""
    def inPass(self):
        # One forward call over the entire batch; the prediction is stored on
        # the learner (``self.l.y``), presumably for later callbacks such as
        # the loss computation to read.
        self.l.y = self.l.model(self.l.xb)
@k1lib.patch(Callbacks, docs=CoreNormal)
def withCoreNormal(self, name:str=None):
    # Create a fresh CoreNormal callback and register it on this Callbacks
    # container under ``name``; the result of ``append`` is handed straight
    # back to the caller (presumably ``self``, enabling chained ``with*`` calls).
    core = CoreNormal()
    return self.append(core, name=name)
@k1lib.patch(Cbs)
class CoreRNN(Callback):
    """RNN forward pass. Expects the model to have an
    ``initHidden(bs) -> torch.Tensor`` method, and to be callable as
    ``model(item, hx) -> (y, hx)``.

    The batch ``l.xb`` is walked along its first dimension one step at a
    time, threading the hidden state through every step. After each step the
    ``"rnnPass"`` checkpoint is fired."""
    def startBatch(self):
        # Fresh hidden state for every batch. The batch size is taken from the
        # second-to-last dim of xb, which suggests xb is laid out as
        # (seq, batch, features) -- NOTE(review): confirm against the loader.
        self.hx = self.l.model.initHidden(self.l.xb.shape[-2])
    def inPass(self):
        # initHidden may create the state on a different device than the
        # batch, so move it next to the inputs before stepping.
        self.hx = self.hx.to(self.l.xb.device)
        for item in self.l.xb:
            # One time step: the model returns this step's output and the
            # updated hidden state, which is fed into the next step.
            self.l.y, self.hx = self.l.model(item, self.hx)
            # Per-step checkpoint, presumably dispatched to the other
            # registered callbacks so they can act on each step's output.
            self.cbs("rnnPass")
@k1lib.patch(Callbacks, docs=CoreRNN)
def withCoreRNN(self, name:str=None):
    # Instantiate a CoreRNN callback and attach it to this Callbacks container
    # under ``name``; whatever ``append`` returns is passed back unchanged.
    rnn_core = CoreRNN()
    return self.append(rnn_core, name=name)