# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
This module allows you to make and combine a bunch of schedules, and set up the
optimizer so that it changes hyperparameter values based on the schedule. It is highly
recommended that you check out the `tutorials section <tutorials.html>`_ on this.
This is exposed automatically with::
from k1lib.imports import *
schedule.Fn # exposed
"""
import math, k1lib; import k1lib.cli as cli
import matplotlib.pyplot as plt, numpy as np
from itertools import accumulate
from k1lib.callbacks import Cbs, Callback
from typing import List, Callable, Union
__all__ = ["Fn", "linear", "smooth", "hump", "exp", "ParamScheduler"]
class Fn:
    """A schedule: wraps a function of progress and, optionally, the name of
    the hyperparameter that it drives."""
    def __init__(self, f:Callable[[float], float], param:str=None):
        """Creates a new schedule based on some custom function.
        Example::

            s = schedule.Fn(lambda x: x**2)
            s(0.2) # returns 0.04

            # you can also use this as a decorator
            @schedule.Fn
            def s(x):
                return x**2

        :param f: function (domain should always be in [0, 1]), can be :class:`~k1lib.cli.modifier.op`
        :param param: (optional) Parameter to schedule (e.g "lr") if using :class:`ParamScheduler`"""
        if isinstance(f, cli.op): f.op_solidify()
        self.f = f; self.param = param
        self.progress = None # last progress passed to _startBatch, None until then
        self.domain = k1lib.Range(0, 1) # scaled by * and /, read when 2 schedules are added
    def __call__(self, x:float):
        """Get the value of the schedule at progress ``x``."""
        return self.f(x)
    def _startBatch(self, paramGroup:dict, progress:float):
        """Records ``progress`` and writes the scheduled value into
        ``paramGroup[self.param]``.

        :param paramGroup: dict to update (an optimizer param group when called from :class:`ParamScheduler`)
        :param progress: current progress, passed straight to the schedule function"""
        # keep the raw progress, not the computed value: `value` and `__repr__`
        # both feed `self.progress` back into the schedule function
        self.progress = progress
        paramGroup[self.param] = self(progress)
    @property
    def value(self):
        """Value of the schedule at the last recorded progress, or None if no
        progress has been recorded yet."""
        if self.progress is None: return None
        return self.f(self.progress)
    def __mul__(self, x):
        """Multiplies the domain by ``x`` in place and returns this same schedule."""
        self.domain *= x; return self
    def __rmul__(self, x):
        """Same as :meth:`__mul__`, for ``x * schedule``."""
        self.domain *= x; return self
    def __truediv__(self, x):
        """Divides the domain by ``x`` in place and returns this same schedule."""
        self.domain /= x; return self
    def __rtruediv__(self, x):
        """Handles ``x / schedule``: multiplies the domain by ``1.0/x`` in place
        and returns this same schedule."""
        # NOTE(review): this scales the same way `schedule / x` does - confirm that is intended
        self * (1.0/x); return self
    def __add__(self, s:Union["Fn", str]) -> "Fn":
        """If given :class:`Fn`, then combines the 2 schedules together.
        If it's a string, then sets the current param to it."""
        if isinstance(s, Fn): return CombinedSchedule(self, s)
        self.param = s; return self
    def iter(self, n:int):
        """Returns an n-step iterator evenly divided in range [0, 1].
        Example::

            s = schedule.Fn(lambda x: x+2)
            list(s.iter(6)) # returns [2.0, 2.2, 2.4, 2.6, 2.8, 3.0]"""
        for e in np.linspace(0, 1, n): yield self(e)
    def modifyOutput(self, f:Callable[[float], float]) -> "Fn":
        """Returns a new :class:`Fn` that has its output modified.
        Example::

            s = Fn(lambda x: x+2)
            s.modifyOutput(lambda x: x**2) # the returned schedule computes (x+2)**2

        The new schedule keeps this schedule's param, but starts with a fresh
        default domain."""
        return Fn(lambda x: f(self.f(x)), self.param)
