# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
This module is for selecting subnetworks using CSS-like selectors so that you can do
special things to them. Check out the tutorial section for a walkthrough. This is exposed
automatically with::
   from k1lib.imports import *
   selector.select # exposed
"""
import k1lib, re
from k1lib import cli
from typing import List, Tuple, Dict, Union, Any, Iterator, Callable
from contextlib import contextmanager; from functools import partial
try: import torch; from torch import nn; hasTorch = True
except:
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
# Names exported by `from <this module> import *`.
__all__ = ["ModuleSelector", "preprocess", "select"]
def preprocess(selectors:str, defaultProp="*") -> List[str]:
    r"""Normalizes a css selector string into a flat list of simple statements.

Redundant whitespace is removed, statements without a property get
``defaultProp`` and comma-delimited element lists are expanded, so every
returned line looks like ``"<elements>:<props>"``. Example::

    # returns ["a:f", "a:g,h", "b:g,h", "t:*"]
    selector.preprocess("a:f; a, b: g,h; t")

:param selectors: single css selector string. Statements separated
    by "\\n" or ";"
:param defaultProp: default property, if statement doesn't have one
:return: list of normalized statements. Blank statements (including
    whitespace-only ones) and empty comma segments are dropped"""
    statements = []
    for line in re.split(r"[;\n]", selectors):
        # collapse whitespace runs, then trim both ends
        line = re.sub(r"\s\s+", " ", line).strip()
        # drop the optional space around the structural characters
        for old, new in ((" >", ">"), ("> ", ">"), (" :", ":"), (": ", ":"), (" ,", ","), (", ", ",")):
            line = line.replace(old, new)
        # filter AFTER trimming, so whitespace-only statements don't turn into ":*"
        if not line: continue
        statements.append(line if ":" in line else f"{line}:{defaultProp}")
    expanded = []
    for statement in statements:
        fields = statement.split(":")
        elements, props = fields[0], fields[1]
        # one output line per comma-separated element, skipping empty segments
        expanded.extend(f"{element}:{props}" for element in elements.split(",") if element)
    return expanded
def _getParts(s:str):
    """Splits the element part of a preprocessed statement (everything before
    the first ":") on ">" and " ", returning the non-empty pieces in order."""
    pieces = []
    for chunk in s.split(":")[0].split(">"):
        if not chunk: continue
        pieces.extend(token for token in chunk.split(" ") if token)
    return pieces
def _getProps(s:str):
    """Returns the non-empty, comma-separated props found between the first
    and second ":" of a preprocessed statement."""
    rawProps = s.split(":")[1]
    return list(filter(None, rawProps.split(",")))
# Shared generator called once per ModuleSelector.__init__ to fill in `idx`
# (presumably an incrementing counter, judging by the helper's name).
_idxAuto = k1lib.AutoIncrement()                                                 # _getProps
class ModuleSelector: # empty methods so that Sphinx generates the docs in order
    """Node of a selector tree that wraps a :class:`torch.nn.Module`.

Each node stores the props assigned to it by css statements and its child
:class:`ModuleSelector` objects, keyed by name. Most methods below are
placeholders that get replaced by the real implementations further down
in this module."""
    props:List[str]
    """Properties of this :class:`ModuleSelector`"""
    idx:int
    """Unique id of this :class:`ModuleSelector` in the entire script. May be useful
for module recognition"""
    nn:"torch.nn.Module"
    """The associated :class:`torch.nn.Module` of this :class:`ModuleSelector`"""
    def __init__(self, parent:"ModuleSelector", name:str, nn:"torch.nn.Module"):
        """
        :param parent: parent selector, or None for the root
        :param name: name of this node
        :param nn: the wrapped :class:`torch.nn.Module`"""
        self.parent = parent; self.name = name; self.nn = nn
        # has to exist before `displayF` is assigned, as the setter walks the tree
        self._children:Dict[str, "ModuleSelector"] = {}
        self.props:List[str] = []; self.depth:int = 0
        self.directSelectors:List[str] = []
        self.indirectSelectors:List[str] = []
        self.displayF:Callable[["ModuleSelector"], str] = lambda mS: ', '.join(mS.props)
        self.idx = _idxAuto()
    def deepestDepth(self): pass
    def highlight(self, prop:str):
        """Highlights the specified prop when displaying the object.

        :param prop: nodes containing this prop are displayed in red
        :return: self"""
        self.displayF = lambda self: (k1lib.fmt.txt.red if prop in self else k1lib.fmt.txt.identity)(', '.join(self.props))
        return self
    def __call__(self, *args, **kwargs):
        """Calls the internal :class:`torch.nn.Module`"""
        return self.nn(*args, **kwargs)
    def __contains__(self): pass
    def named_children(self): pass
    def children(self): pass
    def named_modules(self): pass
    def modules(self): pass
    def directParams(self): pass
    def parse(self): pass
    def apply(self): pass
    def clearProps(self): pass
    @property
    def displayF(self):
        """Function to display each ModuleSelector's lines.
