k1lib.selector module
This module is for selecting a subnetwork using CSS selectors so that you can do special things to it. Check out the tutorial section for a walkthrough. It is exposed automatically with:
```python
from k1lib.imports import *
selector.select # exposed
```
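The core idea, matching CSS-style selectors against a tree of modules, can be sketched without PyTorch. This is a simplified illustration, not k1lib's implementation. `Node`, `matches_simple`, `matches` and `select_names` are hypothetical names. The sketch handles only `*`, `#name`, bare class names and the `>` direct-child combinator, and it ignores the `:prop` part entirely.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for one module in a network tree."""
    name: str  # matched by "#name"
    kind: str  # matched by a bare class name, e.g. "Linear"
    children: List["Node"] = field(default_factory=list)
    parent: Optional["Node"] = field(default=None, repr=False, compare=False)

    def add(self, child: "Node") -> "Node":
        child.parent = self
        self.children.append(child)
        return self


def matches_simple(node: Node, sel: str) -> bool:
    """Match one simple selector: "*", "#name" or a class name."""
    if sel == "*":
        return True
    if sel.startswith("#"):
        return node.name == sel[1:]
    return node.kind == sel


def matches(node: Node, selector: str) -> bool:
    """Match a chain of simple selectors joined by ">" (direct child)."""
    current: Optional[Node] = node
    # Check the rightmost part against the node itself, then walk up
    # one parent per remaining part.
    for part in reversed([p.strip() for p in selector.split(">")]):
        if current is None or not matches_simple(current, part):
            return False
        current = current.parent
    return True


def select_names(root: Node, selector: str) -> List[str]:
    """Names of all matching nodes, in depth-first (pre-order) order."""
    found: List[str] = []
    stack = [root]
    while stack:
        node = stack.pop()
        if matches(node, selector):
            found.append(node.name)
        stack.extend(reversed(node.children))
    return found


# A toy tree: root -> lin1 -> lin, and root -> lin2
root = Node("root", "Sequential")
root.add(Node("lin1", "Sequential").add(Node("lin", "Linear")))
root.add(Node("lin2", "Linear"))

print(select_names(root, "#lin1 > #lin"))  # ['lin']
print(select_names(root, "Linear"))        # ['lin', 'lin2']
```

Matching from the rightmost selector upward means each node is checked only against its own ancestors, which is also how browsers commonly evaluate CSS selectors.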
class k1lib.selector.ModuleSelector(parent: k1lib.selector.ModuleSelector, name: str, nn: torch.nn.modules.module.Module)
Bases: object

nn: torch.nn.Module
The associated torch.nn.Module of this ModuleSelector.

props: List[str]
Properties of this ModuleSelector.

idx: int
Unique id of this ModuleSelector across the entire script. May be useful for recognizing modules.

property deepestDepth
Depth of the deepest branch of the tree below this ModuleSelector. If it doesn't have any children, the depth is 0.

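One natural reading of `deepestDepth` is recursive: a leaf has depth 0, and any other node is one level deeper than its deepest child. The sketch below uses a hypothetical `Node` class, not `ModuleSelector`. It assumes this definition, which matches the documented leaf case but is not confirmed for deeper trees.

```python
from typing import List


class Node:
    """Hypothetical tree node standing in for a ModuleSelector."""

    def __init__(self, *children: "Node"):
        self.children: List["Node"] = list(children)

    @property
    def deepestDepth(self) -> int:
        # Assumed rule: no children -> 0, otherwise 1 + deepest child
        if not self.children:
            return 0
        return 1 + max(c.deepestDepth for c in self.children)


leaf = Node()
tree = Node(Node(Node()), Node())  # root -> (child -> grandchild), root -> leaf
print(leaf.deepestDepth)  # 0
print(tree.deepestDepth)  # 2
```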
__call__(*args, **kwargs)
Calls the internal torch.nn.Module.

__contains__(prop: Optional[str] = None) → bool
Whether this ModuleSelector has a specific prop. Example:
```python
# returns True
"b" in nn.Linear(3, 4).select("*:b")
# returns False
"h" in nn.Linear(3, 4).select("*:b")
# returns True; "*" here means the ModuleSelector has any props at all
"*" in nn.Linear(3, 4).select("*:b")
```

named_children(prop: Optional[str] = None) → Iterator[Tuple[str, k1lib.selector.ModuleSelector]]
Get all named direct children.
Parameters:
- prop – Filter property. See also: __contains__()

children(prop: Optional[str] = None) → Iterator[k1lib.selector.ModuleSelector]
Get all direct children.
Parameters:
- prop – Filter property. See also: __contains__()

named_modules(prop: Optional[str] = None) → Iterator[Tuple[str, k1lib.selector.ModuleSelector]]
Get all named modules recursively, including this one. Example:
```python
modules = list(nn.Sequential(nn.Linear(3, 4), nn.ReLU()).select().named_modules())
# returns 3
len(modules)
# returns tuple ('0', <ModuleSelector of Linear>)
modules[1]
```
Parameters:
- prop – Filter property. See also: __contains__()

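The example above returns 3 entries for a Sequential with two children, which suggests a pre-order walk that yields the module itself first. The sketch below uses a hypothetical `Node` class and mirrors that count. Several details are assumptions, not k1lib's confirmed behaviour: the root name `"root"`, the use of local (not dotted) names, and the filter rule (`None` keeps everything, `"*"` keeps nodes with any props, any other value requires that exact prop).

```python
from typing import Iterator, List, Optional, Tuple


class Node:
    """Hypothetical stand-in for a ModuleSelector."""

    def __init__(self, name: str, props: List[str], *children: "Node"):
        self.name = name
        self.props = props
        self.children = list(children)

    def has(self, prop: Optional[str]) -> bool:
        # Assumed filter rule, modelled on the documented __contains__ examples
        if prop is None:
            return True
        if prop == "*":
            return len(self.props) > 0
        return prop in self.props

    def named_modules(self, prop: Optional[str] = None) -> Iterator[Tuple[str, "Node"]]:
        # Pre-order traversal: yield self first, then every descendant
        if self.has(prop):
            yield self.name, self
        for child in self.children:
            yield from child.named_modules(prop)


root = Node("root", [], Node("0", ["b"]), Node("1", []))
print([name for name, _ in root.named_modules()])     # ['root', '0', '1']
print([name for name, _ in root.named_modules("b")])  # ['0']
```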
modules(prop: Optional[str] = None) → Iterator[k1lib.selector.ModuleSelector]
Get all modules recursively, like named_modules() but without the names.
Parameters:
- prop – Filter property. See also: __contains__()

property directParams
Dict of parameters directly under this module.

parse(selectors: Union[List[str], str]) → k1lib.selector.ModuleSelector
Parses extra selectors. Clears all old selectors, but retains the props. Returns self. Example:
```python
mS = selector.ModuleSelector.sample().parse("Conv2d:propA")
# returns True
"propA" in mS[1][0]
```
Parameters:
- selectors – can be the preprocessed list, or the unprocessed CSS string

apply(f: Callable[[k1lib.selector.ModuleSelector], None])
Applies a function to this ModuleSelector and all of its descendant ModuleSelectors.

clearProps() → k1lib.selector.ModuleSelector
Clears all existing props of this ModuleSelector and all of its descendants. Example:
```python
# returns False
"b" in nn.Linear(3, 4).select("*:b").clearProps()
```

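Clearing props on a whole subtree is a natural job for a recursive `apply`. The sketch below uses a hypothetical `Node` class, not k1lib's code. It assumes that `apply` visits the node before its children and that `clearProps` returns self, which the chained `"b" in ....clearProps()` example suggests.

```python
from typing import Callable, List


class Node:
    """Hypothetical stand-in for a ModuleSelector."""

    def __init__(self, props: List[str], *children: "Node"):
        self.props = props
        self.children = list(children)

    def apply(self, f: Callable[["Node"], None]) -> None:
        # Call f on this node, then recurse into every descendant
        f(self)
        for child in self.children:
            child.apply(f)

    def clearProps(self) -> "Node":
        # Empty every props list in the subtree; return self for chaining
        self.apply(lambda node: node.props.clear())
        return self


tree = Node(["a"], Node(["b"]), Node(["c"], Node(["d"])))
print("b" in tree.clearProps().children[0].props)  # False
```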
property displayF
Function to display each ModuleSelector's lines. Default is just:
```python
lambda mS: ", ".join(mS.props)
```

static sample() → k1lib.selector.ModuleSelector
Creates a new example ModuleSelector that has a bit of hierarchy to it, with no CSS applied.

hookF(f: Callable[[k1lib.selector.ModuleSelector, torch.nn.modules.module.Module, Tuple[torch.Tensor], torch.Tensor], None] = None, prop: str = '*')
Context manager for applying forward hooks. Example:
```python
def f(mS, i, o):
    print(i, o)
m = nn.Linear(3, 4)
with m.select().hookF(f):
    m(torch.randn(2, 3))
```
Parameters:
- f – hook callback; should accept the ModuleSelector, the inputs, and the output. If not specified, it will print out the input and output tensor shapes.
- prop – filter property of the modules to hook onto

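The hook methods scope hooks to a `with` block: they attach the callback to the selected modules on entry and detach it on exit. The pure-Python sketch below shows that pattern with a hypothetical `Layer` class and a hypothetical `hook_forward` helper. The prop filter reuses the assumed "`*` means any props" rule. In PyTorch itself, the attach step would be `register_forward_hook`, whose returned handle has a `remove()` method; this sketch does not claim that k1lib is written this way.

```python
from contextlib import contextmanager
from typing import Callable, Iterator, List


class Layer:
    """Hypothetical stand-in for a module that supports forward hooks."""

    def __init__(self, name: str, props: List[str]):
        self.name, self.props = name, props
        self.hooks: List[Callable] = []

    def __call__(self, x: float) -> float:
        out = x * 2  # toy forward pass
        for hook in self.hooks:
            hook(self, x, out)  # (layer, input, output), like a forward hook
        return out


@contextmanager
def hook_forward(layers: List[Layer], f: Callable, prop: str = "*") -> Iterator[None]:
    """Attach f to every layer matching prop, and detach it again on exit."""
    selected = [layer for layer in layers
                if (prop == "*" and bool(layer.props)) or prop in layer.props]
    for layer in selected:
        layer.hooks.append(f)
    try:
        yield
    finally:
        # Always remove the hooks, even if the body raised
        for layer in selected:
            layer.hooks.remove(f)


log = []
layers = [Layer("a", ["x"]), Layer("b", [])]
with hook_forward(layers, lambda layer, i, o: log.append((layer.name, i, o)), prop="x"):
    layers[0](1.0)
    layers[1](1.0)  # "b" lacks prop "x", so it is not hooked
layers[0](5.0)      # outside the block: no longer hooked
print(log)  # [('a', 1.0, 2.0)]
```

Putting the removal in `finally` keeps an exception inside the `with` body from leaving stale hooks behind.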
hookFp(f=None, prop: str = '*')
Context manager for applying forward pre-hooks. Example:
```python
def f(mS, i):
    print(i)
m = nn.Linear(3, 4)
with m.select().hookFp(f):
    m(torch.randn(2, 3))
```
Parameters:
- f – hook callback; should accept the ModuleSelector and the inputs. If not specified, it will print out the input tensor shapes.
- prop – filter property of the modules to hook onto

hookB(f=None, prop: str = '*')
Context manager for applying backward hooks. Example:
```python
def f(mS, i, o):
    print(i, o)
m = nn.Linear(3, 4)
with m.select().hookB(f):
    m(torch.randn(2, 3)).sum().backward()
```
Parameters:
- f – hook callback; should accept the ModuleSelector, the grad inputs, and the grad outputs. If not specified, it will print out the input tensor shapes.
- prop – filter property of the modules to hook onto

freeze(prop: str = '*')
Returns a context manager that freezes (sets requires_grad to False) parts of the network. Example:
```python
l = k1lib.Learner.sample()
w = l.model.lin1.lin.weight.clone() # weights before
with l.model.select("#lin1").freeze():
    l.run(1)
# returns True
(l.model.lin1.lin.weight == w).all()
```

unfreeze(prop: str = '*')
Returns a context manager that unfreezes (sets requires_grad to True) parts of the network. Example:
```python
l = k1lib.Learner.sample()
w = l.model.lin1.lin.weight.clone() # weights before
with l.model.select("#lin1").freeze():
    with l.model.select("#lin1 > #lin").unfreeze():
        l.run(1)
# returns False
(l.model.lin1.lin.weight == w).all()
```

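Nesting `unfreeze` inside `freeze` works naturally if each context manager saves the current `requires_grad` flags on entry and restores them on exit. The sketch below is a minimal illustration under that assumption. `Param`, `set_requires_grad` and the parameter names are hypothetical. Real PyTorch parameters expose a settable `requires_grad` attribute, but k1lib's exact restore behaviour is not documented here.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class Param:
    """Hypothetical stand-in for a tensor with a requires_grad flag."""
    name: str
    requires_grad: bool = True


@contextmanager
def set_requires_grad(params: List[Param], value: bool) -> Iterator[None]:
    """Temporarily force requires_grad, then restore the previous flags."""
    saved = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad = value
    try:
        yield
    finally:
        for p, old in zip(params, saved):
            p.requires_grad = old


lin1_lin = Param("lin1.lin.weight")      # hypothetical parameter names
lin1_other = Param("lin1.other.weight")
lin1 = [lin1_lin, lin1_other]

with set_requires_grad(lin1, False):            # like select("#lin1").freeze()
    with set_requires_grad([lin1_lin], True):   # like select("#lin1 > #lin").unfreeze()
        print([p.requires_grad for p in lin1])  # [True, False]
print([p.requires_grad for p in lin1])          # [True, True]
```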
cutOff() → torch.nn.modules.module.Module
Creates a new network that returns the selected layer's output. Example:
```python
xb = torch.randn(10, 2)
m = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 4), nn.Linear(4, 6))
m0 = m.select("#0").cutOff(); m1 = m.select("#1").cutOff()
# returns (10, 6)
m(xb).shape
# returns (10, 5)
m0(xb).shape
# returns (10, 4)
m1(xb).shape
```

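For a purely sequential model, cutting off at a layer amounts to running every layer up to and including the selected one and returning that output. For an `nn.Sequential`, slicing (for example `m[:2]`) gives the same result as `m1` above. The sketch below shows the idea with plain functions; `cut_off` is a hypothetical helper. `cutOff` itself also handles selections inside non-sequential models, and how it does so is not described here.

```python
from typing import Callable, List, Sequence

Layer = Callable[[float], float]


def cut_off(layers: Sequence[Layer], idx: int) -> Layer:
    """Return a function that runs layers[0..idx] and returns layer idx's output."""
    kept: List[Layer] = list(layers[: idx + 1])

    def run(x: float) -> float:
        for layer in kept:
            x = layer(x)
        return x

    return run


layers = [lambda x: x + 1, lambda x: x * 10, lambda x: x - 3]
m0, m1 = cut_off(layers, 0), cut_off(layers, 1)
print(m0(2), m1(2))  # 3 30
```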
k1lib.selector.preprocess(selectors: str, defaultProp='*') → List[str]
Removes all quirky features allowed by the CSS language and outputs clean lines, one selector per line. Example:
```python
# returns ["a:f", "a:g,h", "b:g,h", "t:*"]
selector.preprocess("a:f; a, b: g,h; t")
```
Parameters:
- selectors – single CSS selector string. Statements are separated by "\n" or ";"
- defaultProp – default property, used if a statement doesn't have one

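The normalisation in the example can be reproduced by a few string rules:

- split statements on newlines or `;`
- split each statement into selectors and props at the first `:`, falling back to `defaultProp`
- tidy the prop list
- expand comma-separated selectors into one line each

The sketch below is a hypothetical re-implementation (`preprocess_sketch`) that reproduces the documented output. It is not k1lib's code and ignores any CSS features beyond these rules.

```python
from typing import List


def preprocess_sketch(selectors: str, defaultProp: str = "*") -> List[str]:
    """Hypothetical re-implementation of the documented normalisation."""
    lines: List[str] = []
    # Statements are separated by newlines or semicolons
    for statement in selectors.replace("\n", ";").split(";"):
        statement = statement.strip()
        if not statement:
            continue
        # "selectors: props"; use the default prop when there is no colon
        if ":" in statement:
            sels, props = statement.split(":", 1)
        else:
            sels, props = statement, defaultProp
        # Normalise the prop list, e.g. " g, h" -> "g,h"
        props = ",".join(p.strip() for p in props.split(",") if p.strip()) or defaultProp
        # A comma-separated selector list becomes one line per selector
        for sel in sels.split(","):
            sel = sel.strip()
            if sel:
                lines.append(f"{sel}:{props}")
    return lines


print(preprocess_sketch("a:f; a, b: g,h; t"))  # ['a:f', 'a:g,h', 'b:g,h', 't:*']
```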
k1lib.selector.select(model: torch.nn.modules.module.Module, css: str = '*') → k1lib.selector.ModuleSelector
Creates a new ModuleSelector, in sync with a model. Example:
```python
mS = selector.select(nn.Linear(3, 4), "#root:propA")
```
Or, you can do it the more direct way:
```python
mS = nn.Linear(3, 4).select("#root:propA")
```
Parameters:
- model – the torch.nn.Module object to select from
- css – the CSS selectors