refactoring after reviewers' feedback + type hints

This commit is contained in:
Luciano Ramalho 2021-04-15 17:15:53 -03:00
parent 3a75f5ebf6
commit 4ff0a59608
4 changed files with 140 additions and 72 deletions

View File

@ -3,17 +3,10 @@
# Applying `check_identifier` to parameters prevents SQL injection. # Applying `check_identifier` to parameters prevents SQL injection.
import sqlite3 import sqlite3
from typing import NamedTuple from typing import NamedTuple, Optional, Iterator, Any
DEFAULT_DB_PATH = ':memory:' DEFAULT_DB_PATH = ':memory:'
CONNECTION = None CONNECTION: Optional[sqlite3.Connection] = None
SQL_TYPES = {
int: 'INTEGER',
str: 'TEXT',
float: 'REAL',
bytes: 'BLOB',
}
class NoConnection(Exception): class NoConnection(Exception):
@ -38,29 +31,45 @@ class UnexpectedMultipleResults(Exception):
"""Query returned more than 1 row.""" """Query returned more than 1 row."""
SQLType = str
TypeMap = dict[type, SQLType]
SQL_TYPES: TypeMap = {
int: 'INTEGER',
str: 'TEXT',
float: 'REAL',
bytes: 'BLOB',
}
class ColumnSchema(NamedTuple): class ColumnSchema(NamedTuple):
name: str name: str
sql_type: str sql_type: SQLType
def check_identifier(name): FieldMap = dict[str, type]
def check_identifier(name: str) -> None:
if not name.isidentifier(): if not name.isidentifier():
raise ValueError(f'{name!r} is not an identifier') raise ValueError(f'{name!r} is not an identifier')
def connect(db_path=DEFAULT_DB_PATH): def connect(db_path: str = DEFAULT_DB_PATH) -> sqlite3.Connection:
global CONNECTION global CONNECTION
CONNECTION = sqlite3.connect(db_path) CONNECTION = sqlite3.connect(db_path)
CONNECTION.row_factory = sqlite3.Row
return CONNECTION return CONNECTION
def get_connection(): def get_connection() -> sqlite3.Connection:
if CONNECTION is None: if CONNECTION is None:
raise NoConnection() raise NoConnection()
return CONNECTION return CONNECTION
def gen_columns_sql(fields): def gen_columns_sql(fields: FieldMap) -> Iterator[ColumnSchema]:
for name, py_type in fields.items(): for name, py_type in fields.items():
check_identifier(name) check_identifier(name)
try: try:
@ -70,7 +79,7 @@ def gen_columns_sql(fields):
yield ColumnSchema(name, sql_type) yield ColumnSchema(name, sql_type)
def make_schema_sql(table_name, fields): def make_schema_sql(table_name: str, fields: FieldMap) -> str:
check_identifier(table_name) check_identifier(table_name)
pk = 'pk INTEGER PRIMARY KEY,' pk = 'pk INTEGER PRIMARY KEY,'
spcs = ' ' * 4 spcs = ' ' * 4
@ -81,25 +90,24 @@ def make_schema_sql(table_name, fields):
return f'CREATE TABLE {table_name} (\n{spcs}{pk}\n{spcs}{columns}\n)' return f'CREATE TABLE {table_name} (\n{spcs}{pk}\n{spcs}{columns}\n)'
def create_table(table_name, fields): def create_table(table_name: str, fields: FieldMap) -> None:
con = get_connection() con = get_connection()
con.execute(make_schema_sql(table_name, fields)) con.execute(make_schema_sql(table_name, fields))
def read_columns_sql(table_name): def read_columns_sql(table_name: str) -> list[ColumnSchema]:
con = get_connection()
check_identifier(table_name) check_identifier(table_name)
con = get_connection()
rows = con.execute(f'PRAGMA table_info({table_name!r})') rows = con.execute(f'PRAGMA table_info({table_name!r})')
# row fields: cid name type notnull dflt_value pk return [ColumnSchema(r['name'], r['type']) for r in rows]
return [ColumnSchema(r[1], r[2]) for r in rows]
def valid_table(table_name, fields): def valid_table(table_name: str, fields: FieldMap) -> bool:
table_columns = read_columns_sql(table_name) table_columns = read_columns_sql(table_name)
return set(gen_columns_sql(fields)) <= set(table_columns) return set(gen_columns_sql(fields)) <= set(table_columns)
def ensure_table(table_name, fields): def ensure_table(table_name: str, fields: FieldMap) -> None:
table_columns = read_columns_sql(table_name) table_columns = read_columns_sql(table_name)
if len(table_columns) == 0: if len(table_columns) == 0:
create_table(table_name, fields) create_table(table_name, fields)
@ -107,21 +115,21 @@ def ensure_table(table_name, fields):
raise SchemaMismatch(table_name) raise SchemaMismatch(table_name)
def insert_record(table_name, fields): def insert_record(table_name: str, data: dict[str, Any]) -> int:
con = get_connection()
check_identifier(table_name) check_identifier(table_name)
placeholders = ', '.join(['?'] * len(fields)) con = get_connection()
placeholders = ', '.join(['?'] * len(data))
sql = f'INSERT INTO {table_name} VALUES (NULL, {placeholders})' sql = f'INSERT INTO {table_name} VALUES (NULL, {placeholders})'
cursor = con.execute(sql, tuple(fields.values())) cursor = con.execute(sql, tuple(data.values()))
pk = cursor.lastrowid pk = cursor.lastrowid
con.commit() con.commit()
cursor.close() cursor.close()
return pk return pk
def fetch_record(table_name, pk): def fetch_record(table_name: str, pk: int) -> sqlite3.Row:
con = get_connection()
check_identifier(table_name) check_identifier(table_name)
con = get_connection()
sql = f'SELECT * FROM {table_name} WHERE pk = ? LIMIT 2' sql = f'SELECT * FROM {table_name} WHERE pk = ? LIMIT 2'
result = list(con.execute(sql, (pk,))) result = list(con.execute(sql, (pk,)))
if len(result) == 0: if len(result) == 0:
@ -132,19 +140,21 @@ def fetch_record(table_name, pk):
raise UnexpectedMultipleResults() raise UnexpectedMultipleResults()
def update_record(table_name, pk, fields): def update_record(
table_name: str, pk: int, data: dict[str, Any]
) -> tuple[str, tuple[Any, ...]]:
check_identifier(table_name) check_identifier(table_name)
con = get_connection() con = get_connection()
names = ', '.join(fields.keys()) names = ', '.join(data.keys())
placeholders = ', '.join(['?'] * len(fields)) placeholders = ', '.join(['?'] * len(data))
values = tuple(fields.values()) + (pk,) values = tuple(data.values()) + (pk,)
sql = f'UPDATE {table_name} SET ({names}) = ({placeholders}) WHERE pk = ?' sql = f'UPDATE {table_name} SET ({names}) = ({placeholders}) WHERE pk = ?'
con.execute(sql, values) con.execute(sql, values)
con.commit() con.commit()
return sql, values return sql, values
def delete_record(table_name, pk): def delete_record(table_name: str, pk: int) -> sqlite3.Cursor:
con = get_connection() con = get_connection()
check_identifier(table_name) check_identifier(table_name)
sql = f'DELETE FROM {table_name} WHERE pk = ?' sql = f'DELETE FROM {table_name} WHERE pk = ?'