Default is just::
    lambda mS: ", ".join(mS.props) """
        return self._displayF
    @displayF.setter
    def displayF(self, f):
        # the display function is propagated to this node and all descendants
        def applyF(self): self._displayF = f
        self.apply(applyF)
    def __getattr__(self, attr):
        """Fallback lookup: child selectors first, then direct parameters.

        :raises AttributeError: if ``attr`` is private or can't be found"""
        # private names never resolve here; this also stops infinite recursion
        # when `_children` itself has not been set yet
        if attr.startswith("_"): raise AttributeError(attr)
        if attr in self._children: return self._children[attr]
        # attribute protocol requires AttributeError (not KeyError), otherwise
        # hasattr() and getattr(obj, name, default) blow up on unknown names
        try: return self.directParams[attr]
        except KeyError: raise AttributeError(attr) from None
    def __getitem__(self, idx):
        """Same lookup as attribute access, using ``str(idx)`` as the name.

        :raises KeyError: if nothing is found under that name"""
        try: return getattr(self, str(idx))
        except AttributeError: raise KeyError(str(idx)) from None
    @staticmethod
    def sample() -> "ModuleSelector":
        """Create a new example :class:`ModuleSelector` that has a bit of
hierarchy to them, with no css."""
        return nn.Sequential(nn.Linear(3, 4), nn.Sequential(nn.Conv2d(3, 8, 3, 2), nn.ReLU(), nn.Linear(5, 6)), nn.Linear(7, 8)).select("")
    def hookF(self): pass
    def hookFp(self): pass
    def hookB(self): pass
    def freeze(self): pass
    def unfreeze(self): pass
@k1lib.patch(nn.Module)
def select(model:"torch.nn.Module", css:str="*") -> "k1lib.selector.ModuleSelector":
    """Builds the root :class:`ModuleSelector` of a model and parses css into it.
Example::

    mS = selector.select(nn.Linear(3, 4), "#root:propA")

Because this is also patched onto :class:`torch.nn.Module`, the same thing
can be written as::

    mS = nn.Linear(3, 4).select("#root:propA")

:param model: the :class:`torch.nn.Module` object to select from
:param css: the css selectors"""
    statements = preprocess(css)
    rootSelector = ModuleSelector(parent=None, name="root", nn=model)
    rootSelector.parse(statements)
    return rootSelector
@k1lib.patch(ModuleSelector, name="apply")
def _apply(self, f:Callable[[ModuleSelector], None]):
    """Calls ``f`` on this :class:`ModuleSelector` first, then recursively on
    each of its children.

    :param f: function taking in a single :class:`ModuleSelector`"""
    f(self)
    for subSelector in self._children.values():
        subSelector.apply(f)
