# AUTOGENERATED FILE! PLEASE DON'T EDIT
"""
Lots of type hint mechanisms to be used by the `LLVM optimizer <llvm.html>`_
"""
import k1lib.cli as cli
import k1lib, itertools, copy, numbers; import numpy as np
from k1lib.cli.init import yieldT
from typing import List
from collections import defaultdict, deque
try: import torch; hasTorch = True
except: hasTorch = False; torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
# Public names exported by this module. Note that `tDict`, although defined
# further down, is not listed here.
__all__ = ["tBase", "tAny", "tList", "tIter", "tSet", "tCollection", "tExpand",
"tNpArray", "tTensor",
"tListIterSet", "tListSet", "tListIter", "tArrayTypes",
"inferType", "TypeHintException", "tLowest", "tCheck", "tOpt"]
# Shared cli settings object; an "llvm" group with a "k1a" flag is registered on it.
settings = k1lib.settings.cli
settings.add("llvm", k1lib.Settings(), "settings related to LLVM-inspired optimizer `tOpt`. See more at module `k1lib.cli.typehint`")
settings.llvm.add("k1a", True, "utilize the supplementary C-compiled library automatically for optimizations")
class TypeHintException(Exception):
    """Error raised by this module when type hints are invalid or inconsistent."""
def klassName(self):
    """Returns a short display name for a type hint or a plain type.

    :class:`tBase` instances give their class name. Anything else gives its
    ``__name__`` if it has one, else its string form."""
    if not isinstance(self, tBase):
        try:
            return f"{self.__name__}"
        except:
            # no usable ``__name__``, fall back to the string representation
            return f"{self}"
    return self.__class__.__name__
def klassRepr(self):
    """Returns the full repr for :class:`tBase` instances, and the short
    :func:`klassName` for everything else."""
    if isinstance(self, tBase):
        return f"{self}"
    return klassName(self)
class tBase:
    """Base class of all type hints. Wraps a single ``child`` type."""
    def __init__(self, child=type(None)):
        # ``type(None)`` is the marker for "no child type given"
        self.child = child
    def __repr__(self):
        return f"<{klassName(self)} {klassRepr(self.child)}>"
    def __eq__(self, v):
        # equal only if same hint class and children that do not compare unequal
        sameKind = isinstance(v, tBase) and self.__class__ == v.__class__
        return sameKind and not (self.child != v.child)
    def check(self, v):
        """Checks whether a specific object adheres to this type hint or not.
        Returns :attr:`yieldT` if the object does not adhere. If it does, then
        returns the object. Note that in the case that the object is actually
        an iterator, it will return a new iterator containing all elements
        from the old iterator.

        The base implementation performs no check and returns ``NotImplemented``."""
        return NotImplemented
    def item(self):
        """Gets the child type of this type. Basically what's the type if
        it were to go through :class:`~k1lib.cli.utils.item`. Example::

            # returns tTensor(torch.float32, 2)
            tTensor(torch.float32, 3).item()

        Returns :class:`tAny` when no child type was given."""
        if self.child is type(None):
            return tAny()
        return self.child
    def expand(self, n) -> List["tBase"]:
        """Expands the type to a list with ``n`` elements.
        Example::

            # returns [int, int, int, int]
            tList(int).expand(4)
            # returns [int, float, float, str]
            tCollection(int, tExpand(float), str).expand(4)
        """
        elem = tAny() if self.child is type(None) else self.child
        return [elem] * n
    def __hash__(self):
        return hash(f"{self.__class__} {self.child}")
def checkF(t):
    """Builds a checker function for type ``t``.

    For :class:`tBase` hints this is the hint's own ``check`` method. For
    anything else it is a function that returns its argument when it matches
    ``t`` and :attr:`yieldT` otherwise."""
    if isinstance(t, (tBase, cli.typehint.tBase)):
        return t.check
    def inner(x):
        try:
            return x if isinstance(x, t) else yieldT
        except TypeError:
            # ``t`` is not usable with isinstance(); the only other supported
            # case is matching a torch tensor against its dtype
            isTensor = hasTorch and isinstance(x, torch.Tensor)
            if isTensor and x.dtype == t:
                return x
            return yieldT
        except Exception as e:
            print(x, t); raise e
    return inner
