Source code for k1lib.schedule

# AUTOGENERATED FILE! PLEASE DON'T EDIT
import math, k1lib
import matplotlib.pyplot as plt, numpy as np
from itertools import accumulate
from typing import List, Callable, Union
__all__ = [
    "Fn",                               # schedule base class
    "linear", "smooth", "hump", "exp",  # built-in factories returning Fn schedules
    "ParamScheduler",                   # callback applying schedules to optimizer param groups
]
class Fn:
    """A hyperparameter schedule: wraps a function of training progress (0 -> 1).

``0.3*s`` changes the schedule's domain (used as its share when combined), ``s1 + s2``
combines two schedules, and ``s + "lr"`` names the parameter it controls. Schedules are
applied to an optimizer through :class:`ParamScheduler`."""
    def __init__(self, f:Callable[[float], float], param:str=None):
        """Creates a new schedule based on some custom function.

:param f: schedule function; its domain should always be [0, 1]
:param param: (optional) Parameter to schedule (e.g "lr") if using :class:`ParamScheduler`"""
        self.f = f; self.param = param
        self.progress = None  # progress passed to the most recent _startBatch call
        self.domain = k1lib.Range(0, 1)
    def __call__(self, x:float):
        """Returns the schedule's value at progress `x`."""
        return self.f(x)
    def _startBatch(self, paramGroup:dict, progress:float):
        # Store the progress itself (not the computed value): `value` and __repr__
        # re-apply f to it. Then write the scheduled value into the param group.
        self.progress = progress
        paramGroup[self.param] = self(progress)
    @property
    def value(self):
        """Value of the schedule at the last recorded progress."""
        return self.f(self.progress)
    def __mul__(self, x):
        """Modifies this schedule (scales its domain by `x`) and returns it."""
        self.domain *= x; return self
    def __rmul__(self, x):
        """``x * s``: same as ``s * x``."""
        self.domain *= x; return self
    def __truediv__(self, x):
        """Modifies this schedule (divides its domain by `x`) and returns it."""
        self.domain /= x; return self
    def __rtruediv__(self, x):
        """``x / s``: scales the domain by ``1/x``, i.e. behaves like ``s / x``."""
        self * (1.0/x); return self
    def __add__(self, s:Union["Fn", str]) -> "Fn":
        """If given :class:`Fn`, then combines the 2 schedules together. If it's a
string, then sets the current param to it and returns this schedule."""
        if isinstance(s, Fn): return CombinedSchedule(self, s)
        self.param = s; return self
    def iter(self, n:int):
        """Returns an n-step iterator evenly divided in range [0, 1]. Example::

    s = schedule.Fn(lambda x: x+2)
    list(s.iter(6)) # returns [2.0, 2.2, 2.4, 2.6, 2.8, 3.0]"""
        for e in np.linspace(0, 1, n): yield self(e)
    def modifyOutput(self, f:Callable[[float], float]) -> "Fn":
        """Returns a new :class:`Fn` whose output is passed through `f`. This schedule
is left unchanged, and the new one starts with the default domain. Example::

    s = Fn(lambda x: x+2)
    s2 = s.modifyOutput(lambda x: x**2) # s2's function is (x+2)**2"""
        return Fn(lambda x: f(self.f(x)), self.param)
@k1lib.patch(Fn)
def __repr__(self):
    """Plots the schedule over its domain, marking the values at progress 0 and 1 and,
if one was recorded, the last progress point; then returns a short usage summary."""
    plt.figure(dpi=100); c = dict(color="tab:green")
    xs = np.linspace(*self.domain, 1000)
    plt.plot(xs, [self.f(x) for x in xs])
    for end in (0, 1):  # mark and label both endpoints of the schedule
        y = self(end); plt.plot(end, y, "o", **c); plt.annotate("({}, {:.1e})".format(end, y), (end, y))
    x = self.progress
    if x is not None:
        # outside [0.1, 0.9] the marker is faded and unlabelled
        # (presumably so it doesn't crowd the endpoint labels)
        blur = x not in k1lib.Range(0.1, 0.9)
        y = self(x); plt.plot(x, y, "o", **c, alpha=(0.5 if blur else 1))
        if not blur: plt.annotate("({:.1e}, {:.1e})".format(x, y), (x, y))
    plt.show()
    return f"""'{self.param}' schedule. Can...
- s.progress: to get last recorded progress
- s.value: to get last recorded hyper parameter's value
- s(0.3): to get value of schedule at 30% progress"""