@k1lib.patch(Fn)
def __repr__(self):
    """Plots the schedule over its domain, marks the values at 0, at 1 and at the
    last recorded progress (if any), shows the figure, then returns a short
    usage summary."""
    plt.figure(dpi=100)
    marker = dict(color="tab:green")
    # the curve itself, sampled at 1000 points across the domain
    xs = np.linspace(*self.domain, 1000)
    plt.plot(xs, [self.f(t) for t in xs])
    # both endpoints are always annotated
    start = self(0)
    plt.plot(0, start, "o", **marker)
    plt.annotate("(0, {:.1e})".format(start), (0, start))
    end = self(1)
    plt.plot(1, end, "o", **marker)
    plt.annotate("(1, {:.1e})".format(end), (1, end))
    progress = self.progress
    if progress is not None:
        # close to either end the marker is faded and left unlabeled
        blur = progress not in k1lib.Range(0.1, 0.9)
        current = self(progress)
        plt.plot(progress, current, "o", **marker, alpha=(0.5 if blur else 1))
        if not blur:
            plt.annotate("({:.1e}, {:.1e})".format(progress, current), (progress, current))
    plt.show()
    return f"""'{self.param}' schedule. Can...
- s.progress: to get last recorded progress
- s.value: to get last recorded hyper parameter's value
- s(0.3): to get value of schedule at 30% progress"""
class CombinedSchedule(Fn):
    """Schedule that runs ``s1`` first, then ``s2``. The switch-over point is
    ``s1.domain.stop`` divided by the sum of both domains' ``delta``."""
    def __init__(self, s1, s2):
        """:param s1: schedule used while progress is below the split point
        :param s2: schedule used from the split point onwards"""
        total = s1.domain.delta + s2.domain.delta
        split = s1.domain.stop / total
        firstSpan = k1lib.Range(0, split)
        secondSpan = k1lib.Range(split, 1)
        def f(x):
            # pick the active schedule, then map x through that sub-range's `toUnit`
            # (presumably rescaling it to [0, 1] - confirm against k1lib.Range)
            active, span = (s1, firstSpan) if x < split else (s2, secondSpan)
            return active.f(span.toUnit(x))
        # the param comes from s1, or from s2 if s1 has none
        super().__init__(f, s1.param or s2.param)
def decorate(f:Callable[[float, float, float], float]) -> Fn:
    """Decorator, transforms f(low, high, x) into a factory (low, high, param)
    that returns a :class:`Fn` evaluating f(low, high, x)."""
    def _f(low, high, param:str=None):
        # freeze low/high so the schedule only depends on x
        schedule = lambda x: f(low, high, x)
        return Fn(schedule, param)
    wrapper = k1lib.wraps(f)
    return wrapper(_f)
@decorate
def linear(low, high, x):
    """Goes from low (at x=0) to high (at x=1) along a straight line"""
    span = high - low
    return low + x * span
@decorate
def smooth(low, high, x):
    """Smoothly goes from low (at x=0) to high (at x=1) along half a cosine wave"""
    span = high - low
    # 1 + cos(pi * (1-x)) runs from 0 at x=0 up to 2 at x=1
    wave = 1 + math.cos(math.pi * (1-x))
    return low + span * wave / 2
def hump(low, high, param:str=None):
    """Smoothly rises up (30%), then down (70%). The rise starts at
    ``0.8 * low + 0.2 * high`` and the fall ends at ``low``."""
    start = 0.8 * low + 0.2 * high
    rising = 0.3*smooth(start, high)
    # only the falling half carries the param; the combined schedule picks it up
    falling = 0.7*smooth(high, low, param)
    return rising + falling
# despite the name this is e**-3, the value of exp(-4*x + 1) at x = 1
_en4 = math.e**-3
@decorate
def exp(low, high, x):
    """Rises/drops quickly, then rate of change gets smaller and smaller.
    Gives low at x=0 and (approximately) high at x=1"""
    decay = math.exp(1 - 4*x)
    # normalized so that the weight is 1 at x=0 and about 0 at x=1
    weight = (decay - _en4) / (math.e - _en4)
    return weight * (low - high) + high