class tAny(tBase):
    """Type hint that accepts every object."""
    def __init__(self):
        super().__init__()
    def __repr__(self):
        return f"<{klassName(self)}>"
    def __eq__(self, v):
        # all tAny instances are interchangeable
        return isinstance(v, tAny)
    def check(self, v):
        """Always passes, returns ``v`` unchanged."""
        return v
    def item(self):
        """An element of "anything" is again "anything"."""
        return tAny()
    def __hash__(self):
        return hash("tAny")
class tIter(tBase):
    """Type hint for an arbitrary iterable whose elements are all of type ``child``."""
    def check(self, v):
        """Checks every element of ``v`` against the child type.

        :param v: any iterable, it is consumed by this check
        :return: a new iterator over the checked elements, or :attr:`yieldT`
            as soon as one element fails"""
        verify = checkF(self.child); checked = []
        for e in v:
            x = verify(e)
            # identity test on purpose: ``==`` would dispatch to the element's
            # own __eq__, which is elementwise for numpy arrays and tensors
            # and makes the ``if`` raise
            if x is yieldT: return yieldT
            checked.append(x)
        return iter(checked)
class tList(tBase):
    """Type hint for a list, tuple or range whose elements are all of type ``child``."""
    def check(self, v):
        """Returns ``v`` if it is a list/tuple/range with matching elements,
        else :attr:`yieldT`."""
        isSeq = isinstance(v, (list, tuple, range))
        if not isSeq or tIter(self.child).check(v) is yieldT:
            return yieldT
        return v
class tSet(tBase):
    """Type hint for a set whose elements are all of type ``child``."""
    def check(self, v):
        """Returns ``v`` if it is a set with matching elements, else :attr:`yieldT`."""
        # failure must be signalled with the yieldT sentinel: callers test
        # ``is yieldT``, so returning False would be treated as a pass
        if not isinstance(v, set): return yieldT
        if tIter(self.child).check(v) is yieldT: return yieldT
        return v
# Tuples of hint classes, convenient as the second argument of isinstance()
tListIterSet = (tList, tIter, tSet)
tListSet = (tList, tSet)
tListIter = (tList, tIter)
class tDict(tBase):
    """Type hint for a dictionary, with separate hints for its keys and values."""
    def __init__(self, keys, values):
        """Dictionary type.
        Example::

            d = tDict(tIter(str), tIter(int))
            # returns {"a": 3} dict, so check passed
            d.check({"a": 3})

        :param keys: type hint that the list of all keys has to satisfy
        :param values: type hint that the list of all values has to satisfy"""
        super().__init__(); self.keys = keys; self.values = values
    def check(self, v):
        """Returns a new dict built from the checked keys and values, or
        :attr:`yieldT` if ``v`` is not a dict or the keys/values don't match."""
        if not isinstance(v, dict): return yieldT
        ks = self.keys.check(list(v.keys()))
        vs = self.values.check(list(v.values()))
        if ks is yieldT or vs is yieldT: return yieldT
        return dict(zip(ks, vs))
    def __eq__(self, v):
        if not isinstance(v, tDict): return False
        if self.keys != v.keys: return False
        if self.values != v.values: return False
        return True
    def __hash__(self):
        # Overriding __eq__ alone sets __hash__ to None, which made tDict
        # unhashable and broke set membership tests. Mirrors tBase.__hash__.
        return hash(f"{self.__class__} {self.keys} {self.values}")
    def __repr__(self):
        return f"<{klassName(self)} {klassRepr(self.keys)} {klassRepr(self.values)}>"
