# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""Higher order functions"""
__all__ = ["Func", "polyfit", "derivative", "optimize", "inverse", "integral", "batchify"]
from typing import Callable, List
import k1lib, numpy as np, warnings, threading, time, inspect
from functools import partial
plt = k1lib.dep("matplotlib.pyplot")
import k1lib.cli as cli
# Type alias used in the signatures below for a single-argument real function
# x -> y (the wrapped callables may also accept NumPy arrays where noted).
Func = Callable[[float], float]
def polyfit(x:List[float], y:List[float], deg:int=6) -> "Func":
    """Returns a function that approximates :math:`f(x) = y` with a
    least-squares polynomial fit.

    Example::

        xs = [1, 2, 3]
        ys = [2, 3, 5]
        f = k1.polyfit(xs, ys, 1)

    This will create a best-fit function. You can just use it as a regular,
    normal function. You can even pass in :class:`numpy.ndarray` (or a plain
    list)::

        # returns some float
        f(2)
        # plots fit function from 0 to 5
        xs = np.linspace(0, 5)
        plt.plot(xs, f(xs))

    :param x: sample inputs
    :param y: sample outputs, same length as `x`
    :param deg: degree of the polynomial of the returned function"""
    params = np.polyfit(x, y, deg)  # highest-degree coefficient first
    def _inner(_x):
        # np.polyval uses Horner's scheme, which mixes in the float
        # coefficients from the first step. Integer inputs therefore never go
        # through x**k in integer dtype, which silently overflows int64 for
        # large x / high degree. It also accepts plain lists and returns a
        # NumPy scalar (not a 0-d array) for scalar input.
        return np.polyval(params, _x)
    return _inner
def derivative(f:Func, delta:float=1e-6) -> Func:
    """Returns the derivative of a function, estimated numerically.

    Example::

        f = lambda x: x**2
        df = k1lib.derivative(f)
        df(3) # returns roughly 6

    The returned function computes the forward difference
    ``(f(x + delta) - f(x)) / delta``, calling `f` twice per evaluation.

    :param f: function to differentiate
    :param delta: step size of the forward difference"""
    def _slope(x):
        rise = f(x + delta) - f(x)  # f(x + delta) is evaluated first, then f(x)
        return rise / delta
    return _slope
def optimize(f:Func, v:float=1, threshold:float=1e-6, **kwargs) -> float:
    r"""Solves :math:`f(x) = 0` for x using Newton's method, starting from
    the initial value `v`.

    Example::

        f = lambda x: x**2-2
        # returns 1.4142 (root 2)
        k1lib.optimize(f)
        # returns -1.4142 (negative root 2)
        k1lib.optimize(f, -1)

    The slope comes from :meth:`derivative` (a forward difference). The update
    ``v - f(v)/slope`` only stops moving where ``f(v) = 0``, so the root found
    is typically far more accurate than the slope estimate itself.

    :param f: function to find a root of
    :param v: initial guess
    :param threshold: if ``|f(result)|`` exceeds this, or is NaN, a
        "k1lib.optimize not converging" warning is emitted
    :param kwargs: extra keyword arguments bound into `f` with
        :func:`functools.partial`
    :returns: the final estimate of x. Failure to converge produces a warning,
        not an exception."""
    if kwargs: f = partial(f, **kwargs)
    fD = derivative(f)
    for _ in range(20):  # fixed iteration budget
        slope = fD(v)
        # A zero slope (f flat at float precision around v) would raise
        # ZeroDivisionError for Python floats, or turn v into inf/nan for NumPy
        # scalars. Stop here and let the convergence check below report it.
        if slope == 0: break
        v = v - f(v)/slope
    # Written as `not (... <= ...)` rather than `... > ...` so that a NaN
    # residual also triggers the warning.
    if not (abs(f(v)) <= threshold): warnings.warn("k1lib.optimize not converging")
    return v
