# AUTOGENERATED FILE! PLEASE DON'T EDIT
import math as _math, k1lib as _k1lib
import matplotlib.pyplot as _plt, numpy as _np
from itertools import accumulate as _accumulate
import typing as _typing
from typing import List as _List, Callable as _Callable,\
Union as _Union
# Signature shared by the schedule functions in this module: one float in
# (the progress), one float out (the scheduled value).
ScheduleF = _typing.Callable[[float], float]
def combine(lambdas:_List[ScheduleF], ratios:_List[float]=None) -> ScheduleF:
    """Combine multiple different schedules, played one after another.

    The progress axis is cut into consecutive segments whose lengths are
    proportional to ``ratios``. Inside a segment the matching schedule is
    called with the progress rescaled so that the segment spans 0 to 1.
    Values of ``x`` below 0 are handed to the first schedule and values
    above 1 to the last one, each with the correspondingly extrapolated
    local progress.

    :param lambdas: functions with 1 float input in [0, 1]
    :param ratios: weighting different functions. Does not have to
        add up to 1. Lists and numpy arrays are accepted; defaults to
        equal weights.
    :return: the combined schedule function"""
    # `is None`, not `== None`: the latter is elementwise on numpy arrays and
    # makes the `if` raise "truth value of an array is ambiguous".
    if ratios is None: ratios = [1] * len(lambdas)
    ratios = _np.array(ratios)
    ratios = ratios / ratios.sum()
    # checkpoints[i] .. checkpoints[i + 1] is the span owned by lambdas[i]
    checkpoints = [0] + list(_accumulate(ratios))
    lastIdx = len(checkpoints) - 1
    def f(x):
        # The first checkpoint strictly above x closes the active segment;
        # if there is none (x at or past the end) use the last segment.
        idx = next((i for i, checkpoint in enumerate(checkpoints) if checkpoint > x), lastIdx)
        # x < 0 stops at checkpoint 0, which has no segment before it. Pin to
        # the first segment instead of wrapping around to lambdas[-1].
        idx = max(idx, 1)
        a = checkpoints[idx - 1]
        return lambdas[idx - 1]((x - a)/(checkpoints[idx] - a))
    return f
def decorate(f:_Callable[[float, float, float], float]) -> _Callable[[float, float], _Callable[[float], float]]:
    """Decorator, transforms f(low, high, x) to (low, high) -> f(x).

    :param f: function taking ``(low, high, x)`` and returning the value at ``x``
    :return: factory taking ``(low, high)`` and returning a one-argument
        function of ``x`` with ``low`` and ``high`` fixed. The factory
        carries over ``f``'s docstring and name."""
    def _f(low, high): return lambda x: f(low, high, x)
    _f.__doc__ = f.__doc__
    # Copy the name metadata by hand. functools.wraps would also set
    # __wrapped__, which makes introspection report f's 3-argument signature
    # for what is now a 2-argument factory.
    for attr in ("__name__", "__qualname__"):
        if hasattr(f, attr): setattr(_f, attr, getattr(f, attr))
    return _f
@decorate
def linear(low, high, x):
    """Schedule moving from ``low`` (at x = 0) to ``high`` (at x = 1) at a
    constant rate."""
    span = high - low
    return low + x * span
@decorate
def cosine(low, high, x):
    """Schedule easing from ``low`` (at x = 0) to ``high`` (at x = 1) along
    half a cosine wave."""
    # 1 + cos(pi * (1 - x)) runs from 0 at x = 0 up to 2 at x = 1
    ease = 1 + _math.cos(_math.pi * (1 - x))
    return low + (high - low) * ease / 2
def oneCycle(low, high, upPortion=0.3):
    """Goes from middle-low (0.8*low + 0.2*high) to high, then down to low.

    Both legs are cosine schedules joined with :func:`combine`.

    :param low: value reached at the end of the cycle
    :param high: peak value, reached when the rising leg ends
    :param upPortion: fraction of the progress spent on the rising leg;
        the default 0.3 reproduces the previously fixed 0.3 / 0.7 split
    :return: schedule function of the progress"""
    warmup = cosine(0.8 * low + 0.2 * high, high)
    cooldown = cosine(high, low)
    return combine([warmup, cooldown], [upPortion, 1 - upPortion])
