diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f717486 --- /dev/null +++ b/.gitignore @@ -0,0 +1,125 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/prestodb/client.py b/prestodb/client.py index 7939f7b..7899270 100644 --- a/prestodb/client.py +++ b/prestodb/client.py @@ -423,43 +423,15 @@ def process(self, http_response): ) -class PrestoResult(object): +class PrestoQuery(object): """ - Represent the result of a Presto query as an iterator on rows. + Represent the execution of a SQL statement by Presto. - This class implements the iterator protocol as a generator type + Results of the query can be extracted by iterating over this class, since it + implements the iterator protocol as a generator type https://docs.python.org/3/library/stdtypes.html#generator-types """ - def __init__(self, query, rows=None): - self._query = query - self._rows = rows or [] - self._rownumber = 0 - - @property - def rownumber(self): - # type: () -> int - return self._rownumber - - def __iter__(self): - # Initial fetch from the first POST request - for row in self._rows: - self._rownumber += 1 - yield row - self._rows = None - - # Subsequent fetches from GET requests until next_uri is empty. - while not self._query.is_finished(): - rows = self._query.fetch() - for row in rows: - self._rownumber += 1 - logger.debug("row {}".format(row)) - yield row - - -class PrestoQuery(object): - """Represent the execution of a SQL statement by Presto.""" - def __init__( self, request, # type: PrestoRequest @@ -476,7 +448,9 @@ def __init__( self._cancelled = False self._request = request self._sql = sql - self._result = PrestoResult(self) + + self._rows = [] + self._rownumber = 0 @property def columns(self): @@ -490,10 +464,6 @@ def stats(self): def warnings(self): return self._warnings - @property - def result(self): - return self._result - def execute(self): # type: () -> PrestoResult """Initiate a Presto query by sending the SQL statement @@ -514,10 +484,10 @@ def execute(self): self._warnings = getattr(status, "warnings", []) if status.next_uri is None: self._finished = True - self._result = PrestoResult(self, status.rows) - return self._result + self._rows = status.rows + return self - def fetch(self): + def _fetch(self): # type: () -> List[List[Any]] """Continue fetching data for the current query_id""" response = self._request.get(self._request.next_uri) @@ -530,6 +500,14 @@ def fetch(self): self._finished = True return status.rows + def poll(self): + # type: () -> Dict + """Retrieve the current status of a presto query, caching any results.""" + if not self.query_id or self._finished: + return self.stats + self._rows.extend(self._fetch()) + return self.stats + def cancel(self): # type: () -> None """Cancel the current query""" @@ -549,3 +527,12 @@ def cancel(self): def is_finished(self): # type: () -> bool return self._finished + + def __iter__(self): + while self._rows or not self.is_finished(): + for row in self._rows: + self._rownumber += 1 + logger.debug('row {}'.format(row)) + yield row + self._rows = [] + self.poll() diff --git a/prestodb/dbapi.py b/prestodb/dbapi.py index 119c1b5..613fbda 100644 --- a/prestodb/dbapi.py +++ b/prestodb/dbapi.py @@ -225,6 +225,9 @@ def warnings(self): return self._query.warnings return None + def poll(self): + return self._query.poll() + def setinputsizes(self, sizes): raise prestodb.exceptions.NotSupportedError @@ -232,10 +235,8 @@ def setoutputsize(self, size, column): raise prestodb.exceptions.NotSupportedError def execute(self, operation, params=None): - self._query = prestodb.client.PrestoQuery(self._request, sql=operation) - result = self._query.execute() - self._iterator = iter(result) - return result + self._query = prestodb.client.PrestoQuery(self._request, sql=operation).execute() + return self._query def executemany(self, operation, seq_of_params): raise prestodb.exceptions.NotSupportedError @@ -250,13 +251,10 @@ def fetchone(self): An Error (or subclass) exception is raised if the previous call to .execute*() did not produce any result set or no call was issued yet. """ - - try: - return next(self._iterator) - except StopIteration: + result = self.fetchmany(1) + if len(result) != 1: return None - except prestodb.exceptions.HttpError as err: - raise prestodb.exceptions.OperationalError(str(err)) + return result[0] def fetchmany(self, size=None): # type: (Optional[int]) -> List[List[Any]] @@ -284,16 +282,20 @@ def fetchmany(self, size=None): size = self.arraysize result = [] + iterator = iter(self._query) + for _ in range(size): - row = self.fetchone() - if row is None: + try: + result.append(next(iterator)) + except StopIteration: break - result.append(row) + except prestodb.exceptions.HttpError as err: + raise prestodb.exceptions.OperationalError(str(err)) return result def genall(self): - return self._query.result + return self._query def fetchall(self): # type: () -> List[List[Any]]