# AUTOGENERATED FILE! PLEASE DON'T EDIT
import math as _math, k1lib as _k1lib
import matplotlib.pyplot as _plt, numpy as _np
from itertools import accumulate as _accumulate
import typing as _typing
from typing import List as _List, Callable as _Callable,\
Union as _Union
# A schedule function: takes progress (a float in [0, 1]) and returns the hyper parameter's value
ScheduleF = _Callable[[float], float]
def combine(lambdas:_List[ScheduleF], ratios:_List[float]=None) -> ScheduleF:
    """Combine multiple different schedules, played back to back.

Progress in [0, 1] is split into consecutive segments, one per function, and each
function is called with its own segment's progress rescaled to [0, 1]. Inputs below 0
are extrapolated with the first function, inputs above 1 with the last one.
:param lambdas: functions with 1 float input in [0, 1]
:param ratios: weighting different functions, i.e. the relative length of each
    segment. Does not have to add up to 1. Defaults to equal weights.
:raises ValueError: if ``lambdas`` is empty, ``ratios`` has a different length,
    any ratio is negative, or no ratio is positive"""
    from bisect import bisect_right
    if len(lambdas) == 0: raise ValueError("combine() needs at least 1 schedule function")
    # `is None` instead of `== None`, which is elementwise (and then ambiguous) for numpy arrays
    if ratios is None: ratios = [1] * len(lambdas)
    ratios = _np.array(ratios, dtype=float)
    if len(ratios) != len(lambdas):
        raise ValueError(f"Got {len(lambdas)} functions but {len(ratios)} ratios")
    if (ratios < 0).any(): raise ValueError("ratios must be non-negative")
    if ratios.sum() <= 0: raise ValueError("At least 1 ratio must be positive")
    # zero-width segments can never be selected, and a trailing one would divide by
    # zero at x >= 1, so drop them up front
    keep = ratios > 0
    lambdas = [fn for fn, k in zip(lambdas, keep) if k]
    ratios = ratios[keep] / ratios[keep].sum()
    checkpoints = [0] + list(_accumulate(ratios))
    def f(x):
        # index of the first checkpoint strictly greater than x, clamped so that
        # x < 0 maps to the first segment and x >= 1 maps to the last one
        idx = min(max(bisect_right(checkpoints, x), 1), len(lambdas))
        a = checkpoints[idx - 1]
        return lambdas[idx - 1]((x - a)/(checkpoints[idx] - a))
    return f
def decorate(f:_Callable[[float, float, float], float]) -> _Callable[[float, float], _Callable[[float], float]]:
    """Decorator, transforms f(low, high, x) to (low, high) -> f(x).

The decorated function only takes ``low`` and ``high``, and returns a 1-input
schedule function ``x -> f(low, high, x)``. The original's name and docstring are
copied over, so it still shows up properly in docs and reprs."""
    def _f(low, high): return lambda x: f(low, high, x)
    # copy identity by hand instead of functools.wraps: wraps also sets __wrapped__, which
    # would make inspect.signature report f's 3-arg signature instead of _f's 2-arg one
    _f.__name__ = f.__name__; _f.__qualname__ = f.__qualname__
    _f.__doc__ = f.__doc__; return _f
@decorate
def linear(low, high, x):
    """Straight line from ``low`` (at x=0) to ``high`` (at x=1)."""
    span = high - low
    return low + x * span
@decorate
def cosine(low, high, x):
    """Half-cosine ease from ``low`` (at x=0) to ``high`` (at x=1), flat at both ends."""
    weight = 1 + _math.cos(_math.pi * (1 - x))
    return low + (high - low) * weight / 2
def oneCycle(low, high, upPortion:float=0.3):
    """Goes from middle-low (0.8*low + 0.2*high) up to high, then down to low.
:param upPortion: portion of progress (0 to 1) spent going up. Defaults to 0.3"""
    middleLow = 0.8 * low + 0.2 * high
    return combine([cosine(middleLow, high), cosine(high, low)], [upPortion, 1 - upPortion])
