# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
import k1lib, os, dill, time, inspect, json as _json, base64; k1 = k1lib
from typing import List
import k1lib.cli as cli; from k1lib.cli import *
from collections import defaultdict
# pygments is only needed to syntax-highlight the endpoint source shown in the generated html page
pygments = k1.dep("pygments")
# PIL is optional: without it, image widgets are disabled (see ``hasPIL`` in refine)
try:
    import PIL.Image, PIL
    hasPIL = True
except ImportError:  # narrow on purpose: a bare except would also swallow KeyboardInterrupt/SystemExit
    PIL = k1.dep("PIL"); hasPIL = False
__all__ = ["tag_serve",
           "FromNotebook", "FromPythonFile", "BuildPythonFile", "BuildDashFile", "StartServer", "GenerateHtml", "commonCbs", "serve",
           "text", "slider", "html", "json", "date", "serialized", "apiKey", "analyze", "webToPy", "pyToWeb"]
# directory holding the server templates ("suffix.py", "suffix-dash.py", "main.html"), with a trailing separator
basePath = os.path.dirname(inspect.getabsfile(k1lib)) + os.sep + "serve" + os.sep
def pretty_py(code_string):
    """Syntax-highlights Python source code into a self-contained html snippet.

Returns a ``<style>`` tag with pygments' ``.highlight`` css rules (minus the default
background rule), followed by the highlighted markup whose ``<pre`` tags get an
extra ``class="pre"``."""
    lexer = pygments.lexers.PythonLexer()
    body = pygments.highlight(code_string, lexer, pygments.formatters.HtmlFormatter())
    body = body.replace('<pre', '<pre class="pre"')
    styles = pygments.formatters.HtmlFormatter().get_style_defs('.highlight')
    styles = styles.replace(".highlight { background: #f8f8f8; }", "")  # drop the default gray background
    return f'<style>{styles}</style>{body}'
def execPy(s:str, node:str, condaEnv:str=None, condaPath:str="~/miniconda3") -> "list[str]":
    """Runs a Python snippet on a remote node and returns its non-empty output lines.

The snippet is written to a local temporary file, copied to the same path on ``node``
with ``scp``, then executed there with ``python`` over ``ssh``.

:param s: python source code to run
:param node: ssh destination, like "user@hostname"
:param condaEnv: if specified, runs "source {condaPath}/bin/activate {condaEnv}" first
:param condaPath: conda installation directory on the remote node"""
    fn = s | cli.file(); None | cli.cmd(f"scp {fn} {node}:{fn}") | cli.deref()
    # without an env, skip activation entirely: "source .../activate None" would fail,
    # and the "&&" would then silently skip running python
    activate = f"source {condaPath}/bin/activate {condaEnv} && " if condaEnv else ""
    return None | cli.cmd(f"ssh {node} '{activate}python {fn}'") | cli.filt("x") | cli.deref()
def tag_serve(node:str=None, condaEnv:str=None, condaPath:str="~/miniconda3"):
    """Tag that marks the cell that will be extracted to an independent
file and executed. Example::

    # serve(node="user@hostname", condaEnv="torch")

If a remote node is specified, internally, this will run commands on that node using
ssh, so make sure it's reachable via "ssh user@hostname" with default ssh identity,
or else it wouldn't work.

If .condaEnv is specified, then will activate conda before executing the script.
The activation command will be "{condaPath}/bin/activate {condaEnv}"

:param node: what node should the script be executed?
:param condaEnv: if specified, will activate that conda env before executing
:param condaPath: conda installation directory, used to locate the activate script
:returns: dict with the keys node, condaEnv and condaPath"""
    return {"node": node, "condaEnv": condaEnv, "condaPath": condaPath}
