mirror of
https://github.com/PiBrewing/craftbeerpi4.git
synced 2024-11-22 23:18:12 +01:00
391 lines
12 KiB
Python
391 lines
12 KiB
Python
|
# Copyright 2018 John Reese
|
||
|
# Licensed under the MIT license
|
||
|
|
||
|
"""
|
||
|
Core implementation of aiosqlite proxies
|
||
|
"""
|
||
|
|
||
|
import asyncio
|
||
|
import logging
|
||
|
import sqlite3
|
||
|
import sys
|
||
|
import warnings
|
||
|
from functools import partial
|
||
|
from pathlib import Path
|
||
|
from queue import Empty, Queue
|
||
|
from threading import Thread
|
||
|
from typing import (
|
||
|
Any,
|
||
|
AsyncIterator,
|
||
|
Callable,
|
||
|
Generator,
|
||
|
Iterable,
|
||
|
Optional,
|
||
|
Type,
|
||
|
Union,
|
||
|
)
|
||
|
from warnings import warn
|
||
|
|
||
|
from .context import contextmanager
|
||
|
from .cursor import Cursor
|
||
|
|
||
|
__all__ = ["connect", "Connection", "Cursor"]
|
||
|
|
||
|
LOG = logging.getLogger("aiosqlite")
|
||
|
|
||
|
|
||
|
def get_loop(future: asyncio.Future) -> asyncio.AbstractEventLoop:
|
||
|
if sys.version_info >= (3, 7):
|
||
|
return future.get_loop()
|
||
|
else:
|
||
|
return future._loop
|
||
|
|
||
|
|
||
|
class Connection(Thread):
|
||
|
def __init__(
|
||
|
self,
|
||
|
connector: Callable[[], sqlite3.Connection],
|
||
|
iter_chunk_size: int,
|
||
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||
|
) -> None:
|
||
|
super().__init__()
|
||
|
self._running = True
|
||
|
self._connection: Optional[sqlite3.Connection] = None
|
||
|
self._connector = connector
|
||
|
self._tx: Queue = Queue()
|
||
|
self._iter_chunk_size = iter_chunk_size
|
||
|
|
||
|
if loop is not None:
|
||
|
warn(
|
||
|
"aiosqlite.Connection no longer uses the `loop` parameter",
|
||
|
DeprecationWarning,
|
||
|
)
|
||
|
|
||
|
@property
|
||
|
def _conn(self) -> sqlite3.Connection:
|
||
|
if self._connection is None:
|
||
|
raise ValueError("no active connection")
|
||
|
|
||
|
return self._connection
|
||
|
|
||
|
def _execute_insert(
|
||
|
self, sql: str, parameters: Iterable[Any]
|
||
|
) -> Optional[sqlite3.Row]:
|
||
|
cursor = self._conn.execute(sql, parameters)
|
||
|
cursor.execute("SELECT last_insert_rowid()")
|
||
|
return cursor.fetchone()
|
||
|
|
||
|
def _execute_fetchall(
|
||
|
self, sql: str, parameters: Iterable[Any]
|
||
|
) -> Iterable[sqlite3.Row]:
|
||
|
cursor = self._conn.execute(sql, parameters)
|
||
|
return cursor.fetchall()
|
||
|
|
||
|
def run(self) -> None:
|
||
|
"""
|
||
|
Execute function calls on a separate thread.
|
||
|
|
||
|
:meta private:
|
||
|
"""
|
||
|
while self._running:
|
||
|
try:
|
||
|
future, function = self._tx.get(timeout=0.1)
|
||
|
except Empty:
|
||
|
continue
|
||
|
|
||
|
try:
|
||
|
LOG.debug("executing %s", function)
|
||
|
result = function()
|
||
|
LOG.debug("returning %s", result)
|
||
|
|
||
|
def set_result(fut, result):
|
||
|
if not fut.done():
|
||
|
fut.set_result(result)
|
||
|
|
||
|
get_loop(future).call_soon_threadsafe(set_result, future, result)
|
||
|
except BaseException as e:
|
||
|
LOG.info("returning exception %s", e)
|
||
|
|
||
|
def set_exception(fut, e):
|
||
|
if not fut.done():
|
||
|
fut.set_exception(e)
|
||
|
|
||
|
get_loop(future).call_soon_threadsafe(set_exception, future, e)
|
||
|
|
||
|
async def _execute(self, fn, *args, **kwargs):
|
||
|
"""Queue a function with the given arguments for execution."""
|
||
|
function = partial(fn, *args, **kwargs)
|
||
|
future = asyncio.get_event_loop().create_future()
|
||
|
|
||
|
self._tx.put_nowait((future, function))
|
||
|
|
||
|
return await future
|
||
|
|
||
|
async def _connect(self) -> "Connection":
|
||
|
"""Connect to the actual sqlite database."""
|
||
|
if self._connection is None:
|
||
|
try:
|
||
|
self._connection = await self._execute(self._connector)
|
||
|
except Exception:
|
||
|
self._running = False
|
||
|
self._connection = None
|
||
|
raise
|
||
|
|
||
|
return self
|
||
|
|
||
|
def __await__(self) -> Generator[Any, None, "Connection"]:
|
||
|
self.start()
|
||
|
return self._connect().__await__()
|
||
|
|
||
|
async def __aenter__(self) -> "Connection":
|
||
|
return await self
|
||
|
|
||
|
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||
|
await self.close()
|
||
|
|
||
|
@contextmanager
|
||
|
async def cursor(self) -> Cursor:
|
||
|
"""Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
|
||
|
return Cursor(self, await self._execute(self._conn.cursor))
|
||
|
|
||
|
async def commit(self) -> None:
|
||
|
"""Commit the current transaction."""
|
||
|
await self._execute(self._conn.commit)
|
||
|
|
||
|
async def rollback(self) -> None:
|
||
|
"""Roll back the current transaction."""
|
||
|
await self._execute(self._conn.rollback)
|
||
|
|
||
|
async def close(self) -> None:
|
||
|
"""Complete queued queries/cursors and close the connection."""
|
||
|
try:
|
||
|
await self._execute(self._conn.close)
|
||
|
except Exception:
|
||
|
LOG.info("exception occurred while closing connection")
|
||
|
raise
|
||
|
finally:
|
||
|
self._running = False
|
||
|
self._connection = None
|
||
|
|
||
|
@contextmanager
|
||
|
async def execute(self, sql: str, parameters: Iterable[Any] = None) -> Cursor:
|
||
|
"""Helper to create a cursor and execute the given query."""
|
||
|
if parameters is None:
|
||
|
parameters = []
|
||
|
cursor = await self._execute(self._conn.execute, sql, parameters)
|
||
|
return Cursor(self, cursor)
|
||
|
|
||
|
@contextmanager
|
||
|
async def execute_insert(
|
||
|
self, sql: str, parameters: Iterable[Any] = None
|
||
|
) -> Optional[sqlite3.Row]:
|
||
|
"""Helper to insert and get the last_insert_rowid."""
|
||
|
if parameters is None:
|
||
|
parameters = []
|
||
|
return await self._execute(self._execute_insert, sql, parameters)
|
||
|
|
||
|
@contextmanager
|
||
|
async def execute_fetchall(
|
||
|
self, sql: str, parameters: Iterable[Any] = None
|
||
|
) -> Iterable[sqlite3.Row]:
|
||
|
"""Helper to execute a query and return all the data."""
|
||
|
if parameters is None:
|
||
|
parameters = []
|
||
|
return await self._execute(self._execute_fetchall, sql, parameters)
|
||
|
|
||
|
@contextmanager
|
||
|
async def executemany(
|
||
|
self, sql: str, parameters: Iterable[Iterable[Any]]
|
||
|
) -> Cursor:
|
||
|
"""Helper to create a cursor and execute the given multiquery."""
|
||
|
cursor = await self._execute(self._conn.executemany, sql, parameters)
|
||
|
return Cursor(self, cursor)
|
||
|
|
||
|
@contextmanager
|
||
|
async def executescript(self, sql_script: str) -> Cursor:
|
||
|
"""Helper to create a cursor and execute a user script."""
|
||
|
cursor = await self._execute(self._conn.executescript, sql_script)
|
||
|
return Cursor(self, cursor)
|
||
|
|
||
|
async def interrupt(self) -> None:
|
||
|
"""Interrupt pending queries."""
|
||
|
return self._conn.interrupt()
|
||
|
|
||
|
async def create_function(
|
||
|
self, name: str, num_params: int, func: Callable, deterministic: bool = False
|
||
|
) -> None:
|
||
|
"""
|
||
|
Create user-defined function that can be later used
|
||
|
within SQL statements. Must be run within the same thread
|
||
|
that query executions take place so instead of executing directly
|
||
|
against the connection, we defer this to `run` function.
|
||
|
|
||
|
In Python 3.8 and above, if *deterministic* is true, the created
|
||
|
function is marked as deterministic, which allows SQLite to perform
|
||
|
additional optimizations. This flag is supported by SQLite 3.8.3 or
|
||
|
higher, ``NotSupportedError`` will be raised if used with older
|
||
|
versions.
|
||
|
"""
|
||
|
if sys.version_info >= (3, 8):
|
||
|
await self._execute(
|
||
|
self._conn.create_function,
|
||
|
name,
|
||
|
num_params,
|
||
|
func,
|
||
|
deterministic=deterministic,
|
||
|
)
|
||
|
else:
|
||
|
if deterministic:
|
||
|
warnings.warn(
|
||
|
"Deterministic function support is only available on "
|
||
|
'Python 3.8+. Function "{}" will be registered as '
|
||
|
"non-deterministic as per SQLite defaults.".format(name)
|
||
|
)
|
||
|
|
||
|
await self._execute(self._conn.create_function, name, num_params, func)
|
||
|
|
||
|
@property
|
||
|
def in_transaction(self) -> bool:
|
||
|
return self._conn.in_transaction
|
||
|
|
||
|
@property
|
||
|
def isolation_level(self) -> str:
|
||
|
return self._conn.isolation_level
|
||
|
|
||
|
@isolation_level.setter
|
||
|
def isolation_level(self, value: str) -> None:
|
||
|
self._conn.isolation_level = value
|
||
|
|
||
|
@property
|
||
|
def row_factory(self) -> "Optional[Type]": # py3.5.2 compat (#24)
|
||
|
return self._conn.row_factory
|
||
|
|
||
|
@row_factory.setter
|
||
|
def row_factory(self, factory: "Optional[Type]") -> None: # py3.5.2 compat (#24)
|
||
|
self._conn.row_factory = factory
|
||
|
|
||
|
@property
|
||
|
def text_factory(self) -> Type:
|
||
|
return self._conn.text_factory
|
||
|
|
||
|
@text_factory.setter
|
||
|
def text_factory(self, factory: Type) -> None:
|
||
|
self._conn.text_factory = factory
|
||
|
|
||
|
@property
|
||
|
def total_changes(self) -> int:
|
||
|
return self._conn.total_changes
|
||
|
|
||
|
async def enable_load_extension(self, value: bool) -> None:
|
||
|
await self._execute(self._conn.enable_load_extension, value) # type: ignore
|
||
|
|
||
|
async def load_extension(self, path: str):
|
||
|
await self._execute(self._conn.load_extension, path) # type: ignore
|
||
|
|
||
|
async def set_progress_handler(
|
||
|
self, handler: Callable[[], Optional[int]], n: int
|
||
|
) -> None:
|
||
|
await self._execute(self._conn.set_progress_handler, handler, n)
|
||
|
|
||
|
async def set_trace_callback(self, handler: Callable) -> None:
|
||
|
await self._execute(self._conn.set_trace_callback, handler)
|
||
|
|
||
|
async def iterdump(self) -> AsyncIterator[str]:
|
||
|
"""
|
||
|
Return an async iterator to dump the database in SQL text format.
|
||
|
|
||
|
Example::
|
||
|
|
||
|
async for line in db.iterdump():
|
||
|
...
|
||
|
|
||
|
"""
|
||
|
dump_queue: Queue = Queue()
|
||
|
|
||
|
def dumper():
|
||
|
try:
|
||
|
for line in self._conn.iterdump():
|
||
|
dump_queue.put_nowait(line)
|
||
|
dump_queue.put_nowait(None)
|
||
|
|
||
|
except Exception:
|
||
|
LOG.exception("exception while dumping db")
|
||
|
dump_queue.put_nowait(None)
|
||
|
raise
|
||
|
|
||
|
fut = self._execute(dumper)
|
||
|
task = asyncio.ensure_future(fut)
|
||
|
|
||
|
while True:
|
||
|
try:
|
||
|
line: Optional[str] = dump_queue.get_nowait()
|
||
|
if line is None:
|
||
|
break
|
||
|
yield line
|
||
|
|
||
|
except Empty:
|
||
|
if task.done():
|
||
|
LOG.warning("iterdump completed unexpectedly")
|
||
|
break
|
||
|
|
||
|
await asyncio.sleep(0.01)
|
||
|
|
||
|
await task
|
||
|
|
||
|
async def backup(
|
||
|
self,
|
||
|
target: Union["Connection", sqlite3.Connection],
|
||
|
*,
|
||
|
pages: int = 0,
|
||
|
progress: Optional[Callable[[int, int, int], None]] = None,
|
||
|
name: str = "main",
|
||
|
sleep: float = 0.250
|
||
|
) -> None:
|
||
|
"""
|
||
|
Make a backup of the current database to the target database.
|
||
|
|
||
|
Takes either a standard sqlite3 or aiosqlite Connection object as the target.
|
||
|
"""
|
||
|
if sys.version_info < (3, 7):
|
||
|
raise RuntimeError("backup() method is only available on Python 3.7+")
|
||
|
|
||
|
if isinstance(target, Connection):
|
||
|
target = target._conn
|
||
|
|
||
|
await self._execute(
|
||
|
self._conn.backup,
|
||
|
target,
|
||
|
pages=pages,
|
||
|
progress=progress,
|
||
|
name=name,
|
||
|
sleep=sleep,
|
||
|
)
|
||
|
|
||
|
|
||
|
def connect(
|
||
|
database: Union[str, Path],
|
||
|
*,
|
||
|
iter_chunk_size=64,
|
||
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||
|
**kwargs: Any
|
||
|
) -> Connection:
|
||
|
"""Create and return a connection proxy to the sqlite database."""
|
||
|
|
||
|
if loop is not None:
|
||
|
warn(
|
||
|
"aiosqlite.connect() no longer uses the `loop` parameter",
|
||
|
DeprecationWarning,
|
||
|
)
|
||
|
|
||
|
def connector() -> sqlite3.Connection:
|
||
|
if isinstance(database, str):
|
||
|
loc = database
|
||
|
elif isinstance(database, bytes):
|
||
|
loc = database.decode("utf-8")
|
||
|
else:
|
||
|
loc = str(database)
|
||
|
|
||
|
return sqlite3.connect(loc, **kwargs)
|
||
|
|
||
|
return Connection(connector, iter_chunk_size)
|