class tNpArray(tBase):
    """Type hint for numpy arrays, optionally with a known dtype and rank."""
    def __init__(self, child=None, rank=None):
        """Numpy array type.
        Example::

            # returns np.array([2, 3])
            tNpArray(np.int64, 1).check(np.array([2, 3]))

        :param child: the dtype of the array
        :param rank: the rank/dimension of the array"""
        super().__init__(child)
        self.rank = rank
    def check(self, v):
        """Returns ``v`` if it is a numpy array of matching rank (any rank if
        ``rank`` is None), else :attr:`yieldT`. The dtype is not inspected."""
        if not isinstance(v, np.ndarray):
            return yieldT
        rankKnown = self.rank is not None
        if rankKnown and self.rank != len(v.shape):
            return yieldT
        return v
    def __repr__(self):
        return f"<tNpArray {klassName(self.child)} rank={self.rank}>"
    def item(self):
        """Type of a single element along the first axis: an array of one
        rank less, the dtype itself once the rank is at most 1, or an array
        of unknown rank if the rank is unknown."""
        if self.rank is None:
            return tNpArray(self.child, None)
        if self.rank > 1:
            return tNpArray(self.child, self.rank - 1)
        return self.child
    def __eq__(self, v):
        # a dtype or rank of None acts as a wildcard on either side
        if not isinstance(v, tNpArray):
            return False
        dtypesKnown = self.child is not None and v.child is not None
        if dtypesKnown and self.child != v.child:
            return False
        return self.rank is None or v.rank is None or self.rank == v.rank
    def __hash__(self):
        return hash(f"{self.child} - {self.rank}")
    def expand(self, n):
        """Returns ``n`` copies of :meth:`item`."""
        return [self.item()] * n
if hasTorch:
    class tTensor(tBase):
        """Type hint for PyTorch tensors, optionally with a known dtype and rank."""
        def __init__(self, child=None, rank=None):
            """PyTorch tensor type.
            Example::

                # returns torch.tensor([2.0, 3.0])
                tTensor(torch.float32, 1).check(torch.tensor([2.0, 3.0]))

            :param child: the dtype of the array
            :param rank: the rank/dimension of the tensor"""
            super().__init__(child)
            self.rank = rank
        def check(self, v):
            """Returns ``v`` if it is a tensor of matching rank (any rank if
            ``rank`` is None), else :attr:`yieldT`. The dtype is not inspected."""
            if not isinstance(v, torch.Tensor):
                return yieldT
            rankKnown = self.rank is not None
            if rankKnown and self.rank != len(v.shape):
                return yieldT
            return v
        def __repr__(self):
            return f"<tTensor {klassName(self.child)} rank={self.rank}>"
        def item(self):
            """Type of a single element along the first axis: a tensor of one
            rank less, the dtype itself once the rank is at most 1, or a
            tensor of unknown rank if the rank is unknown."""
            if self.rank is None:
                return tTensor(self.child, None)
            if self.rank > 1:
                return tTensor(self.child, self.rank - 1)
            return self.child
        def __eq__(self, v):
            # a dtype or rank of None acts as a wildcard on either side
            if not isinstance(v, tTensor):
                return False
            dtypesKnown = self.child is not None and v.child is not None
            if dtypesKnown and self.child != v.child:
                return False
            return self.rank is None or v.rank is None or self.rank == v.rank
        def __hash__(self):
            return hash(f"{self.child} - {self.rank}")
        def expand(self, n):
            """Returns ``n`` copies of :meth:`item`."""
            return [self.item()] * n
    tArrayTypes = (tNpArray, tTensor)
else:
    class tTensor(tBase):
        """Placeholder used when PyTorch is not installed."""
    tArrayTypes = (tNpArray,)