class FromNotebook(k1.Callback):
    def __init__(self, fileName, tagName:str="serve", allTags:"list[str]"=("test", "notest", "donttest", "thumbnail", "export", "serve", "noserve", "dash")):
        """Grabs source code from a Jupyter notebook. Will grab cells with the comment
like ``# serve`` in the first line.
See :meth:`tag_serve` to see more about its options

:param fileName: notebook path
:param tagName: which tag to extract out?
:param allTags: all possible custom tags that you use. It might complain if there's a tag in your
    notebook that it doesn't know about, so add all of your commonly used tags here"""
        super().__init__(); self.fileName = fileName; self.tagName = tagName; self.allTags = allTags
    def fetchSource(self):
        """Deposits sourceType, sourceCode, tags, node, condaEnv and condaPath into ``self.l``"""
        a = cli.nb.cells(self.fileName) | cli.filt(cli.op()["cell_type"] == "code") | cli.aS(list); self.l["sourceType"] = "notebook"
        # source lines of the cells tagged with tagName, minus each cell's first (tag) line
        self.l["sourceCode"] = a | cli.nb.pretty(whitelist=[self.tagName]) | (cli.op()["source"] | ~cli.head(1)).all() | cli.joinStreams() | cli.deref()
        # first line of every non-empty code cell, kept only if it is a comment like '# serve(node="user@host")'
        self.l["tags"] = a | cli.op()["source"].all() | cli.filt("x") | cli.item().all() | cli.filt(cli.op().startswith("#")) | cli.deref()
        # figures out build vars here, like node & condaEnv. The first tag that specifies a value wins.
        # condaPath starts as None (not its default) so that the "or" below can't mask an explicit
        # condaPath from a tag; the default is applied after all tags are processed
        self.l["node"] = None; self.l["condaEnv"] = None; self.l["condaPath"] = None
        def serve(node:str=None, condaEnv:str=None, condaPath:str=None):
            self.l["node"] = self.l["node"] or node
            self.l["condaEnv"] = self.l["condaEnv"] or condaEnv
            self.l["condaPath"] = self.l["condaPath"] or condaPath
        for tag in self.l["tags"]:
            # known tags evaluate to 0 (no-ops); only "serve" does anything
            cli.nb.executeTags(tag, defaultdict(lambda: 0, {**{x:0 for x in self.allTags}, "serve": serve}))
        self.l["condaPath"] = self.l["condaPath"] or "~/miniconda3"
class FromPythonFile(k1.Callback):
    def __init__(self, fileName):
        """Grabs source code from a python file.

:param fileName: path of the python file, whose whole content becomes the served source"""
        super().__init__()
        self.fileName = fileName
    def fetchSource(self):
        """Deposits the file's lines into ``self.l["sourceCode"]`` and marks ``self.l["sourceType"]`` as a file."""
        self.l["sourceCode"] = cli.cat(self.fileName) | cli.deref()
        self.l["sourceType"] = "file"
class BuildPythonFile(k1.Callback):
    def __init__(self, port=None):
        """Builds the output Python file, ready to be served on localhost.

:param port: which port to run on localhost. If not given, then a port will
    be picked at random, and will be available at ``cbs.l['port']``"""
        super().__init__(); self.port = port; self.suffix = "suffix.py"
    @staticmethod
    def _freeLocalPort() -> int:
        """Asks the OS for a currently free TCP port on this machine."""
        import socket
        with socket.socket() as sock:  # closes the socket even if bind() fails
            sock.bind(('', 0)); return sock.getsockname()[1]
    def buildPythonFile(self):
        """Writes the server script to a temporary file and deposits pythonFile, metaFile
and port into ``self.l``. If a remote node is configured, the script is also copied
to the same path on that node with scp."""
        # server script = k1lib imports + extracted source code; the server suffix is appended below
        self.l["pythonFile"] = ["from k1lib.imports import *", *self.l["sourceCode"]] | cli.file(); port = self.port
        # grabs a fresh temp path for the meta file (allocated on localhost, not remote), then deletes it:
        # StartServer waits for this path to exist again
        self.l["metaFile"] = metaFile = "" | cli.file(); os.remove(metaFile)
        node = self.l["node"]
        if port is None: # grabs a random free port, on whichever machine will run the server
            if node is None: port = self._freeLocalPort()
            else:
                kw = {"node": node, "condaEnv": self.l["condaEnv"], "condaPath": self.l["condaPath"]}
                # execPy returns output lines (strings); convert so the port is an int in both cases
                port = int(execPy('import socket; sock = socket.socket(); sock.bind(("", 0)); port = sock.getsockname()[1]; sock.close(); print(port)', **kw)[0])
        # actually has enough info to build the final file
        self.l["port"] = port
        (cli.cat(f"{basePath}{self.suffix}") | cli.op().replace("SOCKET_PORT", f"{port}").replace("META_FILE", metaFile).all()) >> cli.file(self.l["pythonFile"])
        if node: fn = self.l["pythonFile"]; None | cli.cmd(f"scp {fn} {node}:{fn}") | cli.deref()
