Source code for k1lib.selector

# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
This module is for selecting a subnetwork so that you can do special things to
them. Check out the tutorial section for a walkthrough. This is exposed
automatically with::

   from k1lib.imports import *
   selector.select # exposed
"""
from torch import nn; import k1lib, re
from typing import List, Tuple, Dict, Union, Any, Iterator, Callable
# Public API of this module. Note that ``filter`` (defined below) shadows the
# builtin of the same name everywhere inside this module.
__all__ = ["ModuleSelector", "filter", "select"]
def _getParts(s:str): return [a for elem in s.split(":")[0].split(">") if elem for a in elem.split(" ") if a]
def _getProps(s:str): return [elem for elem in s.split(":")[1].split(",") if elem]
# Module-wide counter shared by all selectors: ModuleSelector.__init__ calls it
# once per instance to assign ``self.idx``. Presumably each call yields a new,
# increasing integer — confirm against k1lib.AutoIncrement.
_idxAuto = k1lib.AutoIncrement()
class ModuleSelector:
    """One node of a tree that mirrors an ``nn.Module`` hierarchy, holding the
    selector props that matched this module. Build a tree with :func:`select`.

    Children are reachable as attributes (``mS.conv1``) or by index
    (``mS["conv1"]``), and so are parameters owned directly by the wrapped module."""
    selectedProps:List[str]
    """Selected properties of this MS"""

    def __init__(self, parent:"ModuleSelector", name:str, nnModule:"torch.nn.Module"):
        """
        :param parent: parent :class:`ModuleSelector`, or None for the root
        :param name: name of this node (the child name under the parent's module)
        :param nnModule: the wrapped PyTorch module"""
        self.parent = parent; self.name = name; self.nnModule = nnModule
        # child name -> ModuleSelector; populated by parse()
        self._children:Dict[str, "ModuleSelector"] = {}
        self.selectedProps:List[str] = []; self.depth:int = 0
        # selector remainders waiting to be matched by direct children (">")
        # or by any descendant (" "); rebuilt by parse()
        self.directSelectors:List[str] = []
        self.indirectSelectors:List[str] = []
        self.displayF:Callable[["ModuleSelector"], str] = lambda mS: ', '.join(mS.selectedProps)
        self.idx = _idxAuto()

    @property
    def displayF(self):
        """Function to display each ModuleSelector's lines. Default is just::

            lambda mS: ", ".join(mS.selectedProps)
        """
        return self._displayF

    @displayF.setter
    def displayF(self, f):
        # assigning propagates the function to this node and all descendants
        def applyF(node): node._displayF = f
        self.apply(applyF)

    def clearProps(self):
        """Clears all selected props on this node and all descendants. Returns self."""
        def applyF(node): node.selectedProps = []
        self.apply(applyF); return self

    def highlight(self, prop:str):
        """Display modules that have ``prop`` in red, everything else as-is. Returns self."""
        self.displayF = lambda mS: (k1lib.fmt.txt.red if mS.selected(prop) else k1lib.fmt.txt.identity)(', '.join(mS.selectedProps))
        return self

    def selected(self, prop:str=None) -> bool:
        """Whether this ModuleSelector has a specific prop (the "all" prop matches everything)"""
        return "all" in self.selectedProps or prop in self.selectedProps

    def named_children(self) -> Iterator[Tuple[str, "ModuleSelector"]]:
        """Get all named direct child"""
        return self._children.items()

    def named_modules(self, prop:str=None) -> Iterator[Tuple[str, "ModuleSelector"]]:
        """Get all named child recursively, self first

        :param prop: Filter property. If given, only nodes where :meth:`selected` is true are yielded"""
        if prop is not None:
            for name, m in self.named_modules():
                if m.selected(prop): yield name, m
            return
        yield self.name, self
        for child in self._children.values():
            yield from child.named_modules()

    def children(self) -> Iterator["ModuleSelector"]:
        """Get all direct child"""
        for _, x in self.named_children(): yield x

    def modules(self, prop:str=None) -> Iterator["ModuleSelector"]:
        """Get all child recursively. Optional filter prop"""
        for _, x in self.named_modules(prop): yield x

    @property
    def directParams(self) -> Dict[str, nn.Parameter]:
        """Params directly under this module (dotted names belong to submodules)"""
        return {name: param for name, param in self.nnModule.named_parameters() if "." not in name}

    def parameters(self) -> Iterator[nn.Parameter]:
        """Get generator of parameters, all depths"""
        return self.nnModule.parameters()

    def __getattr__(self, attr):
        """Falls back to child selectors, then to direct parameters.

        Raises AttributeError when neither exists, so ``hasattr``/``getattr(..., default)`` work."""
        # private/dunder names are never children or params; failing fast also means
        # probes for e.g. ``_children`` before __init__ ran raise instead of recursing
        if attr.startswith("_"): raise AttributeError(attr)
        if attr in self._children: return self._children[attr]
        params = self.directParams
        if attr in params: return params[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}' (not a child module nor a direct parameter)")

    def __getitem__(self, idx):
        """``mS[idx]`` looks up ``getattr(mS, str(idx))``. Raises KeyError if nothing matches."""
        try: return getattr(self, str(idx))
        except AttributeError as e: raise KeyError(str(idx)) from e

    @property
    def deepestDepth(self):
        """Deepest depth of the tree. If self doesn't have any child, then depth is 0"""
        if len(self._children) == 0: return 0
        return 1 + max(child.deepestDepth for child in self._children.values())
@k1lib.patch(ModuleSelector, name="parse")
def _parse(self, selectors:List[str]):
    """Parses extra selectors. Clears all old selectors, but retains the props.

    Each node checks the selector list it was given plus whatever its parent left
    pending in ``indirectSelectors``/``directSelectors``, then recurses into every
    child of :attr:`nnModule`, creating child ModuleSelectors on first sight.

    :param selectors: cleaned selector lines, in the form :func:`filter` returns"""
    self.directSelectors = []; self.indirectSelectors = []
    ogSelectors = selectors
    if self.parent is not None:
        selectors = [*selectors, *self.parent.indirectSelectors, *self.parent.directSelectors]
        # descendant (" ") selectors stay pending for every level below the parent
        self.indirectSelectors += self.parent.indirectSelectors
        self.depth = self.parent.depth + 1
    for selector in selectors:
        parts = _getParts(selector)
        if not parts: continue  # empty selector part (e.g. ":all") cannot match anything
        head = parts[0]
        matches = head == self.nnModule.__class__.__name__ or head == "#" + self.name or head == "*"
        if len(parts) == 1:
            if matches: self.selectedProps += _getProps(selector)
            continue
        if not matches: continue
        # whichever combinator comes first decides how the remainder is passed down:
        # ">" only to direct children, " " to all descendants
        gtPos = selector.find(">"); gtPos = gtPos if gtPos > 0 else float("inf")
        spacePos = selector.find(" "); spacePos = spacePos if spacePos > 0 else float("inf")
        if gtPos < spacePos: self.directSelectors.append(selector[gtPos+1:])
        else: self.indirectSelectors.append(selector[spacePos+1:])
    for name, mod in self.nnModule.named_children():
        if name not in self._children:
            self._children[name] = ModuleSelector(self, name, mod)
        self._children[name].parse(ogSelectors)
    # dedupe while keeping first-seen order, so displayed props are deterministic
    self.selectedProps = list(dict.fromkeys(self.selectedProps))

@k1lib.patch(ModuleSelector, name="apply")
def _apply(self, f:Callable[[ModuleSelector], None]):
    """Applies a function to self and all child :class:`ModuleSelector` (self first, then each subtree)"""
    f(self)
    for child in self._children.values(): child.apply(f)

@k1lib.patch(ModuleSelector, name="copy")
def _copy(self):
    """Copies this selector tree, including descendants. The wrapped nn.Modules are shared, not copied."""
    answer = ModuleSelector(self.parent, self.name, self.nnModule)
    answer.depth = self.depth
    answer.selectedProps = list(self.selectedProps)
    answer.directSelectors = list(self.directSelectors)
    answer.indirectSelectors = list(self.indirectSelectors)
    # set before attaching children: the displayF setter propagates to descendants,
    # and each copied child already carries its own displayF
    answer.displayF = self.displayF
    answer._children = {name: child.copy() for name, child in self._children.items()}
    for child in answer._children.values(): child.parent = answer
    return answer

@k1lib.patch(ModuleSelector)
def __repr__(self, intro:bool=True, header:Union[str, Tuple[str]]="", footer="", tabs:int=None):
    """
    :param intro: whether to include a nice header and footer info
    :param header:
        str: include a header that starts where `displayF` will start
        Tuple[str, str]: first one in tree, second one in displayF section
    :param footer: same thing with header, but at the end
    :param tabs: number of tabs at the beginning. Best to leave this empty
    """
    if tabs is None: tabs = 5 + self.deepestDepth
    answer = "ModuleSelector:\n" if intro else ""
    if header:
        h1, h2 = ("", header) if isinstance(header, str) else header
        answer += h1.ljust(tabs*4, " ") + h2 + "\n"
    answer += f"{self.name}: {self.nnModule.__class__.__name__}".ljust(tabs*4, " ")
    answer += self.displayF(self) + ("\n" if len(self._children) > 0 else "")
    answer += k1lib.tab("\n".join([child.__repr__(tabs=tabs-1, intro=False) for name, child in self._children.items()]))
    if footer:
        f1, f2 = ("", footer) if isinstance(footer, str) else footer
        answer += "\n" + f1.ljust(tabs*4, " ") + f2
    if intro: answer += """\n\nCan...
- mS.displayF = ...: sets a display function (mS -> str) for self and all descendants. Defaults to displaying all props
- mS.deepestDepth: get deepest depth possible
- mS.nnModule: get the underlying nn.Module object
- mS.apply(f): apply to self and all descendants
- mS.copy(): copy everything, including descendants
- mS.selected("HookModule"): whether this module has a specified prop
- mS.highlight(prop): highlights all modules with specified prop
- mS.parse([..., ...]): parses extra css
- mS.clearProps(): to clear all selected props, including descendants
- mS.directParams: get Dict[str, nn.Parameter] that are directly under this module
- mS.named_children(), mS.children(): like PyTorch
- mS.named_modules([prop]), mS.modules([prop]): like PyTorch. Optional filter prop
- mS.parameters(): like PyTorch"""
    return answer
def filter(selectors:str, defaultProp="all") -> List[str]:
    r"""Removes all quirky features allowed by the css language, and outputs nice lines.

    Whitespace is normalized, spaces around ">", ":" and "," are dropped, blank
    statements are skipped, statements without props get ``defaultProp``, and
    comma-separated selectors are expanded into one line each.

    :param selectors: single css selector string. Statements separated by "\\n" or ";"
    :param defaultProp: default property, if statement doesn't have one
    :return: list of ``"selector:prop1,prop2"`` strings"""
    answer = []
    for line in re.split(r"[\n;]", selectors):
        # collapse every whitespace run to one space, then trim both ends
        line = re.sub(r"\s+", " ", line).strip()
        # no spaces around combinators/separators: "A > B : x , y" -> "A>B:x,y"
        line = re.sub(r" ?([>:,]) ?", r"\1", line)
        # test emptiness only after cleaning, otherwise whitespace-only
        # statements would turn into a bogus ":all" selector
        if not line: continue
        if ":" not in line: line = f"{line}:{defaultProp}"
        pieces = line.split(":")
        heads, props = pieces[0], pieces[1]
        # "A,B:x" -> "A:x", "B:x"; empty segments (e.g. from "A,,B") are dropped
        answer.extend(f"{segment}:{props}" for segment in heads.split(",") if segment)
    return answer
def select(model:"torch.nn.Module", selectors:str) -> ModuleSelector:
    """Build a brand-new :class:`ModuleSelector` tree that mirrors ``model``.

    :param model: the PyTorch module to wrap; it becomes the node named "root"
    :param selectors: raw css-like selector text, cleaned by :func:`filter` before parsing
    :return: the root node, with props assigned according to ``selectors``"""
    tree = ModuleSelector(None, "root", model)
    cleaned = filter(selectors)
    tree.parse(cleaned)
    return tree