@k1lib.patch(ModuleSelector, name="parse")
def _parse(self, selectors:Union[List[str], str]) -> ModuleSelector:
    """Parses extra selectors. Clears all old selectors, but retain
the props. Returns self. Example::

    mS = selector.ModuleSelector.sample().parse("Conv2d:propA")
    # returns True
    "propA" in mS[1][0]

:param selectors: can be the preprocessed list, or the unprocessed css string"""
    if isinstance(selectors, str): selectors = preprocess(selectors)
    self.directSelectors = []; self.indirectSelectors = []
    ogSelectors = selectors # children are handed the untouched list
    if self.parent is not None:
        # besides the original statements, also consider what the parent left over
        selectors = list(selectors) + self.parent.indirectSelectors + self.parent.directSelectors
        # indirect (descendant) selectors keep flowing down the tree
        self.indirectSelectors += self.parent.indirectSelectors
        self.depth = self.parent.depth + 1
    clsName = self.nn.__class__.__name__; idName = "#" + self.name
    for selector in selectors:
        parts = _getParts(selector)
        if not parts: continue # statement without any element, nothing to match
        matches = parts[0] in (clsName, idName, "*")
        if len(parts) == 1:
            # last element of the chain: a match means this node gets the props
            if matches: self.props += _getProps(selector)
        elif matches:
            # whichever of ">" / " " comes first decides how the rest is passed on
            a = selector.find(">"); a = a if a > 0 else float("inf")
            b = selector.find(" "); b = b if b > 0 else float("inf")
            if a < b: self.directSelectors.append(selector[a+1:])
            else: self.indirectSelectors.append(selector[b+1:])
    for name, mod in self.nn.named_children():
        if name not in self._children:
            self._children[name] = ModuleSelector(self, name, mod)
        self._children[name].parse(ogSelectors)
    # dedupe while keeping first-seen order, so displays are stable across runs
    self.props = list(dict.fromkeys(self.props)); return self
@k1lib.patch(ModuleSelector)
def __contains__(self, prop:str=None) -> bool:
    """Checks whether this :class:`ModuleSelector` carries the given prop.
Example::

    # returns True
    "b" in nn.Linear(3, 4).select("*:b")
    # returns False
    "h" in nn.Linear(3, 4).select("*:b")
    # returns True, "*" here means the ModuleSelector has any properties at all
    "*" in nn.Linear(3, 4).select("*:b")

A selector that holds the "*" prop itself contains every prop."""
    ownProps = self.props
    hasWildcard = "*" in ownProps
    return hasWildcard or prop in ownProps or (prop == "*" and len(ownProps) > 0)
@k1lib.patch(ModuleSelector)
def named_children(self, prop:str=None) -> Iterator[Tuple[str, ModuleSelector]]:
    """Iterates over (name, selector) pairs of the direct children.

:param prop: Filter property. See also: :meth:`__contains__`"""
    pairs = self._children.items()
    if prop is not None:
        return ((name, child) for name, child in pairs if prop in child)
    return pairs
@k1lib.patch(ModuleSelector)
def children(self, prop:str=None) -> Iterator[ModuleSelector]:
    """Iterates over the direct children, without their names.

:param prop: Filter property. See also: :meth:`__contains__`"""
    namedPairs = self.named_children(prop)
    return (child for _name, child in namedPairs)
@k1lib.patch(ModuleSelector, "directParams")
@property
def directParams(self) -> Dict[str, nn.Parameter]:
    """Parameters of the wrapped module whose name contains no ".",
    as a name -> parameter dict."""
    direct = {}
    for paramName, param in self.nn.named_parameters():
        if "." in paramName: continue
        direct[paramName] = param
    return direct
@k1lib.patch(ModuleSelector)
def named_modules(self, prop:str=None) -> Iterator[Tuple[str, ModuleSelector]]:
    """Yields (name, selector) for this node and every descendant, parents
before their children.
Example::

    modules = list(nn.Sequential(nn.Linear(3, 4), nn.ReLU()).select().named_modules())
    # return 3
    len(modules)
    # return tuple ('0', <ModuleSelector of Linear>)
    modules[1]

:param prop: Filter property. See also: :meth:`__contains__`"""
    if prop is None:
        yield self.name, self
        for sub in self._children.values():
            yield from sub.named_modules()
    else:
        # walk everything unfiltered, then keep only nodes having the prop
        for name, mS in self.named_modules():
            if prop in mS: yield name, mS
@k1lib.patch(ModuleSelector)
def modules(self, prop:str=None) -> Iterator[ModuleSelector]:
    """Yields this node and every descendant, same order as :meth:`named_modules`
but without the names.

:param prop: Filter property. See also: :meth:`__contains__`"""
    yield from (mS for _name, mS in self.named_modules(prop))
@k1lib.patch(ModuleSelector)
def clearProps(self) -> "ModuleSelector":
    """Resets the props of this :class:`ModuleSelector` and of all its
descendants to an empty list, then returns self. Example::

    # returns False
    "b" in nn.Linear(3, 4).select("*:b").clearProps()"""
    def _reset(mS):
        mS.props = []
    self.apply(_reset)
    return self