class BuildDashFile(BuildPythonFile):
    def __init__(self, port=None):
        """Builds the output Python file for a Dash app, ready to be served on localhost

:param port: which port to run on. If not given, then a port will be picked at
    random, and will be available at ``cbs.l['port']``"""
        super().__init__(port)
        self.suffix = "suffix-dash.py"  # Dash-specific server template, instead of "suffix.py"
class StartServer(k1.Callback):
    def __init__(self, maxInitTime=10):
        """Starts the server, verify that it starts okay and dumps meta information (including
function signatures) to ``cbs.l``

:param maxInitTime: time to wait in seconds until the server is online before declaring it's unsuccessful"""
        super().__init__(); self.maxInitTime = maxInitTime
    def startServer(self):
        """Launches ``self.l["pythonFile"]`` (locally, or on ``self.l["node"]`` over ssh), waits
for ``self.l["metaFile"]`` to appear, then loads it with dill into ``self.l["meta"]``.

:raises Exception: if the meta file doesn't appear within ``maxInitTime`` seconds"""
        pythonFile = self.l["pythonFile"]; metaFile = self.l["metaFile"]; port = self.l["port"]; maxInitTime = self.maxInitTime
        node = self.l["node"]; condaEnv = self.l["condaEnv"]; condaPath = self.l["condaPath"]; startTime = time.time()
        def waitFor(metaExists, interval):
            # polls until the meta file (presumably written by the server once it's up) exists
            while not metaExists():
                if time.time()-startTime > maxInitTime: raise Exception(f"Tried to start server up, but no responses yet. Port: {port}, pythonFile: {pythonFile}, metaFile: {metaFile}")
                time.sleep(interval)
        if node is None:
            None | cli.cmd(f"python {pythonFile} &")
            waitFor(lambda: os.path.exists(metaFile), 0.1)
            self.l["meta"] = metaFile | cli.cat(text=False) | cli.aS(dill.loads)
        else:
            activate = f"source {condaPath}/bin/activate {condaEnv} && " if condaEnv else ""
            None | cli.cmd(f"ssh {node} '{activate}nohup python {pythonFile}' &")
            waitFor(lambda: int(None | cli.cmd(f"ssh {node} 'if [ -e {metaFile} ]; then echo 1; else echo 0; fi'") | cli.item()), 0.5)
            self.l["meta"] = dill.loads(b"".join(None | cli.cmd(f"ssh {node} 'cat {metaFile}'", text=False)))
class GenerateHtml(k1.Callback):
    def __init__(self, serverPrefix=None, htmlFile=None, title="Interactive demo"):
        """Generates a html file that communicates with the server.

:param serverPrefix: prefix of server for back and forth requests, like "https://example.com/proj1". If
    empty, tries to grab ``cbs.l["serverPrefix"]``, which you can deposit from your own callback. If
    that's not available then it will fallback to ``localhost:port``
:param htmlFile: path of the target html file. If not specified then a temporary file
    will be created and made available in ``cbs.l["htmlFile"]``
:param title: title of html page"""
        super().__init__(); self.serverPrefix = serverPrefix; self.htmlFile = htmlFile; self.title = title
    def generateHtml(self):
        """Fills the ``main.html`` template with this server's info and deposits the file's path into ``self.l["htmlFile"]``."""
        import re
        meta = dict(self.l["meta"])
        # converts newlines to <br>, except inside "<!-- k1lib.raw.start -->" ... "<!-- k1lib.raw.end -->" sections
        replaceNewlineWithBr = op().split("<!-- k1lib.raw.start -->") | apply(op().split("<!-- k1lib.raw.end -->")) | head(1).split() | (op().replace("\n", "<br>").all(2)) + apply(op().replace("\n", "<br>"), 1) | joinSt(2) | join("")
        # NOTE(review): META_JSON embeds the whole meta dict, and SOURCE_CODE the endpoint's source. If the
        # server's meta still has "privates" (api keys, see analyze()), or the source contains a literal
        # serve.apiKey("...") call, the key ends up in the page. Confirm suffix.py strips/redacts them.
        values = {"META_JSON": base64.b64encode(_json.dumps(meta).encode()).decode(),
                  "SERVER_PREFIX": self.serverPrefix or self.l["serverPrefix"] or f"http://localhost:{self.l['port']}",
                  "TITLE": self.title,
                  # "\ue157" is turned into a real newline only after the <br> conversion
                  "INTRO": meta["mainDoc"] | replaceNewlineWithBr | op().replace("\ue157", "\n"),
                  "SOURCE_CODE": pretty_py(meta["source"])}
        # substitutes every placeholder in a single pass, so text inserted for one placeholder (like the
        # base64 meta json, which can contain "TITLE" or "INTRO" by chance) is never substituted again
        pattern = re.compile("|".join(map(re.escape, values)))
        fill = cli.apply(lambda line: pattern.sub(lambda m: values[m.group(0)], line))
        self.l["htmlFile"] = cli.cat(f"{basePath}main.html") | fill | cli.file(self.htmlFile)