View File

@ -77,7 +77,7 @@ def test_fetch_record(create_movies_sql):
con.execute(create_movies_sql) con.execute(create_movies_sql)
pk = insert_record('movies', fields) pk = insert_record('movies', fields)
row = fetch_record('movies', pk) row = fetch_record('movies', pk)
assert row == (1, 'Frozen', 1_290_000_000.0) assert tuple(row) == (1, 'Frozen', 1_290_000_000.0)
def test_fetch_record_no_such_pk(create_movies_sql): def test_fetch_record_no_such_pk(create_movies_sql):
@ -98,7 +98,7 @@ def test_update_record(create_movies_sql):
row = fetch_record('movies', pk) row = fetch_record('movies', pk)
assert sql == 'UPDATE movies SET (title, revenue) = (?, ?) WHERE pk = ?' assert sql == 'UPDATE movies SET (title, revenue) = (?, ?) WHERE pk = ?'
assert values == ('Frozen', 1_299_999_999, 1) assert values == ('Frozen', 1_299_999_999, 1)
assert row == (1, 'Frozen', 1_299_999_999.0) assert tuple(row) == (1, 'Frozen', 1_299_999_999.0)
def test_delete_record(create_movies_sql): def test_delete_record(create_movies_sql):

View File

@ -4,25 +4,26 @@ A ``Persistent`` class definition::
>>> class Movie(Persistent): >>> class Movie(Persistent):
... title: str ... title: str
... year: int ... year: int
... boxmega: float ... megabucks: float
Implemented behavior:: Implemented behavior::
>>> Movie._connect() # doctest: +ELLIPSIS >>> Movie._connect() # doctest: +ELLIPSIS
<sqlite3.Connection object at 0x...> <sqlite3.Connection object at 0x...>
>>> movie = Movie('The Godfather', 1972, 137) >>> movie = Movie(title='The Godfather', year=1972, megabucks=137)
>>> movie.title >>> movie.title
'The Godfather' 'The Godfather'
>>> movie.boxmega >>> movie.megabucks
137.0 137.0
Instances always have a ``.pk`` attribute, but it is ``None`` until the Instances always have a ``._pk`` attribute, but it is ``None`` until the
object is saved:: object is saved::
>>> movie.pk is None >>> movie._pk is None
True True
>>> movie._persist() >>> movie._save()
>>> movie.pk 1
>>> movie._pk
1 1
Delete the in-memory ``movie``, and fetch the record from the database, Delete the in-memory ``movie``, and fetch the record from the database,
@ -31,7 +32,7 @@ using ``Movie[pk]``—item access on the class itself::
>>> del movie >>> del movie
>>> film = Movie[1] >>> film = Movie[1]
>>> film >>> film
Movie('The Godfather', 1972, 137.0, pk=1) Movie(title='The Godfather', year=1972, megabucks=137.0, _pk=1)
By default, the table name is the class name lowercased, with an appended By default, the table name is the class name lowercased, with an appended
"s" for plural:: "s" for plural::
@ -51,69 +52,89 @@ class declaration::
""" """
from typing import get_type_hints from typing import Any, ClassVar, get_type_hints
import dblib as db import dblib as db
class Field: class Field:
def __init__(self, name, py_type): def __init__(self, name: str, py_type: type) -> None:
self.name = name self.name = name
self.type = py_type self.type = py_type
def __set__(self, instance, value): def __set__(self, instance: 'Persistent', value: Any) -> None:
try: try:
value = self.type(value) value = self.type(value)
except TypeError as e: except (TypeError, ValueError) as e:
msg = f'{value!r} is not compatible with {self.name}:{self.type}.' type_name = self.type.__name__
msg = f'{value!r} is not compatible with {self.name}:{type_name}.'
raise TypeError(msg) from e raise TypeError(msg) from e
instance.__dict__[self.name] = value instance.__dict__[self.name] = value
class Persistent: class Persistent:
def __init_subclass__( _TABLE_NAME: ClassVar[str]
cls, *, db_path=db.DEFAULT_DB_PATH, table='', **kwargs _TABLE_READY: ClassVar[bool] = False
):
super().__init_subclass__(**kwargs) @classmethod
def _fields(cls) -> dict[str, type]:
return {
name: py_type
for name, py_type in get_type_hints(cls).items()
if not name.startswith('_')
}
def __init_subclass__(cls, *, table: str = '', **kwargs: dict):
super().__init_subclass__(**kwargs) # type:ignore
cls._TABLE_NAME = table if table else cls.__name__.lower() + 's' cls._TABLE_NAME = table if table else cls.__name__.lower() + 's'
cls._TABLE_READY = False for name, py_type in cls._fields().items():
for name, py_type in get_type_hints(cls).items():
setattr(cls, name, Field(name, py_type)) setattr(cls, name, Field(name, py_type))
@staticmethod @staticmethod
def _connect(db_path=db.DEFAULT_DB_PATH): def _connect(db_path: str = db.DEFAULT_DB_PATH):
return db.connect(db_path) return db.connect(db_path)
@classmethod @classmethod
def _ensure_table(cls): def _ensure_table(cls) -> str:
if not cls._TABLE_READY: if not cls._TABLE_READY:
db.ensure_table(cls._TABLE_NAME, get_type_hints(cls)) db.ensure_table(cls._TABLE_NAME, cls._fields())
cls._TABLE_READY = True cls._TABLE_READY = True
return cls._TABLE_NAME return cls._TABLE_NAME
def _fields(self): def __class_getitem__(cls, pk: int) -> 'Persistent':
field_names = ['_pk'] + list(cls._fields())
values = db.fetch_record(cls._TABLE_NAME, pk)
return cls(**dict(zip(field_names, values)))
def _asdict(self) -> dict[str, Any]:
return { return {
name: getattr(self, name) name: getattr(self, name)
for name, attr in self.__class__.__dict__.items() for name, attr in self.__class__.__dict__.items()
if isinstance(attr, Field) if isinstance(attr, Field)
} }
def __init__(self, *args, pk=None): def __init__(self, *, _pk=None, **kwargs):
for name, arg in zip(self._fields(), args): field_names = self._asdict().keys()
for name, arg in kwargs.items():
if name not in field_names:
msg = f'{self.__class__.__name__!r} has no attribute {name!r}'
raise AttributeError(msg)
setattr(self, name, arg) setattr(self, name, arg)
self.pk = pk self._pk = _pk
def __class_getitem__(cls, pk): def __repr__(self) -> str:
return cls(*db.fetch_record(cls._TABLE_NAME, pk)[1:], pk=pk) kwargs = ', '.join(
f'{key}={value!r}' for key, value in self._asdict().items()
)
cls_name = self.__class__.__name__
if self._pk is None:
return f'{cls_name}({kwargs})'
return f'{cls_name}({kwargs}, _pk={self._pk})'
def __repr__(self): def _save(self) -> int:
args = ', '.join(repr(value) for value in self._fields().values())
pk = '' if self.pk is None else f', pk={self.pk}'
return f'{self.__class__.__name__}({args}{pk})'
def _persist(self):
table = self.__class__._ensure_table() table = self.__class__._ensure_table()
if self.pk is None: if self._pk is None:
self.pk = db.insert_record(table, self._fields()) self._pk = db.insert_record(table, self._asdict())
else: else:
db.update_record(table, self.pk, self._fields()) db.update_record(table, self._pk, self._asdict())
return self._pk

