# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""Core callbacks, which define how the model's forward pass is run."""
from k1lib.callbacks import Callback, Callbacks
import k1lib, torch
from typing import List, Tuple, Dict, Iterator, Union, Any, Callable
__all__ = ["CoreNormal", "CoreRNN"]
@k1lib.patch(Callback.cls)
class CoreNormal(Callback):
    """Plain feed-forward pass.

    During ``inPass`` the whole batch ``l.xb`` is sent through ``l.model`` in a
    single call, and the result is stored in ``l.y``. (``l`` is presumably the
    owning Learner -- verify against the ``Callback`` base class.)"""
    def inPass(self):
        # One model call over the entire batch; its output becomes l.y
        self.l.y = self.l.model(self.l.xb)
@k1lib.patch(Callbacks, docs=CoreNormal)
def withCoreNormal(self, name:str=None):
    # Build a fresh CoreNormal callback and register it under ``name``,
    # returning whatever ``Callbacks.append`` returns (presumably the
    # Callbacks object itself, for chaining -- verify).
    callback = CoreNormal()
    return self.append(callback, name=name)
@k1lib.patch(Callback.cls)
class CoreRNN(Callback):
    """RNN forward pass. Expects the model to have an
    ``initHidden(bs) -> torch.Tensor`` method.

    At the start of every batch, a fresh hidden state is created with
    ``bs = l.xb.shape[-2]``. During ``inPass``, ``l.xb`` is iterated over its
    first dimension. Each slice is fed to ``l.model(item, hx)`` together with
    the current hidden state, and the returned ``(y, hx)`` pair overwrites
    ``l.y`` and the stored hidden state. The ``"rnnPass"`` event is triggered
    after every iteration. When the loop finishes, ``l.y`` holds the output of
    the last iteration."""
    def startBatch(self):
        # assumes l.xb is (seq, batch, features), so shape[-2] is the batch
        # size -- TODO confirm against the data this callback is used with
        self.hx = self.l.model.initHidden(self.l.xb.shape[-2])
    def inPass(self):
        # initHidden may create the state on a different device than the batch
        self.hx = self.hx.to(self.l.xb.device)
        # NOTE(review): if l.xb is empty along its first dimension the loop
        # never runs, and l.y keeps whatever value it held before this pass
        for item in self.l.xb:
            self.l.y, self.hx = self.l.model(item, self.hx)
            # presumably dispatches "rnnPass" to the other callbacks, once per
            # iteration, so they can act on each intermediate output
            self.cbs("rnnPass")
@k1lib.patch(Callbacks, docs=CoreRNN)
def withCoreRNN(self, name:str=None):
    # Build a fresh CoreRNN callback and register it under ``name``,
    # returning whatever ``Callbacks.append`` returns (presumably the
    # Callbacks object itself, for chaining -- verify).
    callback = CoreRNN()
    return self.append(callback, name=name)