Add persistent lib to Chapter 25

This commit is contained in:
Luciano Ramalho 2021-04-14 23:20:15 -03:00 committed by Leonardo Rochael Almeida
parent bbc664308a
commit 37d3e2b73c
4 changed files with 402 additions and 0 deletions

View File

@ -0,0 +1 @@
*.db

View File

@ -0,0 +1,151 @@
# SQLite3 does not support parameterized table and field names,
# for CREATE TABLE and PRAGMA so we must use Python string formatting.
# Applying `check_identifier` to parameters prevents SQL injection.
import sqlite3
from typing import NamedTuple
DEFAULT_DB_PATH = ':memory:'
CONNECTION = None
SQL_TYPES = {
int: 'INTEGER',
str: 'TEXT',
float: 'REAL',
bytes: 'BLOB',
}
class NoConnection(Exception):
"""Call connect() to open connection."""
class SchemaMismatch(ValueError):
"""The table's schema doesn't match the class."""
def __init__(self, table_name):
self.table_name = table_name
class NoSuchRecord(LookupError):
"""The given primary key does not exist."""
def __init__(self, pk):
self.pk = pk
class UnexpectedMultipleResults(Exception):
"""Query returned more than 1 row."""
class ColumnSchema(NamedTuple):
name: str
sql_type: str
def check_identifier(name):
if not name.isidentifier():
raise ValueError(f'{name!r} is not an identifier')
def connect(db_path=DEFAULT_DB_PATH):
global CONNECTION
CONNECTION = sqlite3.connect(db_path)
return CONNECTION
def get_connection():
if CONNECTION is None:
raise NoConnection()
return CONNECTION
def gen_columns_sql(fields):
for name, py_type in fields.items():
check_identifier(name)
try:
sql_type = SQL_TYPES[py_type]
except KeyError as e:
raise ValueError(f'type {py_type!r} is not supported') from e
yield ColumnSchema(name, sql_type)
def make_schema_sql(table_name, fields):
check_identifier(table_name)
pk = 'pk INTEGER PRIMARY KEY,'
spcs = ' ' * 4
columns = ',\n '.join(
f'{field_name} {sql_type}'
for field_name, sql_type in gen_columns_sql(fields)
)
return f'CREATE TABLE {table_name} (\n{spcs}{pk}\n{spcs}{columns}\n)'
def create_table(table_name, fields):
con = get_connection()
con.execute(make_schema_sql(table_name, fields))
def read_columns_sql(table_name):
con = get_connection()
check_identifier(table_name)
rows = con.execute(f'PRAGMA table_info({table_name!r})')
# row fields: cid name type notnull dflt_value pk
return [ColumnSchema(r[1], r[2]) for r in rows]
def valid_table(table_name, fields):
table_columns = read_columns_sql(table_name)
return set(gen_columns_sql(fields)) <= set(table_columns)
def ensure_table(table_name, fields):
table_columns = read_columns_sql(table_name)
if len(table_columns) == 0:
create_table(table_name, fields)
if not valid_table(table_name, fields):
raise SchemaMismatch(table_name)
def insert_record(table_name, fields):
con = get_connection()
check_identifier(table_name)
placeholders = ', '.join(['?'] * len(fields))
sql = f'INSERT INTO {table_name} VALUES (NULL, {placeholders})'
cursor = con.execute(sql, tuple(fields.values()))
pk = cursor.lastrowid
con.commit()
cursor.close()
return pk
def fetch_record(table_name, pk):
con = get_connection()
check_identifier(table_name)
sql = f'SELECT * FROM {table_name} WHERE pk = ? LIMIT 2'
result = list(con.execute(sql, (pk,)))
if len(result) == 0:
raise NoSuchRecord(pk)
elif len(result) == 1:
return result[0]
else:
raise UnexpectedMultipleResults()
def update_record(table_name, pk, fields):
con = get_connection()
check_identifier(table_name)
names = ', '.join(fields.keys())
placeholders = ', '.join(['?'] * len(fields))
values = tuple(fields.values()) + (pk,)
sql = f'UPDATE {table_name} SET ({names}) = ({placeholders}) WHERE pk = ?'
con.execute(sql, values)
con.commit()
return sql, values
def delete_record(table_name, pk):
con = get_connection()
check_identifier(table_name)
sql = f'DELETE FROM {table_name} WHERE pk = ?'
return con.execute(sql, (pk,))

View File

