# AUTOGENERATED FILE! PLEASE DON'T EDIT
import torch.nn as _nn
import k1lib as _k1lib, re as _re
from typing import List as _List, Tuple as _Tuple, Dict as _Dict, Union as _Union
from typing import Any as _Any, Iterator as _Iterator, Callable as _Callable
def _getParts(s:str): return [a for elem in s.split(":")[0].split(">") if elem for a in elem.split(" ") if a]
def _getProps(s:str): return [elem for elem in s.split(":")[1].split(",") if elem]
class ModuleSelector:
    """Node of a tree that mirrors a :class:`torch.nn.Module` hierarchy. Each
node wraps one module (``nnModule``) and stores the props that css-like
selectors assigned to it. Child nodes are added by ``parse``."""
    _signature = 0  # class-wide counter; each instance takes the next value as its ``signature``
    selectedProps:_List[str]
    """Selected properties of this MS"""
    def __init__(self, parent:"ModuleSelector", name:str, nnModule:_nn.Module):
        """
:param parent: parent selector, or ``None`` for a root
:param name: name of this node, matched by ``#name`` selectors
:param nnModule: the module this node wraps"""
        self.parent = parent; self.name = name; self.nnModule = nnModule
        self._children:_Dict[str, "ModuleSelector"] = {}
        self.selectedProps:_List[str] = []; self.depth:int = 0
        self.directSelectors:_List[str] = []
        self.indirectSelectors:_List[str] = []
        # goes through the ``displayF`` setter, which stores it via ``apply``
        self.displayF:_Callable[["ModuleSelector"], str] = lambda mS: ', '.join(mS.selectedProps)
        self.signature = ModuleSelector._signature; ModuleSelector._signature += 1
    @property
    def displayF(self):
        """Function to display each ModuleSelector's lines.
Default is just::
    lambda mS: ", ".join(mS.selectedProps) """
        return self._displayF
    @displayF.setter
    def displayF(self, f):
        # assigning propagates ``f`` to this node and every current descendant
        def applyF(mS): mS._displayF = f
        self.apply(applyF)
    def clearProps(self):
        """Clears the selected props of this node and all descendants. Returns self."""
        def applyF(mS): mS.selectedProps = []
        self.apply(applyF); return self
    def highlight(self, prop:str):
        """Sets ``displayF`` (on self and all descendants) so that the props of
nodes having ``prop`` are formatted with ``k1lib.format.red``. Returns self."""
        self.displayF = lambda mS: (_k1lib.format.red if mS.selected(prop) else _k1lib.format.identity)(', '.join(mS.selectedProps))
        return self
    def selected(self, prop:str=None) -> bool:
        """Whether this ModuleSelector has a specific prop. A node holding the
special prop ``"all"`` counts as having every prop."""
        return "all" in self.selectedProps or prop in self.selectedProps
    def named_children(self) -> _Iterator[_Tuple[str, "ModuleSelector"]]:
        """Get (name, ModuleSelector) pairs of all direct children"""
        return self._children.items()
    def named_modules(self, prop:str=None) -> _Iterator[_Tuple[str, "ModuleSelector"]]:
        """Get all named ModuleSelectors recursively (self first, then each
child's subtree in order)
:param prop: Filter property. If given, only nodes for which
    ``selected(prop)`` is true are yielded"""
        if prop is not None:
            for name, m in self.named_modules():
                if m.selected(prop): yield name, m
            return
        yield self.name, self
        for child in self._children.values():
            yield from child.named_modules()
    def children(self) -> _Iterator["ModuleSelector"]:
        """Get all direct children"""
        for name, x in self.named_children(): yield x
    def modules(self, prop:str=None) -> _Iterator["ModuleSelector"]:
        """Get all ModuleSelectors recursively. Optional filter prop"""
        for name, x in self.named_modules(prop): yield x
    @property
    def directParams(self) -> _Dict[str, _nn.Parameter]:
        """Params directly under this module (parameter names without a ``.``)"""
        return {name: param for name, param in self.nnModule.named_parameters() if "." not in name}
    def parameters(self) -> _Iterator[_nn.Parameter]:
        """Get generator of parameters of the wrapped module, all depths"""
        return self.nnModule.parameters()
    def __getattr__(self, attr):
        """Only called when normal attribute lookup fails. Falls back to a child
with that name, then to a direct parameter, so ``mS.fc1`` / ``mS.weight`` work."""
        # private/dunder names fail fast; this also prevents infinite recursion
        # when ``_children`` itself has not been set yet
        if attr.startswith("_"): raise AttributeError(attr)
        if attr in self._children: return self._children[attr]
        params = self.directParams
        if attr in params: return params[attr]
        # AttributeError (not KeyError) keeps hasattr() and getattr(obj, name, default) working
        raise AttributeError(f"'{type(self).__name__}' object has no child or direct parameter '{attr}'")
    def __getitem__(self, idx):
        """``mS[idx]`` is ``getattr(mS, str(idx))``, e.g. ``mS[0]`` for a child
named ``"0"``. Raises KeyError if nothing is found."""
        try: return getattr(self, str(idx))
        except AttributeError as e: raise KeyError(idx) from e
    @property
    def deepestDepth(self):
        """Deepest depth of the tree. If self doesn't
have any child, then depth is 0"""
        if not self._children: return 0
        return 1 + max(child.deepestDepth for child in self._children.values())