class tCollection(tBase):
    """Type hint for a fixed-length collection with a separate type per position."""
    def __init__(self, *children):
        """Fixed-length collection of things. Let's say you want a tuple with
        5 values::

            a = [3, [2, 3], "e", 2.0, b'3']

        Then, this would be represented like this::

            tCollection(int, tList(int), str, float, bytes)

        This also works in conjunction with :class:`tExpand`, like this::

            a = [3, [2, 3], "e", 2.0, 3.0]
            tCollection(int, tList(int), str, tExpand(float))

        :param children: one type per position, at most one of them a :class:`tExpand`
        :raises TypeHintException: if more than one :class:`tExpand` is given"""
        super().__init__(None); self.children = list(children)
        expandIdxs = [i for i, e in enumerate(children) if isinstance(e, tExpand)]
        if len(expandIdxs) > 1: raise TypeHintException("Can't have 2 `tExpand` in a `tCollection`")
        self.nChildren = len(children) - len(expandIdxs) # minimum number of children possible
        self.expandIdx = expandIdxs[0] if expandIdxs else -1 # -1 means "no tExpand"
    def __repr__(self):
        a = ' '.join(klassRepr(c) for c in self.children)
        return f"<{klassName(self)} {a}>"
    def __eq__(self, v):
        if not isinstance(v, tCollection): return False
        if len(self.children) != len(v.children): return False
        for x, y in zip(self.children, v.children):
            if x != y: return False
        return True
    def __hash__(self):
        # Overriding __eq__ alone sets __hash__ to None, which made tCollection
        # unhashable and broke set membership tests. Mirrors tBase.__hash__.
        return hash(f"{self.__class__} {self.children}")
    def check(self, v):
        """Checks every element of ``v`` against the type at its position.

        :param v: any iterable, it is materialized into a list first
        :return: the checked elements, in the original container type for
            lists and tuples and as a list otherwise, or :attr:`yieldT` if
            any element fails"""
        t = type(v) if isinstance(v, (list, tuple)) else None
        v = list(v)
        if self.expandIdx >= 0:
            # the tExpand child absorbs whatever the fixed children don't cover
            nMatchExpand = len(v) - self.nChildren
            # too few values for the fixed children: can't match. Previously
            # this raised IndexError or compared values against wrong hints
            if nMatchExpand < 0: return yieldT
            i = self.expandIdx
            hints = [*self.children[:i], *[self.children[i]]*nMatchExpand, *self.children[i+1:]]
        else:
            # NOTE(review): without tExpand the lengths of `v` and `children`
            # are not compared, zip() just stops at the shorter one
            hints = self.children
        l = []
        for hint, e in zip(hints, v):
            x = checkF(hint)(e)
            if x is yieldT: return yieldT
            l.append(x)
        return t(l) if t else l
    def reduce(self):
        """Tries to reduce ``tCollection(int, int)`` to ``tIter(int)`` if possible.
        Returns ``self`` if the children differ or if there are no children."""
        if not self.children: return self
        s = self.children[0]
        for e in self.children:
            if s != e: return self
        return tIter(s)
    def item(self):
        """Lowest shared type of all children, with :class:`tExpand` children unwrapped."""
        return tLowest(*((t.child if isinstance(t, tExpand) else t) for t in self.children))
    def expand(self, n:int) -> List[tBase]:
        """Expands out this collection so that it has a specified length"""
        if self.expandIdx >= 0:
            ts = []
            for t in self.children:
                # the tExpand child is repeated to fill up the remaining slots
                if isinstance(t, tExpand): ts.extend([t.child]*(n - len(self.children) + 1))
                else: ts.append(t)
            return ts
        if len(self.children) == n: return list(self.children)
        # doesn't make sense, so default case should return to list of lowest child
        return [self.item()]*n
class tExpand(tBase):
    """Marks a child of :class:`tCollection` that may repeat."""
    def __init__(self, child):
        """Supplement to :class:`tCollection`

        :param child: type of each of the repeated elements"""
        super().__init__(child)
    def check(self, v):
        """Checks a single element ``v`` against the child type."""
        verify = checkF(self.child)
        return verify(v)
# Objects of these types are treated as leaves by `inferType`, which returns their own type
settings.atomic.add("typeHint", (numbers.Number, np.number, str, bool, bytes), "atomic types used for infering type of object for optimization passes")
def inferType(o):
    """Tries to infer the type of the input.
    Example::

        # returns tList(int)
        inferType(range(10))
        # returns tTensor(torch.float32, 2)
        inferType(torch.randn(2, 3))

    Lists and tuples become ``tList`` if all elements infer to the same type,
    a ``tCollection`` if they differ and there are fewer than 100 of them, and
    a ``tList`` of their lowest shared type otherwise. Anything unrecognized
    becomes :class:`tAny`."""
    if isinstance(o, range):
        return tList(int)
    if isinstance(o, settings.atomic.typeHint):
        return type(o)
    if isinstance(o, np.ndarray):
        return tNpArray(o.dtype, len(o.shape))
    if hasTorch and isinstance(o, torch.Tensor):
        return tTensor(o.dtype, len(o.shape))
    if isinstance(o, (list, tuple)):
        arr = [inferType(e) for e in o]
        first = arr[0] if arr else None
        if not any(first != t for t in arr):
            return tList(first) # homogeneous (or empty)
        if len(arr) < 100:
            return tCollection(*arr)
        return tList(tLowest(*arr))
    if isinstance(o, dict):
        keyType = inferType(list(o.keys()))
        valueType = inferType(list(o.values()))
        return tDict(keyType, valueType)
    return tAny()
