diff --git a/25-class-metaprog/persistent/dblib.py b/25-class-metaprog/persistent/dblib.py index e81edc4..dd1fae4 100644 --- a/25-class-metaprog/persistent/dblib.py +++ b/25-class-metaprog/persistent/dblib.py @@ -3,17 +3,10 @@ # Applying `check_identifier` to parameters prevents SQL injection. import sqlite3 -from typing import NamedTuple +from typing import NamedTuple, Optional, Iterator, Any DEFAULT_DB_PATH = ':memory:' -CONNECTION = None - -SQL_TYPES = { - int: 'INTEGER', - str: 'TEXT', - float: 'REAL', - bytes: 'BLOB', -} +CONNECTION: Optional[sqlite3.Connection] = None class NoConnection(Exception): @@ -38,29 +31,45 @@ class UnexpectedMultipleResults(Exception): """Query returned more than 1 row.""" +SQLType = str + +TypeMap = dict[type, SQLType] + +SQL_TYPES: TypeMap = { + int: 'INTEGER', + str: 'TEXT', + float: 'REAL', + bytes: 'BLOB', +} + + class ColumnSchema(NamedTuple): name: str - sql_type: str + sql_type: SQLType -def check_identifier(name): +FieldMap = dict[str, type] + + +def check_identifier(name: str) -> None: if not name.isidentifier(): raise ValueError(f'{name!r} is not an identifier') -def connect(db_path=DEFAULT_DB_PATH): +def connect(db_path: str = DEFAULT_DB_PATH) -> sqlite3.Connection: global CONNECTION CONNECTION = sqlite3.connect(db_path) + CONNECTION.row_factory = sqlite3.Row return CONNECTION -def get_connection(): +def get_connection() -> sqlite3.Connection: if CONNECTION is None: raise NoConnection() return CONNECTION -def gen_columns_sql(fields): +def gen_columns_sql(fields: FieldMap) -> Iterator[ColumnSchema]: for name, py_type in fields.items(): check_identifier(name) try: @@ -70,7 +79,7 @@ def gen_columns_sql(fields): yield ColumnSchema(name, sql_type) -def make_schema_sql(table_name, fields): +def make_schema_sql(table_name: str, fields: FieldMap) -> str: check_identifier(table_name) pk = 'pk INTEGER PRIMARY KEY,' spcs = ' ' * 4 @@ -81,25 +90,24 @@ def make_schema_sql(table_name, fields): return f'CREATE TABLE {table_name} (\n{spcs}{pk}\n{spcs}{columns}\n)' -def create_table(table_name, fields): +def create_table(table_name: str, fields: FieldMap) -> None: con = get_connection() con.execute(make_schema_sql(table_name, fields)) -def read_columns_sql(table_name): - con = get_connection() +def read_columns_sql(table_name: str) -> list[ColumnSchema]: check_identifier(table_name) + con = get_connection() rows = con.execute(f'PRAGMA table_info({table_name!r})') - # row fields: cid name type notnull dflt_value pk - return [ColumnSchema(r[1], r[2]) for r in rows] + return [ColumnSchema(r['name'], r['type']) for r in rows] -def valid_table(table_name, fields): +def valid_table(table_name: str, fields: FieldMap) -> bool: table_columns = read_columns_sql(table_name) return set(gen_columns_sql(fields)) <= set(table_columns) -def ensure_table(table_name, fields): +def ensure_table(table_name: str, fields: FieldMap) -> None: table_columns = read_columns_sql(table_name) if len(table_columns) == 0: create_table(table_name, fields) @@ -107,21 +115,21 @@ def ensure_table(table_name, fields): raise SchemaMismatch(table_name) -def insert_record(table_name, fields): - con = get_connection() +def insert_record(table_name: str, data: dict[str, Any]) -> int: check_identifier(table_name) - placeholders = ', '.join(['?'] * len(fields)) + con = get_connection() + placeholders = ', '.join(['?'] * len(data)) sql = f'INSERT INTO {table_name} VALUES (NULL, {placeholders})' - cursor = con.execute(sql, tuple(fields.values())) + cursor = con.execute(sql, tuple(data.values())) pk = cursor.lastrowid con.commit() cursor.close() return pk -def fetch_record(table_name, pk): - con = get_connection() +def fetch_record(table_name: str, pk: int) -> sqlite3.Row: check_identifier(table_name) + con = get_connection() sql = f'SELECT * FROM {table_name} WHERE pk = ? LIMIT 2' result = list(con.execute(sql, (pk,))) if len(result) == 0: @@ -132,19 +140,21 @@ def fetch_record(table_name, pk): raise UnexpectedMultipleResults() -def update_record(table_name, pk, fields): +def update_record( + table_name: str, pk: int, data: dict[str, Any] +) -> tuple[str, tuple[Any, ...]]: check_identifier(table_name) con = get_connection() - names = ', '.join(fields.keys()) - placeholders = ', '.join(['?'] * len(fields)) - values = tuple(fields.values()) + (pk,) + names = ', '.join(data.keys()) + placeholders = ', '.join(['?'] * len(data)) + values = tuple(data.values()) + (pk,) sql = f'UPDATE {table_name} SET ({names}) = ({placeholders}) WHERE pk = ?' con.execute(sql, values) con.commit() return sql, values -def delete_record(table_name, pk): +def delete_record(table_name: str, pk: int) -> sqlite3.Cursor: con = get_connection() check_identifier(table_name) sql = f'DELETE FROM {table_name} WHERE pk = ?' diff --git a/25-class-metaprog/persistent/dblib_test.py b/25-class-metaprog/persistent/dblib_test.py index 9a0a93c..dcaf0bb 100644 --- a/25-class-metaprog/persistent/dblib_test.py +++ b/25-class-metaprog/persistent/dblib_test.py @@ -77,7 +77,7 @@ def test_fetch_record(create_movies_sql): con.execute(create_movies_sql) pk = insert_record('movies', fields) row = fetch_record('movies', pk) - assert row == (1, 'Frozen', 1_290_000_000.0) + assert tuple(row) == (1, 'Frozen', 1_290_000_000.0) def test_fetch_record_no_such_pk(create_movies_sql): @@ -98,7 +98,7 @@ def test_update_record(create_movies_sql): row = fetch_record('movies', pk) assert sql == 'UPDATE movies SET (title, revenue) = (?, ?) WHERE pk = ?' assert values == ('Frozen', 1_299_999_999, 1) - assert row == (1, 'Frozen', 1_299_999_999.0) + assert tuple(row) == (1, 'Frozen', 1_299_999_999.0) def test_delete_record(create_movies_sql): diff --git a/25-class-metaprog/persistent/persistlib.py b/25-class-metaprog/persistent/persistlib.py index 633f48c..15c9d38 100644 --- a/25-class-metaprog/persistent/persistlib.py +++ b/25-class-metaprog/persistent/persistlib.py @@ -4,25 +4,26 @@ A ``Persistent`` class definition:: >>> class Movie(Persistent): ... title: str ... year: int - ... boxmega: float + ... megabucks: float Implemented behavior:: >>> Movie._connect() # doctest: +ELLIPSIS - >>> movie = Movie('The Godfather', 1972, 137) + >>> movie = Movie(title='The Godfather', year=1972, megabucks=137) >>> movie.title 'The Godfather' - >>> movie.boxmega + >>> movie.megabucks 137.0 -Instances always have a ``.pk`` attribute, but it is ``None`` until the +Instances always have a ``._pk`` attribute, but it is ``None`` until the object is saved:: - >>> movie.pk is None + >>> movie._pk is None True - >>> movie._persist() - >>> movie.pk + >>> movie._save() + 1 + >>> movie._pk 1 Delete the in-memory ``movie``, and fetch the record from the database, @@ -31,7 +32,7 @@ using ``Movie[pk]``—item access on the class itself:: >>> del movie >>> film = Movie[1] >>> film - Movie('The Godfather', 1972, 137.0, pk=1) + Movie(title='The Godfather', year=1972, megabucks=137.0, _pk=1) By default, the table name is the class name lowercased, with an appended "s" for plural:: @@ -51,69 +52,89 @@ class declaration:: """ -from typing import get_type_hints +from typing import Any, ClassVar, get_type_hints import dblib as db class Field: - def __init__(self, name, py_type): + def __init__(self, name: str, py_type: type) -> None: self.name = name self.type = py_type - def __set__(self, instance, value): + def __set__(self, instance: 'Persistent', value: Any) -> None: try: value = self.type(value) - except TypeError as e: - msg = f'{value!r} is not compatible with {self.name}:{self.type}.' + except (TypeError, ValueError) as e: + type_name = self.type.__name__ + msg = f'{value!r} is not compatible with {self.name}:{type_name}.' raise TypeError(msg) from e instance.__dict__[self.name] = value class Persistent: - def __init_subclass__( - cls, *, db_path=db.DEFAULT_DB_PATH, table='', **kwargs - ): - super().__init_subclass__(**kwargs) + _TABLE_NAME: ClassVar[str] + _TABLE_READY: ClassVar[bool] = False + + @classmethod + def _fields(cls) -> dict[str, type]: + return { + name: py_type + for name, py_type in get_type_hints(cls).items() + if not name.startswith('_') + } + + def __init_subclass__(cls, *, table: str = '', **kwargs: dict): + super().__init_subclass__(**kwargs) # type:ignore cls._TABLE_NAME = table if table else cls.__name__.lower() + 's' - cls._TABLE_READY = False - for name, py_type in get_type_hints(cls).items(): + for name, py_type in cls._fields().items(): setattr(cls, name, Field(name, py_type)) @staticmethod - def _connect(db_path=db.DEFAULT_DB_PATH): + def _connect(db_path: str = db.DEFAULT_DB_PATH): return db.connect(db_path) @classmethod - def _ensure_table(cls): + def _ensure_table(cls) -> str: if not cls._TABLE_READY: - db.ensure_table(cls._TABLE_NAME, get_type_hints(cls)) + db.ensure_table(cls._TABLE_NAME, cls._fields()) cls._TABLE_READY = True return cls._TABLE_NAME - def _fields(self): + def __class_getitem__(cls, pk: int) -> 'Persistent': + field_names = ['_pk'] + list(cls._fields()) + values = db.fetch_record(cls._TABLE_NAME, pk) + return cls(**dict(zip(field_names, values))) + + def _asdict(self) -> dict[str, Any]: return { name: getattr(self, name) for name, attr in self.__class__.__dict__.items() if isinstance(attr, Field) } - def __init__(self, *args, pk=None): - for name, arg in zip(self._fields(), args): + def __init__(self, *, _pk=None, **kwargs): + field_names = self._asdict().keys() + for name, arg in kwargs.items(): + if name not in field_names: + msg = f'{self.__class__.__name__!r} has no attribute {name!r}' + raise AttributeError(msg) setattr(self, name, arg) - self.pk = pk + self._pk = _pk - def __class_getitem__(cls, pk): - return cls(*db.fetch_record(cls._TABLE_NAME, pk)[1:], pk=pk) + def __repr__(self) -> str: + kwargs = ', '.join( + f'{key}={value!r}' for key, value in self._asdict().items() + ) + cls_name = self.__class__.__name__ + if self._pk is None: + return f'{cls_name}({kwargs})' + return f'{cls_name}({kwargs}, _pk={self._pk})' - def __repr__(self): - args = ', '.join(repr(value) for value in self._fields().values()) - pk = '' if self.pk is None else f', pk={self.pk}' - return f'{self.__class__.__name__}({args}{pk})' - - def _persist(self): + def _save(self) -> int: table = self.__class__._ensure_table() - if self.pk is None: - self.pk = db.insert_record(table, self._fields()) + if self._pk is None: + self._pk = db.insert_record(table, self._asdict()) else: - db.update_record(table, self.pk, self._fields()) + db.update_record(table, self._pk, self._asdict()) + return self._pk diff --git a/25-class-metaprog/persistent/persistlib_test.py b/25-class-metaprog/persistent/persistlib_test.py new file mode 100644 index 0000000..1604ccb --- /dev/null +++ b/25-class-metaprog/persistent/persistlib_test.py @@ -0,0 +1,37 @@ +import pytest + + +from persistlib import Persistent + + +def test_field_descriptor_validation_type_error(): + class Cat(Persistent): + name: str + weight: float + + with pytest.raises(TypeError) as e: + felix = Cat(name='Felix', weight=None) + + assert str(e.value) == 'None is not compatible with weight:float.' + + +def test_field_descriptor_validation_value_error(): + class Cat(Persistent): + name: str + weight: float + + with pytest.raises(TypeError) as e: + felix = Cat(name='Felix', weight='half stone') + + assert str(e.value) == "'half stone' is not compatible with weight:float." + + +def test_constructor_attribute_error(): + class Cat(Persistent): + name: str + weight: float + + with pytest.raises(AttributeError) as e: + felix = Cat(name='Felix', weight=3.2, age=7) + + assert str(e.value) == "'Cat' has no attribute 'age'"