@_k1lib.patch(ModuleSelector, name="parse")
def _parse(self, selectors:_List[str]):
    """Parses extra selectors. Clears all old selectors, but retain the props.

Selectors are expected in normalized form (as produced by ``filter``), like
``"Linear>Conv2d:a,b"``. The first element is matched against this node's
class name, ``"#" + name`` or ``"*"``.

- A single-element selector adds its props to a matching node.
- A multi-element selector instead hands its remainder down: as a direct
  selector (``>``, seen by immediate children only) or an indirect one
  (space, also propagated to deeper descendants).

Child nodes are created for any ``nnModule.named_children()`` not yet in the
tree, then parsed recursively with the same ``selectors``.

:param selectors: list of normalized selector strings"""
    self.directSelectors = []; self.indirectSelectors = []
    ogSelectors = selectors
    if self.parent is not None:
        # besides the root selectors, also try whatever the parent handed down
        selectors = [] + selectors + self.parent.indirectSelectors + self.parent.directSelectors
        # indirect selectors keep propagating to this node's own children
        self.indirectSelectors += self.parent.indirectSelectors
        self.depth = self.parent.depth + 1
    for selector in selectors:
        parts = _getParts(selector)
        if not parts: continue  # no element to match (e.g. ":all"); parts[0] would raise IndexError
        matches = parts[0] == self.nnModule.__class__.__name__ or parts[0] == "#" + self.name or parts[0] == "*"
        if len(parts) == 1:
            if matches: self.selectedProps += _getProps(selector)
        else:
            # whichever of the first ">" / first " " comes earlier is the combinator after parts[0]
            a = selector.find(">"); a = a if a > 0 else float("inf")
            b = selector.find(" "); b = b if b > 0 else float("inf")
            direct = a < b
            if matches:
                if direct: self.directSelectors.append(selector[a+1:])
                else: self.indirectSelectors.append(selector[b+1:])
    for name, mod in self.nnModule.named_children():
        if name not in self._children:
            self._children[name] = ModuleSelector(self, name, mod)
        self._children[name].parse(ogSelectors)
    # dedupe while keeping first-seen order; list(set(...)) order varies between runs
    self.selectedProps = list(dict.fromkeys(self.selectedProps))
@_k1lib.patch(ModuleSelector, name="apply")
def _apply(self, f:_Callable[[ModuleSelector], None]):
    """Calls ``f`` on this :class:`ModuleSelector` first, then on every
descendant (each child's subtree in insertion order). ``f``'s return value
is ignored."""
    f(self)
    for subSelector in self._children.values():
        subSelector.apply(f)
@_k1lib.patch(ModuleSelector, name="copy")
def _copy(self):
    """Returns a copy of this node and all its descendants.

The copy wraps the same ``nnModule`` objects and keeps the root's ``parent``.
Copied children point to their copied parent. Props and selector lists are
copied into new lists, each node keeps its display function, and every copied
node gets a fresh ``signature``."""
    answer = ModuleSelector(self.parent, self.name, self.nnModule)
    answer.depth = self.depth
    answer.selectedProps = list(self.selectedProps)
    # children's ``parse`` reads these from their parent, so carry them over too
    answer.directSelectors = list(self.directSelectors)
    answer.indirectSelectors = list(self.indirectSelectors)
    answer.displayF = self.displayF
    answer._children = {name:child.copy() for name, child in self._children.items()}
    for child in answer._children.values(): child.parent = answer
    return answer