def lowestChild(t):
    """Returns the element type of a container-like type hint.

    :raises TypeHintException: if ``t`` is not a collection, list/iter/set or array hint"""
    if isinstance(t, tCollection):
        return tLowest(*t.children)
    if isinstance(t, tListIterSet):
        return t.child
    if not isinstance(t, tArrayTypes):
        raise TypeHintException(f"Type {t} does not have a lowest child")
    # arrays of unknown rank or rank 1 yield their dtype directly
    # NOTE(review): tNpArray.item() returns an array of unknown rank for an
    # unknown rank instead - confirm which of the two is intended
    if t.rank is None or t.rank == 1:
        return t.child
    return t.__class__(t.child, t.rank - 1)
# Scalar types/dtypes grouped by numeric family, used by `tLowest` to find
# the lowest shared numeric type.
intTypes = {int, np.int8, np.int16, np.int32, np.int64, torch.int8, torch.int16, torch.int32, torch.int64}
floatTypes = {float, np.float16, np.float32, np.float64, torch.float16, torch.float32, torch.float64, torch.bfloat16}
# some systems don't have float128, so only add it when numpy exposes it
if hasattr(np, "float128"): floatTypes.add(np.float128)
intFloatTypes = {*intTypes, *floatTypes}
numericTypes = {*intTypes, *floatTypes, complex, numbers.Number}
def allSame(l):
    """Returns True if every element of ``l`` equals the first one (also for empty ``l``)."""
    for t in l:
        if not (t == l[0]):
            return False
    return True
def _inTypeSet(t, typeSet):
    """Membership test that treats unhashable objects as "not a member"."""
    try: return t in typeSet
    except TypeError: return False

def tLowest(*ts):
    """Grabs the lowest possible shared type of all the example types.
    Example::

        # returns tIter(float)
        tLowest(tIter(float), tList(int))

    Returns :class:`tAny` if there is no shared type, or if no types are given."""
    # nothing to compare: previously this crashed with IndexError on ts[0]
    if not ts: return tAny()
    # all arrays of the same library? keep dtype/rank where they all agree
    if all(isinstance(t, tArrayTypes) for t in ts):
        if all(isinstance(t, tTensor) for t in ts) or all(isinstance(t, tNpArray) for t in ts):
            t = ts[0]; rank = t.rank if allSame([t.rank for t in ts]) else None
            child = t.child if allSame([t.child for t in ts]) else None
            return t.__class__(child, rank)
    # sort of like list?
    if all(isinstance(t, (tList, tIter, tSet, *tArrayTypes, tCollection)) for t in ts):
        lC = tLowest(*(lowestChild(t) for t in ts))
        if any(isinstance(t, (tIter, tCollection)) for t in ts): return tIter(lC)
        return tList(lC)
    # all numeric? Uses the tolerant membership test, as some type hints
    # (anything overriding __eq__ without __hash__) can't be looked up in a set
    if all(_inTypeSet(t, numericTypes) for t in ts):
        if all(_inTypeSet(t, intTypes) for t in ts): return int
        if all(_inTypeSet(t, intFloatTypes) for t in ts): return float
        return numbers.Number
    return tAny()
def _tCheck(inp, op):
    """Runs ``inp | op`` and verifies the type hints along the way.

    Three checks are made: the input against its inferred type, the output
    against its inferred type, and the output against the type that ``op``
    predicts through ``_typehint``. On failure the details are also stored in
    the global ``tCheckData``.

    :return: the output after going through the last check
    :raises TypeHintException: if any of the three checks fails"""
    global tCheckData
    a = inferType(inp)
    out = inp | op
    b = inferType(out)
    x = checkF(a)(inp)
    y = checkF(b)(out)
    z = checkF(op._typehint(a))(y)
    c1, c2, c3 = x is yieldT, y is yieldT, z is yieldT
    if not (c1 or c2 or c3):
        return z
    tCheckData = [a, b, c1, c2, c3, inp, out]
    raise TypeHintException(f"Type hints are wrong. Hints: inp type ({a}), out type ({b}). Checks: {c1}, {c2}, {c3}. Inp: {inp}, out: {out}")
