import ast
import base64
import json
import re
import textwrap

import duckdb
from ulid import ULID

HISTORY_FILE = "history.json"
MAX_ROWS = 10000


class SQLError(Exception):
    pass


class NotFoundError(Exception):
    pass


class Q(str):
    UNSAFE = ["CREATE", "DELETE", "DROP", "INSERT", "UPDATE"]
    rows = None

    def __new__(cls, template: str, **kwargs):
        """Create a new Q-string, formatting the template with any supplied kwargs."""
        _template = textwrap.dedent(template).strip()
        try:
            instance = str.__new__(cls, _template.format(**kwargs))
        except KeyError:
            # Not all placeholders were supplied: keep the raw template.
            instance = str.__new__(cls, _template)
        instance.id = str(ULID())
        instance.alias = kwargs.pop("alias", None)
        instance.template = _template
        instance.kwargs = kwargs
        instance.definitions = "\n".join(f"{k} = {v!r}" for k, v in kwargs.items())
        for attr in ("rows", "cols", "source_id", "start", "end"):
            setattr(instance, attr, None)
        return instance

    def __repr__(self):
        """Neat repr for inspecting Q objects."""
        strings = []
        for k, v in self.__dict__.items():
            value_repr = "\n" + textwrap.indent(v, " ") if "\n" in str(v) else v
            strings.append(f"{k}: {value_repr}")
        return "\n".join(strings)

    def run(self, sql_engine=None, save=False, _raise=False):
        """Execute the query, recording start/end ULIDs and the result shape;
        on failure, return the error message unless _raise is True."""
        self.start = ULID()
        try:
            if sql_engine is None:
                res = self.run_duckdb()
            else:
                res = self.run_sql(sql_engine)
            self.rows, self.cols = res.shape
            return res
        except Exception as e:
            if _raise:
                raise e
            return str(e)
        finally:
            self.end = ULID()
            if save:
                self.save()

    def run_duckdb(self):
        if MAX_ROWS:
            # Cap result size by wrapping the query in a LIMITed CTE.
            return duckdb.sql(f"WITH x AS ({self}) SELECT * FROM x LIMIT {MAX_ROWS}")
        else:
            return duckdb.sql(self)
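
    def run_sql(self, sql_engine):
        # Sketch only: the original run_sql is not shown in this file, but run()
        # calls it whenever a sql_engine is passed. Assuming the engine exposes a
        # DuckDB-style .sql() method returning a relation with .shape (and .df()
        # for Q.df), a minimal pass-through would be:
        return sql_engine.sql(self)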

    def df(self, sql_engine=None, save=False, _raise=False):
        """Run the query and return a pandas DataFrame, or None if it failed or returned no rows."""
        res = self.run(sql_engine=sql_engine, save=save, _raise=_raise)
        if not getattr(self, "rows", None):
            return None
        result_df = res.df()
        result_df.q = self
        return result_df

    def save(self, file=HISTORY_FILE):
        with open(file, "a") as f:
            f.write(self.json)
            f.write("\n")

    @property
    def json(self):
        """One-line JSON record of the query and its metadata (written by save)."""
        serialized = {"id": self.id, "q": self}
        serialized.update(self.__dict__)
        return json.dumps(serialized, default=lambda x: x.datetime.strftime("%F %T.%f")[:-3])

    def is_safe(self):
        return not any(cmd in self.template.upper() for cmd in self.UNSAFE)

    @classmethod
    def from_dict(cls, query_dict: dict):
        q = query_dict.pop("q")
        return cls(q, **query_dict)

    @classmethod
    def from_template_and_definitions(cls, template: str, definitions: str, alias: str | None = None):
        query_dict = {"q": template, "alias": alias}
        query_dict.update(parse_definitions(definitions))
        instance = cls.from_dict(query_dict)
        instance.definitions = definitions
        return instance

    @classmethod
    def from_history(cls, query_id=None, alias=None):
        search_query = Q(f"""
            SELECT id, template, kwargs
            FROM '{HISTORY_FILE}'
            WHERE id='{query_id}' OR alias='{alias}'
            LIMIT 1
        """)
        query = search_query.run()
        if search_query.rows == 1:
            source_id, template, kwargs = query.fetchall()[0]
            # Drop NULL values so only the kwargs stored for this entry are reused.
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            instance = cls(template, **kwargs)
            instance.source_id = source_id
            return instance
        elif search_query.rows == 0:
            raise NotFoundError(f"id '{query_id}' / alias '{alias}' not found")
        else:
            raise SQLError(query)

    # @property
    # def definitions(self):
    #     return "\n".join([""] + [f"{k} = {v}" for k, v in self.kwargs.items()])

    def base64(self):
        """Base64-encode the formatted query, e.g. for a shareable URL path."""
        return base64.b64encode(self.encode()).decode()

    @classmethod
    def from_base64(cls, b64):
        """Initialize from a base64-encoded URL path."""
        return cls(base64.b64decode(b64).decode())


def parse_definitions(definitions) -> dict:
    """Parse a string of "key=value" pairs, one per line, into kwargs."""
    kwargs = {}
    for _line in definitions.split("\n"):
        line = re.sub(r"\s+", "", _line)
        if line == "" or line.startswith("#"):
            continue
        if "=" in line:
            key, value = line.split("=", maxsplit=1)
            kwargs[key] = ast.literal_eval(value)
    return kwargs
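

# Illustrative note (not from the original module): parse_definitions strips all
# whitespace on each line and literal-evaluates the right-hand side, so
#     parse_definitions("x=42\ncolname='answer'")
# yields {"x": 42, "colname": "answer"}; lines starting with '#' are ignored.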

EX1 = Q.from_template_and_definitions(
    template="SELECT {x} AS {colname}",
    definitions="\n".join([
        "# Define variables: one '=' per line",
        "x=42",
        "colname='answer'",
    ]),
    alias="example1",
)

EX2 = Q(
    """
    SELECT
        Symbol,
        Number,
        Mass,
        Abundance
    FROM '{url}'
    """,
    url="https://raw.githubusercontent.com/ekwan/cctk/master/cctk/data/isotopes.csv",
    alias="example2",
)

EX3 = Q(
    """
    SELECT *
    FROM 'history.json'
    ORDER BY id DESC
    """,
    alias="example3",
)

EX4 = Q("SELECT nothing", alias="bad_example")
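

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the original app.
    # Assumes DuckDB (and pandas, for .df()) are installed; EX2 is skipped here
    # because it fetches a remote CSV.
    print(EX1.df(_raise=True))    # SELECT 42 AS answer -> one row, one column
    print(EX1.rows, EX1.cols)     # shape recorded on the Q instance: 1 1

    # Round-trip through base64, as used for shareable URL paths.
    restored = Q.from_base64(EX1.base64())
    assert str(restored) == str(EX1)

    # EX4 is intentionally broken: run() returns the error message instead of raising.
    print(EX4.run())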