@_k1lib.patch(ModuleSelector)
def __repr__(self, intro:bool=True, header:_Union[str, _Tuple[str, str]]="", footer:_Union[str, _Tuple[str, str]]="", tabs:int=None):
    """Renders this node and all descendants as a tree, one line per node:
    ``"name: ModuleClass"`` padded to ``tabs*4`` characters, followed by
    ``displayF(node)``. Children are rendered with ``tabs-1`` and passed
    through ``k1lib.tab``.

    :param intro: whether to include a nice header and footer info
    :param header:
        str: include a header that starts where `displayF` will start
        Tuple[str, str]: first one in tree, second one in displayF section
    :param footer: same thing with header, but at the end
    :param tabs: number of tabs at the beginning. Best to leave this empty,
        in which case it is ``5 + deepestDepth``
    """
    if tabs is None: tabs = 5 + self.deepestDepth
    width = tabs*4
    answer = "ModuleSelector:\n" if intro else ""
    if header:
        h1, h2 = ("", header) if isinstance(header, str) else header
        answer += h1.ljust(width, " ") + h2 + "\n"
    answer += f"{self.name}: {self.nnModule.__class__.__name__}".ljust(width, " ")
    answer += self.displayF(self) + ("\n" if len(self._children) > 0 else "")
    answer += _k1lib.tab("\n".join([child.__repr__(tabs=tabs-1, intro=False) for child in self._children.values()]))
    if footer:
        f1, f2 = ("", footer) if isinstance(footer, str) else footer
        answer += "\n" + f1.ljust(width, " ") + f2
    if intro: answer += """\n\nCan...
- mS.displayF = ...: sets a display function (mS -> str) for self and all descendants. Defaults to displaying all props
- mS.deepestDepth: get deepest depth possible
- mS.nnModule: get the underlying nn.Module object
- mS.apply(f): apply to self and all descendants
- mS.copy(): copy everything, including descendants
- mS.selected("HookModule"): whether this module has a specified prop
- mS.highlight(prop): highlights all modules with specified prop
- mS.parse([..., ...]): parses extra css
- mS.clearProps(): to clear all selected props, including descendants
- mS.directParams: get Dict[str, nn.Parameter] that are directly under this module
- mS.named_children(), mS.children(): like PyTorch
- mS.named_modules([prop]), mS.modules([prop]): like PyTorch. Optional filter prop
- mS.parameters(): like PyTorch
"""
    return answer
def filter(selectors:str, defaultProp="all") -> _List[str]:
    r"""Removes all quirky features allowed by the css
language, and outputs nice lines.

Example::

    filter("Linear > Conv2d: a, b; Linear, #fc1")
    # -> ["Linear>Conv2d:a,b", "Linear:all", "#fc1:all"]

:param selectors: single css selector string. Statements separated
    by "\\n" or ";"
:param defaultProp: default property, if statement doesn't have one
:return: one normalized ``"elements:props"`` string per statement and
    comma-separated element group. Blank statements are skipped."""
    answer = []
    for chunk in selectors.split("\n"):
        for statement in chunk.split(";"):
            # collapse every whitespace run (tabs included) into one space, then trim the ends
            statement = _re.sub(r"\s+", " ", statement).strip()
            # drop the spaces around ">", ":" and ",": "A > B : x , y" -> "A>B:x,y"
            statement = _re.sub(r" ?([>:,]) ?", r"\1", statement)
            # blank statements (trailing ";", whitespace-only lines) carry no selector
            if not statement: continue
            if ":" not in statement: statement = f"{statement}:{defaultProp}"
            pieces = statement.split(":")
            head, props = pieces[0], pieces[1]
            # expand "A,B:props" into "A:props" and "B:props"; empty groups ("A,,B") are skipped
            answer.extend(f"{segment}:{props}" for segment in head.split(",") if segment)
    return answer
def select(model:_nn.Module, selectors:str, defaultProp="all") -> ModuleSelector:
    """Creates a new ModuleSelector, in sync with a model.

Example::

    mS = select(model, "Linear: HookModule; #fc1 > *")

:param model: module to build the selector tree for. The root node is named ``"root"``
:param selectors: css-like selector string, normalized with ``filter``
:param defaultProp: prop given to statements that don't specify one"""
    root = ModuleSelector(None, "root", model)
    root.parse(filter(selectors, defaultProp))
    return root