Comments and Clean-up: Database

This commit is contained in:
Corban-Lee 2023-07-10 13:55:37 +01:00
parent 23a6da030e
commit 9876032289
3 changed files with 47 additions and 136 deletions

View File

@ -1,27 +1,23 @@
"""
Initialize the database modules.
Initialize the database modules and create the database tables and default data.
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .models import NewsCategories, Base, DefaultNewsCategories, get_or_create
from .db import DATABASE_URL
from .models import NewsCategories, Base, get_or_create, DefaultNewsCategories
# Initialise a database session, create the tables and insert the default data if needed
engine = create_engine(DATABASE_URL)
Session = sessionmaker(bind=engine)
session = sessionmaker(bind=engine)()
Base.metadata.create_all(engine)
# Create tables if not exists
Base.metadata.create_all(engine)
session = Session()
default_categories = [
(category.name, category.value)
for category in DefaultNewsCategories
]
for category_name, category_id in default_categories:
get_or_create(session, model=NewsCategories, id=category_id, name=category_name)
# Gather the data for and create the default news category entries
for category in DefaultNewsCategories:
get_or_create(session, model=NewsCategories, id=category.id, name=category.name)
session.commit()
session.close()

View File

@ -3,8 +3,7 @@
"""
import logging
import aiosqlite
from os.path import isfile
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
@ -17,6 +16,10 @@ log = logging.getLogger(__name__)
class DatabaseManager:
"""
Asynchronous database context manager for sqlite3.
"""
def __init__(self, database_url=DATABASE_ASYNC_URL):
self.engine = create_async_engine(database_url, future=True)
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
@ -24,124 +27,12 @@ class DatabaseManager:
async def __aenter__(self):
self.session = self.session_maker()
log.debug("Database connection open")
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(self, *_):
await self.session.close()
await self.session.commit()
self.session = None
await self.engine.dispose()
# class DBConnection:
# """
# Asynchronous context manager for database connections.
# """
# def __init__(self):
# self.conn = None
# async def __aenter__(self):
# self.conn = await connect()
# return self.conn
# async def __aexit__(self, *args):
# await close(self.conn)
# async def connect():
# log.info("Opening database connection")
# return await aiosqlite.connect(DB_PATH)
# async def with_commit(func):
# """
# Wrapper to commit changes to the database.
# """
# async def inner(*args, **kwargs):
# await func(*args, **kwargs)
# await commit()
# return inner
# async def build(conn):
# """
# Build the database from the build script.
# """
# log.info("Building database from build script")
# if isfile(BUILD_PATH):
# await scriptexec(conn, BUILD_PATH)
# return
# raise ValueError('Build script not found')
# async def commit(conn):
# """
# Commit changes to the database.
# """
# log.info("Committing database changes")
# await conn.commit()
# async def close(conn):
# """
# Close the database connection.
# """
# log.debug("Closing database connection")
# await conn.close()
# async def field(conn, cmd, *vals):
# """
# Return a single field.
# """
# log.debug("Executing command for field: %s, vals:%s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# if (fetch := await cur.fetchone()) is not None:
# return fetch[0]
# async def record(conn, cmd, *vals):
# """
# Return a single record.
# """
# log.debug("Executing command for record: %s, vals: %s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# return await cur.fetchone()
# async def records(conn, cmd, *vals):
# """
# Return all records.
# """
# log.debug("Executing command for records: %s, vals: %s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# return await cur.fetchall()
# async def column(conn, cmd, *vals):
# """
# Return a single column.
# """
# log.debug("Executing command for column: %s, vals: %s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# return [item[0] for item in await cur.fetchall()]
# async def execute(conn, cmd, *vals):
# """
# Execute a command.
# """
# log.debug("Executing command: %s, vals: %s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# return cur
# async def multiexec(conn, cmd, valset):
# """
# Execute multiple commands.
# """
# log.debug("Executing multiple commands: %s, valset: %s", cmd, valset)
# async with conn.executemany(cmd, valset):
# pass
# async def scriptexec(conn, path):
# """
# Execute a script.
# """
# log.debug("Executing script: %s", path)
# with open(path, 'r', encoding='utf-8') as script:
# await conn.executescript(script.read())
log.debug("Database connection closed")

View File

@ -1,5 +1,5 @@
"""
Models and Enums for the database.
"""
from enum import Enum, auto
@ -11,19 +11,29 @@ from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
def get_or_create(session, model, **kwargs):
"""
Returns an instance of a database model, creates one and returns that if not found.
"""
instance = session.query(model).filter_by(**kwargs).first()
if instance:
return instance
else:
if not instance:
instance = model(**kwargs)
session.add(instance)
session.commit()
return instance
return instance
class ServerChannels(Base):
"""
Database table for server channels.
These are discord server channels to be included when sharing news from certain categories.
"""
__tablename__ = 'server_channels'
# Columns
id = Column(Integer, primary_key=True, autoincrement=True)
server_id = Column(Integer, nullable=False)
channel_id = Column(Integer, nullable=False)
@ -33,23 +43,37 @@ class ServerChannels(Base):
class NewsArticles(Base):
"""
Database table for news articles.
This is not for information about news articles, but rather for keeping track of where
certain articles have been sent to, so that they arent shared twice in the same place.
"""
__tablename__ = 'news_articles'
# Columns
id = Column(Integer, primary_key=True, autoincrement=True)
url = Column(String, nullable=False)
server_channel_id = Column(Integer, ForeignKey('server_channels.id'), nullable=False)
class NewsCategories(Base):
"""
Database table for news categories.
Used to categorise news articles to RSS feeds. Contains default entries that should not be
modified or added to.
"""
__tablename__ = 'news_categories'
# Columns
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String, unique=True, nullable=False)
class DefaultNewsCategories(Enum):
"""
Enum for news categories. Used to create default data in the database table of the same name.
"""
ALL = auto()