# e**-3: the value exp(-4*x + 1) takes at x = 1. `decay` subtracts it so its
# curve ends at `high`. Note the exponent is -3 even though the name says "4".
_en4 = _math.e**-3
@decorate
def decay(low, high, x):
    """Schedule starting at ``low`` (x = 0) and ending at ``high`` (x = 1)
    along an exponential curve, so most of the change happens early."""
    # Share of the (low - high) gap still left: goes from 1 at x = 0 down to
    # 0 at x = 1, up to floating point rounding.
    remaining = (_math.exp(-x*4+1) - _en4) / (_math.e - _en4)
    return remaining * (low - high) + high
class Schedule:
    """Pairs a hyperparameter name with the function that schedules it.

    :param param: key under which the value is written into a param group
    :param scheduleF: function mapping progress to the hyperparameter's value"""
    def __init__(self, param:str, scheduleF:ScheduleF):
        self.param = param
        self.scheduleF = scheduleF
        # stays None until startBatch() records a progress
        self.progress = None
    def __call__(self, x:float):
        """Evaluates the schedule at progress ``x``."""
        return self.scheduleF(x)
    def startBatch(self, paramGroup:dict, progress:float):
        """Writes the scheduled value for ``progress`` into ``paramGroup``
        under this schedule's param name, and records ``progress``."""
        paramGroup[self.param] = self.scheduleF(progress)
        self.progress = progress
    @property
    def value(self):
        """Schedule evaluated at the last recorded progress."""
        return self.scheduleF(self.progress)
    def __repr__(self):
        # Side effect: plots the schedule with matplotlib before returning
        # the text summary.
        _plt.figure(dpi=100)
        marker = dict(color="tab:green")
        ts = _np.linspace(0, 1, 1000)
        _plt.plot(ts, [self.scheduleF(t) for t in ts])
        # mark and label both ends of the curve
        for end in (0, 1):
            v = self(end)
            _plt.plot(end, v, "o", **marker)
            _plt.annotate("({}, {:.1e})".format(end, v), (end, v))
        progress = self.progress
        if progress is not None:
            # outside Range(0.1, 0.9) the marker is faded and left unlabeled
            blur = progress not in _k1lib.Range(0.1, 0.9)
            v = self(progress)
            _plt.plot(progress, v, "o", **marker, alpha=(0.5 if blur else 1))
            if not blur:
                _plt.annotate("({:.1e}, {:.1e})".format(progress, v), (progress, v))
        _plt.show()
        return f"""'{self.param}' schedule. Can...
- s.progress: to get last recorded progress
- s.value: to get last recorded hyper parameter's value
- s(0.3): to get value of schedule at 30% progress"""
    @staticmethod
    def linear(param:str, low:float, high:float) -> _Callable[[float], float]:
        """Schedule for ``param`` that goes from low to high at a constant rate"""
        ramp = linear(low, high)
        return Schedule(param, ramp)
    @staticmethod
    def cosine(param:str, low:float, high:float) -> _Callable[[float], float]:
        """Schedule for ``param`` that goes smoothly from low to high"""
        ease = cosine(low, high)
        return Schedule(param, ease)
    @staticmethod
    def oneCycle(param:str, low:float, high:float, upPortion:float=0.3) -> _Callable[[float], float]:
        """Schedule for ``param`` going from middle-low (0.8*low + 0.2*high)
        up to high, then down to low

        :param upPortion: the fraction of the progress spent going up"""
        warmup = cosine(0.8 * low + 0.2 * high, high)
        cooldown = cosine(high, low)
        cycle = combine([warmup, cooldown], [upPortion, 1-upPortion])
        return Schedule(param, cycle)
    @staticmethod
    def decay(param:str, low:float, high:float) -> _Callable[[float], float]:
        """Schedule for ``param`` that changes quickly at first, then more
        and more slowly"""
        curve = decay(low, high)
        return Schedule(param, curve)