class CombinedSchedule(Fn):
    """Schedule that runs `s1` and then `s2` (created by ``s1 + s2``).

The switch happens at progress ``s1.domain.stop / (s1.domain.delta + s2.domain.delta)``.
The combined schedule keeps the first non-None ``param`` of s1, s2."""
    def __init__(self, s1, s2):
        split = s1.domain.stop / (s1.domain.delta + s2.domain.delta)
        s1r = k1lib.Range(0, split); s2r = k1lib.Range(split, 1)
        def f(x):
            # toUnit presumably maps each sub-range back onto [0, 1] for the sub-schedule
            if x < split: return s1.f(s1r.toUnit(x))
            return s2.f(s2r.toUnit(x))
        super().__init__(f, s1.param or s2.param)

def decorate(f:Callable[[float, float, float], float]) -> Fn:
    """Decorator, transforms f(low, high, x) into a factory
(low, high, param=None) -> :class:`Fn` whose function is x -> f(low, high, x)."""
    def _f(low, high, param:str=None): return Fn(lambda x: f(low, high, x), param)
    return k1lib.wraps(f)(_f)
@decorate
def linear(low, high, x):
    """Goes from `low` (at x=0) to `high` (at x=1) in a straight line."""
    span = high - low
    return low + x * span
@decorate
def smooth(low, high, x):
    """Cosine-eases from `low` (at x=0) to `high` (at x=1), flat at both ends."""
    delta = high - low
    bump = 1 + math.cos(math.pi * (1-x))  # goes 0 -> 2 as x goes 0 -> 1
    return low + delta * bump / 2
def hump(low, high, param:str=None):
    """Smoothly rises up (30%), then down (70%).

The rise starts from ``0.8*low + 0.2*high`` (not from `low`) and peaks at `high`;
the fall goes from `high` back down to `low`."""
    start = 0.8 * low + 0.2 * high
    rise = smooth(start, high)
    fall = smooth(high, low, param)
    return 0.3*rise + 0.7*fall
# e**-3 (despite the name, not e**-4): the value of math.exp(-x*4+1) at x=1.
# `exp` subtracts it so that its curve ends exactly at `high`.
_en4 = math.e**-3
@decorate
def exp(low, high, x):
    """Exponentially approaches `high` from `low`: changes quickly at first, then ever more slowly."""
    # exp(1 - 4x) runs from e (x=0) down to e**-3 (x=1); normalise it into a 1 -> 0 fraction
    frac = (math.exp(-x*4+1) - _en4) / (math.e - _en4)
    return frac * (low - high) + high
