# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
I have several machine learning tools running on my own cluster, hosted at
https://mlexps.com/#kapi, and this module contains functions, classes and
clis that contact those services. The point is that if I wanted to use a language
model in multiple notebooks, I'd otherwise have to load the model onto my GPU for each
notebook, which would waste a lot of resources. I can't run a lot of notebooks
at the same time as I'd just run out of VRAM. So, by having dedicated services/demos,
I can really focus on serving things well and making them performant. For example::
    "some text"           | kapi.embed()    # returns embedding numpy array
    "What is Python? "    | kapi.complete() # returns string, completes the sentence
    "image.png" | toImg() | kapi.ocr()      # returns `Ocr` object, with bounding boxes and text content of all possible texts
    "cute anime girl"     | kapi.txt2im()   # generates an image from some description
    "image.png" | toImg() | caption()       # generates a caption of an image
"""
__all__ = ["status", "segment", "demo", "embed", "embeds", "complete", "ocr", "txt2im", "caption", "speech", "summarize", "post"]
from k1lib.cli.init import BaseCli; import k1lib.cli.init as init
import k1lib.cli as cli, k1lib, base64, html, json
requests = k1lib.dep("requests"); k1 = k1lib
settings = k1lib.settings.cli
s = k1lib.Settings(); settings.add("kapi", s, "cli.kapi settings")
s.add("local", False, "whether to use local url instead of remote url. This only has relevance to me though, as the services are running on localhost")
def get(idx:str, json):                                                          # get
    """Sends a request to any service/demo on https://mlexps.com.
Example::
    # returns "13.0"
    kapi.get("demos/1-arith", {"a": 1, "b": 3, "c": True, "d": 2.5, "e": 10})
:param idx: index of the service, like "kapi/1-embed"
:param json: JSON body to send to the service"""                                 # get
    url = "http://localhost:9000" if s.local else "https://local.mlexps.com"     # get
    res = requests.post(f"{url}/routeServer/{idx.replace('/', '_')}", json=json).json() # get
    if not res["success"]: raise Exception(res["reason"])                        # get
    return res["data"]                                                           # get
def jsF_get(idx, dataIdx):                                                       # jsF_get
    """Returns a JS expression (as a string) that POSTs ``dataIdx`` to the given service and parses the JSON response. JS counterpart of :meth:`get`""" # jsF_get
    url = "https://local.mlexps.com"                                             # jsF_get
    return f"""await (await fetch("{url}/routeServer/{idx.replace('/', '_')}", {{method: "POST", headers: {{ "Content-Type": "application/json" }}, body: JSON.stringify({dataIdx})}})).json()""" # jsF_get
def status():                                                                    # status
    """Displays a table of whether the services are online and available or not""" # status
    ["kapi/1-embed", "kapi/2-complete", "kapi/3-ocr", "kapi/4-txt2im", "kapi/5-caption", "kapi/6-speech"] | cli.apply(lambda x: [x, requests.get(f"https://local.mlexps.com/routeServer/{x.replace(*'/_')}/healthCheck").text == "ok"]) | cli.insert(["Service", "Online"]) | cli.display(None) # status 
class segment(BaseCli):                                                          # segment
    def __init__(self, limit:int=2000):                                          # segment
        """Segments the input string by sentences, such that each segment's length stays below the specified limit (a single sentence longer than the limit becomes its own segment).
Example::
    # returns ['some. Really', 'Long. String', 'Just. Monika']
    "some. Really. Long. String. Just. Monika" | segment(15)
So, this will split the input string by ". ", then incrementally join the pieces back together into segments.
This is useful for breaking up text so that it fits within a language model's context size""" # segment
        self.limit = limit                                                       # segment 
    def __ror__(self, text):                                                     # segment
        if not isinstance(text, str): raise Exception("Input is not a string!")  # segment
        data = [[]]; c = 0; limit = self.limit                                   # segment
        for line in text.split(". "):                                            # segment
            if c + len(line) > limit and c > 0: # if even a single sentence is too big, then just have a segment as that sentence, and don't push it to the next one # segment
                data.append([]); c = 0                                           # segment
            data[-1].append(line); c += len(line)+2                              # segment
        return data | cli.join(". ").all() | cli.deref()                         # segment  
metas = {} # Dict[prefix -> demo meta]                                           # segment
class demo(BaseCli):                                                             # demo
    def __init__(self, prefix:str="demos_1-arith"):                              # demo
        """Sends a request to one of the mlexps.com demos.
Example::
    # returns 21.0
    {"a": 3} | kapi.demo("demos/1-arith")
You don't have to specify all params, just the ones you want to deviate from the defaults
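For example, overriding two parameters at once (a rough sketch reusing the parameter names
from the :meth:`get` example above; check the demo's ``demo_meta.json`` for the real names)::
    {"a": 3, "b": 7} | kapi.demo("demos/1-arith")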
"""                                                                              # demo
        prefix = prefix.replace(*"/_"); self.prefix = prefix                     # demo
        if prefix not in metas: metas[prefix] = json.loads(requests.get(f"https://mlexps.com/{prefix.replace(*'_/')}/demo_meta.json").text) # demo 
    def __ror__(self, d):                                                        # demo
        prefix = self.prefix; meta = metas[prefix]; kw = {}                      # demo
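        # start from the demo's declared defaults (decoding each web-format default into a python value), then overlay the user-provided values and re-encode everything for the request # demo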
        for arg in meta["args"]:                                                 # demo
            a = meta["defaults"][arg]; anno = meta["annos"][arg]                 # demo
            if anno in ("checkbox", "bytes", "image", "serialized"): a = a       # demo
            elif anno == "dropdown": a = a[1][a[0]]                              # demo
            elif anno == "apiKey": a = k1lib.apiKey if hasattr(k1lib, "apiKey") else a[0] # demo
            else: a = a[0]                                                       # demo
            kw[arg] = k1lib.serve.webToPy(a, anno)                               # demo
        for k, v in d.items(): kw[k] = v                                         # demo
        for k, v in kw.items(): kw[k] = k1lib.serve.pyToWeb(v, meta["annos"][k]) # demo
        url = "http://localhost:9003" if k1lib.settings.cli.kapi.local else "https://local.mlexps.com" # demo
        res = requests.post(f"{url}/routeServer/{prefix}", json=kw)              # demo
        if not res.ok: raise Exception(res.reason)                               # demo
        res = res.json()                                                         # demo
        if res["success"]: return k1lib.serve.webToPy(res["data"], meta["annos"]["return"]) # demo
        else: raise Exception(res["reason"])                                     # demo 
    def __repr__(self): return f"<demo prefix='{self.prefix}'>"                  # demo
    def _repr_html_(self): s = html.escape(f"{self}"); return f"{s}{metas[self.prefix]['mainDoc']}" # demo 
class embed(BaseCli):                                                            # embed
    def __init__(self):                                                          # embed
        """Gets an embedding vector for every sentence piped into this using `all-MiniLM-L6-v2`.
Example::
    # returns (384,)
    "abc" | kapi.embed() | shape()
    # returns (2, 384)
    ["abc", "def"] | kapi.embed().all() | shape()
- VRAM: 440MB
- Throughput: 512/s
See also: :class:`~k1lib.cli.models.embed`"""                                    # embed
        pass                                                                     # embed 
    def __ror__(self, it): return self._all_opt([it]) | cli.item()               # embed 
    def _all_opt(self, it:list[str]):                                            # embed
        for b in it | cli.batched(1024, True):                                   # embed
            yield from get("kapi/1-embed", {"lines": k1lib.encode(b)}) | cli.aS(k1lib.decode) # embed 
class embeds(BaseCli):                                                           # embeds
    def __init__(self):                                                          # embeds
        """Breaks up some text and grabs the embedding vectors of each segment.
Example::
    "sone long text" | kapi.embeds() # returns list of (segment, numpy vector)
This is just a convenience cli. Internally, this splits the text up using :class:`segment`
and then embeds each segment using :class:`embed`
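Conceptually, it's roughly equivalent to this sketch (the segment size used internally is 700 characters)::
    segments = "some long text" | segment(700)
    vectors  = segments | embed().all()
    result   = [segments, vectors] | transpose() | deref() # list of (segment, vector)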
"""                                                                              # embeds
        pass                                                                     # embeds 
    def __ror__(self, it): return self._all_opt([it]) | cli.item()               # embeds 
    def _all_opt(self, it:list[str]): return it | cli.apply(segment(700) | cli.iden() & embed().all() | cli.transpose()) | cli.deref() # embeds 
class complete(BaseCli):                                                         # complete
    def __init__(self, prompt:str=None, maxTokens:int=200):                      # complete
        """Generates text completions for the given prompts using `Llama 2`.
Example::
    # returns string completion
    "What is Python?" | kapi.complete()
    # returns list of string completions
    ["What is Python?", "What is C++?"] | kapi.complete().all()
    # returns list of string completions. The prompts sent to the server are ["<paragraph 1>\\n\\n\\nPlease summarize the above paragraph: ", ...]
    ["<paragraph 1>", "<paragraph 2>"] | kapi.complete("Please summarize the above paragraph").all()
- VRAM: 22GB
- Throughput: 8/s
:param maxTokens: maximum number of tokens to generate
See :class:`~k1lib.cli.models.complete`. That one is an older version using Google Flan T5 instead of Llama 2""" # complete
        self.prompt = prompt; self.maxTokens = maxTokens                         # complete 
    def __ror__(self, it): return self._all_opt([it]) | cli.item()               # complete 
    def _all_opt(self, it:list[str]):                                            # complete
        if self.prompt: it = it | cli.apply(lambda x: f"{x}\n\n\n{self.prompt}: ") | cli.deref() # complete
        if not (isinstance(it, (list, tuple)) and isinstance(it[0], str)):       # complete
            raise Exception("You might have forgot to use .all(), like ['str1', 'str2'] | kapi.complete().all()") # complete
        it = it | cli.apply(lambda x: [x, self.maxTokens]) | cli.deref()         # complete
        return get("kapi/2-complete", {"prompts": json.dumps(it)}) | cli.aS(json.loads) # complete
    def _jsF(self, meta):                                                        # complete
        fIdx = cli.init._jsFAuto(); dataIdx = cli.init._jsDAuto()                # complete
        body = f"{{ prompts: JSON.stringify([{dataIdx}].map((x) => [`${{x}}\\n\\n\\n{self.prompt or ''}`, {cli.kjs.v(self.maxTokens)}])) }}" # complete
        return f"""
const {fIdx} = async ({dataIdx}) => {{
    const res = {jsF_get('kapi/2-complete', body)}
    return res[0]
}}""", fIdx                                                                      # complete 
tf = k1.dep("torchvision.transforms")                                            # complete
class ocr(BaseCli):                                                              # ocr
    def __init__(self, paragraph:bool=False, resize=True):                       # ocr
        """Do OCR (optical character recognition) on some image.
Example::
    o = "some_image.png" | toImg() | kapi.ocr() # loads image and do OCR on them
    o.result
That returns something like this::
    [[[686, 718, 4, 12], 'palng', 0.037828799456428475],
     [[53, 89, 9, 29], '150', 0.9862767603969035],
     [[146, 208, 6, 30], '51,340', 0.8688367610346406],
     [[695, 723, 13, 33], '83', 0.9999892947172615],
     [[783, 855, 13, 29], 'UPGRADes', 0.6299456305919845],
     [[783, 855, 47, 61], 'Monkey Ace', 0.7461469463088448],
     [[827, 863, 117, 133], '5350', 0.9847457394951422],
     [[775, 809, 181, 195], '6325', 0.9660267233848572],
     [[827, 863, 181, 195], 's500', 0.24643410742282867],
     [[773, 811, 243, 259], '5800', 0.5125586986541748],
     [[823, 869, 243, 259], '01600', 0.22119118148432848],
     [[775, 809, 303, 321], '5750', 0.7384281754493713],
     [[827, 861, 305, 321], '5850', 0.6789041403197309]]
This is the main way to use this tool. But if you just want a quick glance to judge
the OCR's quality, you can display the returned :class:`Ocr` object directly in a notebook,
which shows the image with the detected bounding boxes and text highlighted.
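A small sketch of pulling out just the recognized strings, using the cli tools this module already imports::
    o.result | cut(1) | deref() # ['palng', '150', '51,340', ...]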
- VRAM: 1GB
- Throughput: depends heavily on image resolution
:param paragraph: whether to try to combine boxes together or not
:param resize: whether to resize the images to a reasonable size before sending them over. Runs faster if True""" # ocr
        self.paragraph = paragraph; self.resize = resize                         # ocr 
    def __ror__(self, it): return self._all_opt([it]) | cli.item()               # ocr 
    def _all_opt(self, it:list["PIL"]):                                          # ocr
        def resize(it): # resizing if they're too big                            # ocr
            for img in it:                                                       # ocr
                w, h = img | cli.shape()                                         # ocr
                if w > h:                                                        # ocr
                    if w > 1000: frac = 1000/w; img = img | tf.Resize([int(h*frac), int(w*frac)]) # ocr
                else:                                                            # ocr
                    if h > 1000: frac = 1000/h; img = img | tf.Resize([int(h*frac), int(w*frac)]) # ocr
                yield img, self.paragraph                                        # ocr
        return (resize(it) if self.resize else it | cli.apply(lambda img: [img, self.paragraph])) | cli.batched(10, True)\
            | cli.apply(lambda imgParas: [imgParas, get("kapi/3-ocr", {"data": k1.encode(imgParas | cli.apply(cli.toBytes(), 0) | cli.deref())}) | cli.aS(k1.decode)] | cli.transpose()) | cli.joinSt() | ~cli.apply(Ocr) # ocr 
class Ocr:                                                                       # Ocr
    def __init__(self, imgPara, res):                                            # Ocr
        """Ocr result object. Stores raw results from model in ``.result`` field and has many more functionalities""" # Ocr
        self.img, self.para = imgPara; self.res = res                            # Ocr
    @property                                                                    # Ocr
    def result(self): return self.res                                            # Ocr
    def __repr__(self): return f"<Ocr shape={self.img | cli.shape()}>"           # Ocr
    def _overlay(self) -> "PIL":                                                 # Ocr
        img = self.img; res = self.res; p5 = k1.p5; w, h = img | cli.shape(); p5.newSketch(*img | cli.shape()); p5.background(255); p5.fill(255, 0) # Ocr
        res | cli.cut(0) | ~cli.apply(lambda x1,x2,y1,y2: [x1,h-y2,x2-x1,y2-y1]) | ~cli.apply(p5.rect) | cli.deref() # Ocr
        res | cli.cut(0, 1) | ~cli.apply(lambda x1,x2,y1,y2: [min(x1,x2), h-max(y1,y2)], 0) | ~cli.apply(lambda xy,s: [s,*xy]) | ~cli.apply(p5.text) | cli.deref() # Ocr
        im2 = p5.img(); alpha = 0.3; return [img, im2] | cli.apply(cli.toTensor() | cli.op()[:3]) | ~cli.aS(lambda x,y: x*alpha+y*(1-alpha)) | cli.op().to(int) | cli.op().permute(1, 2, 0) | cli.toImg() # Ocr
    def _repr_html_(self): s = html.escape(f"{self}"); return f"<pre>{s}</pre><img src='data:image/jpeg;base64, {base64.b64encode(self._overlay() | cli.toBytes()).decode()}' />" # Ocr
    def __getstate__(self): d = {**self.__dict__}; d["img"] = self.img | cli.toBytes(); return d # better compression due to converting to jpg # Ocr
    def __setstate__(self, d): self.__dict__.update(d); self.img = self.img | cli.toImg() # Ocr
class txt2im(BaseCli):                                                           # txt2im
    def __init__(self, num_inference_steps=10):                                  # txt2im
        """Generates images from text descriptions, using stable diffusion v2.
Example::
    "a bowl of apples" | kapi.txt2im() # returns PIL image
- VRAM: 5.42GB
- Throughput: 1/s
"""                                                                              # txt2im
        self.num_inference_steps = num_inference_steps                           # txt2im 
[docs]    def __ror__(self, it): return get("kapi/4-txt2im", {"prompt": it, "num_inference_steps": self.num_inference_steps}) | cli.aS(base64.b64decode) | cli.toImg() # txt2im  
class caption(BaseCli):                                                          # caption
    def __init__(self):                                                          # caption
        """Captions images using model `Salesforce/blip-image-captioning-large`.
Example::
    img = "some_image.png" | toImg() # loads PIL image
    img | kapi.caption()                  # returns string description
- VRAM: 2.5GB
- Throughput: 16/s
"""                                                                              # caption
        pass                                                                     # caption 
    def __ror__(self, it): return self._all_opt([it]) | cli.item()               # caption 
    def _all_opt(self, it:list["PIL"]): return get("kapi/5-caption", {"images": k1lib.encode(it)}) | cli.aS(k1lib.decode) # caption 
class speech(BaseCli):                                                           # speech
    def __init__(self, sep=False):                                               # speech
        """Converts English speech to text using whisper-large-v2.
Example::
    "audio.mp3" | toAudio() | kapi.speech() # returns string transcript
- VRAM: 4GB
- Throughput: a 20-minute video finishes transcribing in ~25s, so roughly 50x faster than real time
If the input audio is too long (>25 minutes), then it will be broken up
into multiple smaller pieces of under 25 minutes each and sent to the server,
so the transcript might be slightly off around the segment boundaries
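For example (a sketch; the exact output depends on the audio)::
    "audio.mp3" | toAudio() | kapi.speech(sep=True) # returns List[str], one transcript per piece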
:param sep: if True, separate transcripts of each segment (returns List[transcript]),
    if False (default), joins segment's transcripts together into a single string""" # speech
        self.sep = sep                                                           # speech 
[docs]    def __ror__(self, audio:"conv.Audio"):                                       # speech
        nSplits = int(audio.raw.duration_seconds/60/25)+1                        # speech
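        # split the audio into nSplits roughly equal pieces (each under ~25 minutes) and transcribe each piece with a separate request # speech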
        res = audio | cli.splitW(*[1]*nSplits) | cli.apply(lambda piece: get("kapi/6-speech", {"audio": base64.b64encode(piece | cli.toBytes()).decode()})) # speech
        return list(res) if self.sep else res | cli.join(". ")                   # speech  
def _summarize(text:str) -> str:                                                 # _summarize
    return text | segment(2000) | complete("<|end of transcript|>\n\nPlease summarize the above transcript using 1-3 sentences: ").all()\
        | cli.op().strip().all() | cli.deref() | cli.join(". ")                  # _summarize
class summarize(BaseCli):                                                        # summarize
    def __init__(self, length=1000):                                             # summarize
        """Summarizes text in multiple stages until it's shorter than ``length`` in
characters or until further compression is not possible. Example::
    url = "https://www.youtube.com/watch?v=NfmSjGbnEWk"
    audio = url   | toAudio()     # downloads audio from youtube
    text  = audio | kapi.speech() # does speech recognition
    text | summarize()            # summarizes the text. For a 23-minute video (~22k characters of text), it should take around 23s to summarize everything
This will return an array of strings::
    [
        "shortened text final stage",
        "shortened text stage 2",
        "shortened text stage 1",
        "original text",
    ]
So in each stage, the original text is split up into multiple pieces, then
each piece is summarized using :class:`complete`, and all the summaries are
joined together, creating the "shortened text stage 1". This continues
until the text's length does not decrease any further, or until it's shorter
than the desired length.
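If you only want the final, shortest summary, grab the first element (a small sketch)::
    text | summarize() | item() # just "shortened text final stage"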
:param length: desired summary string length"""                                  # summarize
        self.length = length                                                     # summarize 
    def __ror__(self, text:str):                                                 # summarize
        stages = [text]; l = len(text)                                           # summarize
        while True:                                                              # summarize
            if len(text) < self.length: return stages | cli.reverse() | cli.deref() # summarize
            l = len(text); text = _summarize(text); stages.append(text)          # summarize
            if len(text)/l > 0.8: return stages | cli.reverse() | cli.deref() # if length not shrinking, then just return early # summarize  
class post(BaseCli):                                                             # post
    def __init__(self, url):                                                     # post
        """Sends a POST request to a URL and returns the response text.
Example::
    # returns str of the results
    {"some": "json data"} | kapi.post("https://some.url/some/path")
Notice how there isn't a GET request counterpart, because you can always just cat() those
directly, as GET requests don't have a body::
    cat("https://some.url/some/path")
"""                                                                              # post
        self.url = url                                                           # post 
    def __ror__(self, d): return requests.post(self.url, json=d).text            # post 
    def _jsF(self, meta):                                                        # post
        fIdx = init._jsFAuto(); dataIdx = init._jsDAuto()                        # post
        return f"""\
const {fIdx} = async ({dataIdx}) => {{
    const res = await fetch({json.dumps(self.url)}, {{ method: "POST", headers: {{ "Content-Type": "application/json" }}, body: JSON.stringify({dataIdx}) }});
    if (res.ok) return await res.text();
    throw new Error(`Can't send POST request to '{self.url}': ${{res.status}} - ${{res.statusText}}`);
}}""", fIdx                                                                      # post