# AUTOGENERATED FILE! PLEASE DON'T EDIT
import k1lib
from .callbacks import Callback, Callbacks, Cbs
# Public names exported by this module. `withFrozen` is not listed because it
# is attached to `Callbacks` via `k1lib.patch` rather than imported by name.
__all__ = ["Frozen"]
@k1lib.patch(Cbs)
class Frozen(Callback):
    """Freezes selected parts of the network.

    When a run starts, every parameter belonging to a module matched by the
    css selector has ``requires_grad`` set to ``False``. When the run ends,
    each parameter gets back the ``requires_grad`` value it had beforehand."""
    def __init__(self, css:str):
        """:param css: css selectors for the parts you want to freeze"""
        # NOTE(review): assumes Callback.__init__ takes no required arguments
        # and sets up the base state this class relies on -- confirm.
        super().__init__()
        self.css = css
        # Initialised here so endRun() is a harmless no-op if it is ever
        # called before startRun().
        self.params = []; self.oldParamValues = []
    def startRun(self):
        """Selects the modules to freeze and turns off their gradients."""
        # Work on a private copy of the learner's selector so the "_frozen_"
        # prop does not leak into the selector other code uses.
        self.selector = self.l.selector.copy()
        self.selector.clearProps()
        self.selector.parse(k1lib.selector.filter(self.css, "_frozen_"))
        self.params = []
        for m in self.selector.modules("_frozen_"):
            self.params.extend(m.parameters())
        # Snapshot every old value BEFORE mutating anything. If the same
        # parameter shows up twice (e.g. reached through 2 selected modules),
        # interleaving read/write would record the already-frozen `False` and
        # endRun() would then leave that parameter frozen.
        self.oldParamValues = [p.requires_grad for p in self.params]
        for p in self.params:
            p.requires_grad = False
    def endRun(self):
        """Restores the ``requires_grad`` values saved by :meth:`startRun`."""
        for p, v in zip(self.params, self.oldParamValues):
            p.requires_grad = v
        # Drop both lists together so no stale state survives into the next run
        self.params = []; self.oldParamValues = []
    def __repr__(self):
        return f"""{self._reprHead}, can...
- f.selector: to get internal k1lib.ModuleSelector object
{self._reprCan}"""
@k1lib.patch(Callbacks, docs=Frozen)
def withFrozen(self, css:str):
    # Convenience shortcut: builds a `Frozen` callback from the given css
    # selector, appends it to this collection and hands back whatever
    # `append` returns.
    frozenCb = Frozen(css)
    return self.append(frozenCb)