# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
Some nice utils to complement :mod:`torch.nn`. This is exposed automatically
with::
   from k1lib.imports import *
   knn.Lambda # exposed
"""
from typing import Callable, Any; import k1lib; import math
try: import torch; from torch import nn; hasTorch = True
except: nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["Lambda", "Identity", "LinBlock", "MultiheadAttention"]
class Lambda(nn.Module):                                                         # Lambda
    def __init__(self, f:Callable[[Any], Any]):                                  # Lambda
        """Creates a simple module with a specified :meth:`forward`
function."""                                                                     # Lambda
        super().__init__(); self.f = f                                           # Lambda 
    def forward(self, x): return self.f(x)                                       # Lambda
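# Usage sketch, assuming torch is available: Lambda wraps any callable as an nn.Module:
#     double = Lambda(lambda x: x * 2)
#     double(torch.tensor([1., 2., 3.]))     # -> tensor([2., 4., 6.])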
class Identity(Lambda):                                                          # Identity
    """Creates a module that returns the input in :meth:`forward`"""             # Identity
    def __init__(self): super().__init__(lambda x: x)                            # Identity 
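# Usage sketch, assuming torch is available; Identity is a no-op passthrough, handy as a
# placeholder when swapping layers in and out:
#     Identity()(torch.ones(3))              # -> tensor([1., 1., 1.])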
class LinBlock(nn.Module):                                                       # LinBlock
    def __init__(self, inC, outC):                                               # LinBlock
        """Linear layer with relu behind it"""                                   # LinBlock
        super().__init__(); self.lin = nn.Linear(inC, outC); self.relu = nn.LeakyReLU() # LinBlock 
    def forward(self, x):                                                        # LinBlock
        return x | self.lin | self.relu                                          # LinBlock  
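# Usage sketch, assuming torch is available. The ``x | self.lin | self.relu`` pipe above is
# k1lib's piping style and amounts to ``self.relu(self.lin(x))``:
#     LinBlock(35, 50)(torch.randn(16, 35)).shape   # -> torch.Size([16, 50])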
class MultiheadAttention(nn.Module):                                             # MultiheadAttention
    def __init__(self, qdim, kdim, vdim, embed, head=4, outdim=None):            # MultiheadAttention
        """Kinda like :class:`torch.nn.MultiheadAttention`, just simpler, shorter, and clearer.
Probably not as fast as the official version, and doesn't have masks and whatnot, but easy to read!
Example::
    xb = torch.randn(14, 32, 35) # (S, N, F), or sequence size 14, batch size 32, feature size 35
    # returns torch.Size([14, 32, 50])
    MultiheadAttention(35, 35, 35, 9, 4, 50)(xb).shape
Although you can use this right away with no mods, I really encourage you to copy and paste the
source code of this and modify it to your needs.
:param qdim: query feature dimension; ``kdim`` and ``vdim`` are the key and value dimensions
:param embed: embedding size per head. Note this differs from :class:`torch.nn.MultiheadAttention`, where it's the total size across all heads
:param outdim: output feature dimension; defaults to ``embed * head``"""         # MultiheadAttention
        super().__init__()                                                       # MultiheadAttention
        self.embed = embed; self.head = head; outdim = outdim or embed*head      # MultiheadAttention
        self.qdim = qdim; self.wq = nn.Linear(qdim, head*embed)                  # MultiheadAttention
        self.kdim = kdim; self.wk = nn.Linear(kdim, head*embed)                  # MultiheadAttention
        self.vdim = vdim; self.wv = nn.Linear(vdim, head*embed)                  # MultiheadAttention
        self.outLin = nn.Linear(head*embed, outdim)                              # MultiheadAttention
        self.softmax = nn.Softmax(-1)                                            # MultiheadAttention 
    def forward(self, query, key=None, value=None):                              # MultiheadAttention
        """If ``key`` or ``value`` is not specified, just default to ``query``.""" # MultiheadAttention
        if key is None: key = query                                              # MultiheadAttention
        if value is None: value = query                                          # MultiheadAttention
        S, N, *_ = key.shape; F = self.embed; head = self.head                   # MultiheadAttention
        q = self.wq(query); k = self.wk(key); v = self.wv(value)                 # MultiheadAttention
        S1 = q.shape[0]                                                          # MultiheadAttention
        if q.shape[1] != k.shape[1]: q = q.expand(-1, k.shape[1], -1).contiguous() # MultiheadAttention
        q = q.view(S1, -1, F).transpose(0, 1)                                    # MultiheadAttention
        k = k.view(S, -1, F).transpose(0, 1)                                     # MultiheadAttention
        v = v.view(S, -1, F).transpose(0, 1)                                     # MultiheadAttention
        mat = self.softmax((q / math.sqrt(F)) @ k.transpose(1, 2))               # MultiheadAttention
        return self.outLin((mat @ v).transpose(0, 1).reshape(S1, N, head*F))     # MultiheadAttention
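# Cross-attention usage sketch, assuming torch is available; the key/value sequence length may
# differ from the query's, and the output shape is (query length, batch, outdim):
#     att = MultiheadAttention(35, 24, 24, 9, 4, 50)
#     q   = torch.randn(14, 32, 35)                  # (S1, N, qdim)
#     kv  = torch.randn(20, 32, 24)                  # (S,  N, kdim) == (S, N, vdim)
#     att(q, kv, kv).shape                           # -> torch.Size([14, 32, 50])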