def inverse(f:Func) -> Func:
    """Returns the inverse of a function.

    Example::

        f = lambda x: x**2
        fInv = k1lib.inverse(f)
        # returns roughly 3
        fInv(9)

    Each call ``fInv(y)`` runs :meth:`optimize` (Newton's method from its
    default starting point) on ``f(x) - y``. It returns whichever root that
    search lands on.

    .. warning::

        The inverse function takes a long time to run, so don't use this
        where you need lots of speed. Also, as you might imagine, the
        inverse function isn't really airtight. Should work well with
        monotonic functions, but all bets are off with other functions."""
    def _inverted(y):
        def _residual(x): return f(x) - y
        return optimize(_residual)
    return _inverted
def integral(f:Func, _range:k1lib.Range, n:int=1000) -> float:
    """Integrates a function over a range, using the trapezoidal rule.

    Example::

        f = lambda x: x**2
        # returns roughly 9
        k1lib.integral(f, [0, 3])

    There is also the cli :class:`~k1lib.cli.modifier.integrate`
    which has a slightly different api.

    :param f: function to integrate; called once per sample point
    :param _range: anything :class:`k1lib.Range` accepts, e.g. ``[start, stop]``
    :param n: number of evenly spaced sample points, including both ends.
        Must be at least 2.
    :raises ValueError: if `n` is less than 2"""
    if n < 2: raise ValueError(f"n must be at least 2, got {n}")
    _range = k1lib.Range(_range)
    xs = np.linspace(*_range, n)  # n points, both endpoints included
    ys = [f(x) for x in xs]
    # linspace with both endpoints gives n-1 intervals, so the spacing is
    # delta/(n-1), not delta/n. (Assumes `_range.delta` is stop - start, as
    # the previous weighting did.) Trapezoidal rule: interior samples get
    # full weight, the two endpoints get half weight.
    h = _range.delta / (n - 1)
    return h * (sum(ys) - (ys[0] + ys[-1]) / 2)
class Controller:
    """Thread-safe mailbox that collects items into numbered batches.

    Every item in a batch shares one :class:`threading.Event`. :meth:`add`
    files an item into the current batch. :meth:`prepare` detaches the whole
    batch (with its event) and starts a fresh, empty one."""
    def __init__(self):
        self.lock = threading.Lock()
        self.data = {}                  # index -> item, for the current batch
        self.count = 0                  # next index to hand out in this batch
        self.event = threading.Event()  # shared by every item in this batch
    def add(self, d):
        """Stores `d` in the current batch and returns ``(index, event)``."""
        with self.lock:
            slot = self.count
            self.data[slot] = d
            self.count = slot + 1
            return slot, self.event
    def prepare(self):
        """Detaches the current batch and starts a new one, numbered from 0.

        Returns a shallow copy of the detached batch dict and that batch's
        event."""
        with self.lock:
            batch = dict(self.data)
            batchEvent = self.event
            self.data, self.count, self.event = {}, 0, threading.Event()
        return batch, batchEvent
