Source code for k1lib.k1ui.main

# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""`k1ui <https://github.com/157239n/k1ui>`_ is another project made in Java that
aims to record and manipulate the screen, keyboard and mouse. The interface to
that project on its own is clunky, and this module is the Python interface to
ease its use.

Not quite developed yet tho, because I'm lazy."""
import k1lib, numpy as np, asyncio, time, inspect, json, threading, dill, math, base64, os, random, warnings
k1 = k1lib; cli = k1.cli; from k1lib.cli import *; knn = k1.knn; Cbs = k1.Cbs; viz = k1.viz; websockets = k1.dep("websockets")
nn = k1.dep("torch.nn"); optim = k1.dep("torch.optim"); tf = k1.dep("torchvision.transforms")
PIL = k1.dep("PIL"); k1.dep("graphviz"); requests = k1.dep("requests")
try: import torch; hasTorch = True
except: torch = k1.dep("torch"); hasTorch = False

from typing import Callable, List, Iterator, Tuple, Union, Dict; from collections import defaultdict, deque; from functools import lru_cache
import matplotlib as mpl; import matplotlib.pyplot as plt
__all__ = ["get", "WsSession", "selectArea", "record", "execute", "Recording",
           "Track", "CharTrack", "WordTrack", "ContourTrack", "ClickTrack", "WheelTrack", "StreamTrack",
           "distNet", "TrainScreen"]
k1lib.settings.add("k1ui", k1.Settings().add("server", k1.Settings().add("http", "http://localhost:9511", "normal http server").add("ws", "ws://localhost:9512", "websocket server"), "server urls"), "docs related to k1ui java library");
settings = k1lib.settings.k1ui
settings.add("draw", k1.Settings(), "drawing settings")
settings.draw.add("trackHeight", 30, "Track's height in Recording visualization")
settings.draw.add("pad", 10, "Padding between tracks");
def get(path):
    """Sends a get request to the Java server. Example::

        k1ui.get("mouse/200/300") # move mouse to (200, 300)"""
    return requests.get(f"{settings.server.http}/{path}", timeout=60*10).text
def post(path, jsObj):
    """Sends a post request to the Java server. Example::

        k1ui.post("mouse/200/300") # move mouse to (200, 300)"""
    return requests.post(f"{settings.server.http}/{path}", json=jsObj, timeout=60*10).text
portAutoInc = k1.AutoIncrement(9520)
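# Hedged usage sketch (not part of the original module): assuming the k1ui Java
# server is running locally on the default port from settings.k1ui.server.http,
# the request helpers above can be called directly. The "mouse/<x>/<y>" path
# comes from the docstrings above and "selectArea/..." from selectArea() below;
# other endpoints depend on the server build.
def _demoRequests(): # hypothetical helper, for illustration only
    print(get("mouse/200/300")) # moves the mouse to (200, 300), prints the server's text reply
    print(get("selectArea/0/0/800/600")) # same endpoint that selectArea() wraps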
class WsSession:
    def __init__(self, eventCb:Callable[["WsSession", dict], None], mainThreadCb:Callable[["WsSession"], None]):
        """Creates a websocket connection with the server, with some callback functions.
        The callback functions (most are async btw) will be passed a WebSocket object
        as the first argument. You can use it to send messages like this::

            # this will send a signal to the server to close the session
            sess.ws.send(json.dumps({"type": "close"}))
            # this will send a signal to the server requesting the current screenshot. Result will be deposited into eventCb
            sess.ws.send(json.dumps({"type": "screenshot"}))
            # this will execute a single event
            sess.ws.send(json.dumps({"type": "execute", "event": {"type": "keyTyped", "javaKeyCode": 0, ...}}))

        Complete, minimal example::

            events = []
            async def eventCb(sess, event): events.append(event)
            async def mainThreadCb(sess):
                sess.stream(300) # starts a stream with output screen width of 300px
                await asyncio.sleep(2)
                await sess.ws.send(json.dumps({"type": "execute", "event": {"type": "keyPressed", "javaKeyCode": 65, "timestamp": 0}}))
                await sess.ws.send(json.dumps({"type": "execute", "event": {"type": "keyReleased", "javaKeyCode": 65, "timestamp": 0}}))
                await asyncio.sleep(10); sess.close()
            await k1ui.WsSession(eventCb, mainThreadCb).run()

        This code communicates with the server continuously for 12 seconds, capturing
        all events in the meantime and saving them into the ``events`` list. It starts
        up a UDP stream to capture screenshots continuously, and after 2 seconds, it
        sends 2 events to the server, trying to type the letter "A". Finally, it waits
        for another 10 seconds and then terminates the connection.

        This interface is quite low-level, and is the basis for all other
        functionalities. Some of them include:

        - :meth:`record`: records a session
        - :meth:`execute`: executes a list of events

        :param eventCb: (async) will be called whenever there's a new event
        :param mainThreadCb: (async) will be called after setting up everything"""
        self.ws = None; self.eventCb = eventCb; self.mainThreadCb = mainThreadCb
        if not inspect.iscoroutinefunction(eventCb): raise Exception(f"eventCb has to be an async function")
        if not inspect.iscoroutinefunction(mainThreadCb): raise Exception(f"mainThreadCb has to be an async function")
        self.closed = False; self.streams = {} # width -> [width, locks, port]
    async def _listenLoop(self):
        while True:
            res = await self.ws.recv() | cli.aS(json.loads); _type = res["type"]
            if _type == "close": break # python sends close signal to java, java then sends a close signal back, as an acknowledgement
            if _type == "screenshot": await self.eventCb(self, {"type": "screenshot", "bytes": base64.b64decode(res["screenshot"]), "timestamp": int(time.time()*1000)})
            if _type == "newEvent": await self.eventCb(self, res["event"])
    async def _pingLoop(self):
        while True:
            if self.closed: break
            try: await self.ws.send({"type": "ping"} | cli.aS(json.dumps)); await asyncio.sleep(1)
            except: break
    async def _streamLoop(self, width, locks, port):
        import cv2; streamRefresh = 100 # refreshes udp stream after this many seconds, so that it doesn't hang
        def threadLoop(lock, port):
            with lock, k1.captureStdout(False, True):
                get(f"startStream/{width}/{port}"); cap = cv2.VideoCapture(f'udp://0.0.0.0:{port}', cv2.CAP_FFMPEG); beginTime = time.time()
                while (cap.isOpened()):
                    if self.closed: break
                    res, frame = cap.read()
                    if not res: break
                    self.loop.create_task(self.eventCb(self, {"type": "stream", "width": width, "frame": frame[:,:,::-1], "timestamp": int(time.time()*1000)}))
                    if time.time() - beginTime > streamRefresh + 10: break # there will be a short time (5s) where 2 udp streams run simultaneously, both dumping events
                cap.release(); get(f"stopStream/{port}")
        ports = [port, port + 100]; sel = 0
        while not self.closed:
            threading.Thread(target=threadLoop, args=(locks[sel], ports[sel])).start()
            await asyncio.sleep(streamRefresh); sel = 1-sel
    def stream(self, width):
        """Starts a stream with a particular output width. The lower the width,
        the higher the fps and vice versa"""
        if width in self.streams: raise Exception(f"Can't start stream with width {width}. Just use the existing stream.")
        port = portAutoInc()
        self.streams[width] = [width, [threading.Lock(), threading.Lock()], port]
        import cv2 # placed here so that users see the error message if cv2 is not installed
        asyncio.create_task(self._streamLoop(*self.streams[width]))
    async def run(self):
        """Connects with the Java server, sets things up and runs ``mainThreadCb``"""
        async with websockets.connect(settings.server.ws, max_size=1_000_000_000) as ws:
            self.ws = ws; self.loop = asyncio.get_event_loop()
            _listenLoop = asyncio.create_task(self._listenLoop())
            _pingLoop = asyncio.create_task(self._pingLoop())
            try: await self.mainThreadCb(self)
            except asyncio.CancelledError: self.close()
            await _listenLoop
    def close(self):
        """Closes the connection with the Java server"""
        if self.closed: print("Already closed"); return
        self.closed = True; asyncio.create_task(self.ws.send({"type": "close"} | cli.aS(json.dumps)))
        for width, locks, port in self.streams.values():
            with locks[0]: # make sure all locks are freed. Also important to have the 2 locks be nested in each other, in case everything aligns just right that evades this mechanism
                with locks[1]: pass
    async def execute(self, events):
        """Executes a series of events"""
        events = events | sortF(op()["timestamp"]) | aS(list)
        deltaT = int(time.time()*1000) - events[0]["timestamp"]
        for e in events | apply(lambda x: {**x, "timestamp": x["timestamp"]+deltaT}):
            st = e["timestamp"]/1000 - time.time()
            if st > 0: await asyncio.sleep(st)
            await self.ws.send(json.dumps({"type": "execute", "event": e}))
def selectArea(x, y, w, h):
    """Selects an area on the screen to focus into"""
    return get(f"selectArea/{x}/{y}/{w}/{h}")
async def record(t=None, keyCode=None, streamWidth=300, f=iden()):
    """Records activities. Examples::

        events = await k1ui.record(t=5) # records for 5 seconds
        events = await k1ui.record(keyCode=5) # records until "Escape" is pressed
        events = await k1ui.record() # records until an interrupt signal is sent to the process

    Note: these examples only work in jupyter notebooks. For regular Python
    processes, check out the official Python docs
    (https://docs.python.org/3/library/asyncio-task.html)

    :param t: record duration
    :param keyCode: key to stop the recording
    :param streamWidth: if truthy, opens the UDP stream and captures screenshots at this width
    :param f: extra event post-processing function"""
    events = []
    async def eventCb(sess, event):
        res = f(event)
        if res is not None: events.append(res)
        if event["type"] == "keyReleased" and event["keyCode"] == keyCode: sess.close()
    async def mainThreadCb(sess):
        if streamWidth: sess.stream(streamWidth)
        if t is not None: await asyncio.sleep(t); sess.close()
        else: await asyncio.sleep(1e9)
    await WsSession(eventCb, mainThreadCb).run(); return events
async def execute(events:List[dict]):
    """Executes some events"""
    async def eventCb(sess, event): pass
    async def mainThreadCb(sess): await sess.execute(events); sess.close()
    await WsSession(eventCb, mainThreadCb).run()
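# Hedged sketch (not part of the original module): a record-then-replay round
# trip built from record() and execute() above. Assumes a k1ui server is
# reachable and that this coroutine runs inside an event loop (e.g. awaited
# from a jupyter cell).
async def _demoRoundTrip(): # hypothetical helper, for illustration only
    events = await record(t=5, streamWidth=0) # record keyboard/mouse for 5 seconds, no UDP stream (streamWidth falsy)
    await execute(events) # replay the same events, shifted to the current time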
uuid = k1.AutoIncrement(random.randint(0, int(1e9)), prefix="k1ui-")
def escapeHtml(s): return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
class HtmlView:
    def __init__(self, html): self.html = html
    def _repr_html_(self): return self.html
class Recording:
    def __init__(self, events):
        self.uuid = uuid(); self._tracks = []
        if len(events) == 0: return # shortcut to initialize using cloned tracks rather than events
        events = events | sortF(op()["timestamp"]) | deref()
        self._tracks.extend(ContourTrack.parse(events))
        self._tracks.extend(CharTrack.parse(events))
        self._tracks.extend(ClickTrack.parse(events))
        self._tracks.extend(WheelTrack.parse(events))
        self._tracks.extend(StreamTrack.parse(events))
        self._tracks = self._tracks | filt(op()) | apply(lambda x: x._rec(self)) | deref()
        self._resetTimes(); self._resetDis()
    def _resetTimes(self):
        self.startTime, self.endTime = self._tracks | op().timeUnix().all() | joinStreams() | filt(op()) | toMin() & toMax(); return self
    def _resetDis(self): # display times
        self.dis1, self.dis2 = (self.startTime+self.endTime)/2 | aS(lambda x: [x-self.duration*0.53, x+self.duration*0.53]); return self
    @property
    def duration(self): return self.endTime - self.startTime
    def addTracks(self, *tracks) -> "Recording":
        """Adds tracks to the Recording"""
        if not isinstance(tracks[0], Track) and len(tracks) == 1: tracks = tracks[0]
        self._tracks.extend(tracks | apply(lambda tr: tr._rec(self))); self._resetTimes(); self._resetDis(); return self
    def removeTracks(self, *tracks) -> "Recording":
        """Removes tracks from the Recording"""
        if not isinstance(tracks[0], Track) and len(tracks) == 1: tracks = tracks[0]
        tracks | apply(self._tracks.remove) | ignore(); self._resetTimes(); self._resetDis(); return self
    def _normTime(self, t=None, default=None): return default if t is None else t + self.startTime
    def zoom(self, t1=None, t2=None):
        """Zooms into a particular time range. If either bound is not specified,
        it will default to the start/end of all events.

        :param t1: time values are relative to the recording's start time"""
        _dis1 = self.dis1; t1 = _dis1 if t1 is None else t1 + self.startTime
        _dis2 = self.dis2; t2 = _dis2 if t2 is None else t2 + self.startTime
        delta = t2-t1; t1-=delta*0.03; t2+=delta*0.03; self.dis1 = t1; self.dis2 = t2
        html = self._repr_html_(); self.dis1 = _dis1; self.dis2 = _dis2; return HtmlView(html)
    def sel(self, t1=None, t2=None, klass=None) -> List["Track"]:
        """Selects a subset of tracks using several filters. For selecting time,
        assuming we have tracks that look like this (x, y are t1, t2)::

            # |-1--|    |-2-|
            #     |---3---|
            #   x     y

        Then, tracks 1 and 3 are selected. Time values are relative to the
        recording's start time.

        :param t1: choose tracks that happen after this time
        :param t2: choose tracks that happen before this time
        :param klass: choose a specific track class"""
        tracks = self._tracks
        if klass: tracks = tracks | instanceOf(klass)
        if t1 is not None or t2 is not None:
            t1 = self._normTime(t1, self.startTime); t2 = self._normTime(t2, self.endTime)
            tracks = tracks | apply(lambda o: [o.startTime or 0, o.endTime, o]) | ~filt(op()[1]<t1) | ~filt(op()[0]>=t2) | cut(2)
        return tracks | aS(list)
    def sel1(self, **kwargs) -> "Track":
        """Like :meth:`sel`, but this time gets the first element only."""
        return self.sel(**kwargs) | item()
    def time0(self) -> List[float]:
        """Start and end recording times. Start time is zero"""
        return [0, self.endTime - self.startTime]
    def timeUnix(self) -> List[float]:
        """Start and end recording times. Both are absolute unix times"""
        return [self.startTime, self.endTime]
    def events(self) -> List[dict]:
        """Reconstructs events from the Recording's internal data. The events are lossy though::

            events = ... # events recorded
            r = k1ui.Recording(events)
            assert r.events() != events # this is the lossy part. Don't expect the produced events to match the originals exactly"""
        return self._tracks | op().events().all() | joinStreams() | sortF(op()["timestamp"]) | deref(igT=False)
    def copy(self) -> "Recording":
        """Creates a clone of this recording"""
        return Recording([]).addTracks(self._tracks | op().copy().all())._resetDis()
    def _repr_html_(self): return self | aS(createTrackss) | aS(drawTrackss)
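# Hedged sketch (not part of the original module): the typical Recording
# workflow, combining record() with the class above. The durations and
# time ranges here are made up for illustration.
async def _demoRecording(): # hypothetical helper, for illustration only
    r = Recording(await record(t=10)) # capture 10 seconds of activity
    view = r.zoom(2, 5) # HtmlView of seconds 2-5, relative to the recording's start
    clicks = r.sel(t1=0, t2=3, klass=ClickTrack) # ClickTracks in the first 3 seconds
    return r.events() # lossy reconstruction of the raw events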
class Track:
    def __init__(self, startTime, endTime):
        """Time values are absolute unix times."""
        self.recording = None; self.startTime = startTime if startTime else None; self.endTime = endTime; self.uuid = uuid()
    def time0(self) -> List[float]:
        """Start and end track times. Times are relative to the track's start time"""
        return [0, self.endTime - self.startTime]
    def time0Rec(self) -> List[float]:
        """Start and end track times. Times are relative to the recording's start time"""
        return [self.startTime-self.recording.startTime if self.startTime else None, self.endTime-self.recording.startTime]
    def timeUnix(self) -> List[float]:
        """Start and end track times. Times are absolute unix times"""
        return [self.startTime, self.endTime]
    def concurrent(self) -> List["Track"]:
        """Grabs all tracks that are concurrent with this track"""
        return self.recording.sel(*self.time0Rec())
    def _rec(self, recording): self.recording = recording; return self # inject dependency
    def _tooltip(self, ctx): return ""
    def _displayTimes(self): # shortcut func for displaying in __repr__
        s = f"{self.startTime-self.recording.startTime:.2f}s" if self.startTime else None
        e = f"{self.endTime-self.recording.startTime:.2f}s"; return f"time ({s}->{e})"
    def events(self) -> List[dict]:
        """Reconstructs events from the Track's internal data, to be implemented by subclasses."""
        return NotImplemented
    def copy(self):
        """Creates a clone of this Track, to be implemented by subclasses"""
        return NotImplemented
    def move(self, deltaTime):
        """Moves the entire track left or right, to be implemented by subclasses.

        :param deltaTime: if negative, move left by this number of seconds, else move right"""
        self.startTime += deltaTime; self.endTime += deltaTime; self.recording._resetTimes(); self.recording._resetDis()
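# Hedged sketch (not part of the original module): what a minimal custom Track
# subclass could look like. The base class above only asks subclasses for
# events() and copy(); move() already has a usable default. "MarkerTrack" is
# a made-up name, purely for illustration.
class _MarkerTrack(Track): # hypothetical subclass, for illustration only
    def __init__(self, label:str, times:List[float]):
        super().__init__(*times); self.label = label
    def events(self): return [] # contributes no replayable events
    def copy(self): return _MarkerTrack(self.label, self.timeUnix())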
class CharTrack(Track):
    def __init__(self, keyText:str, keyCode:int, mods:List[bool], times:List[float]):
        """Representing 1 key pressed and released.

        :param keyText: text to display to the user, like "Enter"
        :param keyCode: event's "javaKeyCode"
        :param mods: list of 3 booleans: whether ctrl, shift or alt is pressed"""
        super().__init__(*times); self.keyText = keyText; self.keyCode = keyCode; self.mods = mods
    @staticmethod
    def parse(events) -> List["CharTrack"]:
        stacks = {} # keyCode -> obj
        def process(e):
            _type, keyText, keyCode, mods, timestamp = e
            if _type == "keyPressed":
                if keyCode in stacks and stacks[keyCode]:
                    a = stacks[keyCode]; stacks[keyCode] = e
                    return [a, [_type, keyText, keyCode, mods, timestamp - 0.001]]
                    #raise Exception("Strange case. Why would the same key be pressed twice without being released first")
                stacks[keyCode] = e
            if _type == "keyReleased":
                a = stacks[keyCode] if keyCode in stacks and stacks[keyCode] else None
                stacks[keyCode] = None; return [a, e]
        def makeTrack(x, y):
            if x is None: x = [0, y[1], y[2], y[3], None]
            return CharTrack(x[1], x[2], x[3], [x[4], y[4]])
        return events | filt(op()["type"].startswith("key")) | filt(op()["type"] != "keyTyped") | apply(lambda x: [x["type"], x["keyText"], x["javaKeyCode"], [x["ctrl"], x["shift"], x["alt"]], x["timestamp"]/1000]) | apply(process) | filt(op()) | ~apply(makeTrack) | deref()
    def _tooltip(self, ctx): return escapeHtml(self.__repr__())
    def __repr__(self): return f"<CharTrack {self._displayTimes()} keyText ({self.keyText})>"
    def events(self):
        d = []; t1, t2 = self.timeUnix()
        # does not care about mods because the mods will have a separate CharTrack already, so we don't have to repeat
        if t1: d.append({"type": "keyPressed", "keyText": self.keyText, "javaKeyCode": self.keyCode, "timestamp": int(t1*1000)})
        if t2: d.append({"type": "keyReleased", "keyText": self.keyText, "javaKeyCode": self.keyCode, "timestamp": int(t2*1000)})
        return d
    def copy(self): return CharTrack(self.keyText, self.keyCode, self.mods, self.timeUnix())
    def move(self, deltaTime):
        if self.startTime: self.startTime += deltaTime
        self.endTime += deltaTime; self.recording._resetTimes()
def _ord2(x):
    y = x | apply(ord) | deref()
    x2y = [x, y] | toDict(False)
    y2x = [y, x] | toDict(False)
    return [x, y, x2y, y2x]
_upper, _upperCs, _upperD1, _upperD2 = _ord2("ABCDEFGHIJKLMNOPQRSTUVWXYZ"); _lower, _lowerCs, _lowerD1, _lowerD2 = _ord2("abcdefghijklmnopqrstuvwxyz")
_num, _numCs, _numD1, _numD2 = _ord2("1234567890")
_puncLower, _puncLowerCs, _puncLowerD1, _puncLowerD2 = _ord2("[];',./`-=\\")
_puncUpper, _puncUpperCs, _puncUpperD1, _puncUpperD2 = _ord2("{}:\"<>?~_+|")
# maps from numbers 12345 to punctuation like !@#$%
_numPunc, _numPuncCs, _numPuncD1, _numPuncD2 = _ord2("!@#$%^&*()")
_numPuncMap1 = [_numPuncCs, _numCs] | toDict(False); _numPuncMap2 = [_numCs, _numPuncCs] | toDict(False)
_punc, _puncCs, _puncD1, _puncD2 = _ord2(_puncLower + _puncUpper + _numPunc + " ")
# maps from lower case punctuation like ;',./ into upper case like :"<>?
_puncMap = [_puncLower, _puncUpper] | toDict(False); _puncMapCs = [_puncLowerCs, _puncUpperCs] | toDict(False)
_puncMap2 = [_puncUpper, _puncLower] | toDict(False)
def _inferText(code:int, mods) -> str:
    if mods[0] or mods[2]: return None
    shift = mods[1]
    if shift:
        if code in _upperCs: return _upperD2[code]
        if code in _lowerCs: return _lowerD2[code].upper()
        if code in _numCs: return _numPuncD2[_numPuncMap2[code]]
        if code in _puncLowerCs: return _puncUpperD2[_puncMapCs[code]]
        if code in _puncCs: return _puncD2[code]
        return None
    else:
        if code in _upperCs: return _upperD2[code].lower()
        if code in _lowerCs: return _lowerD2[code]
        if code in _numCs: return _numD2[code]
        if code in _puncCs: return _puncD2[code]
        return None
def _isUpper(x:str) -> bool: return x in _upper or x in _puncUpper or x in _numPunc
def _canon(x:str) -> Union[int, str]: # returns canonical key to be pressed
    if x in _num: return _numD1[x]
    if x in _upper: return _upperD1[x]
    if x in _lower: return _upperD1[x.upper()]
    if x in _puncLower: return x
    if x in _puncUpper: return _puncMap2[x]
    if x in _numPunc: return _numPuncMap1[_numPuncD1[x]]
    if x in _punc: return x
    return None
def _textToKeys(text:str): # opposite of _inferText
    cap = False; d = []; sk = 16 # shift key
    for c in text:
        _cap = _isUpper(c)
        if _cap and not cap: d.append(["down", sk]); cap = True # change to upper
        elif not _cap and cap: d.append(["up", sk]); cap = False # change to lower
        d.append(["down", _canon(c)]); d.append(["up", _canon(c)])
    if cap: d.append(["up", sk])
    return d
def _getTextBlocks(charTracks:List["CharTrack"]): # get potential collections of CharTracks
    es = charTracks | filt(op().startTime) | sortF(op().startTime) | apply(lambda x: [_inferText(x.keyCode, x.mods), x]) | aS(list)
    d = []; _d = []; inBlock = False
    for c, obj in es:
        if c is None and inBlock: d.append(_d); inBlock = False # ends a block
        elif c is not None and not inBlock: _d = []; inBlock = True # starts a new block
        if inBlock: _d.append([c, obj])
    if inBlock: d.append(_d)
    return d | apply(transpose() | join("") + iden())
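# Hedged sketch (not part of the original module): how the private key-mapping
# helpers above fit together. _inferText maps (java key code, [ctrl, shift, alt])
# to a character, and _textToKeys produces the shift-aware down/up sequence that
# WordTrack.events() later replays. 65 is the java key code for "A", 16 is shift.
def _demoKeyMapping(): # hypothetical helper, for illustration only
    assert _inferText(65, [False, False, False]) == "a" # no modifiers -> lowercase
    assert _inferText(65, [False, True, False]) == "A" # shift held -> uppercase
    return _textToKeys("Hi") # [["down", 16], ["down", 72], ["up", 72], ["up", 16], ["down", 73], ["up", 73]]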
class WordTrack(Track):
    def __init__(self, text, times:List[float]):
        """Representing normal text input. This is not created from events
        directly. Rather, it's created by scanning over CharTracks and merging
        them together"""
        super().__init__(*times); self.text = text
    def _tooltip(self, ctx): return escapeHtml(self.__repr__())
    def __repr__(self): return f"<WordTrack {self._displayTimes()} text ({self.text}) >"
    def events(self):
        es = _textToKeys(self.text); d = []; ts = np.linspace(*self.timeUnix(), len(es))
        for t, (_type, code) in zip(ts, es):
            _type = "keyPressed" if _type == "down" else "keyReleased"; t = int(t*1000)
            if isinstance(code, str): d.append({"type": _type, "text": code, "timestamp": t})
            else: d.append({"type": _type, "javaKeyCode": code, "timestamp": t})
        return d
    def copy(self): return WordTrack(self.text, self.timeUnix())
@k1.patch(Recording)
def formWords(self) -> Recording:
    """Tries to merge nearby CharTracks together when it looks like the user
    is trying to type something. Assuming the user types "a", then "b", then
    "c", this should be able to detect the intent of typing "abc", and replace
    the 3 CharTracks with a single WordTrack. Example::

        # example recording, run in a notebook cell to see the interactive interface
        r = k1ui.Recording.sample(); r
        # run in another notebook cell and compare the difference
        r.formWords()"""
    for word, charTracks in _getTextBlocks(self.sel(klass=CharTrack)):
        if len(word) <= 0: continue
        ts = charTracks | op().timeUnix().all() | joinStreams() | toMin() & toMax() | deref()
        self.removeTracks(charTracks); self.addTracks(WordTrack(word, ts))
        self.removeTracks(self.sel(*ts | apply(op()-self.startTime), klass=CharTrack) | filt(op().keyCode == 16)) # removing shift CharTracks
    return self
class ContourTrack(Track): # mouse movements
    def __init__(self, coords):
        """Representing a mouse trajectory ("mouseMoved" events).

        :param coords: numpy array with shape (#events, [x, y, unix time])"""
        super().__init__(*coords | cut(2) | toMin() & toMax()); self.coords = coords; self._cachedImg = None
    @staticmethod
    def parse(events) -> List["ContourTrack"]:
        coords = events | filt(lambda x: x["type"] == "mouseMoved" or x["type"] == "mouseDragged") | apply(lambda x: [x["x"], x["y"], x["timestamp"]/1000]) | deref() | aS(np.array)
        return [] if coords | shape(0) == 0 else [ContourTrack(coords)]
    def _img(self):
        if self._cachedImg: return self._cachedImg
        x, y, t = self.coords | transpose(); c = mpl.cm.rainbow(t - t[0] | aS(lambda x: x/x[-1])); plt.scatter(x, y, None, c, ".")
        plt.colorbar(mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(*self.time0Rec()), cmap=mpl.cm.rainbow)).ax.set_title("Time (s)")
        plt.title("ContourTrack"); plt.grid(True); plt.tight_layout(); self._cachedImg = plt.gcf() | toImg(); return self._cachedImg
    def __repr__(self): return f"<ContourTrack {self._displayTimes()} n ({self.coords.shape[0]})>"
    def _tooltip(self, ctx): return f"""<div><div style="margin-bottom:10px">{escapeHtml(self.__repr__())}</div>{self._imgHtml()}</div>"""
    def _imgHtml(self): return f"""<img src="data:image/png;base64,{self._img() | toBytes(imgType="png") | aS(base64.b64encode) | op().decode()}" alt="Mouse trajectory" />"""
    def _repr_html_(self): return f"""<!-- k1ui.ContourTrack --><div>{self._imgHtml()}</div>"""
    def events(self): return self.coords | ~apply(lambda x, y, t: {"type": "mouseMoved", "x": x, "y": y, "timestamp": int(t*1000)}) | deref()
    def copy(self): return ContourTrack(np.copy(self.coords))
    def move(self, deltaTime): self.coords[:,2] += deltaTime; super().move(deltaTime)
class ClickTrack(Track): # mouse down, then up
    def __init__(self, coords:np.ndarray, times:List[float]):
        """Representing a mouse pressed and released event"""
        super().__init__(*times); self.coords = coords # coords = [[x1, y1], [x2, y2]]
    @staticmethod
    def parse(events) -> List["ClickTrack"]:
        tracks = []; pressedEvents = defaultdict(lambda: None) # haha, get it?
        def process(e):
            _type, x, y, button, t = e
            pe = pressedEvents[button]
            if _type == "mousePressed":
                if pe: raise Exception("Strange case. Why would inRange be true when mouse has just been pressed?")
                pressedEvents[button] = e
            if _type == "mouseReleased":
                if pe: tracks.append(ClickTrack(np.array([pe[1:4], e[1:4]]), [pe[4], e[4]])); pressedEvents[button] = None
                else: warnings.warn("Strange case. Why would mouse be released right at the start? Not strange enough to warrant an exception though")
        events | filt(lambda x: x["type"] == "mousePressed" or x["type"] == "mouseReleased") | apply(lambda x: [x["type"], x["x"], x["y"], x["button"], x["timestamp"]/1000]) | apply(process) | deref()
        return tracks
    def isClick(self, threshold=1):
        """Whether this ClickTrack represents a single click.

        :param threshold: if the Manhattan distance between start and end is less than this amount, then declare it a single click"""
        return abs(self.coords[0] - self.coords[1]).sum() <= threshold
    def __repr__(self): return f"<ClickTrack {self._displayTimes()} coords ({self.coords[0]} -> {self.coords[1]})>"
    def _tooltip(self, ctx): return escapeHtml(f"{self}")
    def events(self):
        xy1, xy2 = self.coords; t1, t2 = self.timeUnix()
        return [{"type": "mousePressed", "x": xy1[0], "y": xy1[1], "button": xy1[2], "timestamp": int(t1*1000)},
                {"type": "mouseReleased", "x": xy2[0], "y": xy2[1], "button": xy2[2], "timestamp": int(t2*1000)}]
    def copy(self): return ClickTrack(self.coords | deref(), self.timeUnix())
class WheelTrack(Track):
    def __init__(self, coords:np.ndarray, times:List[float]):
        """Representing a mouse wheel moved event"""
        super().__init__(*times); self.coords = coords
    @staticmethod
    def parse(events) -> List["WheelTrack"]:
        d = []; _d = []; lastTime = 0
        for rot, t in events | filt(op()["type"] == "mouseWheelMoved") | apply(lambda x: [x["wheelRotation"], x["timestamp"]/1000]):
            if t > lastTime + 2: d.append(_d); _d = []
            _d.append([rot, t]); lastTime = t
        d.append(_d); return d | filt(lambda x: len(x)) | apply(aS(np.array) & (cut(1) | rows(0, -1)) | ~aS(WheelTrack)) | aS(list)
    def __repr__(self): return f"<WheelTrack {self._displayTimes()} rotations (avg {self.coords[:,0].sum()}, {self.coords[:,0] | apply(lambda x: '+' if x > 0.5 else '0') | join('')})>"
    def _tooltip(self, ctx): return escapeHtml(f"{self}")
    def events(self):
        rs = self.coords[:,0]; ts = np.linspace(*self.timeUnix(), self.coords.shape[0])
        return [rs, ts] | transpose() | ~apply(lambda rot, t: {"type": "mouseWheelMoved", "wheelRotation": rot, "timestamp": int(t*1000)})
    def copy(self): return WheelTrack(self.coords, self.timeUnix())
class StreamTrack(Track):
    def __init__(self, frames:np.ndarray, times:np.ndarray):
        """Representing screenshots from the UDP stream"""
        super().__init__(times[0], times[-1]); self.frames = frames; self.times = times; self.aspect = self.frames.shape[2]/self.frames.shape[1]
    @staticmethod
    def parse(events) -> List["StreamTrack"]:
        events = events | filt(op()["type"] == "stream") | aS(list)
        if len(events) == 0: return []
        return [StreamTrack(*events | apply(lambda x: [x["frame"], x["timestamp"]/1000]) | transpose() | apply(np.array))]
    def __repr__(self): return f"<StreamTrack {self._displayTimes()} #frames ({self.frames.shape[0]}) resolution {self.frames.shape[1:3][::-1]}>"
    def _frames(self, n, f=iden()): return [self.frames, self.times] | transpose() | insertIdColumn(True, False) | f | aS(list) | aS(lambda x: x | batched(len(x)//n)) | item().all()
    def _carousel(self): return self._frames(36) | cut(0) | toImg().all() | batched(9) | plotImgs(3, self.aspect, 3, im=True).all() | aS(k1.viz.Carousel)
    def _tooltip(self, ctx):
        metaId = ctx.metaId; streamId = autoId(); f = filt(ctx.dis1<op()<ctx.dis2, 1)
        data = self._frames(40, f) | apply(toImg() | aS(k1.viz.HtmlImage, style="width:800px") | aS(lambda x: x._repr_html_()), 0) | deref() | aS(json.dumps)
        ctx.scriptTags[streamId] = f"""
            data_{streamId} = {data};
            meta_{metaId}.cbs[{streamId}] = (x) => {{
                const stream_{streamId} = document.querySelector("#stream_{streamId}");
                const streamText_{streamId} = document.querySelector("#streamText_{streamId}");
                if (!stream_{streamId}) return;
                const fT = x/800*{ctx.dis2-ctx.dis1}+{ctx.dis1}; // frame time
                let minT = Infinity; let minIm = null; let minI = null
                for (const [imE, t, i] of data_{streamId}) {{
                    const dT = Math.abs(fT-t);
                    if (dT < minT) {{ minIm = imE; minT = dT; minI = i }}
                    else break;
                }}
                stream_{streamId}.innerHTML = minIm;
                streamText_{streamId}.innerHTML = "frame: " + minI;
            }};"""
        return f"""<div>{escapeHtml(str(self))}
            <div style="position:relative">
                <div id="stream_{streamId}"></div>
                <div id="streamText_{streamId}" style="position:absolute;top:8px;left:12px;padding:4px 8px;background-color:white;border-radius:12px"></div>
            </div>
        </div>"""
    def _repr_html_(self): return f"""<div>{escapeHtml(str(self))}<div>{self._carousel()._repr_html_()}</div></div>"""
    def events(self): return []
    def copy(self): return StreamTrack(np.copy(self.frames), np.copy(self.times))
    def move(self, deltaTime): self.times += deltaTime; super().move(deltaTime)
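# Hedged sketch (not part of the original module): pulling frames out of a
# StreamTrack. frames is a plain (N, height, width, 3) numpy array, so normal
# indexing works; toImg() is the same k1lib cli used elsewhere in this file.
def _demoStream(rec): # hypothetical helper, for illustration only
    st = rec.sel1(klass=StreamTrack) # first stream track of the recording
    return st.frames[len(st.frames)//2] | toImg() # middle frame as a PIL image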
def createTrackss(rec:Recording):
    dis1 = rec.dis1; dis2 = rec.dis2; delta = dis2-dis1
    # nTrack for "new track"
    def process(f=iden()):
        trackss = []
        for nTrack in rec._tracks | f | apply(lambda x: [max(x.startTime or 0, dis1+delta*0.01), min(x.endTime, dis2-delta*0.01), x]) | filt(op()>dis1, 1) | filt(op()<dis2, 0) | deref():
            cTracks = None # "chosen track"
            for eTracks in trackss: # "existing track"
                if eTracks["tracks"][-1][1] < nTrack[0]: cTracks = eTracks; break # can fit
            if cTracks: cTracks["tracks"].append(nTrack)
            else: trackss.append({"tracks": [nTrack], "type": nTrack[2].__class__.__name__.split(".")[-1]})
        return trackss
    trackss = [
        *process(instanceOf(CharTrack)), *process(instanceOf(WordTrack)),
        *process(instanceOf(ContourTrack)), *process(instanceOf(ClickTrack)),
        *process(instanceOf(WheelTrack)), *process(instanceOf(StreamTrack))
    ]; return [trackss, rec]
autoId = k1.AutoIncrement(random.randint(0, int(1e9)))
def drawTrackss(obj) -> "html":
    h = settings.draw.trackHeight; pad = settings.draw.pad; trackss, rec = obj; sidebarW = 120 # width
    infoId = autoId(); metaId = autoId(); timeId = autoId(); timeLId = autoId(); sketchId = autoId(); sketchLId = autoId()
    ctx = k1.Object.fromDict({"id2Tt": {}, "dis1": rec.dis1, "dis2": rec.dis2, "metaId": metaId, "scriptTags": {}})
    children = enumerate(trackss) | permute(1, 0) | ~apply(drawTracks, ctx=ctx) | join("")
    trackNames = trackss | op()["type"].all() | insertIdColumn() | ~apply(lambda i, x: f"<div style='position:absolute;top:{pad+(pad+h)*i}px;left:12px;height:{h}px;text-align:center;line-height:{h}px'><div>{x}s</div></div>") | join("")
    st0 = rec.dis1 - rec.startTime; et0 = rec.dis2 - rec.startTime; ticks0 = k1.ticks(st0, et0) # 0-based
    ticksP = (ticks0+rec.startTime-rec.dis1)/(rec.dis2-rec.dis1)*800 # pixel scale
    ticks = [ticks0, ticksP] | transpose() | filt(op()>0, 1) | filt(op()<800, 1) | ~apply(lambda x, y: f"<div style='position:absolute;width:1px;height:10px;background-color:black;left:{y}px;bottom:4px'></div> <div style='position:absolute;left:{y-8}px;top:0px'>{x}</div>") | join("")
    sketchH = (pad+h)*len(trackss)+pad; extraScripts = "\n".join(ctx.scriptTags.values())
    return f"""
<div style="display:flex;flex-direction:column;align-items:flex-start">
    <div style="display:flex;flex-direction:row">
        <div style="width:{sidebarW}px;padding-right:10px;display:flex;justify-content:center;align-items:center"><div>Time (s)</div></div>
        <div id="time_{timeId}" style="background-color:red;height:{h}px;position:relative;height:34px">
            {ticks}
            <div id="timeL_{timeLId}" style="position:absolute;top:0px;background-color:white;border:1px solid black;border-radius:8px;padding:0px 8px">&nbsp;&nbsp;</div>
        </div>
    </div>
    <div style="display:flex;flex-direction:row">
        <div style="width:{sidebarW}px;padding-right:10px;position:relative">{trackNames}</div>
        <div id="sketch_{sketchId}" style="width:{800}px;height:{sketchH}px;background-color:grey;position:relative">
            <div id="sketchL_{sketchLId}" style="position:absolute;width:1px;height:{sketchH}px;background-color:black;top:0px"></div>
            {children}
        </div>
    </div>
    <div id="info_{infoId}" style="min-height:30px;display:flex;flex-direction:column;justify-content:center;align-items:flex-start;padding:4px 12px"></div>
</div>
<script>
    id2Tt = {ctx.id2Tt | aS(json.dumps)}
    info_{infoId} = document.querySelector("#info_{infoId}");
    time_{timeId} = document.querySelector("#time_{timeId}");
    sketch_{sketchId} = document.querySelector("#sketch_{sketchId}");
    sketchL_{sketchLId} = document.querySelector("#sketchL_{sketchLId}");
    timeL_{timeLId} = document.querySelector("#timeL_{timeLId}");
    meta_{metaId} = {{x: 0, y: 0, cbs: {{}}}};
    for (const [k, v] of Object.entries(id2Tt)) {{
        let elem = document.querySelector(`#track_${{k}}`);
        elem.onmouseover = () => {{ info_{infoId}.innerHTML = atob(v[0]); elem.style.backgroundColor = "red"; }};
        elem.onmouseout = () => {{ info_{infoId}.innerHTML = ""; elem.style.backgroundColor = "white"; }};
    }}
    sketch_{sketchId}.onmousemove = (event) => {{
        const x = event.pageX-sketch_{sketchId}.getBoundingClientRect().x; meta_{metaId}.x = x;
        sketchL_{sketchLId}.style.left = x + "px";
        timeL_{timeLId}.style.left = (x-timeL_{timeLId}.getBoundingClientRect().width/2) + "px";
        timeL_{timeLId}.innerHTML = Number(x/800*{et0-st0}+{st0}).toFixed(2) + "s";
        for (const cb of Object.values(meta_{metaId}.cbs)) cb(x);
    }}
    {extraScripts}
</script>"""
def drawTracks(tracks, rowId, ctx) -> "html": return tracks["tracks"] | apply(drawTrack, rowId=rowId, ctx=ctx) | join("")
def drawTrack(track, rowId, ctx) -> "html":
    h = settings.draw.trackHeight; pad = settings.draw.pad; st, et, obj = track
    x1 = (st-ctx.dis1)/(ctx.dis2-ctx.dis1)*800; x2 = (et-ctx.dis1)/(ctx.dis2-ctx.dis1)*800
    y = rowId*(h+pad)+pad; w = x2-x1; trackId = autoId()
    tooltip = obj._tooltip(ctx).encode() | aS(base64.b64encode) | op().decode()
    ctx.id2Tt[trackId] = [tooltip, x1, x2, y]
    return f"""<div id="track_{trackId}" style="top:{y}px;left:{x1}px;width:{w}px;height:{h}px;background-color:white;position:absolute"></div>"""
basePath = os.path.dirname(inspect.getabsfile(k1lib)) + os.sep + "k1ui" + os.sep
@k1.patch(Recording, static=True)
def sampleEvents() -> List[dict]:
    """Grabs the built-in example events. The result will be really long, so
    beware, as it can crash your notebook if you try to display it."""
    mouseE, keyE = cat(f"{basePath}mouseKey.pth", False) | aS(dill.loads)
    deltaT = keyE()[0]["timestamp"] - mouseE()[0]["timestamp"]
    ev = [*mouseE() | apply(lambda x: {**x, "timestamp": x["timestamp"]+deltaT}), *keyE()]
    try: # local comp has the k1ui-screen file, but it will not be bundled with the library, cause it's like 80MB!
        screenE = cat("screen.pth", False) | aS(dill.loads)
        deltaT = keyE()[0]["timestamp"] - screenE()[0]["timestamp"]
        return [*screenE() | apply(lambda x: {**x, "timestamp": x["timestamp"]+deltaT}), *ev]
    except: return ev
@k1.patch(Recording, static=True)
def sample() -> Recording:
    """Creates a Recording from :meth:`sampleEvents`"""
    return Recording(Recording.sampleEvents())
@k1.patch(ContourTrack)
def split(self, times:List[float]):
    """Splits this contour track by multiple timestamps relative to the
    recording's start time. Example::

        r = k1ui.Recording.sample()
        r.sel1(klass=k1ui.ContourTrack).split([5])"""
    rec = self.recording; c = self.coords; i = 0; x = 0; y = 0; d = []; cps = np.array(times) + rec.startTime
    while True:
        if cps[i] > c[y,2]: y += 1
        else:
            if y > x: d.append(c[x:y])
            x = y; i += 1
        if y >= len(c): d.append(c[x:y]); break
        if i >= len(cps): d.append(c[x:]); break
    rec.removeTracks(self)
    rec.addTracks(d | apply(ContourTrack))
@k1.patch(ContourTrack)
def splitClick(self, clickTracks:List["ClickTrack"]=None):
    """Splits this contour track by click events. Essentially, the click
    events chop this contour into multiple segments. Example::

        r = k1ui.Recording.sample()
        r.sel1(klass=k1ui.ContourTrack).splitClick()

    :param clickTracks: if not specified, use all ClickTracks from the recording"""
    rec = self.recording; c = self.coords; i = 0; x = 0; y = 0; d = []
    if clickTracks is None: clickTracks = rec.sel(*self.time0Rec()) | instanceOf(ClickTrack)
    self.split(clickTracks | ~filt(op().isClick(-1)) | op().timeUnix().all() | joinStreams() | sort(None) | apply(op()-rec.startTime) | deref())
@k1.patch(Recording)
def addTime(self, t:float, duration:float) -> Recording:
    """Inserts a specific duration into a specific point in time. More
    clearly, this transforms this::

        # |-1--|    |-2-|
        #     |---3---|
        #       ^ insert duration=3 here

    Into this::

        # |-1--|       |-2-|
        #     |---3------|

    Tracks that partly overlap with the range will have their start/end times
    modified, and potentially some of their internal data deleted:

    - Tracks where only the start and end times are modified: Char, Word, Click, Wheel
    - Tracks whose internal data is also modified: Contour, Stream

    :param t: where to insert the duration, relative to the Recording's start time
    :param duration: how long (in seconds) to insert"""
    at = self.sel(t,t); after = self.sel(t) # tracks at or after the specified time
    unix = t + self.startTime
    for track in at: after.remove(track)
    for track in at:
        track.endTime += duration
        if isinstance(track, ContourTrack):
            c = track.coords; idx = (c[:,2] > unix).argmax(); track._cachedImg = None
            if c[idx,2] > unix: c[idx:,2] += duration # index is valid
        if isinstance(track, StreamTrack):
            c = track.times; idx = (c > unix).argmax()
            if c[idx] > unix: c[idx:] += duration # index is valid
    for track in after:
        track.startTime += duration; track.endTime += duration
        if isinstance(track, ContourTrack): track.coords[2] += duration
        if isinstance(track, StreamTrack): track.times += duration
    self.endTime += duration; self._resetDis(); return self
@k1.patch(Recording)
def removeTime(self, t1:float, t2:float) -> Recording:
    """Deletes time from t1 to t2 (relative to the Recording's start time).
    All tracks lying completely inside this range will be deleted. More
    clearly, it transforms this::

        # |-1--|  |-2-|  |-3-|
        #     |---4---|  |-5-|
        #       ^      ^ delete between these carets

    Into this::

        # |-1--|  |-3-|
        #     |-4-||5-|

    Tracks that partly overlap with the range will have their start/end times
    modified, and potentially some of their internal data deleted:

    - Tracks where only the start and end times are modified: Char, Word, Click, Wheel
    - Tracks whose internal data is also modified: Contour, Stream"""
    duration = t2 - t1; t1U = t1 + self.startTime; t2U = t2 + self.startTime
    self.removeTracks(self.sel(t1, t2) | filt(op().startTime >= t1U) | filt(op().endTime < t2U)) # removing everything that's completely inside
    overlap = self.sel(t1, t2) | aS(list); after = self.sel(t2) | filt(op().startTime >= t2U) | aS(list)
    for track in overlap: # handling left overhang
        if isinstance(track, ContourTrack):
            c = track.coords; idx1 = (c[:,2] > t1U).argmax(); idx2 = (c[:,2] > t2U).argmax()
            if c[idx2,2] <= t2U: idx2 = len(c)
            a = c[:idx1]; b = c[idx2:]; b[:,2] -= duration
            track.coords = np.concatenate([a, b]); track._cachedImg = None
        if isinstance(track, StreamTrack):
            c = track.times; idx1 = (c > t1U).argmax(); idx2 = (c > t2U).argmax()
            if c[idx2] <= t2U: idx2 = len(c) # special case if idx2 is not valid
            track.times = np.concatenate([track.times[:idx1], track.times[idx2:]-duration])
            track.frames = np.concatenate([track.frames[:idx1], track.frames[idx2:]])
        track.endTime = max(t1U, track.endTime - duration); track.startTime = min(t1U, track.startTime)
    for track in after:
        if isinstance(track, ContourTrack): track.coords[:,2] -= duration
        if isinstance(track, StreamTrack): track.times -= duration
        track.startTime -= duration; track.endTime -= duration
    self._resetTimes(); self._resetDis(); return self
def _move(cs, e1, e2):
    det = e1[0]*e2[1] - e1[1]*e2[0]; dot = e1@e2; angle = math.atan2(det, dot)
    s = math.sin(angle); c = math.cos(angle); rot = np.array([[c, -s], [s, c]])
    scale = (e2**2).sum()**0.5/(e1**2).sum()**0.5; return (rot @ cs.T)*scale | transpose()
@k1.patch(ContourTrack)
def movePoint(self, x, y, start=True):
    """Moves the contour's start/end to another location, smoothly scaling all
    intermediary points along.

    :param start: if True, move the start point, else move the end point"""
    c = self.coords; e2 = np.array([x, y])
    if start: s = c[-1,:2]; e1 = c[0,:2] - s
    else: s = c[0,:2]; e1 = c[-1,:2] - s
    e2 = e2 - s; c[:,:2] = _move(c[:,:2]-s, e1, e2)+s
@k1.patch(Track)
def nextTrack(self) -> Track:
    """Grabs the next track (ordered by start time) in the recording"""
    return self.recording._tracks | filt(op().startTime) | filt(op().startTime > (self.startTime or 0)) | sortF(op().startTime) | item()
@k1.patch(Recording)
def refine(self, enabled:List[int]=[1,1,0]) -> Recording:
    """Performs sensible default operations to refine the Recording. This currently includes:

    - (0) Forming words from nearby CharTracks
    - (1) Splitting ContourTracks into multiple smaller tracks using click events
    - (2) Removing open-close CharTracks, i.e. CharTracks that don't have a begin or end time

    :param enabled: list of integers, whether to turn certain features on or
        off. 1 to turn on, 0 to turn off"""
    if enabled[0]: self.formWords()
    if enabled[1]: self.sel(klass=ContourTrack) | op().splitClick().all() | ignore()
    if enabled[2]: self.removeTracks(self.sel(klass=CharTrack) | ~filt(op().startTime))
    return self
def convBlock(inC, outC, kernel=3, stride=2, padding=1):
    return torch.nn.Sequential(torch.nn.Conv2d(inC, outC, kernel, stride, padding), torch.nn.ReLU(), torch.nn.BatchNorm2d(outC))
if hasTorch:
    class skipBlock(torch.nn.Module):
        def __init__(self, inC):
            super().__init__(); self.conv1 = convBlock(inC, inC, stride=1)
            self.conv2 = convBlock(inC, inC*2)
        def forward(self, x): return ((x | self.conv1) + x) | self.conv2
    class Net(torch.nn.Module):
        def __init__(self, skips:int=5):
            super().__init__()
            self.skips = torch.nn.Sequential(convBlock(3, 8), *[skipBlock(8*2**i) for i in range(skips)])
            self.avgPool = torch.nn.AdaptiveAvgPool2d([1, 1]); self.lin1 = knn.LinBlock(8 * 2**skips, 50)
            self.lin2 = torch.nn.Linear(50, 10); self.softmax = torch.nn.Softmax(dim=1)
            self.distThreshold = torch.nn.Parameter(torch.tensor(-0.5)); self.sigmoid = torch.nn.Sigmoid()
            self.headOnly = True
        def forward(self, x):
            x = x | self.skips | self.avgPool | op().squeeze() | self.lin1
            return x if self.headOnly else x | self.lin2
            # unreachable below: leftover from the pairwise-distance training mode
            x = ((x[None] - x[:,None])**2).sum(dim=-1)
            x = (x + 1e-7)**0.5 + self.distThreshold | self.sigmoid
            return x
def distNet() -> "torch.nn.Module":
    """Grabs a pretrained network that might be useful in distinguishing
    between screens. Example::

        net = k1ui.distNet()
        net(torch.randn(16, 3, 192, 192)) # returns tensor of shape (16, 10)"""
    net = Net(); net.load_state_dict(cat(f"{basePath}256.model.state_dict.pth", False) | aS(dill.loads))
    net.parameters() | op().requires_grad_(False).all() | ignore(); net.eval(); return net
def discardTransients(it, col=None, countThres=7, regular=False): # consistent for 7 consecutive frames, then output the result
    lastRow = None; lastE = None
    yielded = False; count = 0
    for row in it:
        e = row[col] if col else row
        if e == lastE: count += 1
        else: count = 0; lastE = e; lastRow = row; yielded = False
        if count > countThres-2 and not yielded: yielded = True; yield lastRow
        elif regular: yield None
class Buffer:
    def __init__(self): self.l = deque()
    def append(self, x): self.l.append(x)
    def __next__(self): return self.l.popleft()
if hasTorch: np2Tensor = toImg() | aS(tf.Resize([192, 192])) | toTensor()
class MLP(nn.Module):
    def __init__(self, nClasses, **kwargs):
        super().__init__(); self.l1 = knn.LinBlock(50, nClasses); self.l2 = nn.Linear(nClasses, nClasses)
    def forward(self, xb): return xb | self.l1 | self.l2
whatever = object()
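# Hedged sketch (not part of the original module): discardTransients() above
# only lets a value through after it has stayed stable for countThres
# consecutive items, which is how noisy per-frame screen predictions get
# debounced into clean transitions further down in TrainScreen.
def _demoDebounce(): # hypothetical helper, for illustration only
    noisy = ["a"]*10 + ["b"]*2 + ["c"]*10 # "b" flickers too briefly to survive
    return list(discardTransients(noisy, countThres=3)) # ['a', 'c']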
class TrainScreen:
    data: List[Tuple[int, str]]
    """Core dataset of TrainScreen. Essentially just a list of (frameId, screen name)"""
    def __init__(self, r:Recording):
        """Creates a screen training system that will train a small neural
        network to recognize different screens, using a small amount of
        feedback from the user. Overview of how it's supposed to look:

        Setting up::

            r = k1ui.Recording(await k1ui.record(30)) # records everything for 30 seconds and creates a recording out of it
            ts = k1ui.TrainScreen(r) # creates the TrainScreen object
            r # run this in a cell to display the recording, including StreamTrack
            ts.addRule("home", "settings", "home") # add expected screen transition dynamics (home -> settings -> home)

        Training with the user's feedback::

            ts.registerFrames({"home": [100, 590, 4000, 4503], "settings": [1200, 2438]}) # label some frames of the recording. Network will train for ~6 seconds
            next(ts) # display 20 images that confuse the network the most
            ts.register({"home": [2, 6], "settings": [1, 16]}) # label some frames from the last line. Notice the frame numbers are much smaller and are <20
            next(ts); ts.register({}); next(ts); ts.register({}) # repeat the last 2 lines a few times (3-5 times is probably good enough for ~7 screens)

        Evaluating the performance::

            ts.graphs() # displays 2 graphs: the network's prediction graph and the actual rule graph. Best way to judge performance
            ts.l.Accuracy.plot() # actual accuracy metric while training. The network could have bad accuracy here while still being able to construct a perfect graph, so don't rely much on this

        Using the model::

            ts.predict(torch.randn(2, 3, 192, 192) | k1ui.distNet()) # returns a list of ints. Can use the ts.idx2Name dict to convert to screen names

        Saving the model::

            ts | aS(dill.dumps) | file("ts.pth")

        .. warning::

            This won't actually save the associated recording, because
            recordings are very heavy objects (several GB). It is expected
            that you manually manage the lifecycle of the recording."""
        self.r = r; self.data = [] # [(frame id, screen name)]
        self._aspect = self.frames | item() | op().shape[:2] | ~aS(lambda x, y: y/x)
        self._distNet = distNet(); self._rules = set(); self._trainParams = {"joinAlpha": 0, "epochs": 300}
        self._lastScreenName = None; self._screenDump = Buffer(); self._screenTransients = discardTransients(self._screenDump, regular=True)
    @property
    def _coldStart(self): return len(self.data) == 0 # whether there is any data at all to work with
    def _coldGuard(self):
        if self._coldStart: raise Exception("TrainScreen has not started yet. Run `next(ts)`, choose a few frames using `ts.register()` to access this functionality")
    def _learner(self):
        self._coldGuard(); l = k1.Learner(); l.data = self._dataF()
        l.model = MLP(len(self.name2Idx))
        l.opt = optim.AdamW([l.model.parameters(), self._distNet.parameters()] | joinStreams(), lr=3e-3)
        l.cbs.add(Cbs.LossCrossEntropy()); l.css = "none"
        l.ConfusionMatrix.categories = deref()(self.name2Idx.items()) | sort(1) | cut(0) | deref()
        l.cbs.remove("AccuracyTop5", "AccF0"); return l
    def train(self, restart=True):
        """Trains the network for a while (300 epochs/6 seconds). Will be
        called automatically when you register new frames to the system.

        :param restart: whether to restart the small network or not"""
        if restart: self.l = self._learner()
        self.l.run(self._trainParams["epochs"])
    def trainParams(self, joinAlpha:float=None, epochs:int=None):
        """Sets training parameters.

        :param joinAlpha: (default 0) alpha used in the joinStreamsRandom component
            for each screen category. Read more at :class:`~k1lib.cli.structural.joinStreamsRandom`
        :param epochs: (default 300) number of epochs for each training session"""
        if joinAlpha: self._trainParams["joinAlpha"] = joinAlpha
        if epochs: self._trainParams["epochs"] = epochs
    @property
    def frames(self) -> np.ndarray:
        """Grabs the frames from the first :class:`StreamTrack` of the :class:`Recording`"""
        return self.r.sel1(klass=StreamTrack).frames
    @property
    @lru_cache
    def feats(self) -> List[np.ndarray]:
        """Gets the feature array of shape (N, 10) by passing the frames
        through :meth:`distNet`. This returns a list of arrays, not a giant,
        stacked array, for memory performance"""
        self._coldGuard(); print("Converting all frames to features using `distNet`..."); a = k1.AutoIncrement()
        res = self.frames | tee(lambda x: f"{a()}/{len(self)}").crt() | np2Tensor.all() | batched(16, True) | apply(aS(list) | aS(torch.stack) | aS(self._distNet)) | joinStreams() | aS(list)
        print(); return res
    def __len__(self): return len(self.frames)
    def _randomConsidering(self): return range(len(self)) | splitW(1, 1, 1, 1, 1) | apply(randomize(None, 42) | head(4)) | joinStreams() | aS(list)
    def __next__(self) -> "PIL.Image.Image": # show frames
        if self._coldStart: self._considering = self._randomConsidering()
        else:
            a = self.transitionScreens(False) | randomize(None, 42) | cut(0) | aS(iter)
            b = self._randomConsidering() | aS(iter); c = self._midBoundaryConsidering() | aS(iter)
            self._considering = [a, a, b, c, c, c] | apply(wrapList() | insert(yieldT | repeat(), False) | joinStreams() | randomize()) | joinStreamsRandom() | head(20) | deref()
        return self._considering | lookup(self.frames) | insertIdColumn(begin=False) | plotImgs(5, self._aspect-0.2, im=True)
    def _refreshIdx(self): self.idx2Name, self.name2Idx = self.data | cut(1) | aS(set) | insertIdColumn() | toDict() & (permute(1, 0) | toDict()) | deref()
    def register(self, d):
        """Tells the object which images previously displayed by
        :meth:`__next__` associate with which screen name. Example::

            next(ts) # displays the images out to a notebook cell
            ts.register({"home": [3, 4, 7], "settings": [5, 19, 2], "monkeys": [15, 11], "guns": []})

        This will also quickly (around 6 seconds) train a small neural network
        on all available frames based on the new information you provided.
        See also: :meth:`registerFrames`"""
        self.data = [self.data, deref()(d.items()) | apply(repeat(), 0) | transpose().all() | joinStreams() | permute(1, 0) | lookup(self._considering, 0)] | joinStreams() | sort(0) | unique(0) | deref()
        self._refreshIdx(); self.train()
    def registerFrames(self, data:Dict[str, List[int]]):
        """Tells the object which frames should have which labels. Example::

            ts.registerFrames({"home": [328, 609], "settings": [12029], "monkeys": [1238]})

        This differs from :meth:`register` in that the frame ids here are
        absolute frame indices in the recording, while in :meth:`register`,
        they index into the frames displayed by :meth:`__next__`."""
        self.data = [self.data, deref()(data.items()) | apply(repeat(), 0) | transpose().all() | joinStreams() | permute(1, 0)] | joinStreams() | sort(0) | unique(0) | deref(); self._refreshIdx(); self.train()
    def addRule(self, *screenNames:List[str]) -> "TrainScreen":
        """Adds a screen transition rule. Let's say that the transition
        dynamics look like this:

        .. code-block:: text

            home <---> settings <---> account
                          ^
                          |
                          v
                      shortcuts

        You can represent it like this::

            ts.addRule("home", "settings", "account", "settings", "home")
            ts.addRule("settings", "shortcuts", "settings")"""
        screenNames | window(2) | apply(tuple) | apply(self._rules.add) | ignore(); return self
    def transitionScreens(self, obeyRule:bool=whatever) -> List[Tuple[int, str]]:
        """Gets the list of screens (list of (frameId, screen name) tuples)
        that the network deems to be transitions between screen states.

        :param obeyRule: if not specified, don't filter. If True, return only
            screens that are part of the specified rules and vice versa"""
        self._coldGuard()
        with torch.no_grad():
            transitions = self.predict(self.feats) | insertIdColumn() | aS(discardTransients, 1) | window(2, True, None) | filt(~aS(lambda x, y: not y or x[1] != y[1])) | cut(0) | lookup(self.idx2Name, 1) | deref()
        if obeyRule is whatever: return transitions
        f = inSet(self._rules, 1) if obeyRule else ~inSet(self._rules, 1)
        return transitions | window(2) | apply(transpose() | iden() + aS(tuple)) | f | transpose().all() | joinStreams() | unique(0) | deref()
    def newEvent(self, sess:WsSession, event:dict):
        if event["type"] == "stream":
            with torch.no_grad():
                name = event["frame"] | np2Tensor | op().reshape(-1, 3, 192, 192) | aS(self._distNet) | op().view(1, -1)\
                    | self.l.model | op().argmax().item() | aS(lambda x: self.idx2Name[x])
            sess.loop.create_task(sess.eventCb(sess, {"type": "screenName", "name": name}))
        if event["type"] == "screenName":
            self._screenDump.append(event["name"]); res = next(self._screenTransients)
            if res:
                sess.loop.create_task(sess.eventCb(sess, {"type": "screenTransition", "transition": (self._lastScreenName, res)}))
                self._lastScreenName = res
    def predict(self, feats:"torch.Tensor") -> List[int]:
        """Using the built-in network, tries to predict the screen names for a
        bunch of features of shape (N, 10). Example::

            r = ...; ts = k1ui.TrainScreen(r); next(ts)
            ts.register({"bg": [9, 10, 11, 12, 17, 19], "docs": [5, 6, 7, 8, 0, 1, 4], "jupyter": [2, 3]})
            # returns a list of 2 integers
            ts.predict(torch.randn(2, 3, 192, 192) | aS(k1ui.distNet()))"""
        self._coldGuard(); return feats | batched(128, True) | apply(aS(list) | aS(torch.stack) | aS(self.l.model) | op().argmax(1).numpy()) | joinStreams()
    def transitionGraph(self) -> "graphviz.dot.Digraph":
        """Gets a screen transition graph of the entire recording. See also: :meth:`graphs`"""
        g = k1.digraph(); self.transitionScreens() | cut(1) | window(2) | apply(tuple) | count() | cut(0, 1) | ~apply(lambda c, xy: g(*xy, label=f" {c}")) | ignore(); return g
    def ruleGraph(self) -> "graphviz.dot.Digraph":
        """Gets a screen transition graph based on the specified rules. Rules
        are added using :meth:`addRule`. See also: :meth:`graphs`"""
        g = k1.digraph(); self._rules | ~apply(g) | ignore(); return g
    def graphs(self) -> viz.Carousel:
        """Combines both graphs from :meth:`transitionGraph` and :meth:`ruleGraph`"""
        return [self.transitionGraph(), self.ruleGraph()] | toImg().all() | aS(viz.Carousel)
    def labeledData(self) -> viz.Carousel:
        """Visualizes the labeled data"""
        return self.data | groupBy(1) | apply(randomize(None) | head(5) | lookup(self.frames, 0)) | batched(5) | plotImgs(5, self._aspect-0.2, table=True, im=True).all() | aS(viz.Carousel)
    def __getstate__(self):
        d = dict(self.__dict__); del d["r"]; del d["_lastScreenName"]; del d["_screenDump"]; del d["_screenTransients"]; return d
    def __setstate__(self, d):
        self.__dict__.update(d); self._lastScreenName = None; self._screenDump = Buffer()
        self._screenTransients = discardTransients(self._screenDump, regular=True)
    def correctRatio(self):
        """Ratio between the number of screens that are in a valid transition
        and the ones that aren't. Just a quick metric to see how well the
        network is doing: the higher the number, the better"""
        return len(self.transitionScreens(True))/len(self.transitionScreens(False))
def fillIn(n, states):
    iS = 0 # index into states
    state = None; nextI, nextS = states[iS]
    for i in range(n):
        if i >= nextI:
            iS += 1; state = nextS
            if iS < len(states): nextI, nextS = states[iS]
        yield [i, state]
def blocks(it):
    lastIs = []; lastE = None
    for i, e in it:
        if e != lastE:
            if lastE is not None: yield min(lastIs), max(lastIs), lastE
            lastIs = []; lastE = e
        lastIs.append(i)
    yield min(lastIs), max(lastIs), lastE
@k1.patch(TrainScreen)
def _dataF(self, bs=64):
    self._coldGuard()
    v1 = fillIn(len(self), self.data) | filt(op(), 1) # old version. A bit more liberal than v2, and will accidentally auto-label wrongly from time to time
    v2 = blocks(self.data) | ~apply(lambda x, y, z: [range(x, y+1), z | repeat()] | transpose()) | joinStreams()
    js = deref() | aS(lambda xs: xs | apply(repeatFrom() | randomize()) | joinStreamsRandom(self._trainParams["joinAlpha"], xs | apply(len) | deref())) # joinStreams
    return v2 | randomize(None) | groupBy(1) | filt(lambda x: len(x) > 1) | splitW().all() | transpose()\
        | apply(js | lookup(self.feats, 0) | lookup(self.name2Idx, 1) | batched(bs)\
        | apply(transpose() | (aS(list) | aS(torch.stack)) + toTensor(int))
    ) | stagger.tv(1024/bs) | aS(list)
def midBounds(it): # grabs data samples that are in between blocks, aka the really confusing cases in-between transitions, so that the user can guide it effectively
    lastI = 0; lastE = None
    for i, e in it:
        if e != lastE:
            yield (i + lastI)//2, i-lastI, lastE; lastE = e
        lastI = i
    return it
@k1.patch(TrainScreen)
def _midBoundaryConsidering(self): return midBounds(self.data) | ~head(1) | ~sort(1) | cut(0)
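# Hedged sketch (not part of the original module): what the two labeling
# helpers above produce. fillIn() forward-fills sparse (frameId, label) pairs
# across all n frames, while blocks() collapses consecutive equal labels into
# (firstId, lastId, label) runs.
def _demoLabelHelpers(): # hypothetical helper, for illustration only
    states = [(2, "home"), (5, "settings")]
    a = list(fillIn(8, states)) # [[0, None], [1, None], [2, 'home'], ..., [5, 'settings'], ...]
    b = list(blocks([(0, "home"), (1, "home"), (4, "settings")])) # [(0, 1, 'home'), (4, 4, 'settings')]
    return a, b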