View File

@ -0,0 +1,37 @@
import pytest
from persistlib import Persistent
def test_field_descriptor_validation_type_error():
    """A value the field type rejects with ``TypeError`` is reported as such.

    ``weight`` is annotated as ``float``; passing ``None`` must raise
    ``TypeError`` with a message naming the value, the field and its type.
    """

    class Cat(Persistent):
        name: str
        weight: float

    with pytest.raises(TypeError) as exc_info:
        # Only the raised exception matters, so the instance is not bound.
        Cat(name='Felix', weight=None)

    assert str(exc_info.value) == 'None is not compatible with weight:float.'
def test_field_descriptor_validation_value_error():
    """A value the field type cannot parse is reported as ``TypeError``.

    ``weight`` is annotated as ``float``; a non-numeric string must raise
    ``TypeError`` (not ``ValueError``) with a message naming the value,
    the field and its type.
    """

    class Cat(Persistent):
        name: str
        weight: float

    expected = "'half stone' is not compatible with weight:float."

    with pytest.raises(TypeError) as exc_info:
        # Only the raised exception matters, so the instance is not bound.
        Cat(name='Felix', weight='half stone')

    assert str(exc_info.value) == expected
def test_constructor_attribute_error():
    """An unknown keyword argument to the constructor is rejected.

    ``Cat`` declares only ``name`` and ``weight``; passing ``age`` must
    raise ``AttributeError`` naming the class and the offending keyword.
    """

    class Cat(Persistent):
        name: str
        weight: float

    with pytest.raises(AttributeError) as exc_info:
        # Only the raised exception matters, so the instance is not bound.
        Cat(name='Felix', weight=3.2, age=7)

    assert str(exc_info.value) == "'Cat' has no attribute 'age'"