2021-04-15 04:20:15 +02:00
|
|
|
"""
A ``Persistent`` class definition::

    >>> class Movie(Persistent):
    ...     title: str
    ...     year: int
    ...     box_office: float

Implemented behavior::

    >>> Movie._connect()  # doctest: +ELLIPSIS
    <sqlite3.Connection object at 0x...>
    >>> movie = Movie(title='The Godfather', year=1972, box_office=137)
    >>> movie.title
    'The Godfather'
    >>> movie.box_office
    137.0

Instances always have a ``._pk`` attribute, but it is ``None`` until the
object is saved::

    >>> movie._pk is None
    True
    >>> movie._save()
    1
    >>> movie._pk
    1

Delete the in-memory ``movie`` and fetch the record from the database
using ``Movie[pk]``—item access on the class itself::

    >>> del movie
    >>> film = Movie[1]
    >>> film
    Movie(title='The Godfather', year=1972, box_office=137.0, _pk=1)

By default, the table name is the class name lowercased, with an appended
"s" for plural::

    >>> Movie._TABLE_NAME
    'movies'

If needed, a custom table name can be given as a ``table`` keyword argument
in the class declaration::

    >>> class Aircraft(Persistent, table='aircraft'):
    ...     registration: str
    ...     model: str
    ...
    >>> Aircraft._TABLE_NAME
    'aircraft'

"""
|
|
|
|
|
2021-04-15 22:15:53 +02:00
|
|
|
from typing import Any, ClassVar, get_origin, get_type_hints

import dblib as db
|
|
|
|
|
|
|
|
|
|
|
|
class Field:
    """Descriptor that coerces every assigned value to a declared type.

    Only ``__set__`` is defined, so reading the attribute returns the value
    stored in the instance ``__dict__`` (or the ``Field`` object itself when
    nothing has been assigned yet).
    """

    def __init__(self, name: str, py_type: type) -> None:
        # name: the managed attribute, also used as the instance __dict__ key.
        # py_type: callable used to coerce assigned values, e.g. ``float``.
        self.name = name
        self.type = py_type

    def __set__(self, instance: 'Persistent', value: Any) -> None:
        """Store ``self.type(value)`` in ``instance.__dict__[self.name]``.

        Raises:
            TypeError: if ``self.type(value)`` raises ``TypeError`` or
                ``ValueError``; the original exception is chained as
                ``__cause__`` and nothing is stored.
        """
        try:
            value = self.type(value)
        except (TypeError, ValueError) as e:
            # Not every annotation object has a __name__ (e.g. some typing
            # constructs on older Pythons, functools.partial objects); fall
            # back to repr() so building the message cannot raise
            # AttributeError and hide the real conversion error.
            type_name = getattr(self.type, '__name__', repr(self.type))
            msg = f'{value!r} is not compatible with {self.name}:{type_name}.'
            raise TypeError(msg) from e
        instance.__dict__[self.name] = value