# e^-3, i.e. the value of exp(-x*4+1) at x = 1. Computed with exp() rather than
# math.e**-3 so it is bit-identical to what decay() computes at x = 1, which makes
# decay(low, high)(1) == high exactly
_en4 = _math.exp(-3)
@decorate
def decay(low, high, x):
    """Goes from ``low`` (at x=0) to ``high`` (at x=1) along an exponential curve.
Rises/drops quickly at first, then the rate of change gets smaller and smaller."""
    # exp(-x*4+1) runs from e (x=0) down to e^-3 (x=1); rescale that range to [1, 0]
    frac = (_math.exp(-x*4+1) - _en4) / (_math.e - _en4)
    return frac * (low - high) + high
class Schedule:
    """A hyper parameter schedule, bound to the name of the param it controls.

:param param: name of the param-group key to set, like "lr"
:param scheduleF: function mapping progress in [0, 1] to the param's value"""
    def __init__(self, param:str, scheduleF:ScheduleF):
        self.param = param; self.scheduleF = scheduleF
        self.progress = None # last progress passed to startBatch(), None until then
    def __call__(self, x:float): return self.scheduleF(x)
    def startBatch(self, paramGroup:dict, progress:float):
        """Sets ``paramGroup[self.param]`` to the schedule's value at ``progress``,
and records ``progress`` for :attr:`value`."""
        paramGroup[self.param] = self.scheduleF(progress)
        self.progress = progress
    @property
    def value(self):
        """Value of the schedule at the last recorded progress, or None if
:meth:`startBatch` has not been called yet."""
        if self.progress is None: return None
        return self.scheduleF(self.progress)
    def __repr__(self):
        # Plots the whole schedule, marking both endpoints and (if any) the last
        # recorded progress, then returns a short usage hint
        _plt.figure(dpi=100); c = dict(color="tab:green")
        xs = _np.linspace(0, 1, 1000); _plt.plot(xs, [self.scheduleF(x) for x in xs])
        for end in (0, 1):
            y = self(end); _plt.plot(end, y, "o", **c); _plt.annotate("({}, {:.1e})".format(end, y), (end, y))
        progress = self.progress
        if progress is not None:
            # fade (and don't label) the progress marker near the endpoints,
            # presumably so it doesn't clutter the endpoint labels
            blur = not (progress in _k1lib.Range(0.1, 0.9))
            y = self(progress); _plt.plot(progress, y, "o", **c, alpha=(0.5 if blur else 1))
            if not blur: _plt.annotate("({:.1e}, {:.1e})".format(progress, y), (progress, y))
        _plt.show()
        return f"""'{self.param}' schedule. Can...
- s.progress: to get last recorded progress
- s.value: to get last recorded hyper parameter's value
- s(0.3): to get value of schedule at 30% progress"""
    @staticmethod
    def linear(param:str, low:float, high:float) -> "Schedule":
        """Linearly goes from low to high"""
        return Schedule(param, linear(low, high))
    @staticmethod
    def cosine(param:str, low:float, high:float) -> "Schedule":
        """Smoothly goes from low to high"""
        return Schedule(param, cosine(low, high))
    @staticmethod
    def oneCycle(param:str, low:float, high:float, upPortion:float=0.3) -> "Schedule":
        """Goes from middle-low (0.8*low + 0.2*high) to high to low
:param upPortion: portion of progress (0 to 1) spent going up"""
        middleLow = 0.8 * low + 0.2 * high
        return Schedule(param, combine([cosine(middleLow, high), cosine(high, low)], [upPortion, 1-upPortion]))
    @staticmethod
    def decay(param:str, low:float, high:float) -> "Schedule":
        """Rises/drops quickly, then rate of change gets smaller and smaller"""
        return Schedule(param, decay(low, high))
