k1lib.selector module
This module selects subnetworks of a model using CSS-like selectors, so that you can operate on just those parts (attach hooks, freeze or unfreeze them, cut the network off at them, and so on). Check out the tutorial section for a walkthrough. It is exposed automatically with:
from k1lib.imports import *
selector.select # exposed
-
class k1lib.selector.ModuleSelector(parent: k1lib.selector.ModuleSelector, name: str, nn: torch.nn.modules.module.Module)
Bases: object
-
nn: torch.nn.Module
The associated torch.nn.Module of this ModuleSelector.
-
props: List[str]
Properties of this ModuleSelector.
-
idx: int
Unique id of this ModuleSelector within the entire script. May be useful for identifying modules.
-
property deepestDepth
Depth of the deepest level of the tree under this ModuleSelector. If it has no children, the depth is 0.
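A minimal sketch of how such a depth can be computed recursively. It uses a hypothetical `Node` class with a `children` list as a stand-in for a ModuleSelector; this is an illustration under that assumption, not k1lib's implementation.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for a ModuleSelector: a name plus child nodes."""
    name: str
    children: List["Node"] = field(default_factory=list)


def deepest_depth(node: Node) -> int:
    """Depth of the deepest descendant; a node with no children has depth 0."""
    if not node.children:
        return 0
    return 1 + max(deepest_depth(child) for child in node.children)


# root -> a -> a1, and root -> b
tree = Node("root", [Node("a", [Node("a1")]), Node("b")])
print(deepest_depth(tree))          # 2
print(deepest_depth(Node("leaf")))  # 0
```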
-
__call__(*args, **kwargs)
Calls the internal torch.nn.Module.
-
__contains__(prop: Optional[str] = None) → bool
Whether this ModuleSelector has a specific prop. Example:
# returns True
"b" in nn.Linear(3, 4).select("*:b")
# returns False
"h" in nn.Linear(3, 4).select("*:b")
# returns True, "*" here means the ModuleSelector has any properties at all
"*" in nn.Linear(3, 4).select("*:b")
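The membership rule from the example can be modeled as a plain function over a list of props. The name `has_prop` is hypothetical, and only the cases shown above are modeled (a named prop, and `"*"` meaning "has any props at all"); the behavior for `prop=None` is not covered.

```python
from typing import List


def has_prop(props: List[str], prop: str) -> bool:
    """Hypothetical model of `prop in mS`, where `props` plays the role of mS.props."""
    if prop == "*":
        # "*" asks whether the selector carries any property at all
        return len(props) > 0
    return prop in props


props = ["b"]  # the props that select("*:b") attaches to the module
print(has_prop(props, "b"))  # True
print(has_prop(props, "h"))  # False
print(has_prop(props, "*"))  # True
print(has_prop([], "*"))     # False
```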
-
named_children(prop: Optional[str] = None) → Iterator[Tuple[str, k1lib.selector.ModuleSelector]]
Get all direct children, along with their names.
- Parameters
prop – Filter property. See also: __contains__()
-
children(prop: Optional[str] = None) → Iterator[k1lib.selector.ModuleSelector]
Get all direct children.
- Parameters
prop – Filter property. See also: __contains__()
-
named_modules(prop: Optional[str] = None) → Iterator[Tuple[str, k1lib.selector.ModuleSelector]]
Get this ModuleSelector and all of its descendants recursively, along with their names. Example:
modules = list(nn.Sequential(nn.Linear(3, 4), nn.ReLU()).select().named_modules())
# returns 3
len(modules)
# returns tuple ('0', <ModuleSelector of Linear>)
modules[1]
- Parameters
prop – Filter property. See also: __contains__()
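A recursive traversal with a prop filter can be sketched as follows. The `Node` class, the `"root"` name for the top node, and the choice that `prop=None` means "no filter" are assumptions for illustration; the only behavior taken from the example is that the root comes first, followed by its descendants.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for a ModuleSelector."""
    name: str
    props: List[str] = field(default_factory=list)
    children: List["Node"] = field(default_factory=list)

    def has(self, prop: Optional[str]) -> bool:
        # Assumption: None means "no filter"; "*" means "has any prop"
        if prop is None:
            return True
        if prop == "*":
            return len(self.props) > 0
        return prop in self.props

    def named_modules(self, prop: Optional[str] = None) -> Iterator[Tuple[str, "Node"]]:
        """Yield (name, node) for self, then every descendant, depth-first."""
        if self.has(prop):
            yield self.name, self
        # Recurse even when self is filtered out, so matching descendants are found
        for child in self.children:
            yield from child.named_modules(prop)


# Mirrors nn.Sequential(nn.Linear(3, 4), nn.ReLU()): a root with two children
root = Node("root", children=[Node("0", props=["b"]), Node("1")])
mods = list(root.named_modules())
print(len(mods))                                 # 3
print(mods[1][0])                                # 0
print([n for n, _ in root.named_modules("b")])   # ['0']
```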
-
modules(prop: Optional[str] = None) → Iterator[k1lib.selector.ModuleSelector]
Get all children recursively.
- Parameters
prop – Filter property. See also: __contains__()
-
property directParams
Dict of params directly under this module.
-
parse(selectors: Union[List[str], str]) → k1lib.selector.ModuleSelector
Parses extra selectors. Clears all old selectors, but retains the props. Returns self. Example:
mS = selector.ModuleSelector.sample().parse("Conv2d:propA")
# returns True
"propA" in mS[1][0]
- Parameters
selectors – can be the preprocessed list, or the unprocessed css string
-
apply(f: Callable[[k1lib.selector.ModuleSelector], None])
Applies a function to this ModuleSelector and all of its descendant ModuleSelectors.
-
clearProps() → k1lib.selector.ModuleSelector
Clears all existing props of this ModuleSelector and all of its descendants. Example:
# returns False
"b" in nn.Linear(3, 4).select("*:b").clearProps()
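Because apply() visits this ModuleSelector and every descendant, an operation like clearProps() can be expressed on top of it. The sketch below uses a hypothetical `Node` class and a pre-order visit (parent before children); it illustrates the idea and is not k1lib's code.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Node:
    """Hypothetical stand-in for a ModuleSelector."""
    name: str
    props: List[str] = field(default_factory=list)
    children: List["Node"] = field(default_factory=list)

    def apply(self, f: Callable[["Node"], None]) -> None:
        # Call f on this node first, then recurse into every child
        f(self)
        for child in self.children:
            child.apply(f)

    def clearProps(self) -> "Node":
        # Empty the props of this node and all descendants; return self for chaining
        self.apply(lambda n: n.props.clear())
        return self


root = Node("root", ["a"], [Node("0", ["b"]), Node("1", ["b", "c"])])
visited = []
root.apply(lambda n: visited.append(n.name))
print(visited)                                    # ['root', '0', '1']
root.clearProps()
print([n.props for n in (root, *root.children)])  # [[], [], []]
```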
-
property displayF
Function to display each ModuleSelector’s lines. Default is just:
lambda mS: ", ".join(mS.props)
-
static sample() → k1lib.selector.ModuleSelector
Creates a new example ModuleSelector that has a bit of hierarchy to it, with no CSS applied.
-
hookF(f: Callable[[k1lib.selector.ModuleSelector, torch.nn.modules.module.Module, Tuple[torch.Tensor], torch.Tensor], None] = None, prop: str = '*')
Context manager for applying forward hooks. Example:
def f(mS, i, o):
    print(i, o)
m = nn.Linear(3, 4)
with m.select().hookF(f):
    m(torch.randn(2, 3))
- Parameters
f – hook callback, should accept the ModuleSelector, inputs and output. If not specified, it prints out the input and output tensor shapes.
prop – filter property of modules to hook onto
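The general pattern behind a hook context manager is: register a hook on every selected module on entry, keep the handles, and remove them on exit. In PyTorch, `nn.Module.register_forward_hook` returns a handle with a `.remove()` method. The sketch below imitates that pattern in plain Python with a hypothetical `Layer` class and a hypothetical `hook_forward` function, and assumes that `prop="*"` selects modules carrying any prop; it is not k1lib's implementation.

```python
from contextlib import contextmanager
from typing import Callable, List


class Layer:
    """Hypothetical module: doubles its input, then runs its forward hooks."""

    def __init__(self, name: str, props: List[str]):
        self.name, self.props, self._hooks = name, props, []

    def register_forward_hook(self, hook: Callable) -> Callable[[], None]:
        self._hooks.append(hook)
        # The returned function plays the role of a hook handle's remove()
        return lambda: self._hooks.remove(hook)

    def __call__(self, x):
        out = x * 2
        for hook in self._hooks:
            hook(self, x, out)
        return out


@contextmanager
def hook_forward(layers: List[Layer], f: Callable, prop: str = "*"):
    """Attach f(layer, input, output) to layers matching prop; detach on exit."""
    selected = [layer for layer in layers
                if (bool(layer.props) if prop == "*" else prop in layer.props)]
    removers = [layer.register_forward_hook(f) for layer in selected]
    try:
        yield
    finally:
        # Runs even if the with-body raises, so no hook outlives the context
        for remove in removers:
            remove()


layers = [Layer("a", ["b"]), Layer("c", [])]
seen = []
with hook_forward(layers, lambda layer, i, o: seen.append((layer.name, i, o))):
    layers[0](3)
    layers[1](3)  # no props, so not selected by "*"
layers[0](5)      # outside the context: the hook is gone
print(seen)       # [('a', 3, 6)]
```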
-
hookFp(f=None, prop: str = '*')
Context manager for applying forward pre hooks. Example:
def f(mS, i):
    print(i)
m = nn.Linear(3, 4)
with m.select().hookFp(f):
    m(torch.randn(2, 3))
- Parameters
f – hook callback, should accept the ModuleSelector and inputs. If not specified, it prints out the input tensor shapes.
prop – filter property of modules to hook onto
-
hookB(f=None, prop: str = '*')
Context manager for applying backward hooks. Example:
def f(mS, i, o):
    print(i, o)
m = nn.Linear(3, 4)
with m.select().hookB(f):
    m(torch.randn(2, 3)).sum().backward()
- Parameters
f – hook callback, should accept the ModuleSelector, grad inputs and grad outputs. If not specified, it prints out the input tensor shapes.
prop – filter property of modules to hook onto
-
freeze(prop: str = '*')
Returns a context manager that freezes (sets requires_grad to False) parts of the network. Example:
l = k1lib.Learner.sample()
w = l.model.lin1.lin.weight.clone() # weights before
with l.model.select("#lin1").freeze():
    l.run(1)
# returns True
(l.model.lin1.lin.weight == w).all()
-
unfreeze(prop: str = '*')
Returns a context manager that unfreezes (sets requires_grad to True) parts of the network. Example:
l = k1lib.Learner.sample()
w = l.model.lin1.lin.weight.clone() # weights before
with l.model.select("#lin1").freeze():
    with l.model.select("#lin1 > #lin").unfreeze():
        l.run(1)
# returns False
(l.model.lin1.lin.weight == w).all()
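The nesting in the example works if each context manager sets `requires_grad` on its selected parameters and restores the previous values when it exits. The sketch below shows that design with a hypothetical `Param` class and a hypothetical `set_requires_grad` helper; whether k1lib restores the exact previous values this way is an assumption.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterable


@dataclass
class Param:
    """Hypothetical stand-in for a torch parameter."""
    name: str
    requires_grad: bool = True


@contextmanager
def set_requires_grad(params: Iterable[Param], value: bool):
    """Set requires_grad on params, restoring each original value on exit."""
    params = list(params)
    saved = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad = value
    try:
        yield
    finally:
        for p, old in zip(params, saved):
            p.requires_grad = old


other = Param("lin1.other.weight")  # under #lin1 but not #lin1 > #lin
inner = Param("lin1.lin.weight")    # under #lin1 > #lin
with set_requires_grad([other, inner], False):     # like freeze() on "#lin1"
    with set_requires_grad([inner], True):          # like unfreeze() on "#lin1 > #lin"
        inside = (other.requires_grad, inner.requires_grad)
after = (other.requires_grad, inner.requires_grad)
print(inside)  # (False, True)
print(after)   # (True, True)
```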
-
cutOff() → torch.nn.modules.module.Module
Creates a new network that returns the selected layer’s output. Example:
xb = torch.randn(10, 2)
m = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 4), nn.Linear(4, 6))
m0 = m.select("#0").cutOff(); m1 = m.select("#1").cutOff()
# returns (10, 6)
m(xb).shape
# returns True, since m0's output has shape (10, 5)
m0(xb).shape == torch.Size([10, 5])
# returns True, since m1's output has shape (10, 4)
m1(xb).shape == torch.Size([10, 4])
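For a purely sequential model, the effect of cutOff() can be pictured as truncating the layer list right after the selected layer. The sketch below does this with plain Python callables standing in for the three `nn.Linear` layers; `cut_off` is a hypothetical helper, and the simplification only covers sequential chains, not how k1lib handles general networks.

```python
from typing import Callable, List


def cut_off(layers: List[Callable], idx: int) -> Callable:
    """Return a function that runs layers[0..idx] and returns the output of layers[idx]."""
    kept = layers[: idx + 1]

    def truncated(x):
        for layer in kept:
            x = layer(x)
        return x

    return truncated


# Three "layers" standing in for the three nn.Linear modules
layers = [lambda x: x + 1, lambda x: x * 10, lambda x: x - 3]
m0, m1 = cut_off(layers, 0), cut_off(layers, 1)
print(m0(2))                  # 3
print(m1(2))                  # 30
print(cut_off(layers, 2)(2))  # 27, same as running the full chain
```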
-
-
k1lib.selector.preprocess(selectors: str, defaultProp='*') → List[str]
Normalizes the shorthand that the CSS syntax allows (multiple statements, comma-separated selectors, missing properties) into one "selector:props" line per selector. Example:
# returns ["a:f", "a:g,h", "b:g,h", "t:*"]
selector.preprocess("a:f; a, b: g,h; t")
- Parameters
selectors – single css selector string. Statements separated by “\n” or “;”
defaultProp – default property, if statement doesn’t have one
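A minimal sketch of this normalization, written only to reproduce the documented example: split statements on newlines and semicolons, split each statement on its first `:`, fall back to `defaultProp` when there is none, and emit one line per comma-separated selector. `preprocess_sketch` is a hypothetical name; k1lib's own function may handle more cases.

```python
from typing import List


def preprocess_sketch(selectors: str, defaultProp: str = "*") -> List[str]:
    """Hypothetical re-creation of selector.preprocess for simple inputs."""
    lines = []
    # Statements may be separated by "\n" or ";"
    for stmt in selectors.replace("\n", ";").split(";"):
        stmt = stmt.strip()
        if not stmt:
            continue
        if ":" in stmt:
            sel, props = stmt.split(":", 1)
        else:
            sel, props = stmt, defaultProp
        props = ",".join(p.strip() for p in props.split(","))
        # "a, b: g,h" expands to one line per selector
        for s in sel.split(","):
            lines.append(f"{s.strip()}:{props}")
    return lines


print(preprocess_sketch("a:f; a, b: g,h; t"))  # ['a:f', 'a:g,h', 'b:g,h', 't:*']
```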
-
k1lib.selector.select(model: torch.nn.modules.module.Module, css: str = '*') → k1lib.selector.ModuleSelector
Creates a new ModuleSelector, in sync with a model. Example:
mS = selector.select(nn.Linear(3, 4), "#root:propA")
Or, you can do it the more direct way:
mS = nn.Linear(3, 4).select("#root:propA")
- Parameters
model – the torch.nn.Module object to select from
css – the CSS selectors