From 37d3e2b73ca2877c7e2e3efd83ae34483e5ef094 Mon Sep 17 00:00:00 2001 From: Luciano Ramalho Date: Wed, 14 Apr 2021 23:20:15 -0300 Subject: [PATCH 1/3] Add persistent lib to Chapter 25 --- 25-class-metaprog/persistent/.gitignore | 1 + 25-class-metaprog/persistent/dblib.py | 151 +++++++++++++++++++++ 25-class-metaprog/persistent/dblib_test.py | 131 ++++++++++++++++++ 25-class-metaprog/persistent/persistlib.py | 119 ++++++++++++++++ 4 files changed, 402 insertions(+) create mode 100644 25-class-metaprog/persistent/.gitignore create mode 100644 25-class-metaprog/persistent/dblib.py create mode 100644 25-class-metaprog/persistent/dblib_test.py create mode 100644 25-class-metaprog/persistent/persistlib.py diff --git a/25-class-metaprog/persistent/.gitignore b/25-class-metaprog/persistent/.gitignore new file mode 100644 index 0000000..98e6ef6 --- /dev/null +++ b/25-class-metaprog/persistent/.gitignore @@ -0,0 +1 @@ +*.db diff --git a/25-class-metaprog/persistent/dblib.py b/25-class-metaprog/persistent/dblib.py new file mode 100644 index 0000000..68dfeab --- /dev/null +++ b/25-class-metaprog/persistent/dblib.py @@ -0,0 +1,151 @@ +# SQLite3 does not support parameterized table and field names, +# for CREATE TABLE and PRAGMA so we must use Python string formatting. +# Applying `check_identifier` to parameters prevents SQL injection. + +import sqlite3 +from typing import NamedTuple + +DEFAULT_DB_PATH = ':memory:' +CONNECTION = None + +SQL_TYPES = { + int: 'INTEGER', + str: 'TEXT', + float: 'REAL', + bytes: 'BLOB', +} + + +class NoConnection(Exception): + """Call connect() to open connection.""" + + +class SchemaMismatch(ValueError): + """The table's schema doesn't match the class.""" + + def __init__(self, table_name): + self.table_name = table_name + + +class NoSuchRecord(LookupError): + """The given primary key does not exist.""" + + def __init__(self, pk): + self.pk = pk + + +class UnexpectedMultipleResults(Exception): + """Query returned more than 1 row.""" + + +class ColumnSchema(NamedTuple): + name: str + sql_type: str + + +def check_identifier(name): + if not name.isidentifier(): + raise ValueError(f'{name!r} is not an identifier') + + +def connect(db_path=DEFAULT_DB_PATH): + global CONNECTION + CONNECTION = sqlite3.connect(db_path) + return CONNECTION + + +def get_connection(): + if CONNECTION is None: + raise NoConnection() + return CONNECTION + + +def gen_columns_sql(fields): + for name, py_type in fields.items(): + check_identifier(name) + try: + sql_type = SQL_TYPES[py_type] + except KeyError as e: + raise ValueError(f'type {py_type!r} is not supported') from e + yield ColumnSchema(name, sql_type) + + +def make_schema_sql(table_name, fields): + check_identifier(table_name) + pk = 'pk INTEGER PRIMARY KEY,' + spcs = ' ' * 4 + columns = ',\n '.join( + f'{field_name} {sql_type}' + for field_name, sql_type in gen_columns_sql(fields) + ) + return f'CREATE TABLE {table_name} (\n{spcs}{pk}\n{spcs}{columns}\n)' + + +def create_table(table_name, fields): + con = get_connection() + con.execute(make_schema_sql(table_name, fields)) + + +def read_columns_sql(table_name): + con = get_connection() + check_identifier(table_name) + rows = con.execute(f'PRAGMA table_info({table_name!r})') + # row fields: cid name type notnull dflt_value pk + return [ColumnSchema(r[1], r[2]) for r in rows] + + +def valid_table(table_name, fields): + table_columns = read_columns_sql(table_name) + return set(gen_columns_sql(fields)) <= set(table_columns) + + +def ensure_table(table_name, fields): + table_columns = read_columns_sql(table_name) + if len(table_columns) == 0: + create_table(table_name, fields) + if not valid_table(table_name, fields): + raise SchemaMismatch(table_name) + + +def insert_record(table_name, fields): + con = get_connection() + check_identifier(table_name) + placeholders = ', '.join(['?'] * len(fields)) + sql = f'INSERT INTO {table_name} VALUES (NULL, {placeholders})' + cursor = con.execute(sql, tuple(fields.values())) + pk = cursor.lastrowid + con.commit() + cursor.close() + return pk + + +def fetch_record(table_name, pk): + con = get_connection() + check_identifier(table_name) + sql = f'SELECT * FROM {table_name} WHERE pk = ? LIMIT 2' + result = list(con.execute(sql, (pk,))) + if len(result) == 0: + raise NoSuchRecord(pk) + elif len(result) == 1: + return result[0] + else: + raise UnexpectedMultipleResults() + + +def update_record(table_name, pk, fields): + con = get_connection() + check_identifier(table_name) + names = ', '.join(fields.keys()) + placeholders = ', '.join(['?'] * len(fields)) + values = tuple(fields.values()) + (pk,) + sql = f'UPDATE {table_name} SET ({names}) = ({placeholders}) WHERE pk = ?' + con.execute(sql, values) + con.commit() + return sql, values + + +def delete_record(table_name, pk): + con = get_connection() + check_identifier(table_name) + sql = f'DELETE FROM {table_name} WHERE pk = ?' + return con.execute(sql, (pk,)) diff --git a/25-class-metaprog/persistent/dblib_test.py b/25-class-metaprog/persistent/dblib_test.py new file mode 100644 index 0000000..9a0a93c --- /dev/null +++ b/25-class-metaprog/persistent/dblib_test.py @@ -0,0 +1,131 @@ +from textwrap import dedent + +import pytest + +from dblib import gen_columns_sql, make_schema_sql, connect, read_columns_sql +from dblib import ColumnSchema, insert_record, fetch_record, update_record +from dblib import NoSuchRecord, delete_record, valid_table + + +@pytest.fixture +def create_movies_sql(): + sql = ''' + CREATE TABLE movies ( + pk INTEGER PRIMARY KEY, + title TEXT, + revenue REAL + ) + ''' + return dedent(sql).strip() + + +@pytest.mark.parametrize( + 'fields, expected', + [ + ( + dict(title=str, awards=int), + [('title', 'TEXT'), ('awards', 'INTEGER')], + ), + ( + dict(picture=bytes, score=float), + [('picture', 'BLOB'), ('score', 'REAL')], + ), + ], +) +def test_gen_columns_sql(fields, expected): + result = list(gen_columns_sql(fields)) + assert result == expected + + +def test_make_schema_sql(create_movies_sql): + fields = dict(title=str, revenue=float) + result = make_schema_sql('movies', fields) + assert result == create_movies_sql + + +def test_read_columns_sql(create_movies_sql): + expected = [ + ColumnSchema(name='pk', sql_type='INTEGER'), + ColumnSchema(name='title', sql_type='TEXT'), + ColumnSchema(name='revenue', sql_type='REAL'), + ] + with connect() as con: + con.execute(create_movies_sql) + result = read_columns_sql('movies') + assert result == expected + + +def test_read_columns_sql_no_such_table(create_movies_sql): + with connect() as con: + con.execute(create_movies_sql) + result = read_columns_sql('no_such_table') + assert result == [] + + +def test_insert_record(create_movies_sql): + fields = dict(title='Frozen', revenue=1_290_000_000) + with connect() as con: + con.execute(create_movies_sql) + for _ in range(3): + result = insert_record('movies', fields) + assert result == 3 + + +def test_fetch_record(create_movies_sql): + fields = dict(title='Frozen', revenue=1_290_000_000) + with connect() as con: + con.execute(create_movies_sql) + pk = insert_record('movies', fields) + row = fetch_record('movies', pk) + assert row == (1, 'Frozen', 1_290_000_000.0) + + +def test_fetch_record_no_such_pk(create_movies_sql): + with connect() as con: + con.execute(create_movies_sql) + with pytest.raises(NoSuchRecord) as e: + fetch_record('movies', 42) + assert e.value.pk == 42 + + +def test_update_record(create_movies_sql): + fields = dict(title='Frozen', revenue=1_290_000_000) + with connect() as con: + con.execute(create_movies_sql) + pk = insert_record('movies', fields) + fields['revenue'] = 1_299_999_999 + sql, values = update_record('movies', pk, fields) + row = fetch_record('movies', pk) + assert sql == 'UPDATE movies SET (title, revenue) = (?, ?) WHERE pk = ?' + assert values == ('Frozen', 1_299_999_999, 1) + assert row == (1, 'Frozen', 1_299_999_999.0) + + +def test_delete_record(create_movies_sql): + fields = dict(title='Frozen', revenue=1_290_000_000) + with connect() as con: + con.execute(create_movies_sql) + pk = insert_record('movies', fields) + delete_record('movies', pk) + with pytest.raises(NoSuchRecord) as e: + fetch_record('movies', pk) + assert e.value.pk == pk + + +def test_persistent_valid_table(create_movies_sql): + fields = dict(title=str, revenue=float) + + with connect() as con: + con.execute(create_movies_sql) + con.commit() + assert valid_table('movies', fields) + + +def test_persistent_valid_table_false(create_movies_sql): + # year field not in movies_sql + fields = dict(title=str, revenue=float, year=int) + + with connect() as con: + con.execute(create_movies_sql) + con.commit() + assert not valid_table('movies', fields) diff --git a/25-class-metaprog/persistent/persistlib.py b/25-class-metaprog/persistent/persistlib.py new file mode 100644 index 0000000..633f48c --- /dev/null +++ b/25-class-metaprog/persistent/persistlib.py @@ -0,0 +1,119 @@ +""" +A ``Persistent`` class definition:: + + >>> class Movie(Persistent): + ... title: str + ... year: int + ... boxmega: float + +Implemented behavior:: + + >>> Movie._connect() # doctest: +ELLIPSIS + + >>> movie = Movie('The Godfather', 1972, 137) + >>> movie.title + 'The Godfather' + >>> movie.boxmega + 137.0 + +Instances always have a ``.pk`` attribute, but it is ``None`` until the +object is saved:: + + >>> movie.pk is None + True + >>> movie._persist() + >>> movie.pk + 1 + +Delete the in-memory ``movie``, and fetch the record from the database, +using ``Movie[pk]``—item access on the class itself:: + + >>> del movie + >>> film = Movie[1] + >>> film + Movie('The Godfather', 1972, 137.0, pk=1) + +By default, the table name is the class name lowercased, with an appended +"s" for plural:: + + >>> Movie._TABLE_NAME + 'movies' + +If needed, a custom table name can be given as a keyword argument in the +class declaration:: + + >>> class Aircraft(Persistent, table='aircraft'): + ... registration: str + ... model: str + ... + >>> Aircraft._TABLE_NAME + 'aircraft' + +""" + +from typing import get_type_hints + +import dblib as db + + +class Field: + def __init__(self, name, py_type): + self.name = name + self.type = py_type + + def __set__(self, instance, value): + try: + value = self.type(value) + except TypeError as e: + msg = f'{value!r} is not compatible with {self.name}:{self.type}.' + raise TypeError(msg) from e + instance.__dict__[self.name] = value + + +class Persistent: + def __init_subclass__( + cls, *, db_path=db.DEFAULT_DB_PATH, table='', **kwargs + ): + super().__init_subclass__(**kwargs) + cls._TABLE_NAME = table if table else cls.__name__.lower() + 's' + cls._TABLE_READY = False + for name, py_type in get_type_hints(cls).items(): + setattr(cls, name, Field(name, py_type)) + + @staticmethod + def _connect(db_path=db.DEFAULT_DB_PATH): + return db.connect(db_path) + + @classmethod + def _ensure_table(cls): + if not cls._TABLE_READY: + db.ensure_table(cls._TABLE_NAME, get_type_hints(cls)) + cls._TABLE_READY = True + return cls._TABLE_NAME + + def _fields(self): + return { + name: getattr(self, name) + for name, attr in self.__class__.__dict__.items() + if isinstance(attr, Field) + } + + def __init__(self, *args, pk=None): + for name, arg in zip(self._fields(), args): + setattr(self, name, arg) + self.pk = pk + + def __class_getitem__(cls, pk): + return cls(*db.fetch_record(cls._TABLE_NAME, pk)[1:], pk=pk) + + def __repr__(self): + args = ', '.join(repr(value) for value in self._fields().values()) + pk = '' if self.pk is None else f', pk={self.pk}' + return f'{self.__class__.__name__}({args}{pk})' + + def _persist(self): + table = self.__class__._ensure_table() + if self.pk is None: + self.pk = db.insert_record(table, self._fields()) + else: + db.update_record(table, self.pk, self._fields()) From f8a1268fb10c8d1f4fdc24caacce572035904561 Mon Sep 17 00:00:00 2001 From: Luciano Ramalho Date: Thu, 15 Apr 2021 12:50:50 -0300 Subject: [PATCH 2/3] dblib.update_record: Get connection only if identifiers are ok Co-authored-by: Leonardo Rochael Almeida --- 25-class-metaprog/persistent/dblib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/25-class-metaprog/persistent/dblib.py b/25-class-metaprog/persistent/dblib.py index 68dfeab..c858e3c 100644 --- a/25-class-metaprog/persistent/dblib.py +++ b/25-class-metaprog/persistent/dblib.py @@ -133,8 +133,8 @@ def fetch_record(table_name, pk): def update_record(table_name, pk, fields): - con = get_connection() check_identifier(table_name) + con = get_connection() names = ', '.join(fields.keys()) placeholders = ', '.join(['?'] * len(fields)) values = tuple(fields.values()) + (pk,) From ee418d7d972baf418c1f01f17736e80b214d3338 Mon Sep 17 00:00:00 2001 From: Luciano Ramalho Date: Thu, 15 Apr 2021 13:34:17 -0300 Subject: [PATCH 3/3] Fix unneeded posessive 's in docstring Co-authored-by: Leonardo Rochael Almeida --- 25-class-metaprog/persistent/dblib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/25-class-metaprog/persistent/dblib.py b/25-class-metaprog/persistent/dblib.py index c858e3c..e81edc4 100644 --- a/25-class-metaprog/persistent/dblib.py +++ b/25-class-metaprog/persistent/dblib.py @@ -21,7 +21,7 @@ class NoConnection(Exception): class SchemaMismatch(ValueError): - """The table's schema doesn't match the class.""" + """The table schema doesn't match the class.""" def __init__(self, table_name): self.table_name = table_name