def commonCbs():
    """Grabs common callbacks, including :class:`BuildPythonFile` and :class:`StartServer`.

A source-fetching callback (like :class:`FromNotebook`) and :class:`GenerateHtml` are
not included here."""
    return k1.Callbacks().add(BuildPythonFile()).add(StartServer())
def serve(cbs):
    """Runs the serving pipeline.

Resets ``cbs.l`` to a fresh ``defaultdict(lambda: None)``, then triggers these stages
on ``cbs`` in order: begin, fetchSource, beforeBuildPythonFile, buildPythonFile,
beforeStartServer, startServer, beforeGenerateHtml, generateHtml, end.

:param cbs: callbacks to run, like the ones from :meth:`commonCbs` plus a source and html callback
:returns: ``cbs`` itself, with results deposited in ``cbs.l``"""
    import flask, flask_cors  # not used below; presumably imported so missing server deps fail fast
    cbs.l = defaultdict(lambda: None)
    stages = ["begin",
              "fetchSource",                               # fetches cells
              "beforeBuildPythonFile", "buildPythonFile",  # builds python server file
              "beforeStartServer", "startServer",          # starts serving the model and adds more meta info
              "beforeGenerateHtml", "generateHtml",        # produces a standalone html file that provides interactive functionalities
              "end"]
    for stage in stages: cbs(stage)
    return cbs
class baseType:
    def __init__(self):
        """Base type for all widget types"""
    def getConfig(self):
        """Placeholder hook; none of the widget types here override it, so it returns ``NotImplemented``."""
        result = NotImplemented
        return result
class text(baseType):
    def __init__(self, multiline:bool=True, password:bool=False):
        """Represents text, either on single or multiple lines.
If `password` is true, then will set multiline to false automatically,
and creates a text box that blurs out the contents. Example::

    def endpoint(s:serve.text()="abc") -> str: pass

For inputs only. Use ``str`` for outputs

:param multiline: whether the text box spans multiple lines
:param password: whether to hide the contents; forces ``multiline`` to False"""
        super().__init__()
        self.multiline = False if password else multiline
        self.password = password
    def __repr__(self): return f"<text multiline={self.multiline}>"
class slider(baseType):
    def __init__(self, start:float, stop:float, intervals:int=100):
        """Represents a slider from `start` to `stop` with a bunch of
intervals in the middle. The parameter's default value (which :meth:`analyze`
requires for every argument) is passed along as the slider's default. Example::

    def endpoint(a:serve.slider(2, 3.2)=2.3) -> str: pass

For inputs only

:param start: value at one end of the slider
:param stop: value at the other end of the slider
:param intervals: number of steps between start and stop"""
        super().__init__(); self.start = start; self.stop = stop; self.intervals = intervals
        self.dt = (stop-start)/intervals  # step size, handed to the widget config by refine()
    def __repr__(self): return f"<slider {self.start}---{self.intervals}-->{self.stop}>"
class html(baseType):
    def __init__(self):
        """Raw html.
Example::

    def endpoint() -> serve.html(): pass

For outputs only"""
        super().__init__()
    def __repr__(self): return "<html>"
class json(baseType):
    # NOTE: this widget shadows the stdlib module name, which is why the module is imported as _json in this file
    def __init__(self):
        """Raw json.
Example::

    def endpoint(a:serve.json()={"a": 3}) -> serve.json(): pass

For inputs and outputs"""
        super().__init__()
    def __repr__(self): return "<json>"