@k1lib.patch(k1lib.Callback.cls)
class ParamScheduler(k1lib.Callback):
    """Schedules a param in parts of the network.

:param css: the selected parts of the network to schedule
:param schedules: :class:`Fn` schedules, each with its ``param`` set (e.g. "lr")"""
    def __init__(self, css:str, *schedules:Fn):
        super().__init__(); self.css = css
        for i, s in enumerate(schedules):
            if s.param is None: raise RuntimeError(f"Schedule {i} does not have associated parameter! Set with `s.param = 'lr'`.")
        # keyed by param name; a later schedule for the same param replaces an earlier one
        self.schedules = {s.param:s for s in schedules}
        # index of this scheduler's optimizer param group, assigned in startRun (None = no group)
        self.groupId = None
        # startBatch reads self.l.ProgressBar.progress. This must be a set holding the
        # callback *name*: set("ProgressBar") would be a set of single characters.
        self.dependsOn = {"ProgressBar"}
        self.initialized = False; self.prop = None
    def endRun(self):
        ":meta private:"
        # lets the next run's startRun rebuild the param groups
        self.initialized = False
    def __getstate__(self):
        # selector is recreated in startRun, so it is left out of the pickled state
        # (presumably because it isn't picklable)
        answer = dict(self.__dict__)
        if "selector" in answer: del answer["selector"]
        return answer
    def startBatch(self):
        """While training, writes each schedule's value at the current progress
(from ``self.l.ProgressBar.progress``) into this scheduler's optimizer param group.
Does nothing if no param group was assigned."""
        if self.l.model.training and self.groupId is not None:
            paramGroup = self.l.opt.param_groups[self.groupId]
            progress = self.l.ProgressBar.progress
            for schedule in self.schedules.values(): schedule._startBatch(paramGroup, progress)
    def __repr__(self):
        print(f"{self._reprHead}, css: \"{self.css}\", selector prop: \"{self.prop}\", schedules:")
        for schedule in self.schedules.values(): schedule.__repr__()
        return f"""Can...
- ps.schedules["lr"]: to get the schedule for a specific param
- ps.selector: to view the selected parameters
{self._reprCan}"""
def _markSelectedModules(selector, params:set):
    """Sets each module's ``displayF`` in `selector` to show "*" when it directly owns one of `params`.

This is a separate function so every ``displayF`` lambda closes over this call's own
`params`. A loop variable would be rebound, and all lambdas would see the last value."""
    def applyF(mS):
        mS.displayF = lambda s: "*" if any(p in params for p in s.directParams.values()) else ""
    selector.apply(applyF)

@k1lib.patch(ParamScheduler, name="startRun")
def _startRun(self):
    """Runs once per run for all active (non-suspended) ParamSchedulers.

Gives each one its own optimizer param group, handling deeper selections first, puts all
remaining parameters in a "rest" group, and marks the selected modules.

:raises RuntimeError: if a scheduler's css selects no modules"""
    if self.initialized: return
    # get all other ParamSchedulers
    pss = [cb for cb in self.l.cbs if isinstance(cb, ParamScheduler) and not cb.suspended]
    for i, ps in enumerate(pss):
        # make sure only 1 startRun is ran across all ParamSchedulers
        ps.initialized = True; ps.prop = f"_ps_{i}"
        ps.groupId = None  # drop any group index left over from a previous run
        ps.selector = k1lib.selector.select(self.l.model, ps.css)
        first = next(ps.selector.modules(ps.prop), None)
        if first is None:
            raise RuntimeError(f"ParamScheduler with css \"{ps.css}\" selected no modules (prop \"{ps.prop}\")")
        ps._depth = first.depth
    # sort pss based on depth, so that deeper ones get accounted for first
    pss = sorted(pss, key=lambda ps: -ps._depth)
    # Clear and rebuild param groups. Insertion-ordered dicts (instead of sets) keep the
    # parameter order inside each group deterministic between runs.
    self.l.opt.param_groups = []
    allParams = dict.fromkeys(self.l.selector.parameters())
    for ps in pss:
        params = {}
        for m in ps.selector.modules(ps.prop):
            for p in m.parameters():
                if p in allParams: params[p] = None; del allParams[p]
        if params:
            # so that we have a way to reference the group later on
            ps.groupId = len(self.l.opt.param_groups)
            self.l.opt.add_param_group({"prop": ps.prop, "css": ps.css, "params": list(params), **self.l.opt.defaults})
    self.l.opt.add_param_group({"prop": "rest", "css": "*", "params": list(allParams), **self.l.opt.defaults})
    for ps in pss:
        if ps.groupId is None: continue
        _markSelectedModules(ps.selector, set(self.l.opt.param_groups[ps.groupId]["params"]))

@k1lib.patch(k1lib.Callbacks, docs=ParamScheduler, name="withParamScheduler")
def _withParamScheduler(self, css:str, schedules:list, name:str=None):
    """Appends a :class:`ParamScheduler` for `css` with the given schedules.
A single :class:`Fn` is also accepted in place of a list."""
    if isinstance(schedules, Fn): schedules = [schedules]
    # ParamScheduler takes schedules as varargs, so the list must be unpacked
    return self.append(ParamScheduler(css, *schedules), name=name)