Comments and Clean-up: Database
This commit is contained in:
parent
23a6da030e
commit
9876032289
@ -1,27 +1,23 @@
|
||||
"""
|
||||
Initialize the database modules.
|
||||
Initialize the database modules and create the database tables and default data.
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from .models import NewsCategories, Base, DefaultNewsCategories, get_or_create
|
||||
from .db import DATABASE_URL
|
||||
from .models import NewsCategories, Base, get_or_create, DefaultNewsCategories
|
||||
|
||||
# Initialise a database session, create the tables and insert the default data if needed
|
||||
engine = create_engine(DATABASE_URL)
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = sessionmaker(bind=engine)()
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
# Create tables if not exists
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
session = Session()
|
||||
|
||||
default_categories = [
|
||||
(category.name, category.value)
|
||||
for category in DefaultNewsCategories
|
||||
]
|
||||
|
||||
for category_name, category_id in default_categories:
|
||||
get_or_create(session, model=NewsCategories, id=category_id, name=category_name)
|
||||
# Gather the data for and create the default news category entries
|
||||
for category in DefaultNewsCategories:
|
||||
get_or_create(session, model=NewsCategories, id=category.id, name=category.name)
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
|
125
src/db/db.py
125
src/db/db.py
@ -3,8 +3,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import aiosqlite
|
||||
from os.path import isfile
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -17,6 +16,10 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
Asynchronous database context manager for sqlite3.
|
||||
"""
|
||||
|
||||
def __init__(self, database_url=DATABASE_ASYNC_URL):
|
||||
self.engine = create_async_engine(database_url, future=True)
|
||||
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
|
||||
@ -24,124 +27,12 @@ class DatabaseManager:
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = self.session_maker()
|
||||
log.debug("Database connection open")
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
async def __aexit__(self, *_):
|
||||
await self.session.close()
|
||||
await self.session.commit()
|
||||
self.session = None
|
||||
await self.engine.dispose()
|
||||
|
||||
|
||||
# class DBConnection:
|
||||
# """
|
||||
# Asynchronous context manager for database connections.
|
||||
# """
|
||||
|
||||
# def __init__(self):
|
||||
# self.conn = None
|
||||
|
||||
# async def __aenter__(self):
|
||||
# self.conn = await connect()
|
||||
# return self.conn
|
||||
|
||||
# async def __aexit__(self, *args):
|
||||
# await close(self.conn)
|
||||
|
||||
|
||||
# async def connect():
|
||||
# log.info("Opening database connection")
|
||||
# return await aiosqlite.connect(DB_PATH)
|
||||
|
||||
# async def with_commit(func):
|
||||
# """
|
||||
# Wrapper to commit changes to the database.
|
||||
# """
|
||||
# async def inner(*args, **kwargs):
|
||||
# await func(*args, **kwargs)
|
||||
# await commit()
|
||||
|
||||
# return inner
|
||||
|
||||
# async def build(conn):
|
||||
# """
|
||||
# Build the database from the build script.
|
||||
# """
|
||||
# log.info("Building database from build script")
|
||||
|
||||
# if isfile(BUILD_PATH):
|
||||
# await scriptexec(conn, BUILD_PATH)
|
||||
# return
|
||||
|
||||
# raise ValueError('Build script not found')
|
||||
|
||||
# async def commit(conn):
|
||||
# """
|
||||
# Commit changes to the database.
|
||||
# """
|
||||
# log.info("Committing database changes")
|
||||
# await conn.commit()
|
||||
|
||||
# async def close(conn):
|
||||
# """
|
||||
# Close the database connection.
|
||||
# """
|
||||
# log.debug("Closing database connection")
|
||||
# await conn.close()
|
||||
|
||||
# async def field(conn, cmd, *vals):
|
||||
# """
|
||||
# Return a single field.
|
||||
# """
|
||||
# log.debug("Executing command for field: %s, vals:%s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# if (fetch := await cur.fetchone()) is not None:
|
||||
# return fetch[0]
|
||||
|
||||
# async def record(conn, cmd, *vals):
|
||||
# """
|
||||
# Return a single record.
|
||||
# """
|
||||
# log.debug("Executing command for record: %s, vals: %s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# return await cur.fetchone()
|
||||
|
||||
# async def records(conn, cmd, *vals):
|
||||
# """
|
||||
# Return all records.
|
||||
# """
|
||||
# log.debug("Executing command for records: %s, vals: %s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# return await cur.fetchall()
|
||||
|
||||
# async def column(conn, cmd, *vals):
|
||||
# """
|
||||
# Return a single column.
|
||||
# """
|
||||
# log.debug("Executing command for column: %s, vals: %s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# return [item[0] for item in await cur.fetchall()]
|
||||
|
||||
# async def execute(conn, cmd, *vals):
|
||||
# """
|
||||
# Execute a command.
|
||||
# """
|
||||
# log.debug("Executing command: %s, vals: %s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# return cur
|
||||
|
||||
# async def multiexec(conn, cmd, valset):
|
||||
# """
|
||||
# Execute multiple commands.
|
||||
# """
|
||||
# log.debug("Executing multiple commands: %s, valset: %s", cmd, valset)
|
||||
# async with conn.executemany(cmd, valset):
|
||||
# pass
|
||||
|
||||
# async def scriptexec(conn, path):
|
||||
# """
|
||||
# Execute a script.
|
||||
# """
|
||||
# log.debug("Executing script: %s", path)
|
||||
# with open(path, 'r', encoding='utf-8') as script:
|
||||
# await conn.executescript(script.read())
|
||||
log.debug("Database connection closed")
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
|
||||
Models and Enums for the database.
|
||||
"""
|
||||
|
||||
from enum import Enum, auto
|
||||
@ -11,19 +11,29 @@ from sqlalchemy.ext.declarative import declarative_base
|
||||
Base = declarative_base()
|
||||
|
||||
def get_or_create(session, model, **kwargs):
|
||||
"""
|
||||
Returns an instance of a database model, creates one and returns that if not found.
|
||||
"""
|
||||
|
||||
instance = session.query(model).filter_by(**kwargs).first()
|
||||
if instance:
|
||||
return instance
|
||||
else:
|
||||
|
||||
if not instance:
|
||||
instance = model(**kwargs)
|
||||
session.add(instance)
|
||||
session.commit()
|
||||
return instance
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
class ServerChannels(Base):
|
||||
"""
|
||||
Database table for server channels.
|
||||
These are discord server channels to be included when sharing news from certain categories.
|
||||
"""
|
||||
|
||||
__tablename__ = 'server_channels'
|
||||
|
||||
# Columns
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
server_id = Column(Integer, nullable=False)
|
||||
channel_id = Column(Integer, nullable=False)
|
||||
@ -33,23 +43,37 @@ class ServerChannels(Base):
|
||||
|
||||
|
||||
class NewsArticles(Base):
|
||||
"""
|
||||
Database table for news articles.
|
||||
This is not for information about news articles, but rather for keeping track of where
|
||||
certain articles have been sent to, so that they arent shared twice in the same place.
|
||||
"""
|
||||
|
||||
__tablename__ = 'news_articles'
|
||||
|
||||
# Columns
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
url = Column(String, nullable=False)
|
||||
server_channel_id = Column(Integer, ForeignKey('server_channels.id'), nullable=False)
|
||||
|
||||
|
||||
class NewsCategories(Base):
|
||||
"""
|
||||
Database table for news categories.
|
||||
Used to categorise news articles to RSS feeds. Contains default entries that should not be
|
||||
modified or added to.
|
||||
"""
|
||||
|
||||
__tablename__ = 'news_categories'
|
||||
|
||||
# Columns
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String, unique=True, nullable=False)
|
||||
|
||||
|
||||
class DefaultNewsCategories(Enum):
|
||||
"""
|
||||
|
||||
Enum for news categories. Used to create default data in the database table of the same name.
|
||||
"""
|
||||
|
||||
ALL = auto()
|
||||
|
Loading…
x
Reference in New Issue
Block a user