@_k1lib.patch(_k1lib.Callback.cls)
class ParamScheduler(_k1lib.Callback):
    """Schedules a param in parts of the network.

    :param css: the selected parts of the network to schedule
    :param schedules: a single :class:`Schedule` or a list of them; they are
        stored keyed by each schedule's ``param`` name"""
    def __init__(self, css:str, schedules:_Union[_List[Schedule], Schedule]):
        super().__init__(); self.css = css; schedules = schedules or []
        schedules = [schedules] if isinstance(schedules, Schedule) else schedules
        self.schedules = {s.param:s for s in schedules}
        # Set literal on purpose: set("ProgressBar") would iterate the string
        # and yield a set of single characters, not the callback's name.
        self.groupId = None; self.dependsOn = {"ProgressBar"}
        self.initialized = False; self.prop = None
    def endRun(self):
        ":meta private:"
        # forces the param-group setup to run again on the next run
        self.initialized = False
    def __getstate__(self):
        # Copy of the instance dict with the "selector" entry left out.
        answer = dict(self.__dict__)
        if "selector" in answer: del answer["selector"]
        return answer
    def startBatch(self):
        ":meta private:"
        # Only while the model is in training mode: write every schedule's
        # value for the current progress into this callback's param group.
        if self.l.model.training:
            paramGroup = self.l.opt.param_groups[self.groupId]
            progress = self.l.ProgressBar.progress
            for schedule in self.schedules.values(): schedule.startBatch(paramGroup, progress)
    def __repr__(self):
        # Side effects: prints a header line and triggers each schedule's
        # own __repr__ (return value discarded) before returning the text.
        print(f"{self._reprHead}, css: \"{self.css}\", selector prop: \"{self.prop}\", schedules:")
        for schedule in self.schedules.values(): schedule.__repr__()
        return f"""Can...
- ps.schedules["lr"]: to get the schedule for a specific param
- ps.selector: to view the selected parameters
{self._reprCan}"""
@_k1lib.patch(ParamScheduler, name="startRun")
def _startRun(self):
    ":meta private:"
    # Runs the setup once for ALL active ParamSchedulers of the learner: the
    # first one to get here marks every other one as initialized too.
    if not self.initialized:
        # get all other ParamSchedulers
        cbs = [cb for cb in self.l.cbs if isinstance(cb, ParamScheduler) and not cb.suspended]
        # delete all old _ps_{i} selectors, and add new ones
        css = [line for line in self.l.css.split("\n") if "_ps_" not in line]
        for i, cb in enumerate(cbs):
            cb.prop = f"_ps_{i}"; css += _k1lib.selector.filter(cb.css, cb.prop)
            cb.initialized = True # make sure only 1 startRun is ran across all ParamSchedulers
        self.l.css = "\n".join(css)
        # sort cbs based on depth, so that deeper ones gets accounted for first
        for cb in cbs: cb._depth = next(self.l.selector.modules(cb.prop)).depth
        cbs = sorted(cbs, key=lambda cb: -cb._depth)
        # clear and add param groups
        self.l.opt.param_groups = []
        allParams = set(self.l.selector.parameters())
        for cb in cbs:
            # each parameter goes to the first (deepest) callback that
            # selects it, and is removed from the pool so no group shares it
            params = set()
            for m in self.l.selector.modules(cb.prop):
                for p in m.parameters():
                    if p in allParams:
                        params.add(p); allParams.remove(p)
            if len(params) > 0:
                cb.groupId = len(self.l.opt.param_groups)
                self.l.opt.add_param_group({"prop": cb.prop, "css": cb.css, "params": list(params), **self.l.opt.defaults})
        # whatever no callback claimed ends up in a final catch-all group
        self.l.opt.add_param_group({"prop": "rest", "css": "*", "params": list(allParams), **self.l.opt.defaults})
        def markOwned(owned):
            # Builds the applyF for one callback with `owned` bound now.
            # Reading the loop variable from inside the lambda would be
            # resolved only when displayF is called, i.e. after the loop, so
            # every selector would star the last callback's parameters.
            def applyF(mS):
                mS.displayF = lambda s: "*" if any([p in owned for p in s.directParams.values()]) else ""
            return applyF
        for cb in cbs:
            params = set(self.l.opt.param_groups[cb.groupId]["params"])
            cb.selector = self.l.selector.copy()
            cb.selector.apply(markOwned(params))
@_k1lib.patch(_k1lib.Callbacks, docs=ParamScheduler, name="withParamScheduler")
def _withParamScheduler(self, css:str, schedules:list, name:str=None):
    # Build the callback first, then register it under the optional `name`;
    # whatever `append` returns is handed straight back to the caller.
    scheduler = ParamScheduler(css, schedules)
    return self.append(scheduler, name=name)