def batchify(period=0.1) -> "singleFn":
    """Transforms a function taking in batches to taking in singles,
    for performance reasons.

    Say you have this function that does some heavy computation::

        def f1(x, y):
            time.sleep(1)     # simulating heavy load, like loading large libraries/binaries
            return x + y      # does not take lots of time

    Let's also say that you often want to execute that function over multiple samples::

        res = [] # will be filled with [3, 7, 11]
        for x, y in [[1, 2], [3, 4], [5, 6]]:
            res.append(f1(x, y))

    This would take 3 seconds to complete. But it is often advantageous to merge
    them together and execute everything all at once::

        def f2(xs, ys):
            res = []; time.sleep(1)                       # loading large libraries/binaries only once
            for x, y in zip(xs, ys): res.append(x+y)      # run through all samples quickly
            return res
        res = f2([1, 3, 5], [2, 4, 6]) # filled with [3, 7, 11], just like before

    But maybe you're in a multithreaded application and want the original function "f1(x, y)",
    instead of the batched function "f2(xs, ys)". An example is running an LLM (large language
    model) on requests submitted by people on the internet. Each incoming request runs on a
    different thread, but it's still desirable to pool all of those requests together, run through
    the model once, and then split up the results to each respective request. That's where this
    functionality comes in::

        @k1.batchify(0.1)     # pool up all calls every 0.1 seconds
        def f2(xs, ys):
            ...               # same as previous example
        @k1.batchify          # can also do it like this. It'll default to a period of 0.1
        def f2(xs, ys):
            ...               # same as previous example
        res = []
        t1 = threading.Thread(target=lambda: res.append(f2(1, 2)))
        t2 = threading.Thread(target=lambda: res.append(f2(3, 4)))
        t3 = threading.Thread(target=lambda: res.append(f2(5, 6)))
        ths = [t1, t2, t3]
        for th in ths: th.start()
        for th in ths: th.join() # after this point, res will have a permutation of [3, 7, 11] (because t1, t2, t3 execution order is not known)

    This will take on average 1 + 0.1 seconds (heavy load execution time + refresh rate). You can
    decorate any Flask endpoint with this and it will just work::

        @app.get("/run/<int:nums>")
        @k1.batchify(0.3)
        def run(nums):
            time.sleep(1) # long running process
            return nums | apply("x**2") | apply(str) | deref()

    Now, if you send 10 requests within a window of 0.3s, then the total running time would only be
    1.3s, instead of 10s like before.

    Details:

    - Keyword arguments are gathered into one list per key. Callers that did not pass
      a given key contribute ``None`` for it. Positional arguments are transposed into
      one list per position; missing ones presumably also become ``None`` via the
      ``transpose(fill=...)`` step.
    - The batched function must return (an iterable of) one result per call, in call order.
    - If the batched function raises, or returns too few results, every caller in that batch
      gets the exception re-raised. Later batches are still served.
    - The background worker is a daemon thread, so it does not keep the interpreter alive.
    - The wrapper copies the batched function's ``__name__``/``__doc__`` (:func:`functools.wraps`),
      so several batchified Flask endpoints don't collide on the endpoint name.

    :param period: polling period in seconds, or the function itself when used as a bare decorator"""
    from functools import wraps
    def wrap(batchedFn):
        con = Controller()
        def bgProcess():
            while True:
                time.sleep(period); data, event = con.prepare()
                if len(data) == 0: continue
                try:
                    d = data.values(); sentinel = object()
                    args = d | cli.cut(1) | cli.transpose(fill=sentinel) | cli.aS(lambda x: None if x is sentinel else x).all(2) | cli.deref()
                    kws = d | cli.cut(2) | cli.deref()
                    keys = kws | cli.op().keys().all() | cli.joinSt() | cli.aS(set) | cli.deref()
                    kws = keys | cli.apply(lambda key: [key, kws | cli.apply(lambda kw: kw.get(key, None))]) | cli.toDict() | cli.deref()
                    res = list(batchedFn(*args, **kws))
                    # Too few results would otherwise leave some callers
                    # silently receiving the `...` placeholder.
                    if len(res) < len(data):
                        raise ValueError(f"k1lib.batchify: batched function returned {len(res)} results for {len(data)} calls")
                    for k, out in zip(data.keys(), res): data[k][0][0] = out
                except Exception as e:
                    # Hand the error to every waiting caller instead of letting it
                    # kill this worker thread, which would hang all callers forever.
                    for entry in data.values(): entry[0][1] = e
                finally:
                    event.set()  # always release this batch's callers
        # daemon=True: this infinite loop must not block interpreter shutdown.
        threading.Thread(target=bgProcess, daemon=True).start()
        @wraps(batchedFn)
        def inner(*args, **kwargs):
            output = [..., None]  # [result, exception]; filled in by bgProcess
            _, event = con.add([output, args, kwargs])
            event.wait()
            if output[1] is not None: raise output[1]
            return output[0]
        inner.fullargspec = inspect.getfullargspec(batchedFn)
        return inner
    # `@batchify(0.3)` passes a number; bare `@batchify` passes the function itself.
    if isinstance(period, (int, float)): return wrap
    else: batchedFn = period; period = 0.1; return wrap(batchedFn)