class date(baseType):
    def __init__(self, min=None, max=None):
        """Local date time (no timezone information).
Example::

    def endpoint(d:serve.date()="2023-12-07T00:00") -> str: pass

:param min: min date, also in format '2023-12-07T00:00'
:param max: max date, also in format '2023-12-07T00:00'"""
        super().__init__()
        self.minDate = min
        self.maxDate = max
    def __repr__(self): return "<date>"
class serialized(baseType):
    def __init__(self):
        """For serialized objects using :mod:`dill`.
Example::

    def endpoint(a:serve.serialized()) -> serve.serialized():
        return {"any": "data structure", "you": "want", "even": np.random.randn(100)}
"""
        super().__init__()
    def __repr__(self): return "<serialized>"
class apiKey(baseType):
    def __init__(self, apiKey:str=None):
        """Protects your endpoint with an api key.
Example::

    def endpoint(apiKey:serve.apiKey("your api key here")="") -> str: pass

When compiled, your api key won't appear anywhere, not in the html, not in the meta
files, and someone calling the endpoint must specify it, else it will just error out

:param apiKey: the secret key that callers must provide
:raises ValueError: if no key is given"""
        super().__init__()
        # there is no meaningful default key: refuse to build an endpoint that expects an unspecified key
        if apiKey is None: raise ValueError("Please specify the api key, like serve.apiKey(\"your api key here\")")
        self.apiKey = apiKey
    def __repr__(self): return "<apiKey>"
def refine(param:str, anno:baseType, default): # anno is not necessarily baseType, can be other types like "int"
    """Converts a parameter's annotation and default value into a widget spec.

:returns: ``[param, widgetType, widgetData, private]``, where ``private`` is only
    non-None for :class:`apiKey` (the key itself)
:raises ValueError: if a dropdown's default is not one of its options
:raises Exception: if the annotation is not a supported type"""
    if anno == int: return [param, "int", [default, False], None]
    if anno == float: return [param, "float", [default, False], None]
    if anno == bool: return [param, "checkbox", default, None]
    if anno == str:
        # plain str defaults to a multi-line box for long or multi-line defaults
        multiline = len((default or "").split("\n")) > 1 or len(default or "") > 100
        return [param, "text", [default, multiline], None]
    if isinstance(anno, text): return [param, "text", [default, anno.multiline, anno.password], None]
    if isinstance(anno, slider): return [param, "slider", [default, anno.start, anno.stop, anno.dt], None]
    if isinstance(anno, range): return [param, "slider", [default, anno.start, anno.stop, anno.step], None]
    byte2Str = aS(base64.b64encode) | op().decode("ascii")
    if hasPIL and anno == PIL.Image.Image: return [param, "image", (default | toBytes() | byte2Str) if default is not None else None, None]
    if anno == bytes: return [param, "bytes", (default | byte2Str) if default is not None else None, None]
    if isinstance(anno, serialized): return [param, "serialized", (default | aS(dill.dumps) | byte2Str) if default is not None else None, None]
    if isinstance(anno, list):
        # same exception type list.index() would raise, but with a message that names the parameter
        if default not in anno: raise ValueError(f"Default value {default!r} of parameter '{param}' is not one of the dropdown options {anno}")
        return [param, "dropdown", [anno.index(default), anno], None]
    if isinstance(anno, html): return [param, "html", [default], None]
    if isinstance(anno, json): return [param, "json", [default, True], None]
    if isinstance(anno, date): return [param, "date", [default, anno.minDate, anno.maxDate], None]
    if isinstance(anno, apiKey): return [param, "apiKey", [default], anno.apiKey]
    raise Exception(f"Unknown type {anno}")