@k1lib.patch(Cbs)
class ParamScheduler(Callback):
    """Schedules a param in parts of the network.

    :param css: the selected parts of the network to schedule
    :param schedules: one or more :class:`Fn`, each with its ``param`` set"""
    def __init__(self, css:str, *schedules:Fn):
        """:raises RuntimeError: if any of the schedules has no ``param``"""
        super().__init__(); self.css = css
        for i, s in enumerate(schedules):
            if s.param is None: raise RuntimeError(f"Schedule {i} does not have associated parameter! Set with `s.param = 'lr'`.")
        # keyed by param name, so a later schedule for the same param replaces an earlier one
        self.schedules = {s.param:s for s in schedules}
        self.groupId = None # index into the optimizer's param_groups, set during startRun
        # a set holding the single name "ProgressBar"; `set("ProgressBar")` would
        # instead build a set of its individual characters
        self.dependsOn = {"ProgressBar"}
        self.initialized = False; self.prop = None
    def endRun(self):
        ":meta private:"
        # forces the param groups to be rebuilt at the start of the next run
        self.initialized = False
    def __getstate__(self):
        """Returns a copy of the instance dict without the ``selector`` entry."""
        # NOTE(review): presumably dropped because the selector can't be pickled - confirm
        answer = dict(self.__dict__)
        answer.pop("selector", None)
        return answer
    def startBatch(self):
        """While the model is training and a param group has been assigned,
        updates that group with every schedule's value at the current progress."""
        if self.l.model.training and self.groupId is not None:
            paramGroup = self.l.opt.param_groups[self.groupId]
            progress = self.l.ProgressBar.progress
            for schedule in self.schedules.values():
                schedule._startBatch(paramGroup, progress)
    def __repr__(self):
        """Prints a header, calls ``__repr__`` on every schedule, then returns a
        short usage summary."""
        print(f"{self._reprHead}, css: \"{self.css}\", selector prop: \"{self.prop}\", schedules:")
        for schedule in self.schedules.values(): schedule.__repr__()
        return f"""Can...
- ps.schedules["lr"]: to get the schedule for a specific param
- ps.selector: to view the selected parameters
{self._reprCan}"""
def _psMarkOwned(params:set):
    """Builds the function handed to ``selector.apply`` that installs a
    ``displayF`` showing "*" on nodes which directly own a parameter in ``params``.

    Built by a factory so that each scheduler's display function keeps its own
    ``params``, instead of all of them sharing one loop variable."""
    def displayF(s):
        return "*" if any(p in params for p in s.directParams.values()) else ""
    def applyF(mS): mS.displayF = displayF
    return applyF
@k1lib.patch(ParamScheduler, name="startRun")
def _startRun(self):
    """Rebuilds the optimizer's param groups so that every active
    :class:`ParamScheduler` that selects any parameters gets a group of its own,
    with all remaining parameters collected in a final "rest" group.

    Runs once for all ParamSchedulers together: the first one to get here
    marks every one of them as initialized."""
    if self.initialized: return
    # get all other ParamSchedulers
    pss = [cb for cb in self.l.cbs if isinstance(cb, ParamScheduler) and not cb.suspended]
    for i, ps in enumerate(pss):
        # make sure only 1 startRun is ran across all ParamSchedulers
        ps.initialized = True; ps.prop = f"_ps_{i}"
        ps.selector = k1lib.selector.select(self.l.model, ps.css)
        # sort pss based on depth, so that deeper ones gets accounted for first
        # NOTE(review): next() raises StopIteration if the css selects no module
        ps._depth = next(ps.selector.modules(ps.prop)).depth
    pss = sorted(pss, key=lambda ps: -ps._depth)
    # clear and add param groups
    self.l.opt.param_groups = []
    allParams = set(self.l.selector.nn.parameters())
    for ps in pss:
        # forget the group index of a previous run, it no longer points at anything valid
        ps.groupId = None
        params = set()
        for m in ps.selector.modules(ps.prop):
            for p in m.nn.parameters():
                # each parameter goes to the first (deepest) scheduler that claims it
                if p in allParams:
                    params.add(p); allParams.remove(p)
        if len(params) > 0:
            # so that we have a way to reference the group later on
            ps.groupId = len(self.l.opt.param_groups)
            self.l.opt.add_param_group({"prop": ps.prop, "css": ps.css, "params": list(params), **self.l.opt.defaults})
    self.l.opt.add_param_group({"prop": "rest", "css": "*", "params": list(allParams), **self.l.opt.defaults})
    for ps in pss:
        if ps.groupId is None: continue
        params = set(self.l.opt.param_groups[ps.groupId]["params"])
        ps.selector.apply(_psMarkOwned(params))
@k1lib.patch(k1lib.Callbacks, docs=ParamScheduler, name="withParamScheduler")
def _withParamScheduler(self, css:str, schedules:list, name:str=None):
    # a lone schedule is accepted too
    if isinstance(schedules, Fn): schedules = [schedules]
    # ParamScheduler takes its schedules as varargs, so the list has to be unpacked;
    # passed as is, the list itself would be treated as a single schedule
    return self.append(ParamScheduler(css, *schedules), name=name)