diff --git a/db.sqlite b/db.sqlite new file mode 100644 index 0000000..1ed4214 Binary files /dev/null and b/db.sqlite differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2b85b06 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +aiohttp==3.9.1 +aiosignal==1.3.1 +aiosqlite==0.19.0 +attrs==23.1.0 +discord.py==2.3.2 +frozenlist==1.4.0 +greenlet==3.0.2 +idna==3.6 +multidict==6.0.4 +psycopg2==2.9.9 +python-dotenv==1.0.0 +SQLAlchemy==2.0.23 +typing_extensions==4.9.0 +yarl==1.9.4 diff --git a/src/db/__init__.py b/src/db/__init__.py new file mode 100644 index 0000000..4e3ad91 --- /dev/null +++ b/src/db/__init__.py @@ -0,0 +1,20 @@ +""" +Initialize the database modules, create the database tables and default data. +""" + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from .models import Base, AuditModel +from .db import DatabaseManager + +# Initialise a database session +engine = create_engine(DatabaseManager.get_database_url(use_async=False)) +session = sessionmaker(bind=engine)() + +# Create tables if not exists +Base.metadata.create_all(engine) + +session.commit() +session.close() + diff --git a/src/db/db.py b/src/db/db.py new file mode 100644 index 0000000..a129616 --- /dev/null +++ b/src/db/db.py @@ -0,0 +1,59 @@ +""" +Database Manager +""" + +import logging +from os import getenv + +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker + +DB_TYPE = getenv("DB_TYPE", default="sqlite") +DB_HOST = getenv("DB_HOST", default="db.sqlite") +DB_PORT = getenv("DB_PORT") +DB_USERNAME = getenv("DB_USERNAME") +DB_PASSWORD = getenv("DB_PASSWORD") +DB_DATABASE = getenv("DB_DATABASE") + +log = logging.getLogger(__name__) + + +class DatabaseManager: + """ + Asynchronous database context manager. + """ + + def __init__(self): + database_url = self.get_database_url() + self.engine = create_async_engine(database_url, future=True) + self.session_maker = sessionmaker(self.engine, class_=AsyncSession) + self.session = None + + @staticmethod + def get_database_url(use_async=True): + """ + Returns a connection string for the database. + """ + + if DB_TYPE not in ("sqlite", "mariadb", "mysql", "postgresql"): + raise ValueError(f"Unknown Database Type: {DB_TYPE}") + + is_sqlite = DB_TYPE == "sqlite" + + url = f"sqlite:///{DB_HOST}" if is_sqlite else f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}" + url = url.replace(":/", "+aiosqlite:/" if is_sqlite else "+asyncpg:/") if use_async else url + + return url + + + async def __aenter__(self): + self.session = self.session_maker() + log.debug("Database connection open") + return self + + async def __aexit__(self, *_): + await self.session.commit() + await self.session.close() + self.session = None + await self.engine.dispose() + log.debug("Database connection closed") diff --git a/src/db/models.py b/src/db/models.py new file mode 100644 index 0000000..0cfc886 --- /dev/null +++ b/src/db/models.py @@ -0,0 +1,27 @@ +""" +Models and Enums for the database. +All table classes should be suffixed with `Model`. +""" + +from enum import Enum, auto + +from sqlalchemy import Column, Integer, String, DateTime, BigInteger +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + + +class AuditModel(Base): + """ + Table for taking audits. + """ + + __tablename__ = "audit" + + id = Column(Integer, primary_key=True, autoincrement=True) + discord_user_id = Column(BigInteger, nullable=False) + message = Column(String, nullable=False) + created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + active = Column(Integer, default=True, nullable=False) diff --git a/src/extensions/test.py b/src/extensions/test.py index 94d5c21..95fb2e7 100644 --- a/src/extensions/test.py +++ b/src/extensions/test.py @@ -7,6 +7,9 @@ import logging from discord import app_commands, Interaction from discord.ext import commands, tasks +from sqlalchemy import insert, select + +from db import DatabaseManager, AuditModel log = logging.getLogger(__name__) @@ -27,7 +30,13 @@ class Test(commands.Cog): @app_commands.command(name="test-command") async def test_command(self, inter: Interaction): - await inter.response.send_message("test") + + async with DatabaseManager() as database: + message = f"Test command has been invoked successfully!" + query = insert(AuditModel).values(discord_user_id=inter.user.id, message=message) + await database.session.execute(query) + + await inter.response.send_message("the audit log test was successful") async def setup(bot):