@ -0,0 +1,131 @@
from textwrap import dedent
import pytest
from dblib import gen_columns_sql, make_schema_sql, connect, read_columns_sql
from dblib import ColumnSchema, insert_record, fetch_record, update_record
from dblib import NoSuchRecord, delete_record, valid_table
@pytest.fixture
def create_movies_sql():
sql = '''
CREATE TABLE movies (
pk INTEGER PRIMARY KEY,
title TEXT,
revenue REAL
)
'''
return dedent(sql).strip()
@pytest.mark.parametrize(
'fields, expected',
[
(
dict(title=str, awards=int),
[('title', 'TEXT'), ('awards', 'INTEGER')],
),
(
dict(picture=bytes, score=float),
[('picture', 'BLOB'), ('score', 'REAL')],
),
],
)
def test_gen_columns_sql(fields, expected):
result = list(gen_columns_sql(fields))
assert result == expected
def test_make_schema_sql(create_movies_sql):
fields = dict(title=str, revenue=float)
result = make_schema_sql('movies', fields)
assert result == create_movies_sql
def test_read_columns_sql(create_movies_sql):
expected = [
ColumnSchema(name='pk', sql_type='INTEGER'),
ColumnSchema(name='title', sql_type='TEXT'),
ColumnSchema(name='revenue', sql_type='REAL'),
]
with connect() as con:
con.execute(create_movies_sql)
result = read_columns_sql('movies')
assert result == expected
def test_read_columns_sql_no_such_table(create_movies_sql):
with connect() as con:
con.execute(create_movies_sql)
result = read_columns_sql('no_such_table')
assert result == []
def test_insert_record(create_movies_sql):
fields = dict(title='Frozen', revenue=1_290_000_000)
with connect() as con:
con.execute(create_movies_sql)
for _ in range(3):
result = insert_record('movies', fields)
assert result == 3
def test_fetch_record(create_movies_sql):
fields = dict(title='Frozen', revenue=1_290_000_000)
with connect() as con:
con.execute(create_movies_sql)
pk = insert_record('movies', fields)
row = fetch_record('movies', pk)
assert row == (1, 'Frozen', 1_290_000_000.0)
def test_fetch_record_no_such_pk(create_movies_sql):
with connect() as con:
con.execute(create_movies_sql)
with pytest.raises(NoSuchRecord) as e:
fetch_record('movies', 42)
assert e.value.pk == 42
def test_update_record(create_movies_sql):
fields = dict(title='Frozen', revenue=1_290_000_000)
with connect() as con:
con.execute(create_movies_sql)
pk = insert_record('movies', fields)
fields['revenue'] = 1_299_999_999
sql, values = update_record('movies', pk, fields)
row = fetch_record('movies', pk)
assert sql == 'UPDATE movies SET (title, revenue) = (?, ?) WHERE pk = ?'
assert values == ('Frozen', 1_299_999_999, 1)
assert row == (1, 'Frozen', 1_299_999_999.0)
def test_delete_record(create_movies_sql):
fields = dict(title='Frozen', revenue=1_290_000_000)
with connect() as con:
con.execute(create_movies_sql)
pk = insert_record('movies', fields)
delete_record('movies', pk)
with pytest.raises(NoSuchRecord) as e:
fetch_record('movies', pk)
assert e.value.pk == pk
def test_persistent_valid_table(create_movies_sql):
fields = dict(title=str, revenue=float)
with connect() as con:
con.execute(create_movies_sql)
con.commit()
assert valid_table('movies', fields)
def test_persistent_valid_table_false(create_movies_sql):
# year field not in movies_sql
fields = dict(title=str, revenue=float, year=int)
with connect() as con:
con.execute(create_movies_sql)
con.commit()
assert not valid_table('movies', fields)

View File

@ -0,0 +1,119 @@
"""
A ``Persistent`` class definition::
>>> class Movie(Persistent):
... title: str
... year: int
... boxmega: float
Implemented behavior::
>>> Movie._connect() # doctest: +ELLIPSIS
<sqlite3.Connection object at 0x...>
>>> movie = Movie('The Godfather', 1972, 137)
>>> movie.title
'The Godfather'
>>> movie.boxmega
137.0
Instances always have a ``.pk`` attribute, but it is ``None`` until the
object is saved::
>>> movie.pk is None
True
>>> movie._persist()
>>> movie.pk
1
Delete the in-memory ``movie``, and fetch the record from the database,
using ``Movie[pk]``item access on the class itself::
>>> del movie
>>> film = Movie[1]
>>> film
Movie('The Godfather', 1972, 137.0, pk=1)
By default, the table name is the class name lowercased, with an appended
"s" for plural::
>>> Movie._TABLE_NAME
'movies'
If needed, a custom table name can be given as a keyword argument in the
class declaration::
>>> class Aircraft(Persistent, table='aircraft'):
... registration: str
... model: str
...
>>> Aircraft._TABLE_NAME
'aircraft'
"""
from typing import get_type_hints
import dblib as db
class Field:
def __init__(self, name, py_type):
self.name = name
self.type = py_type
def __set__(self, instance, value):
try:
value = self.type(value)
except TypeError as e:
msg = f'{value!r} is not compatible with {self.name}:{self.type}.'
raise TypeError(msg) from e
instance.__dict__[self.name] = value
class Persistent:
def __init_subclass__(
cls, *, db_path=db.DEFAULT_DB_PATH, table='', **kwargs
):
super().__init_subclass__(**kwargs)
cls._TABLE_NAME = table if table else cls.__name__.lower() + 's'
cls._TABLE_READY = False
for name, py_type in get_type_hints(cls).items():
setattr(cls, name, Field(name, py_type))
@staticmethod
def _connect(db_path=db.DEFAULT_DB_PATH):
return db.connect(db_path)
@classmethod
def _ensure_table(cls):
if not cls._TABLE_READY:
db.ensure_table(cls._TABLE_NAME, get_type_hints(cls))
cls._TABLE_READY = True
return cls._TABLE_NAME
def _fields(self):
return {
name: getattr(self, name)
for name, attr in self.__class__.__dict__.items()
if isinstance(attr, Field)
}
def __init__(self, *args, pk=None):
for name, arg in zip(self._fields(), args):
setattr(self, name, arg)
self.pk = pk
def __class_getitem__(cls, pk):
return cls(*db.fetch_record(cls._TABLE_NAME, pk)[1:], pk=pk)
def __repr__(self):
args = ', '.join(repr(value) for value in self._fields().values())
pk = '' if self.pk is None else f', pk={self.pk}'
return f'{self.__class__.__name__}({args}{pk})'
def _persist(self):
table = self.__class__._ensure_table()
if self.pk is None:
self.pk = db.insert_record(table, self._fields())
else:
db.update_record(table, self.pk, self._fields())