@_k1lib.patch(_k1lib.Callback.cls)
class ParamScheduler(_k1lib.Callback):
    """Schedules a param in parts of the network.
    
:param css: the selected parts of the network to schedule
:param schedules: a :class:`Schedule`, or a list of them (one per param name, like "lr")"""
    def __init__(self, css:str, schedules:_Union[_List[Schedule], Schedule]):
        super().__init__(); self.css = css; schedules = schedules or []
        schedules = [schedules] if isinstance(schedules, Schedule) else schedules
        self.schedules = {s.param:s for s in schedules}
        # set literal, not set("ProgressBar"), which would be the set of its characters
        self.groupId = None; self.dependsOn = {"ProgressBar"}
        self.initialized = False; self.prop = None
    def endRun(self):
        ":meta private:"
        self.initialized = False
    def __getstate__(self):
        # drop the selector copy when pickling; it is rebuilt in startRun, and its
        # displayF lambdas are likely not picklable anyway
        answer = dict(self.__dict__)
        if "selector" in answer: del answer["selector"]
        return answer
    def startBatch(self):
        ":meta private:"
        # groupId stays None when the css selected no parameters: nothing to schedule
        if self.groupId is None or not self.l.model.training: return
        paramGroup = self.l.opt.param_groups[self.groupId]
        progress = self.l.ProgressBar.progress
        for schedule in self.schedules.values(): schedule.startBatch(paramGroup, progress)
    def __repr__(self):
        print(f"{self._reprHead}, css: \"{self.css}\", selector prop: \"{self.prop}\", schedules:")
        for schedule in self.schedules.values(): schedule.__repr__()
        return f"""Can...
- ps.schedules["lr"]: to get the schedule for a specific param
- ps.selector: to view the selected parameters
{self._reprCan}"""
def _markParamsF(params):
    """Returns a `selector.apply` callback that marks modules directly owning any of ``params`` with "*"."""
    def applyF(mS):
        mS.displayF = lambda s: "*" if any(p in params for p in s.directParams.values()) else ""
    return applyF
@_k1lib.patch(ParamScheduler, name="startRun")
def _startRun(self):
    """Rebuilds the optimizer's param groups: one group per active ParamScheduler
(deepest selection first), plus a final "rest" group for all remaining params.

:meta private:"""
    if self.initialized: return
    # get all other ParamSchedulers
    cbs = [cb for cb in self.l.cbs if isinstance(cb, ParamScheduler) and not cb.suspended]
    # delete all old _ps_{i} selectors, and add new ones
    css = [line for line in self.l.css.split("\n") if "_ps_" not in line]
    for i, cb in enumerate(cbs):
        cb.prop = f"_ps_{i}"; css += _k1lib.selector.filter(cb.css, cb.prop)
        cb.groupId = None # reset, so a group id from a previous run is never reused
        cb.initialized = True # make sure only 1 startRun is ran across all ParamSchedulers
    self.l.css = "\n".join(css)
    # sort cbs based on depth, so that deeper ones gets accounted for first.
    # cbs that selected no module get depth -1, so they are sorted last
    for cb in cbs:
        first = next(iter(self.l.selector.modules(cb.prop)), None)
        cb._depth = -1 if first is None else first.depth
    cbs = sorted(cbs, key=lambda cb: -cb._depth)
    # clear and add param groups
    self.l.opt.param_groups = []
    # dicts instead of sets: same O(1) membership, but insertion-ordered, so params
    # inside each group come out in a deterministic order on every run
    allParams = dict.fromkeys(self.l.selector.parameters())
    for cb in cbs:
        params = {}
        for m in self.l.selector.modules(cb.prop):
            for p in m.parameters():
                if p in allParams:
                    params[p] = None; del allParams[p]
        if len(params) > 0:
            cb.groupId = len(self.l.opt.param_groups)
            self.l.opt.add_param_group({"prop": cb.prop, "css": cb.css, "params": list(params), **self.l.opt.defaults})
    self.l.opt.add_param_group({"prop": "rest", "css": "*", "params": list(allParams), **self.l.opt.defaults})
    for cb in cbs:
        groupParams = set() if cb.groupId is None else set(self.l.opt.param_groups[cb.groupId]["params"])
        cb.selector = self.l.selector.copy()
        # build the callback through a factory so each selector captures its own
        # groupParams (a closure over the loop variable would see only the last one)
        cb.selector.apply(_markParamsF(groupParams))
@_k1lib.patch(_k1lib.Callbacks, docs=ParamScheduler, name="withParamScheduler")
def _withParamScheduler(self, css:str, schedules:list, name:str=None):
    # build the scheduler callback, then register it (optionally under `name`);
    # whatever `append` returns is handed straight back to the caller
    scheduler = ParamScheduler(css, schedules)
    return self.append(scheduler, name=name)