# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
Some nice utils to complement :mod:`torch.nn`. This is exposed automatically
with::
   from k1lib.imports import *
   knn.Lambda # exposed
"""
from torch import nn
from typing import Callable, Any
# Public names exported by ``from <this module> import *``.
__all__ = [
    "Lambda",
    "Identity",
    "LinBlock",
]
class Lambda(nn.Module):
    """Wrap an arbitrary callable as a :class:`torch.nn.Module`.

    The wrapped callable becomes the module's :meth:`forward`, so it can be
    composed with other modules (e.g. inside ``nn.Sequential``)."""
    def __init__(self, f:Callable[[Any], Any]):
        """Creates a simple module with a specified :meth:`forward` function.

        :param f: callable applied to the input every time the module is
            called; stored as ``self.f``"""
        super().__init__()
        self.f = f
    def forward(self, x):
        """Return ``self.f(x)``."""
        return self.f(x)
class Identity(Lambda):
    """Creates a module that returns the input in :meth:`forward`"""
    def __init__(self):
        # Use a named static function rather than ``lambda x: x``. A lambda
        # stored on the instance cannot be pickled, so ``torch.save`` on any
        # model containing an Identity would fail. A static method pickles
        # by its qualified name.
        super().__init__(Identity._identity)
    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged."""
        return x
class LinBlock(nn.Module):
    """A :class:`torch.nn.Linear` layer followed by a :class:`torch.nn.ReLU`."""
    def __init__(self, inC, outC):
        """Linear layer with relu behind it

        :param inC: number of input features of the linear layer
        :param outC: number of output features of the linear layer"""
        super().__init__()
        self.lin = nn.Linear(inC, outC)
        self.relu = nn.ReLU()
    def forward(self, x):
        """Apply ``self.lin`` and then ``self.relu`` to ``x``."""
        # Call the submodules directly instead of ``x | self.lin | self.relu``.
        # The pipe form only works if something has monkey-patched
        # ``nn.Module.__ror__``; plain PyTorch does not define it.
        return self.relu(self.lin(x))