Merge pull request #8 from leorochael/25-class-metaprog-persistent
Suggested improvements by @leorochael
This commit is contained in:
commit
3a75f5ebf6
1
25-class-metaprog/persistent/.gitignore
vendored
Normal file
1
25-class-metaprog/persistent/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
*.db
|
151
25-class-metaprog/persistent/dblib.py
Normal file
151
25-class-metaprog/persistent/dblib.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
# SQLite3 does not support parameterized table and field names,
|
||||||
|
# for CREATE TABLE and PRAGMA so we must use Python string formatting.
|
||||||
|
# Applying `check_identifier` to parameters prevents SQL injection.
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
DEFAULT_DB_PATH = ':memory:'
|
||||||
|
CONNECTION = None
|
||||||
|
|
||||||
|
SQL_TYPES = {
|
||||||
|
int: 'INTEGER',
|
||||||
|
str: 'TEXT',
|
||||||
|
float: 'REAL',
|
||||||
|
bytes: 'BLOB',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class NoConnection(Exception):
|
||||||
|
"""Call connect() to open connection."""
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaMismatch(ValueError):
|
||||||
|
"""The table schema doesn't match the class."""
|
||||||
|
|
||||||
|
def __init__(self, table_name):
|
||||||
|
self.table_name = table_name
|
||||||
|
|
||||||
|
|
||||||
|
class NoSuchRecord(LookupError):
|
||||||
|
"""The given primary key does not exist."""
|
||||||
|
|
||||||
|
def __init__(self, pk):
|
||||||
|
self.pk = pk
|
||||||
|
|
||||||
|
|
||||||
|
class UnexpectedMultipleResults(Exception):
|
||||||
|
"""Query returned more than 1 row."""
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnSchema(NamedTuple):
|
||||||
|
name: str
|
||||||
|
sql_type: str
|
||||||
|
|
||||||
|
|
||||||
|
def check_identifier(name):
|
||||||
|
if not name.isidentifier():
|
||||||
|
raise ValueError(f'{name!r} is not an identifier')
|
||||||
|
|
||||||
|
|
||||||
|
def connect(db_path=DEFAULT_DB_PATH):
|
||||||
|
global CONNECTION
|
||||||
|
CONNECTION = sqlite3.connect(db_path)
|
||||||
|
return CONNECTION
|
||||||
|
|
||||||
|
|
||||||
|
def get_connection():
|
||||||
|
if CONNECTION is None:
|
||||||
|
raise NoConnection()
|
||||||
|
return CONNECTION
|
||||||
|
|
||||||
|
|
||||||
|
def gen_columns_sql(fields):
|
||||||
|
for name, py_type in fields.items():
|
||||||
|
check_identifier(name)
|
||||||
|
try:
|
||||||
|
sql_type = SQL_TYPES[py_type]
|
||||||
|
except KeyError as e:
|
||||||
|
raise ValueError(f'type {py_type!r} is not supported') from e
|
||||||
|
yield ColumnSchema(name, sql_type)
|
||||||
|
|
||||||
|
|
||||||
|
def make_schema_sql(table_name, fields):
|
||||||
|
check_identifier(table_name)
|
||||||
|
pk = 'pk INTEGER PRIMARY KEY,'
|
||||||
|
spcs = ' ' * 4
|
||||||
|
columns = ',\n '.join(
|
||||||
|
f'{field_name} {sql_type}'
|
||||||
|
for field_name, sql_type in gen_columns_sql(fields)
|
||||||
|
)
|
||||||
|
return f'CREATE TABLE {table_name} (\n{spcs}{pk}\n{spcs}{columns}\n)'
|
||||||
|
|
||||||
|
|
||||||
|
def create_table(table_name, fields):
|
||||||
|
con = get_connection()
|
||||||
|
con.execute(make_schema_sql(table_name, fields))
|
||||||
|
|
||||||
|
|
||||||
|
def read_columns_sql(table_name):
|
||||||
|
con = get_connection()
|
||||||
|
check_identifier(table_name)
|
||||||
|
rows = con.execute(f'PRAGMA table_info({table_name!r})')
|
||||||
|
# row fields: cid name type notnull dflt_value pk
|
||||||
|
return [ColumnSchema(r[1], r[2]) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def valid_table(table_name, fields):
|
||||||
|
table_columns = read_columns_sql(table_name)
|
||||||
|
return set(gen_columns_sql(fields)) <= set(table_columns)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_table(table_name, fields):
|
||||||
|
table_columns = read_columns_sql(table_name)
|
||||||
|
if len(table_columns) == 0:
|
||||||
|
create_table(table_name, fields)
|
||||||
|
if not valid_table(table_name, fields):
|
||||||
|
raise SchemaMismatch(table_name)
|
||||||
|
|
||||||
|
|
||||||
|
def insert_record(table_name, fields):
|
||||||
|
con = get_connection()
|
||||||
|
check_identifier(table_name)
|
||||||
|
placeholders = ', '.join(['?'] * len(fields))
|
||||||
|
sql = f'INSERT INTO {table_name} VALUES (NULL, {placeholders})'
|
||||||
|
cursor = con.execute(sql, tuple(fields.values()))
|
||||||
|
pk = cursor.lastrowid
|
||||||
|
con.commit()
|
||||||
|
cursor.close()
|
||||||
|
return pk
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_record(table_name, pk):
|
||||||
|
con = get_connection()
|
||||||
|
check_identifier(table_name)
|
||||||
|
sql = f'SELECT * FROM {table_name} WHERE pk = ? LIMIT 2'
|
||||||
|
result = list(con.execute(sql, (pk,)))
|
||||||
|
if len(result) == 0:
|
||||||
|
raise NoSuchRecord(pk)
|
||||||
|
elif len(result) == 1:
|
||||||
|
return result[0]
|
||||||
|
else:
|
||||||
|
raise UnexpectedMultipleResults()
|
||||||
|
|
||||||
|
|
||||||
|
def update_record(table_name, pk, fields):
|
||||||
|
check_identifier(table_name)
|
||||||
|
con = get_connection()
|
||||||
|
names = ', '.join(fields.keys())
|
||||||
|
placeholders = ', '.join(['?'] * len(fields))
|
||||||
|
values = tuple(fields.values()) + (pk,)
|
||||||
|
sql = f'UPDATE {table_name} SET ({names}) = ({placeholders}) WHERE pk = ?'
|
||||||
|
con.execute(sql, values)
|
||||||
|
con.commit()
|
||||||
|
return sql, values
|
||||||
|
|
||||||
|
|
||||||
|
def delete_record(table_name, pk):
|
||||||
|
con = get_connection()
|
||||||
|
check_identifier(table_name)
|
||||||
|
sql = f'DELETE FROM {table_name} WHERE pk = ?'
|
||||||
|
return con.execute(sql, (pk,))
|
131
25-class-metaprog/persistent/dblib_test.py
Normal file
131
25-class-metaprog/persistent/dblib_test.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
from textwrap import dedent
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from dblib import gen_columns_sql, make_schema_sql, connect, read_columns_sql
|
||||||
|
from dblib import ColumnSchema, insert_record, fetch_record, update_record
|
||||||
|
from dblib import NoSuchRecord, delete_record, valid_table
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_movies_sql():
|
||||||
|
sql = '''
|
||||||
|
CREATE TABLE movies (
|
||||||
|
pk INTEGER PRIMARY KEY,
|
||||||
|
title TEXT,
|
||||||
|
revenue REAL
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
return dedent(sql).strip()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'fields, expected',
|
||||||
|
[
|
||||||
|
(
|
||||||
|
dict(title=str, awards=int),
|
||||||
|
[('title', 'TEXT'), ('awards', 'INTEGER')],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
dict(picture=bytes, score=float),
|
||||||
|
[('picture', 'BLOB'), ('score', 'REAL')],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_gen_columns_sql(fields, expected):
|
||||||
|
result = list(gen_columns_sql(fields))
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_schema_sql(create_movies_sql):
|
||||||
|
fields = dict(title=str, revenue=float)
|
||||||
|
result = make_schema_sql('movies', fields)
|
||||||
|
assert result == create_movies_sql
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_columns_sql(create_movies_sql):
|
||||||
|
expected = [
|
||||||
|
ColumnSchema(name='pk', sql_type='INTEGER'),
|
||||||
|
ColumnSchema(name='title', sql_type='TEXT'),
|
||||||
|
ColumnSchema(name='revenue', sql_type='REAL'),
|
||||||
|
]
|
||||||
|
with connect() as con:
|
||||||
|
con.execute(create_movies_sql)
|
||||||
|
result = read_columns_sql('movies')
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_columns_sql_no_such_table(create_movies_sql):
|
||||||
|
with connect() as con:
|
||||||
|
con.execute(create_movies_sql)
|
||||||
|
result = read_columns_sql('no_such_table')
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_record(create_movies_sql):
|
||||||
|
fields = dict(title='Frozen', revenue=1_290_000_000)
|
||||||
|
with connect() as con:
|
||||||
|
con.execute(create_movies_sql)
|
||||||
|
for _ in range(3):
|
||||||
|
result = insert_record('movies', fields)
|
||||||
|
assert result == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_record(create_movies_sql):
|
||||||
|
fields = dict(title='Frozen', revenue=1_290_000_000)
|
||||||
|
with connect() as con:
|
||||||
|
con.execute(create_movies_sql)
|
||||||
|
pk = insert_record('movies', fields)
|
||||||
|
row = fetch_record('movies', pk)
|
||||||
|
assert row == (1, 'Frozen', 1_290_000_000.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_record_no_such_pk(create_movies_sql):
|
||||||
|
with connect() as con:
|
||||||
|
con.execute(create_movies_sql)
|
||||||
|
with pytest.raises(NoSuchRecord) as e:
|
||||||
|
fetch_record('movies', 42)
|
||||||
|
assert e.value.pk == 42
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_record(create_movies_sql):
|
||||||
|
fields = dict(title='Frozen', revenue=1_290_000_000)
|
||||||
|
with connect() as con:
|
||||||
|
con.execute(create_movies_sql)
|
||||||
|
pk = insert_record('movies', fields)
|
||||||
|
fields['revenue'] = 1_299_999_999
|
||||||
|
sql, values = update_record('movies', pk, fields)
|
||||||
|
row = fetch_record('movies', pk)
|
||||||
|
assert sql == 'UPDATE movies SET (title, revenue) = (?, ?) WHERE pk = ?'
|
||||||
|
assert values == ('Frozen', 1_299_999_999, 1)
|
||||||
|
assert row == (1, 'Frozen', 1_299_999_999.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_record(create_movies_sql):
|
||||||
|
fields = dict(title='Frozen', revenue=1_290_000_000)
|
||||||
|
with connect() as con:
|
||||||
|
con.execute(create_movies_sql)
|
||||||
|
pk = insert_record('movies', fields)
|
||||||
|
delete_record('movies', pk)
|
||||||
|
with pytest.raises(NoSuchRecord) as e:
|
||||||
|
fetch_record('movies', pk)
|
||||||
|
assert e.value.pk == pk
|
||||||
|
|
||||||
|
|
||||||
|
def test_persistent_valid_table(create_movies_sql):
|
||||||
|
fields = dict(title=str, revenue=float)
|
||||||
|
|
||||||
|
with connect() as con:
|
||||||
|
con.execute(create_movies_sql)
|
||||||
|
con.commit()
|
||||||
|
assert valid_table('movies', fields)
|
||||||
|
|
||||||
|
|
||||||
|
def test_persistent_valid_table_false(create_movies_sql):
|
||||||
|
# year field not in movies_sql
|
||||||
|
fields = dict(title=str, revenue=float, year=int)
|
||||||
|
|
||||||
|
with connect() as con:
|
||||||
|
con.execute(create_movies_sql)
|
||||||
|
con.commit()
|
||||||
|
assert not valid_table('movies', fields)
|
119
25-class-metaprog/persistent/persistlib.py
Normal file
119
25-class-metaprog/persistent/persistlib.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
"""
|
||||||
|
A ``Persistent`` class definition::
|
||||||
|
|
||||||
|
>>> class Movie(Persistent):
|
||||||
|
... title: str
|
||||||
|
... year: int
|
||||||
|
... boxmega: float
|
||||||
|
|
||||||
|
Implemented behavior::
|
||||||
|
|
||||||
|
>>> Movie._connect() # doctest: +ELLIPSIS
|
||||||
|
<sqlite3.Connection object at 0x...>
|
||||||
|
>>> movie = Movie('The Godfather', 1972, 137)
|
||||||
|
>>> movie.title
|
||||||
|
'The Godfather'
|
||||||
|
>>> movie.boxmega
|
||||||
|
137.0
|
||||||
|
|
||||||
|
Instances always have a ``.pk`` attribute, but it is ``None`` until the
|
||||||
|
object is saved::
|
||||||
|
|
||||||
|
>>> movie.pk is None
|
||||||
|
True
|
||||||
|
>>> movie._persist()
|
||||||
|
>>> movie.pk
|
||||||
|
1
|
||||||
|
|
||||||
|
Delete the in-memory ``movie``, and fetch the record from the database,
|
||||||
|
using ``Movie[pk]``—item access on the class itself::
|
||||||
|
|
||||||
|
>>> del movie
|
||||||
|
>>> film = Movie[1]
|
||||||
|
>>> film
|
||||||
|
Movie('The Godfather', 1972, 137.0, pk=1)
|
||||||
|
|
||||||
|
By default, the table name is the class name lowercased, with an appended
|
||||||
|
"s" for plural::
|
||||||
|
|
||||||
|
>>> Movie._TABLE_NAME
|
||||||
|
'movies'
|
||||||
|
|
||||||
|
If needed, a custom table name can be given as a keyword argument in the
|
||||||
|
class declaration::
|
||||||
|
|
||||||
|
>>> class Aircraft(Persistent, table='aircraft'):
|
||||||
|
... registration: str
|
||||||
|
... model: str
|
||||||
|
...
|
||||||
|
>>> Aircraft._TABLE_NAME
|
||||||
|
'aircraft'
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import get_type_hints
|
||||||
|
|
||||||
|
import dblib as db
|
||||||
|
|
||||||
|
|
||||||
|
class Field:
|
||||||
|
def __init__(self, name, py_type):
|
||||||
|
self.name = name
|
||||||
|
self.type = py_type
|
||||||
|
|
||||||
|
def __set__(self, instance, value):
|
||||||
|
try:
|
||||||
|
value = self.type(value)
|
||||||
|
except TypeError as e:
|
||||||
|
msg = f'{value!r} is not compatible with {self.name}:{self.type}.'
|
||||||
|
raise TypeError(msg) from e
|
||||||
|
instance.__dict__[self.name] = value
|
||||||
|
|
||||||
|
|
||||||
|
class Persistent:
|
||||||
|
def __init_subclass__(
|
||||||
|
cls, *, db_path=db.DEFAULT_DB_PATH, table='', **kwargs
|
||||||
|
):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
cls._TABLE_NAME = table if table else cls.__name__.lower() + 's'
|
||||||
|
cls._TABLE_READY = False
|
||||||
|
for name, py_type in get_type_hints(cls).items():
|
||||||
|
setattr(cls, name, Field(name, py_type))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _connect(db_path=db.DEFAULT_DB_PATH):
|
||||||
|
return db.connect(db_path)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _ensure_table(cls):
|
||||||
|
if not cls._TABLE_READY:
|
||||||
|
db.ensure_table(cls._TABLE_NAME, get_type_hints(cls))
|
||||||
|
cls._TABLE_READY = True
|
||||||
|
return cls._TABLE_NAME
|
||||||
|
|
||||||
|
def _fields(self):
|
||||||
|
return {
|
||||||
|
name: getattr(self, name)
|
||||||
|
for name, attr in self.__class__.__dict__.items()
|
||||||
|
if isinstance(attr, Field)
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, *args, pk=None):
|
||||||
|
for name, arg in zip(self._fields(), args):
|
||||||
|
setattr(self, name, arg)
|
||||||
|
self.pk = pk
|
||||||
|
|
||||||
|
def __class_getitem__(cls, pk):
|
||||||
|
return cls(*db.fetch_record(cls._TABLE_NAME, pk)[1:], pk=pk)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
args = ', '.join(repr(value) for value in self._fields().values())
|
||||||
|
pk = '' if self.pk is None else f', pk={self.pk}'
|
||||||
|
return f'{self.__class__.__name__}({args}{pk})'
|
||||||
|
|
||||||
|
def _persist(self):
|
||||||
|
table = self.__class__._ensure_table()
|
||||||
|
if self.pk is None:
|
||||||
|
self.pk = db.insert_record(table, self._fields())
|
||||||
|
else:
|
||||||
|
db.update_record(table, self.pk, self._fields())
|
Loading…
x
Reference in New Issue
Block a user