@k1lib.patch(ModuleSelector, name="deepestDepth")
@property
def deepestDepth(self):
    """Height of the tree below this node: 0 for a node without
children, otherwise one more than the deepest child."""
    return max((1 + child.deepestDepth for child in self._children.values()), default=0)
@k1lib.patch(ModuleSelector)
def __repr__(self, intro:bool=True, header:Union[str, Tuple[str]]="", footer="", tabs:int=None):
    """Renders this selector and its descendants as an indented tree, with
    the output of `displayF` in a column on the right.

    :param intro: whether to include a nice header and footer info
    :param header:
        str: include a header that starts where `displayF` will start
        Tuple[str, str]: first one in tree, second one in displayF section
    :param footer: same thing with header, but at the end
    :param tabs: number of tabs at the beginning. Best to leave this empty
    """
    if tabs is None: tabs = 5 + self.deepestDepth
    width = tabs*4 # column where the displayF section starts
    pieces = ["ModuleSelector:\n"] if intro else []
    if header:
        left, right = header if not isinstance(header, str) else ("", header)
        pieces.append(left.ljust(width, " ") + right + "\n")
    pieces.append(f"{self.name}: {self.nn.__class__.__name__}".ljust(width, " "))
    pieces.append(self.displayF(self))
    if len(self._children) > 0: pieces.append("\n")
    # children render themselves one tab narrower, and are then indented as a whole
    pieces.append(self._children.values() | cli.apply(lambda child: child.__repr__(tabs=tabs-1, intro=False).split("\n")) | cli.joinStreams() | cli.tab() | cli.join("\n"))
    if footer:
        left, right = footer if not isinstance(footer, str) else ("", footer)
        pieces.append("\n" + left.ljust(width, " ") + right)
    if intro: pieces.append(f"""\n\nCan...
- mS.deepestDepth: get deepest depth possible
- mS.nn: get the underlying nn.Module object
- mS.apply(f): apply to self and all descendants
- "HookModule" in mS: whether this module has a specified prop
- mS.highlight(prop): highlights all modules with specified prop
- mS.parse([..., ...]): parses extra css
- mS.directParams: get Dict[str, nn.Parameter] that are directly under this module""")
    return "".join(pieces)
def _strTensor(t):
    """Returns the formatted ``shape`` of ``t``, or the string "None" if ``t`` is None."""
    if t is None:
        return "None"
    return f"{t.shape}"
def strTensorTuple(ts):
    """Describes the shapes of a sequence of tensors as text.

    :param ts: sequence whose elements are None or have a ``shape`` attribute
    :return: a single line for one element, a bulleted list for several, and
        a short "0 total" note for an empty sequence (e.g. a module that was
        called without positional arguments)"""
    if len(ts) == 0:
        # nothing to index into, ts[0] below would raise IndexError
        return "tensors (0 total)"
    if len(ts) > 1:
        shapes = "\n".join(f"- {_strTensor(t)}" for t in ts)
        return f"tensors ({len(ts)} total) shapes:\n{shapes}"
    return f"tensor shape: {_strTensor(ts[0])}"
@k1lib.patch(ModuleSelector)
@contextmanager
def hookF(self, f:Callable[[ModuleSelector, "torch.nn.Module", Tuple[torch.Tensor], torch.Tensor], None]=None, prop:str="*"):
    """Context manager for applying forward hooks.
Example::

    def f(mS, i, o):
        print(i, o)
    m = nn.Linear(3, 4)
    with m.select().hookF(f):
        m(torch.randn(2, 3))

All hooks are removed again when the context exits.

:param f: hook callback, should accept :class:`ModuleSelector`, inputs and output.
    If not specified, then it will print out input and output tensor shapes
:param prop: filter property of module to hook onto"""
    # default printer. Has to use its own `mS` argument, there's no `m` in scope here
    if f is None: f = lambda mS, i, o: print(f"Forward hook {mS}:\n" + ([f"Input  {strTensorTuple(i)}", f"Output tensor shape: {o.shape}"] | cli.tab() | cli.join("\n")))
    # NOTE(review): every callback gets `self` (the selector hookF was called on),
    # not the selector of the module that fired - confirm that this is intended
    g = lambda m, i, o: f(self, i, o)
    handles = []
    try:
        # registering inside the try, so a failure midway still cleans up earlier hooks
        for m in self.modules(prop): handles.append(m.nn.register_forward_hook(g))
        yield
    finally:
        for h in handles: h.remove()