class tCheck(cli.BaseCli):
    def __init__(self):
        """Tool similar to :class:`~k1lib.cli.trace.trace` to check whether
        all type hint outputs of all clis are good or not. Example::

            assert range(1, 3) | tCheck() | item() | op()*2 == 2

        Mainly used in cli unit tests. Return type of statement will be :class:`tCheck`,
        which might be undesirable, so you can pipe it to :data:`yieldT` like this::

            # returns tCheck object
            range(1, 3) | tCheck() | item() | op()*2
            # returns number "2"
            range(1, 3) | tCheck() | item() | op()*2 | yieldT"""
        self.inp = None
    def __ror__(self, v):
        """Captures the pipeline input."""
        self.inp = v
        return self
    def __or__(self, op):
        # piping into yieldT unwraps the current value, anything else is
        # executed and type-checked right away
        if op is not yieldT:
            self.inp = _tCheck(self.inp, op)
            return self
        return self.inp
    def __eq__(self, v):
        return self.inp == v
class tOpt(cli.BaseCli):
    # registered passes as flat lists, plus the same passes grouped by abstractness
    _passes = []; _serialPasses = []
    _passStruct = {}; _serialStruct = {}
    n = 10 # maximum number of optimization rounds
    def __init__(self):
        """Optimizes clis. Let's say you have something
        like this::

            range(1000) | toList() | head() | deref()

        For whatever reason you forgot that you've dereferenced everything
        in the middle, although you're only using 10 first elements, so the
        code can't be lazy anymore. You can apply optimizations to it like this::

            range(1000) | tOpt() | toList() | head() | deref()

        This will effectively turn it into this::

            range(1000) | tOpt() | head() | deref()

        Normally, you'd use it in this form instead::

            # returns the optimized cli
            f = "file.txt" | tOpt() | cat() | shape(0) | tOpt
            # then you can do this to pass it through as usual
            "other file.txt" | f

        Checkout the `llvm optimizer tutorial <llvm.html>`_ for a more in-depth explanation of this

        More over, this combines nicely with :class:`~k1lib.cli.trace.trace` like this::

            range(5) | tOpt() | trace() | apply(op()**2) | deref()"""
        self.inp = None; self.clis = []
        self._out = yieldT # sentinel meaning "not computed yet"
    @staticmethod
    def _addBasePass(p, abstractness=1):
        """Adds an optimization pass that acts upon a single cli.
        Example::

            def o1(c:BaseCli, t:tBase):
                if ...:
                    return aS(lambda x: x**2)
                else:
                    return None
            tOpt._addBasePass(o1, 6)

        :param p: the optimization pass
        :param abstractness: rounded and clamped to the range [1, 2]"""
        tOpt._passes.append([p, round(max(min(abstractness, 2), 1))])
        # regroup all passes by abstractness, highest first
        tOpt._passStruct = {a1: [q for q, a2 in tOpt._passes if a2 == a1] for a1 in range(2, 0, -1)}
    @staticmethod
    def addPass(p, klasses:List[cli.BaseCli]=[], abstractness=3):
        """Adds an optimization pass that acts upon multiple clis in series.
        Example::

            # cs: list of clis, ts: list of input type hints, 1 for each cli
            def o1(cs:List[BaseCli], ts:List[tBase], metadata={}):
                return [cs[1], cs[0]] # reorder the clis
            tOpt.addPass(o1, [toList, head], 3)

        Here, we're declaring an optimization pass ``o1``. You will be given a list of cli
        objects, the cli's input type hints and some extra metadata. If you can optimize
        it, then you should return a list of new clis, else you should return None

        Also, ``abstractness`` has varying number of legal values:

        - 1-5: generic optimizations
        - 6-10: analysis passes. Passes must not return anything

        Higher abstraction optimizations will be called first, and then lower abstraction
        optimizations will be called later. So, the idea is, just like LLVM, you can do
        some analysis which will compute metadata that you can use in your optimization
        passes, which will return optimized clis if it can.

        Within optimization passes, you can prioritize optimizations that look at the global
        picture first, before breaking the code up into tiny fragments with more detailed
        optimizations, at which point it's hard to look at the global picture.

        :param p: the optimization pass
        :param klasses: list of cli classes in series that will trigger the pass
        :param abstractness: how abstract is this optimization. Rounded and
            clamped to the range [1, 15]"""
        tOpt._serialPasses.append([p, tuple(klasses), round(max(min(abstractness, 15), 1))])
        # regroup all passes by abstractness (highest first), then by trigger classes
        serialStruct = {}
        for a1 in range(15, 0, -1):
            serialStruct[a1] = defaultdict(list)
            for q, ks, a2 in tOpt._serialPasses:
                if a2 == a1: serialStruct[a1][ks].append(q)
        tOpt._serialStruct = serialStruct
    @staticmethod
    def clearPasses():
        """Clears all passes, then registers the built-in ones again via ``addSerialOpt``"""
        tOpt._passes = []; tOpt._serialPasses = []
        tOpt._passStruct = {}; tOpt._serialStruct = {}
        addSerialOpt()
    @property
    def out(self):
        """Output of the optimized pipeline. Optimization and execution
        happen on first access, after that the cached result is returned."""
        # identity test on purpose: once computed, `_out` can be any object
        # (e.g. a numpy array), and `==` would dispatch to its own __eq__
        if self._out is yieldT:
            if isinstance(self.inp, cli.BaseCli):
                self.clis = [self.inp, *self.clis]; self.inp = None
            # why wrap 2 times? We want passes to select klasses=[serial]
            c = cli.serial(cli.serial(*self.clis)); t = inferType(self.inp)
            # run all base passes repeatedly, until nothing changes or
            # the round limit is reached
            for i in range(tOpt.n):
                atLeastOnce = False
                for passes in tOpt._passStruct.values():
                    for p in passes:
                        repl = p(c, t)
                        if repl is not None: atLeastOnce = True; c = repl # optimized version
                if not atLeastOnce: break
            assert isinstance(c, cli.serial) and len(c.clis) == 1
            self._optCli = c.clis[0]; self._out = self.inp | c
        return self._out
    @property
    def optCli(self):
        """Grabs the optimized cli.
        Example::

            # returns optimized cli
            (range(5) | tOpt() | apply(op()**2) | deref()).optCli
            # you can also do it like this:
            range(5) | tOpt() | apply(op()**2) | deref() | tOpt.optCli
            # or even shorter like this:
            range(5) | tOpt() | apply(op()**2) | deref() | tOpt
        """
        self.out; return self._optCli # accessing `out` triggers the optimization
    def __ror__(self, it):
        """Captures the pipeline input."""
        self.inp = it; return self
    def __iter__(self): return iter(self.out)
    def __or__(self, o):
        # piping into yieldT executes the pipeline, piping into tOpt/tOpt.optCli
        # returns the optimized cli, anything else is collected for later
        if o is yieldT: return self.out
        if o is tOpt.optCli or o is tOpt:
            return self.optCli
        self.clis.append(o); return self
    def __repr__(self): return f"{self.out}"
    def __eq__(self, v): return self.out == v
    def __bool__(self):
        # __bool__ has to return an actual bool, else Python raises TypeError
        return bool(self.out)