|
|
|
|
|
|
|
|
|
|
|
|
class Persistent:
    """Base class that maps annotated subclasses to database tables.

    Every public (non-underscore, non-``ClassVar``) annotation on a subclass
    becomes a ``Field`` descriptor and a column of the subclass's table.
    Storage calls are delegated to the ``dblib`` module imported as ``db``.
    """

    # Set on each subclass by __init_subclass__.
    _TABLE_NAME: ClassVar[str]
    # True once _ensure_table() has run for this particular class.
    _TABLE_READY: ClassVar[bool] = False

    @classmethod
    def _fields(cls) -> dict[str, type]:
        """Return ``{name: type}`` for the persistent fields of ``cls``.

        Names starting with ``_`` (such as ``_TABLE_NAME``) are skipped.
        ``ClassVar`` annotations declare class-level constants rather than
        per-instance data, so they are skipped too; wrapping them in a
        ``Field`` would replace the class value with a descriptor.
        """
        return {
            name: py_type
            for name, py_type in get_type_hints(cls).items()
            if not name.startswith('_')
            and py_type is not ClassVar
            and get_origin(py_type) is not ClassVar
        }

    def __init_subclass__(cls, *, table: str = '', **kwargs: Any) -> None:
        """Configure a new subclass: table name, readiness flag, fields.

        ``table`` overrides the default table name, which is the class name
        lowercased with ``'s'`` appended.
        """
        super().__init_subclass__(**kwargs)  # type:ignore
        cls._TABLE_NAME = table if table else cls.__name__.lower() + 's'
        # Reset on every subclass: otherwise a subclass of a model whose
        # table was already ensured inherits the parent's True flag via the
        # MRO and never ensures its own, differently named, table.
        cls._TABLE_READY = False
        for name, py_type in cls._fields().items():
            setattr(cls, name, Field(name, py_type))

    def __init__(self, *, _pk=None, **kwargs):
        """Initialize field values from keyword arguments.

        ``_pk`` is the primary key; it stays ``None`` until ``_save()``.
        Raises ``AttributeError`` for a keyword that is not a field. Each
        value is assigned through its ``Field`` descriptor, which raises
        ``TypeError`` when the value cannot be converted.
        """
        field_names = self._asdict().keys()
        for name, arg in kwargs.items():
            if name not in field_names:
                msg = f'{self.__class__.__name__!r} has no attribute {name!r}'
                raise AttributeError(msg)
            setattr(self, name, arg)
        self._pk = _pk

    def __repr__(self) -> str:
        """Show field values as keyword arguments, plus ``_pk`` once set."""
        kwargs = ', '.join(
            f'{key}={value!r}' for key, value in self._asdict().items()
        )
        cls_name = self.__class__.__name__
        if self._pk is None:
            return f'{cls_name}({kwargs})'
        return f'{cls_name}({kwargs}, _pk={self._pk})'

    def _asdict(self) -> dict[str, Any]:
        """Return ``{field_name: current_value}`` for this instance.

        Fields are the ``Field`` objects in the class ``__dict__``
        (``__init_subclass__`` installs them on every subclass). A field that
        was never assigned yields the ``Field`` object itself, because
        ``Field`` defines no ``__get__``.
        """
        return {
            name: getattr(self, name)
            for name, attr in self.__class__.__dict__.items()
            if isinstance(attr, Field)
        }

    # database methods

    @staticmethod
    def _connect(db_path: str = db.DEFAULT_DB_PATH):
        """Return ``db.connect(db_path)``.

        The module doctest shows the result is a ``sqlite3.Connection``.
        """
        return db.connect(db_path)

    @classmethod
    def _ensure_table(cls) -> str:
        """Call ``db.ensure_table`` once per class; return the table name.

        Presumably ``db.ensure_table`` creates the table from ``_fields()``
        when it does not exist yet — confirm against ``dblib``.
        """
        if not cls._TABLE_READY:
            db.ensure_table(cls._TABLE_NAME, cls._fields())
            cls._TABLE_READY = True
        return cls._TABLE_NAME

    def __class_getitem__(cls, pk: int) -> 'Persistent':
        """Load the record whose primary key is ``pk``, e.g. ``Movie[1]``.

        NOTE(review): assumes ``db.fetch_record`` returns the primary key
        followed by the field values in ``_fields()`` order — confirm
        against ``dblib``. Unlike ``_save()``, this uses ``_TABLE_NAME``
        directly and does not call ``_ensure_table()``.
        """
        field_names = ['_pk'] + list(cls._fields())
        values = db.fetch_record(cls._TABLE_NAME, pk)
        return cls(**dict(zip(field_names, values)))

    def _save(self) -> int:
        """Insert (first save) or update this record; return ``_pk``.

        On insert, ``_pk`` is set to the value returned by
        ``db.insert_record``; later saves call ``db.update_record``.
        """
        table = self.__class__._ensure_table()
        if self._pk is None:
            self._pk = db.insert_record(table, self._asdict())
        else:
            db.update_record(table, self._pk, self._asdict())
        return self._pk
|