@k1lib.patch(ModuleSelector)
@contextmanager
def hookFp(self, f=None, prop:str="*"):
    """Context manager for applying forward pre hooks.
Example::

    def f(mS, i):
        print(i)
    m = nn.Linear(3, 4)
    with m.select().hookFp(f):
        m(torch.randn(2, 3))

All hooks are removed again when the context exits.

:param f: hook callback, should accept :class:`ModuleSelector` and inputs.
    If not specified, then it will print out input tensor shapes
:param prop: filter property of module to hook onto"""
    # default printer. Has to use its own `mS` argument, there's no `m` in scope here
    if f is None: f = lambda mS, i: print(f"Forward pre hook {mS}:\n" + ([f"Input {strTensorTuple(i)}"] | cli.tab() | cli.join("\n")))
    # NOTE(review): every callback gets `self` (the selector hookFp was called on),
    # not the selector of the module that fired - confirm that this is intended
    g = lambda m, i: f(self, i)
    handles = []
    try:
        # registering inside the try, so a failure midway still cleans up earlier hooks
        for m in self.modules(prop): handles.append(m.nn.register_forward_pre_hook(g))
        yield
    finally:
        for h in handles: h.remove()
@k1lib.patch(ModuleSelector)
@contextmanager
def hookB(self, f=None, prop:str="*"):
    """Context manager for applying backward hooks.
Example::

    def f(mS, i, o):
        print(i, o)
    m = nn.Linear(3, 4)
    with m.select().hookB(f):
        m(torch.randn(2, 3)).sum().backward()

All hooks are removed again when the context exits.

:param f: hook callback, should accept :class:`ModuleSelector`, grad inputs and outputs.
    If not specified, then it will print out the shapes of both
:param prop: filter property of module to hook onto"""
    # default printer. Has to use its own `mS` argument, there's no `m` in scope here
    if f is None: f = lambda mS, i, o: print(f"Backward hook {mS}:\n" + ([f"Input  {strTensorTuple(i)}", f"Output {strTensorTuple(o)}"] | cli.tab() | cli.join("\n")))
    # NOTE(review): every callback gets `self` (the selector hookB was called on),
    # not the selector of the module that fired - confirm that this is intended
    g = lambda m, i, o: f(self, i, o)
    handles = []
    try:
        # registering inside the try, so a failure midway still cleans up earlier hooks
        for m in self.modules(prop): handles.append(m.nn.register_full_backward_hook(g))
        yield
    finally:
        for h in handles: h.remove()
from contextlib import ExitStack                                                 # hookB
@contextmanager
def _intercept(self, value:bool):
    """Context manager that records what flows through all modules having any prop.

    Yields a list with one ``[forwardRecords, backwardRecords]`` entry per
    hooked module, in ``self.modules("*")`` order. Each forward hook call
    appends ``[inputs, output]`` and each backward hook call appends
    ``[first hook argument, second hook argument]`` (as lists), everything
    detached. Hooks are removed when the context exits.

    :param self: the :class:`ModuleSelector` to take modules from
    :param value: unused, kept so existing callers don't break"""
    handles = []
    try:
        data = []
        f = lambda x: x.detach() if x is not None else None
        for m in self.modules("*"):
            fwd = []; bwd = []; data.append([fwd, bwd])
            # bind this module's lists as default arguments. A plain closure would
            # look the variables up late, so all hooks would write into the lists
            # of the LAST module
            handles.append(m.nn.register_forward_hook(lambda _m, i, o, rec=fwd: rec.append([[f(e) for e in i], f(o)])))
            handles.append(m.nn.register_full_backward_hook(lambda _m, i, o, rec=bwd: rec.append([[f(e) for e in i], [f(e) for e in o]])))
        yield data
    finally:
        for h in handles: h.remove()
