diff --git a/src/db/__init__.py b/src/db/__init__.py index 580dd9b..ade3339 100644 --- a/src/db/__init__.py +++ b/src/db/__init__.py @@ -1,27 +1,23 @@ """ -Initialize the database modules. +Initialize the database modules and create the database tables and default data. """ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from .models import NewsCategories, Base, DefaultNewsCategories, get_or_create from .db import DATABASE_URL -from .models import NewsCategories, Base, get_or_create, DefaultNewsCategories +# Initialise a database session, create the tables and insert the default data if needed engine = create_engine(DATABASE_URL) -Session = sessionmaker(bind=engine) +session = sessionmaker(bind=engine)() -Base.metadata.create_all(engine) +# Create tables if not exists +Base.metadata.create_all(engine) -session = Session() - -default_categories = [ - (category.name, category.value) - for category in DefaultNewsCategories -] - -for category_name, category_id in default_categories: - get_or_create(session, model=NewsCategories, id=category_id, name=category_name) +# Gather the data for and create the default news category entries +for category in DefaultNewsCategories: + get_or_create(session, model=NewsCategories, id=category.id, name=category.name) session.commit() session.close() diff --git a/src/db/db.py b/src/db/db.py index 3983a3f..f34e936 100644 --- a/src/db/db.py +++ b/src/db/db.py @@ -3,8 +3,7 @@ """ import logging -import aiosqlite -from os.path import isfile + from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker @@ -17,6 +16,10 @@ log = logging.getLogger(__name__) class DatabaseManager: + """ + Asynchronous database context manager for sqlite3. + """ + def __init__(self, database_url=DATABASE_ASYNC_URL): self.engine = create_async_engine(database_url, future=True) self.session_maker = sessionmaker(self.engine, class_=AsyncSession) @@ -24,124 +27,12 @@ class DatabaseManager: async def __aenter__(self): self.session = self.session_maker() + log.debug("Database connection open") return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, *_): await self.session.close() await self.session.commit() self.session = None await self.engine.dispose() - - -# class DBConnection: -# """ -# Asynchronous context manager for database connections. -# """ - -# def __init__(self): -# self.conn = None - -# async def __aenter__(self): -# self.conn = await connect() -# return self.conn - -# async def __aexit__(self, *args): -# await close(self.conn) - - -# async def connect(): -# log.info("Opening database connection") -# return await aiosqlite.connect(DB_PATH) - -# async def with_commit(func): -# """ -# Wrapper to commit changes to the database. -# """ -# async def inner(*args, **kwargs): -# await func(*args, **kwargs) -# await commit() - -# return inner - -# async def build(conn): -# """ -# Build the database from the build script. -# """ -# log.info("Building database from build script") - -# if isfile(BUILD_PATH): -# await scriptexec(conn, BUILD_PATH) -# return - -# raise ValueError('Build script not found') - -# async def commit(conn): -# """ -# Commit changes to the database. -# """ -# log.info("Committing database changes") -# await conn.commit() - -# async def close(conn): -# """ -# Close the database connection. -# """ -# log.debug("Closing database connection") -# await conn.close() - -# async def field(conn, cmd, *vals): -# """ -# Return a single field. -# """ -# log.debug("Executing command for field: %s, vals:%s", cmd, vals) -# async with conn.execute(cmd, tuple(vals)) as cur: -# if (fetch := await cur.fetchone()) is not None: -# return fetch[0] - -# async def record(conn, cmd, *vals): -# """ -# Return a single record. -# """ -# log.debug("Executing command for record: %s, vals: %s", cmd, vals) -# async with conn.execute(cmd, tuple(vals)) as cur: -# return await cur.fetchone() - -# async def records(conn, cmd, *vals): -# """ -# Return all records. -# """ -# log.debug("Executing command for records: %s, vals: %s", cmd, vals) -# async with conn.execute(cmd, tuple(vals)) as cur: -# return await cur.fetchall() - -# async def column(conn, cmd, *vals): -# """ -# Return a single column. -# """ -# log.debug("Executing command for column: %s, vals: %s", cmd, vals) -# async with conn.execute(cmd, tuple(vals)) as cur: -# return [item[0] for item in await cur.fetchall()] - -# async def execute(conn, cmd, *vals): -# """ -# Execute a command. -# """ -# log.debug("Executing command: %s, vals: %s", cmd, vals) -# async with conn.execute(cmd, tuple(vals)) as cur: -# return cur - -# async def multiexec(conn, cmd, valset): -# """ -# Execute multiple commands. -# """ -# log.debug("Executing multiple commands: %s, valset: %s", cmd, valset) -# async with conn.executemany(cmd, valset): -# pass - -# async def scriptexec(conn, path): -# """ -# Execute a script. -# """ -# log.debug("Executing script: %s", path) -# with open(path, 'r', encoding='utf-8') as script: -# await conn.executescript(script.read()) + log.debug("Database connection closed") diff --git a/src/db/models.py b/src/db/models.py index 25c0992..86636aa 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -1,5 +1,5 @@ """ - +Models and Enums for the database. """ from enum import Enum, auto @@ -11,19 +11,29 @@ from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() def get_or_create(session, model, **kwargs): + """ + Returns an instance of a database model, creates one and returns that if not found. + """ + instance = session.query(model).filter_by(**kwargs).first() - if instance: - return instance - else: + + if not instance: instance = model(**kwargs) session.add(instance) session.commit() - return instance + + return instance class ServerChannels(Base): + """ + Database table for server channels. + These are discord server channels to be included when sharing news from certain categories. + """ + __tablename__ = 'server_channels' + # Columns id = Column(Integer, primary_key=True, autoincrement=True) server_id = Column(Integer, nullable=False) channel_id = Column(Integer, nullable=False) @@ -33,23 +43,37 @@ class ServerChannels(Base): class NewsArticles(Base): + """ + Database table for news articles. + This is not for information about news articles, but rather for keeping track of where + certain articles have been sent to, so that they arent shared twice in the same place. + """ + __tablename__ = 'news_articles' + # Columns id = Column(Integer, primary_key=True, autoincrement=True) url = Column(String, nullable=False) server_channel_id = Column(Integer, ForeignKey('server_channels.id'), nullable=False) class NewsCategories(Base): + """ + Database table for news categories. + Used to categorise news articles to RSS feeds. Contains default entries that should not be + modified or added to. + """ + __tablename__ = 'news_categories' + # Columns id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String, unique=True, nullable=False) class DefaultNewsCategories(Enum): """ - + Enum for news categories. Used to create default data in the database table of the same name. """ ALL = auto()