def analyze(f):
    """Analyzes an endpoint function's signature and docstring into the meta info needed to serve it.

Every argument must be annotated and have a default value, and the return value must
be annotated too. Docstring ``:param name:`` lines become per-argument docs; the text
before them becomes the main doc.

:param f: endpoint function. If it has a ``fullargspec`` attribute, that spec is used
    instead of inspecting ``f``
:returns: dict with keys "args" (list of names), "annos" (name -> widget type),
    "defaults" (name -> widget data), "privates" (name -> private value, like api keys),
    "docs" (name -> description), "mainDoc", "source" (source of ``f``) and "pid"
:raises Exception: if annotations or defaults are missing, or the return type is a slider"""
    spec = getattr(f, "fullargspec", None)
    if spec is None: spec = inspect.getfullargspec(f)  # only inspect when no spec is attached
    args = spec.args; n = len(args)
    annos = spec.annotations; defaults = spec.defaults or ()
    # ":param name: description" docstring lines -> {name: description}
    docs = (f.__doc__ or "").split("\n") | grep(":param", sep=True).till() | filt(op().ab_len() > 0) | op().strip().all(2) | (join(" ") | op().split(":") | ~aS(lambda x, y, *z: [y.split(" ")[1], ":".join(z).strip()])).all() | toDict()
    # the docstring's leading text, up to the first ":param" line
    mainDoc = (f.__doc__ or " ").split("\n") | grep(".", sep=True).till(":param") | breakIf(op()[0].startswith(":param")) | join("\n").all() | join("\n")

    if len(annos) != n + 1: raise Exception(f"Please annotate all of your arguments ({n} args + 1 return != {len(annos)} annos). Args: {args}, annos: {annos}")
    if len(defaults) != n: raise Exception(f"Please specify default values for all of your arguments ({n} args != {len(defaults)} default values)")
    # one [param, widgetType, widgetData, private] row per argument
    a = [args, args | lookup(annos), defaults] | transpose() | ~apply(refine) | deref()
    ret = refine("return", annos["return"], None)[1]; defaults = a | cut(0, 2) | toDict()
    annos = a | cut(0, 1) | toDict(); annos["return"] = ret; privates = a | cut(0, 3) | toDict()
    if ret == "slider": raise Exception(f"Return value is a slider, which doesn't really make sense. Return float, str or sth like that")

    # args:list, annos:dict, defaults:dict, privates:dict, docs:dict
    return {"args": args, "annos": annos, "defaults": defaults, "privates": privates, "docs": docs,
            "mainDoc": mainDoc, "source": inspect.getsource(f), "pid": os.getpid()}
class Html(str):
    """A str holding raw html; ``_repr_html_`` is the IPython rich-display hook, so notebooks render it as html."""
    def _repr_html_(self):
        markup = self
        return markup
def webToPy(o:str, klass:str):
    """Converts a value received from the web page into a Python object.

:param o: the received value, usually a string
:param klass: widget type name (as produced by ``refine``), like "int" or "image"
:returns: the converted object, or ``NotImplemented`` for unknown widget types"""
    if klass == "json": return o  # passed through as-is, not stringified
    o = str(o)
    if klass == "int":
        # parse exactly first: going through float silently loses precision above 2**53
        try: return int(o)
        except ValueError: return int(float(o))  # still accepts things like "3.0" or "1e3"
    if klass == "float": return float(o)
    if klass == "slider": o = float(o); return int(o) if round(o) == o else o
    if klass in ("text", "dropdown", "apiKey", "date"): return o
    if klass == "checkbox": return o.lower() == "true"
    if klass == "bytes": return base64.b64decode(o)
    # NOTE(review): dill.loads on request data can execute arbitrary code; only trusted callers should hit serialized endpoints
    if klass == "serialized": return dill.loads(base64.b64decode(o))
    if klass == "image": return o | aS(base64.b64decode) | toImg()
    if klass == "html": return Html(base64.b64decode(o).decode())
    return NotImplemented
def pyToWeb(o, klass:str) -> str:
    """Converts a Python value returned by an endpoint into its web representation.

This is the counterpart of :meth:`webToPy`: binary-ish types are base64-encoded.

:param o: the value to convert
:param klass: widget type name (as produced by ``refine``), like "int" or "image"
:returns: a string, except for "json" (``o`` itself is returned) and unknown
    widget types (``NotImplemented``)"""
    if klass in ("int", "float", "text", "checkbox", "slider", "apiKey", "date"): return f"{o}"
    if klass == "bytes": return base64.b64encode(o).decode()
    if klass == "serialized": return base64.b64encode(dill.dumps(o)).decode()
    if klass == "image": return base64.b64encode(o | toBytes()).decode()
    if klass == "dropdown": return o
    if klass == "html": return base64.b64encode(o.encode()).decode()
    if klass == "json": return o # that one case where it returns an object instead of a string
    return NotImplemented