class window(cli.BaseCli):
    """Sliding window over an iterable, that also exposes what comes before
    and after each window."""
    def __init__(self, n, newList=False):
        """:param n: window size
        :param newList: not used by this class"""
        self.n = n
    def __ror__(self, it):
        """Yields ``(before, window, rest)`` for every full window, where
        ``before`` is the list of elements that already left the window
        (the same list object every time, it keeps growing), ``window`` is a
        tuple of ``n`` elements and ``rest`` is the underlying iterator."""
        size = self.n
        seen = []; q = deque([], size)
        it = iter(it)
        for elem in it:
            q.append(elem)
            if len(q) != size: continue
            yield seen, tuple(q), it
            # runs once the consumer resumes: slide the window by one
            seen.append(q.popleft())
def grabTypes(cs, t):
    """Propagates input type ``t`` through the clis ``cs``.

    :return: list of ``len(cs) + 1`` types: the input type, followed by the
        type after each cli, as reported by its ``_typehint`` method"""
    ts = [t]
    for c in cs:
        ts.append(c._typehint(ts[-1]))
    return ts
def grabKlasses(iKlasses):
    """Returns the list of types of the given objects."""
    return list(map(type, iKlasses))
# `debug` turns on the trace prints in `serialOpt`; `depth` is the nesting
# counter those prints report (only touched while `debug` is on)
depth = 0; debug = False
def serialOpt(c, t, metadata=None):
    """Optimizes ``c``, which is supposed to be a :class:`~init.serial`
    object, with the input type hint ``t``. If it can actually optimize it,
    then it will return a new :class:`~init.serial` object, else it returns
    None.

    Slides windows of increasing size over the clis of ``c`` and looks up
    registered serial passes by the tuple of cli classes in the window. The
    first pass that returns something wins, and its result replaces the
    window.

    :param c: cli to optimize. Anything that is not a serial object is left alone
    :param t: input type hint of ``c``
    :param metadata: dict shared with the passes. Its "route" list gets
        "serial" appended while this function runs. Created if not given"""
    if debug: global depth; depth += 1; print(f"serial depth: {depth}")
    if metadata is None: metadata = {"route": []}
    # returns None, or a new serial object
    if not isinstance(c, cli.serial):
        if debug: print(f"out depth, not serial: {depth}"); depth -= 1
        return None
    metadata["route"].append("serial")
    # ts[i] is the input type of cs[i]
    cs = c.clis; ts = grabTypes(cs, t)
    if debug: print(f"serialOpt: {[c.__class__.__name__ for c in cs]}, {ts}")
    for windowSize in range(1, len(cs)+1):
        # a: items before the window, e: the window itself, c: iterator over
        # the items after it (this rebinds the `c` parameter). Each item is
        # presumably a (cli, input type) pair - verify against cli.transpose
        for a, e, c in [cs, ts] | cli.transpose() | window(windowSize):
            iKlasses, ths = e | cli.transpose()
            klasses = tuple(type(e) for e in iKlasses)
            #print(klasses)
            # `e` is reused here for the per-abstractness pass tables, which
            # are visited in insertion order
            for e in tOpt._serialStruct.values():
                if klasses in e:
                    for p in e[klasses]:
                        res = p(iKlasses, ths, metadata)
                        #print(f"serial p: {p}, res: {0}, klasses: {klasses}")
                        if res is not None:
                            a = a | cli.toList(); c = c | cli.toList()
                            if debug: print(f"out depth new: {depth}"); depth -= 1
                            metadata["route"].pop()
                            # stitch together: clis before + replacement + clis after
                            return cli.serial(*(a | cli.transpose() | cli.item() if len(a) > 0 else []),
                                              *res,
                                              *(c | cli.transpose() | cli.item() if len(c) > 0 else []))
    # no pass applied, falls through and implicitly returns None
    if debug: print(f"out depth none: {depth}"); depth -= 1
    metadata["route"].pop()
def addSerialOpt():
    """Registers the built-in passes: :func:`serialOpt` as a base pass, and a
    serial pass that applies it to nested serial clis. Called by
    ``tOpt.clearPasses``."""
    # the abstractness of 5 is clamped to 2 by `_addBasePass`
    tOpt._addBasePass(serialOpt, 5)
    def inner(cs, ts, metadata):
        # adapts serialOpt to the serial pass interface (lists in, list or None out)
        res = serialOpt(cs[0], ts[0], metadata)
        return None if res is None else [res]
    tOpt.addPass(inner, [cli.serial], 15)
    # NOTE(review): the source's indentation is lost; this try/except is
    # assumed to belong to this function - confirm against the original file
    try: cli.optimizations.basics() # cyclic include, so mainly intended for regular use after first initialization
    except: pass
# Initialize the pass registry at import time: clearPasses() empties it and
# then calls addSerialOpt() to register the built-in passes
tOpt.clearPasses();