@k1lib.patch(ModuleSelector)
def intercept(self):
    """Gives back a context manager that captures the forward and backward
signals passing through the selected parts of the network. Example::

    l = k1lib.Learner.sample()
    with l.model.select("#lin1").intercept() as d:
        l.run(2)
    # returns (1, 2, 600, 2, 1, 32, 1), or (#selected modules, [forward, backward], #steps, [input, output], actual data)
    d | shape()

See also: :meth:`hookF`, :meth:`hookFp`, :meth:`hookB`"""
    interceptor = _intercept(self, value=False)
    return interceptor
from contextlib import ExitStack                                                 # intercept
@contextmanager
def _freeze(self, value:bool, prop:str):
    """Context manager that calls ``requires_grad_(value)`` on every module
    having ``prop``, after entering that module's ``gradContext()``.

    The entered contexts are closed together when this context exits
    (presumably putting the old requires_grad state back - see ``gradContext``).

    :param value: value to pass to ``requires_grad_``
    :param prop: filter property of modules to affect"""
    stack = ExitStack()
    with stack:
        for mS in self.modules(prop):
            module = mS.nn
            stack.enter_context(module.gradContext())
            module.requires_grad_(value)
        yield
@k1lib.patch(ModuleSelector)
def freeze(self, prop:str="*"):
    """Gives back a context manager that freezes (sets requires_grad to False)
the selected parts of the network. Example::

    l = k1lib.Learner.sample()
    w = l.model.lin1.lin.weight.clone() # weights before
    with l.model.select("#lin1").freeze():
        l.run(1)
    # returns True
    (l.model.lin1.lin.weight ==  w).all()

:param prop: filter property of modules to freeze"""
    return _freeze(self, value=False, prop=prop)
@k1lib.patch(ModuleSelector)
def unfreeze(self, prop:str="*"):
    """Gives back a context manager that unfreezes (sets requires_grad to True)
the selected parts of the network. Example::

    l = k1lib.Learner.sample()
    w = l.model.lin1.lin.weight.clone() # weights before
    with l.model.select("#lin1").freeze():
        with l.model.select("#lin1 > #lin").unfreeze():
            l.run(1)
    # returns False
    (l.model.lin1.lin.weight ==  w).all()

:param prop: filter property of modules to unfreeze"""
    return _freeze(self, value=True, prop=prop)
class CutOff(nn.Module):
    """Wraps a network so that calling it returns the output of one inner module.

    A forward hook on ``m`` remembers its latest output. :meth:`forward` runs
    the whole ``net`` and returns that remembered value instead of the
    network's own output."""
    def __init__(self, net, m):
        """
        :param net: network to run
        :param m: module whose output should be returned. Its hook only
            fires if ``net`` actually calls it"""
        import weakref # local import, only needed here
        super().__init__()
        self.net = net; self.m = m; self._lastOutput = None
        # The hook lives on `m`, which can outlive this wrapper. Holding `self`
        # strongly in there would keep the wrapper alive for as long as `m`
        # exists, so __del__ (and thus the hook removal) would never run.
        selfRef = weakref.ref(self)
        def f(m, i, o):
            wrapper = selfRef()
            if wrapper is not None: wrapper._lastOutput = o
        self.handle = self.m.register_forward_hook(f)
    def forward(self, *args, **kwargs):
        """Runs the full network and returns the hooked module's output, or
        None if the hooked module was not called during this pass."""
        self._lastOutput = None
        self.net(*args, **kwargs)
        return self._lastOutput
    def __del__(self):
        # `handle` is missing if __init__ failed before registering the hook
        handle = getattr(self, "handle", None)
        if handle is not None: handle.remove()
@k1lib.patch(ModuleSelector)
def cutOff(self) -> nn.Module:
    """Builds a new network that returns the selected layer's output
instead of the network's own output.
Example::

    xb = torch.randn(10, 2)
    m = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 4), nn.Linear(4, 6))
    m0 = m.select("#0").cutOff(); m1 = m.select("#1").cutOff()
    # returns (10, 6)
    m(xb).shape
    # returns True
    m0(xb).shape == torch.Size([10, 5])
    # returns True
    m1(xb).shape == torch.Size([10, 4])"""
    selectedModule = self.modules("*") | cli.item() | cli.op().nn
    return CutOff(self.nn, selectedModule)