Source code for k1lib.knn
# AUTOGENERATED FILE! PLEASE DON'T EDIT
from torch import nn
from typing import Callable, Any
# Public names of this module: the ones a star-import of it will bring in.
__all__ = ["Lambda", "Identity"]
class Lambda(nn.Module):
    """Module that wraps an arbitrary single-argument callable.

    The callable is stored on the instance as ``f``; :meth:`forward` hands
    its input to that callable and returns the result unchanged.
    """

    def __init__(self, f: Callable[[Any], Any]):
        """Creates a simple module with a specified :meth:`forward`
        function.

        :param f: callable that :meth:`forward` will invoke with its input
        """
        super().__init__()
        self.f = f

    def forward(self, x):
        """Apply the stored callable to ``x`` and return its result."""
        return self.f(x)
class Identity(Lambda):
    """Creates a module that returns the input in :meth:`forward`"""

    def __init__(self):
        """Initialise the parent with a callable that returns its argument."""
        # A pass-through lambda is handed to the parent constructor, which
        # stores it as the function that forward() calls.
        super().__init__(lambda x: x)