database
This commit is contained in:
parent
fd962a9e76
commit
cec1db209b
14
requirements.txt
Normal file
14
requirements.txt
Normal file
@ -0,0 +1,14 @@
|
||||
aiohttp==3.9.1
|
||||
aiosignal==1.3.1
|
||||
aiosqlite==0.19.0
|
||||
attrs==23.1.0
|
||||
discord.py==2.3.2
|
||||
frozenlist==1.4.0
|
||||
greenlet==3.0.2
|
||||
idna==3.6
|
||||
multidict==6.0.4
|
||||
psycopg2==2.9.9
|
||||
python-dotenv==1.0.0
|
||||
SQLAlchemy==2.0.23
|
||||
typing_extensions==4.9.0
|
||||
yarl==1.9.4
|
20
src/db/__init__.py
Normal file
20
src/db/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""
|
||||
Initialize the database modules, create the database tables and default data.
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from .models import Base, AuditModel
|
||||
from .db import DatabaseManager
|
||||
|
||||
# Initialise a database session
|
||||
engine = create_engine(DatabaseManager.get_database_url(use_async=False))
|
||||
session = sessionmaker(bind=engine)()
|
||||
|
||||
# Create tables if not exists
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
|
59
src/db/db.py
Normal file
59
src/db/db.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""
|
||||
Database Manager
|
||||
"""
|
||||
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
DB_TYPE = getenv("DB_TYPE", default="sqlite")
|
||||
DB_HOST = getenv("DB_HOST", default="db.sqlite")
|
||||
DB_PORT = getenv("DB_PORT")
|
||||
DB_USERNAME = getenv("DB_USERNAME")
|
||||
DB_PASSWORD = getenv("DB_PASSWORD")
|
||||
DB_DATABASE = getenv("DB_DATABASE")
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
Asynchronous database context manager.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
database_url = self.get_database_url()
|
||||
self.engine = create_async_engine(database_url, future=True)
|
||||
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
|
||||
self.session = None
|
||||
|
||||
@staticmethod
|
||||
def get_database_url(use_async=True):
|
||||
"""
|
||||
Returns a connection string for the database.
|
||||
"""
|
||||
|
||||
if DB_TYPE not in ("sqlite", "mariadb", "mysql", "postgresql"):
|
||||
raise ValueError(f"Unknown Database Type: {DB_TYPE}")
|
||||
|
||||
is_sqlite = DB_TYPE == "sqlite"
|
||||
|
||||
url = f"sqlite:///{DB_HOST}" if is_sqlite else f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}"
|
||||
url = url.replace(":/", "+aiosqlite:/" if is_sqlite else "+asyncpg:/") if use_async else url
|
||||
|
||||
return url
|
||||
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = self.session_maker()
|
||||
log.debug("Database connection open")
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_):
|
||||
await self.session.commit()
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
await self.engine.dispose()
|
||||
log.debug("Database connection closed")
|
27
src/db/models.py
Normal file
27
src/db/models.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""
|
||||
Models and Enums for the database.
|
||||
All table classes should be suffixed with `Model`.
|
||||
"""
|
||||
|
||||
from enum import Enum, auto
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, BigInteger
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class AuditModel(Base):
|
||||
"""
|
||||
Table for taking audits.
|
||||
"""
|
||||
|
||||
__tablename__ = "audit"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
discord_user_id = Column(BigInteger, nullable=False)
|
||||
message = Column(String, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
active = Column(Integer, default=True, nullable=False)
|
@ -7,6 +7,9 @@ import logging
|
||||
|
||||
from discord import app_commands, Interaction
|
||||
from discord.ext import commands, tasks
|
||||
from sqlalchemy import insert, select
|
||||
|
||||
from db import DatabaseManager, AuditModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -27,7 +30,13 @@ class Test(commands.Cog):
|
||||
|
||||
@app_commands.command(name="test-command")
|
||||
async def test_command(self, inter: Interaction):
|
||||
await inter.response.send_message("test")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
message = f"Test command has been invoked successfully!"
|
||||
query = insert(AuditModel).values(discord_user_id=inter.user.id, message=message)
|
||||
await database.session.execute(query)
|
||||
|
||||
await inter.response.send_message("the audit log test was successful")
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
|
Loading…
x
Reference in New Issue
Block a user