complete rewrite
This commit is contained in:
parent
ae21eab7f1
commit
8dc97e9e70
16
.vscode/launch.json
vendored
Normal file
16
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Discord Bot",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/src/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
}
|
||||
]
|
||||
}
|
42
db/build.sql
Normal file
42
db/build.sql
Normal file
@ -0,0 +1,42 @@
|
||||
|
||||
/*
|
||||
Server Channels
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS 'server_channels' (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL,
|
||||
news_category_id INTEGER NOT NULL,
|
||||
active INTEGER NOT NULL,
|
||||
FOREIGN KEY (news_category_id) REFERENCES 'news_categories' (id)
|
||||
ON DELETE CASCADE
|
||||
);
|
||||
|
||||
|
||||
/*
|
||||
News Articles
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS 'news_articles' (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
url TEXT NOT NULL,
|
||||
server_channel_id INTEGER NOT NULL,
|
||||
FOREIGN KEY (server_channel_id) REFERENCES 'server_channels' (id)
|
||||
ON DELETE CASCADE
|
||||
);
|
||||
|
||||
|
||||
/*
|
||||
News Categories
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS 'news_categories' (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE NOT NULL
|
||||
);
|
||||
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('all');
|
||||
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('world');
|
||||
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('uk');
|
||||
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('north_america');
|
||||
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('entertainment');
|
||||
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('business');
|
||||
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('tech');
|
||||
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('science');
|
||||
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('top_stories');
|
BIN
db/db.sqlite
Normal file
BIN
db/db.sqlite
Normal file
Binary file not shown.
165
src/bot.py
165
src/bot.py
@ -1,6 +1,10 @@
|
||||
"""The discord bot for the application"""
|
||||
"""
|
||||
The discord bot for the application.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
@ -10,6 +14,9 @@ from discord import Intents, Interaction, app_commands
|
||||
from discord.ext import commands, tasks
|
||||
from bbc_feeds import news
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
EXTENSIONS_DIRECTORY = "src/extensions/"
|
||||
|
||||
|
||||
class DiscordBot(commands.Bot):
|
||||
|
||||
@ -17,24 +24,36 @@ class DiscordBot(commands.Bot):
|
||||
super().__init__(command_prefix="-", intents=Intents.all())
|
||||
|
||||
async def sync_app_commands(self):
|
||||
"""Sync application commands"""
|
||||
"""
|
||||
Sync application commands.
|
||||
"""
|
||||
|
||||
await self.wait_until_ready()
|
||||
await self.tree.sync()
|
||||
print("app commands synced")
|
||||
log.info("Application commands successfully synced")
|
||||
|
||||
async def on_ready(self):
|
||||
"""When the bot is ready"""
|
||||
"""
|
||||
When the bot is ready.
|
||||
"""
|
||||
|
||||
await self.add_cog(CommandsCog(self))
|
||||
await self.add_cog(ErrorCog(self))
|
||||
await self.sync_app_commands()
|
||||
|
||||
async def load_extensions(self):
|
||||
"""
|
||||
Load any extensions found in the extensions dictionary.
|
||||
"""
|
||||
|
||||
for filename in os.listdir(EXTENSIONS_DIRECTORY):
|
||||
if filename.endswith(".py"):
|
||||
await self.load_extension(f"extensions.{filename[:-3]}")
|
||||
|
||||
|
||||
class CommandsCog(commands.Cog):
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.news_task.start()
|
||||
# self.news_task.start()
|
||||
|
||||
async def story_to_embed(self, story) -> discord.Embed:
|
||||
"""
|
||||
@ -97,135 +116,3 @@ class CommandsCog(commands.Cog):
|
||||
embed = await self.story_to_embed(story)
|
||||
channel = self.bot.get_channel(1057004889458348042)
|
||||
await channel.send(embed=embed)
|
||||
|
||||
|
||||
class ErrorCog(commands.Cog):
|
||||
"""Error handling cog."""
|
||||
|
||||
__slots__ = ()
|
||||
default_err_msg = "I'm sorry, but I've encountered an " \
|
||||
"error while processing your command."
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
# Register the error handler
|
||||
bot.tree.error(coro = self._dispatch_to_app_command_handler)
|
||||
|
||||
def trace_error(self, error: Exception):
|
||||
print(f"{type(error).__name__} {error}")
|
||||
raise error
|
||||
|
||||
async def _dispatch_to_app_command_handler(
|
||||
self,
|
||||
inter: Interaction,
|
||||
error: app_commands.AppCommandError
|
||||
):
|
||||
"""Dispatches the error to the app command handler"""
|
||||
|
||||
self.bot.dispatch("app_command_error", inter, error)
|
||||
|
||||
async def _respond_to_interaction(self, inter: Interaction) -> bool:
|
||||
"""Respond to an interaction with an error message"""
|
||||
|
||||
try:
|
||||
await inter.response.send_message(
|
||||
self.default_err_msg,
|
||||
ephemeral=True
|
||||
)
|
||||
except discord.InteractionResponded:
|
||||
return
|
||||
|
||||
@commands.Cog.listener("on_app_command_error")
|
||||
async def get_app_command_error(
|
||||
self,
|
||||
inter: Interaction,
|
||||
error: app_commands.AppCommandError
|
||||
):
|
||||
"""Handles the application command error.
|
||||
|
||||
Responds with the appropriate error message.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Send the default error message and create an edit
|
||||
# shorthand to add more details to the message once
|
||||
# we've figured out what the error is.
|
||||
print(error.with_traceback(None))
|
||||
await self._respond_to_interaction(inter)
|
||||
edit = lambda x: inter.edit_original_response(content=x)
|
||||
|
||||
raise error
|
||||
|
||||
except app_commands.CommandInvokeError as _err:
|
||||
|
||||
# The interaction has already been responded to.
|
||||
if isinstance(
|
||||
_err.original,
|
||||
discord.InteractionResponded
|
||||
):
|
||||
await edit(_err.original)
|
||||
return
|
||||
|
||||
# Some other error occurred while invoking the command.
|
||||
await edit(
|
||||
f"`{type(_err.original).__name__}` " \
|
||||
f": {_err.original}"
|
||||
)
|
||||
|
||||
except app_commands.CheckFailure as _err:
|
||||
|
||||
# The command is still on cooldown.
|
||||
if isinstance(
|
||||
_err,
|
||||
app_commands.CommandOnCooldown
|
||||
):
|
||||
await edit(
|
||||
f"Woah, slow down! This command is on cooldown, " \
|
||||
f"wait `{str(_err).split(' ')[7]}` !"
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(
|
||||
_err,
|
||||
app_commands.MissingPermissions
|
||||
):
|
||||
await edit(
|
||||
"You don't have the required permissions to " \
|
||||
"run this command!"
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(
|
||||
_err,
|
||||
app_commands.BotMissingPermissions
|
||||
):
|
||||
await edit(
|
||||
"I don't have the required permissions to " \
|
||||
"run this command! Please ask an admin to " \
|
||||
"grant me the required permissions."
|
||||
)
|
||||
return
|
||||
|
||||
# A different check has failed.
|
||||
await edit(f"`{type(_err).__name__}` : {_err}")
|
||||
|
||||
except app_commands.CommandNotFound:
|
||||
|
||||
# The command could not be found.
|
||||
await edit(
|
||||
f"I couldn't find the command you were looking for... "
|
||||
"\nThis is probably a discord bug related to " \
|
||||
"desynchronization between my commands and discord's " \
|
||||
"servers. Please try again later."
|
||||
)
|
||||
|
||||
except Exception as _err:
|
||||
# Caught here:
|
||||
# app_commands.TransformerError
|
||||
# app_commands.CommandLimitReached
|
||||
# app_commands.CommandAlreadyRegistered
|
||||
# app_commands.CommandSignatureMismatch
|
||||
|
||||
self.trace_error(_err)
|
||||
|
27
src/db/__init__.py
Normal file
27
src/db/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""
|
||||
Initialize the database modules.
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from .db import DATABASE_URL
|
||||
from .models import NewsCategories, Base, get_or_create, DefaultNewsCategories
|
||||
|
||||
engine = create_engine(DATABASE_URL)
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
session = Session()
|
||||
|
||||
default_categories = [
|
||||
(category.name, category.value)
|
||||
for category in DefaultNewsCategories
|
||||
]
|
||||
|
||||
for category_name, category_id in default_categories:
|
||||
get_or_create(session, model=NewsCategories, id=category_id, name=category_name)
|
||||
|
||||
session.commit()
|
||||
session.close()
|
147
src/db/db.py
Normal file
147
src/db/db.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import aiosqlite
|
||||
from os.path import isfile
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
DB_PATH = "db/db.sqlite"
|
||||
BUILD_PATH = "db/build.sql"
|
||||
DATABASE_URL = "sqlite:///db/db.sqlite"
|
||||
DATABASE_ASYNC_URL = "sqlite+aiosqlite:///db/db.sqlite"
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
def __init__(self, database_url=DATABASE_ASYNC_URL):
|
||||
self.engine = create_async_engine(database_url, future=True)
|
||||
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
|
||||
self.session = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = self.session_maker()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.session.close()
|
||||
await self.session.commit()
|
||||
self.session = None
|
||||
await self.engine.dispose()
|
||||
|
||||
|
||||
# class DBConnection:
|
||||
# """
|
||||
# Asynchronous context manager for database connections.
|
||||
# """
|
||||
|
||||
# def __init__(self):
|
||||
# self.conn = None
|
||||
|
||||
# async def __aenter__(self):
|
||||
# self.conn = await connect()
|
||||
# return self.conn
|
||||
|
||||
# async def __aexit__(self, *args):
|
||||
# await close(self.conn)
|
||||
|
||||
|
||||
# async def connect():
|
||||
# log.info("Opening database connection")
|
||||
# return await aiosqlite.connect(DB_PATH)
|
||||
|
||||
# async def with_commit(func):
|
||||
# """
|
||||
# Wrapper to commit changes to the database.
|
||||
# """
|
||||
# async def inner(*args, **kwargs):
|
||||
# await func(*args, **kwargs)
|
||||
# await commit()
|
||||
|
||||
# return inner
|
||||
|
||||
# async def build(conn):
|
||||
# """
|
||||
# Build the database from the build script.
|
||||
# """
|
||||
# log.info("Building database from build script")
|
||||
|
||||
# if isfile(BUILD_PATH):
|
||||
# await scriptexec(conn, BUILD_PATH)
|
||||
# return
|
||||
|
||||
# raise ValueError('Build script not found')
|
||||
|
||||
# async def commit(conn):
|
||||
# """
|
||||
# Commit changes to the database.
|
||||
# """
|
||||
# log.info("Committing database changes")
|
||||
# await conn.commit()
|
||||
|
||||
# async def close(conn):
|
||||
# """
|
||||
# Close the database connection.
|
||||
# """
|
||||
# log.debug("Closing database connection")
|
||||
# await conn.close()
|
||||
|
||||
# async def field(conn, cmd, *vals):
|
||||
# """
|
||||
# Return a single field.
|
||||
# """
|
||||
# log.debug("Executing command for field: %s, vals:%s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# if (fetch := await cur.fetchone()) is not None:
|
||||
# return fetch[0]
|
||||
|
||||
# async def record(conn, cmd, *vals):
|
||||
# """
|
||||
# Return a single record.
|
||||
# """
|
||||
# log.debug("Executing command for record: %s, vals: %s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# return await cur.fetchone()
|
||||
|
||||
# async def records(conn, cmd, *vals):
|
||||
# """
|
||||
# Return all records.
|
||||
# """
|
||||
# log.debug("Executing command for records: %s, vals: %s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# return await cur.fetchall()
|
||||
|
||||
# async def column(conn, cmd, *vals):
|
||||
# """
|
||||
# Return a single column.
|
||||
# """
|
||||
# log.debug("Executing command for column: %s, vals: %s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# return [item[0] for item in await cur.fetchall()]
|
||||
|
||||
# async def execute(conn, cmd, *vals):
|
||||
# """
|
||||
# Execute a command.
|
||||
# """
|
||||
# log.debug("Executing command: %s, vals: %s", cmd, vals)
|
||||
# async with conn.execute(cmd, tuple(vals)) as cur:
|
||||
# return cur
|
||||
|
||||
# async def multiexec(conn, cmd, valset):
|
||||
# """
|
||||
# Execute multiple commands.
|
||||
# """
|
||||
# log.debug("Executing multiple commands: %s, valset: %s", cmd, valset)
|
||||
# async with conn.executemany(cmd, valset):
|
||||
# pass
|
||||
|
||||
# async def scriptexec(conn, path):
|
||||
# """
|
||||
# Execute a script.
|
||||
# """
|
||||
# log.debug("Executing script: %s", path)
|
||||
# with open(path, 'r', encoding='utf-8') as script:
|
||||
# await conn.executescript(script.read())
|
63
src/db/models.py
Normal file
63
src/db/models.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
from enum import Enum, auto
|
||||
|
||||
from sqlalchemy import Column, Integer, String, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
def get_or_create(session, model, **kwargs):
|
||||
instance = session.query(model).filter_by(**kwargs).first()
|
||||
if instance:
|
||||
return instance
|
||||
else:
|
||||
instance = model(**kwargs)
|
||||
session.add(instance)
|
||||
session.commit()
|
||||
return instance
|
||||
|
||||
|
||||
class ServerChannels(Base):
|
||||
__tablename__ = 'server_channels'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
server_id = Column(Integer, nullable=False)
|
||||
channel_id = Column(Integer, nullable=False)
|
||||
news_category_id = Column(Integer, nullable=False)
|
||||
active = Column(Integer, nullable=False, default=True)
|
||||
news_articles = relationship('NewsArticles', cascade='all, delete')
|
||||
|
||||
|
||||
class NewsArticles(Base):
|
||||
__tablename__ = 'news_articles'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
url = Column(String, nullable=False)
|
||||
server_channel_id = Column(Integer, ForeignKey('server_channels.id'), nullable=False)
|
||||
|
||||
|
||||
class NewsCategories(Base):
|
||||
__tablename__ = 'news_categories'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String, unique=True, nullable=False)
|
||||
|
||||
|
||||
class DefaultNewsCategories(Enum):
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
ALL = auto()
|
||||
WORLD = auto()
|
||||
UK = auto()
|
||||
NORTH_AMERICA = auto()
|
||||
ENTERTAINMENT = auto()
|
||||
BUSINESS = auto()
|
||||
TECH = auto()
|
||||
SCIENCE = auto()
|
||||
TOP_STORIES = auto()
|
146
src/extensions/errors.py
Normal file
146
src/extensions/errors.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""
|
||||
Extension for the `ErrorCog` cog.
|
||||
Loading this file via `commands.Bot.load_extension` will add the `ErrorCog` cog to the bot.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from discord import app_commands, Interaction
|
||||
from discord.ext import commands
|
||||
from discord.errors import InteractionResponded
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorCog(commands.Cog):
|
||||
"""
|
||||
Error handling cog.
|
||||
Discordpy has problems with error handling, this cog corrects
|
||||
"""
|
||||
|
||||
default_err_msg = "I'm sorry, but I've encountered an " \
|
||||
"error while processing your command."
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
# Register the error handler
|
||||
bot.tree.error(coro=self._dispatch_to_app_command_handler)
|
||||
|
||||
def trace_error(self, error: Exception):
|
||||
log.error(f"{type(error).__name__} {error}")
|
||||
raise error
|
||||
|
||||
async def _dispatch_to_app_command_handler(
|
||||
self,
|
||||
inter: Interaction,
|
||||
error: app_commands.AppCommandError
|
||||
):
|
||||
"""
|
||||
Dispatches the error to the app command handler.
|
||||
"""
|
||||
|
||||
self.bot.dispatch("app_command_error", inter, error)
|
||||
|
||||
async def _respond_to_interaction(self, inter: Interaction) -> bool:
|
||||
"""
|
||||
Respond to an interaction with an error message.
|
||||
"""
|
||||
|
||||
try:
|
||||
await inter.response.send_message(
|
||||
self.default_err_msg,
|
||||
ephemeral=True
|
||||
)
|
||||
except InteractionResponded:
|
||||
log.debug("Interaction already responded to.")
|
||||
return
|
||||
|
||||
@commands.Cog.listener("on_app_command_error")
|
||||
async def get_app_command_error(
|
||||
self,
|
||||
inter: Interaction,
|
||||
error: app_commands.AppCommandError
|
||||
):
|
||||
"""
|
||||
Handles the application command error and responds with the appropriate error message.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Send the default error message and create an edit
|
||||
# shorthand to add more details to the message once
|
||||
# we've figured out what the error is.
|
||||
log.error(error.with_traceback(None))
|
||||
await self._respond_to_interaction(inter)
|
||||
edit = lambda x: inter.edit_original_response(content=x)
|
||||
|
||||
raise error
|
||||
|
||||
except app_commands.CommandInvokeError as _err:
|
||||
|
||||
# The interaction has already been responded to.
|
||||
if isinstance(_err.original, InteractionResponded):
|
||||
await edit(_err.original)
|
||||
return
|
||||
|
||||
# Some other error occurred while invoking the command.
|
||||
await edit(f"`{type(_err.original).__name__}` : {_err.original}")
|
||||
|
||||
except app_commands.CheckFailure as _err:
|
||||
|
||||
# The command is still on cooldown.
|
||||
if isinstance(_err, app_commands.CommandOnCooldown):
|
||||
await edit(
|
||||
f"Woah, slow down! This command is on cooldown, " \
|
||||
f"wait `{str(_err).split(' ')[7]}` !"
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(_err, app_commands.MissingPermissions):
|
||||
await edit(
|
||||
"You don't have the required permissions to " \
|
||||
"run this command!"
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(_err, app_commands.BotMissingPermissions):
|
||||
await edit(
|
||||
"I don't have the required permissions to " \
|
||||
"run this command! Please ask an admin to " \
|
||||
"grant me the required permissions."
|
||||
)
|
||||
return
|
||||
|
||||
# A different check has failed.
|
||||
await edit(f"`{type(_err).__name__}` : {_err}")
|
||||
|
||||
except app_commands.CommandNotFound:
|
||||
|
||||
# The command could not be found.
|
||||
await edit(
|
||||
f"I couldn't find the command you were looking for... "
|
||||
"\nThis is probably a discord bug related to " \
|
||||
"desynchronization between my commands and discord's " \
|
||||
"servers. Please try again later."
|
||||
)
|
||||
|
||||
except Exception as _err:
|
||||
# Caught here:
|
||||
# app_commands.TransformerError
|
||||
# app_commands.CommandLimitReached
|
||||
# app_commands.CommandAlreadyRegistered
|
||||
# app_commands.CommandSignatureMismatch
|
||||
|
||||
self.trace_error(_err)
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
Adds the `ErrorCog` cog to the bot.
|
||||
"""
|
||||
|
||||
cog = ErrorCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
335
src/extensions/news.py
Normal file
335
src/extensions/news.py
Normal file
@ -0,0 +1,335 @@
|
||||
"""
|
||||
Extension for the `NewsCog` cog.
|
||||
Loading this file via `commands.Bot.load_extension` will add the `NewsCog` cog to the bot.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
from discord import app_commands, Interaction
|
||||
from discord.ext import commands, tasks
|
||||
from sqlalchemy import select, insert, delete, and_
|
||||
from bbc_feeds import news as news_api, feedparser
|
||||
from bs4 import BeautifulSoup as bs4
|
||||
|
||||
from db.db import DatabaseManager
|
||||
from db.models import ServerChannels, DefaultNewsCategories, NewsArticles
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
category_choices = [
|
||||
app_commands.Choice(name=category.name.replace("_", " "), value=category.value)
|
||||
for category in DefaultNewsCategories
|
||||
]
|
||||
|
||||
|
||||
class NewsStoryType:
|
||||
"""
|
||||
Type hinting class for news stories.
|
||||
"""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
title_detail: str
|
||||
summary: str
|
||||
summary_detail: str
|
||||
link: str
|
||||
links: list[str]
|
||||
guidislink: str
|
||||
published: str
|
||||
published_parsed: datetime
|
||||
|
||||
|
||||
class NewsCog(commands.Cog):
|
||||
"""
|
||||
News cog.
|
||||
Delivers embeds of news articles to discord channels.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
self.news_task.start()
|
||||
|
||||
def fetch_articles(self, category_id: int) -> list[NewsStoryType]:
|
||||
"""
|
||||
Fetch the latest news articles from the category matching the given `category_id`.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
category_id : int
|
||||
The ID of the category to fetch news articles from.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
The given `category_id` doesn't match any known categories.
|
||||
"""
|
||||
|
||||
n = news_api()
|
||||
|
||||
# IMPORTANT:
|
||||
# The order of items in this list matters,
|
||||
# Items follow the order of items in the Enum: DefaultNewsCategories
|
||||
category_methods = [
|
||||
n.all,
|
||||
n.world,
|
||||
n.uk,
|
||||
n.north_america,
|
||||
n.entertainment,
|
||||
n.business,
|
||||
n.tech,
|
||||
n.science,
|
||||
n.top_stories
|
||||
]
|
||||
|
||||
try:
|
||||
return category_methods[category_id - 1](limit=1)
|
||||
except IndexError as err:
|
||||
raise ValueError(f"Invalid category_id: {category_id}") from err
|
||||
|
||||
async def story_to_embed(self, story: NewsStoryType, category_name: str) -> discord.Embed:
|
||||
"""
|
||||
Returns a discord.Embed object representing the given story.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
story : NewsStoryType
|
||||
Parsed details on the news story.
|
||||
|
||||
Returns
|
||||
-------
|
||||
discord.Embed
|
||||
A `discord.Embed` object updated with information on the news story.
|
||||
"""
|
||||
|
||||
# Fetch web data for the thumbnail
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(story.link) as response:
|
||||
html = await response.text()
|
||||
|
||||
# Parse the thumbnail for the news story
|
||||
soup = bs4(html, "html.parser")
|
||||
image_src = soup.select_one("meta[property='og:image']").get("content")
|
||||
|
||||
category_name = category_name.replace("_", " ").title()
|
||||
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.from_str("#FFFFFF"),
|
||||
title=story.title,
|
||||
description=story.summary,
|
||||
url=story.link,
|
||||
)
|
||||
embed.set_image(url=image_src)
|
||||
embed.set_author(name=f"BBC News • {category_name}")
|
||||
|
||||
return embed
|
||||
|
||||
async def followup_with_articles(self, inter: Interaction, category: app_commands.Choice[int]):
|
||||
"""
|
||||
Collects articles and follows up an interaction with embeds for these articles.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
inter : discord.Interaction
|
||||
The interaction between the bot and the user.
|
||||
|
||||
category : app_commands.Choice[int]
|
||||
The category chosen by the user, represented as a `Choice` object.
|
||||
"""
|
||||
|
||||
stories = self.fetch_articles(category.value)
|
||||
|
||||
if not stories:
|
||||
await inter.followup.send("No articles found")
|
||||
return
|
||||
|
||||
for story in stories:
|
||||
embed = await self.story_to_embed(story, category.name)
|
||||
await inter.followup.send(embed=embed)
|
||||
|
||||
async def _get_or_fetch_channel(self, channel_id: int) -> discord.TextChannel:
|
||||
"""
|
||||
Returns a `discord.TextChannel` object based on the given `channel_id`.
|
||||
Will try to return from cache if present, otherwise will make an API call for the channel.
|
||||
"""
|
||||
|
||||
channel = self.bot.get_channel(channel_id)
|
||||
return channel or await self.bot.fetch_channel(channel_id)
|
||||
|
||||
@tasks.loop(minutes=5)
|
||||
async def news_task(self):
|
||||
"""
|
||||
Task that checks for the latest news and shares it to discord.
|
||||
"""
|
||||
|
||||
log.info("Doing news task.")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
for category in DefaultNewsCategories:
|
||||
|
||||
# Find channels that accept this category of news
|
||||
whereclause = and_(ServerChannels.news_category_id == category.value)
|
||||
query = select(ServerChannels).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
server_channels = result.scalars().all()
|
||||
|
||||
if not server_channels:
|
||||
continue
|
||||
|
||||
stories = self.fetch_articles(category.value)
|
||||
if not stories:
|
||||
continue
|
||||
|
||||
story = stories[0]
|
||||
embed = None
|
||||
|
||||
for item in server_channels:
|
||||
|
||||
channel_id = item.channel_id
|
||||
|
||||
# Check the article hasn't already been shared with this channel.
|
||||
query = select(NewsArticles).where(and_(NewsArticles.server_channel_id == channel_id, NewsArticles.url == story.link))
|
||||
result = await database.session.execute(query)
|
||||
existing_items = result.scalars().all()
|
||||
|
||||
if existing_items:
|
||||
log_string = (
|
||||
"Existing items not sent"
|
||||
f"\nchannel id: {channel_id}"
|
||||
f"\narticle url: {story.link}"
|
||||
f"\ncategory: {category.name}\n"
|
||||
)
|
||||
log.info(log_string)
|
||||
continue
|
||||
|
||||
log_string = (
|
||||
"Items sent"
|
||||
f"\nchannel id: {channel_id}"
|
||||
f"\narticle url: {story.link}"
|
||||
f"\ncategory: {category.name}\n"
|
||||
)
|
||||
log.info(log_string)
|
||||
|
||||
# Add this article as shown to this channel, prevents it from being shown here again.
|
||||
query = insert(NewsArticles).values(server_channel_id=channel_id, url=story.link)
|
||||
await database.session.execute(query)
|
||||
|
||||
embed = embed or await self.story_to_embed(story, category.name)
|
||||
|
||||
channel = await self._get_or_fetch_channel(channel_id)
|
||||
await channel.send(embed=embed)
|
||||
|
||||
await database.session.commit() # commit will terminate the connection for some reason.
|
||||
|
||||
base_group = app_commands.Group(name="bbc", description="BBC News related commands")
|
||||
channels_group = app_commands.Group(parent=base_group, name="channels", description="Channel commands")
|
||||
|
||||
@base_group.command(name="news")
|
||||
@app_commands.choices(category=category_choices)
|
||||
async def get_news(self, inter: Interaction, category: app_commands.Choice[int]):
|
||||
"""
|
||||
Get the latest article from BBC news.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
await self.followup_with_articles(inter, category)
|
||||
|
||||
# @base_group.command(name="ping-me")
|
||||
# async def ping_me(self, inter: Interaction):
|
||||
# """
|
||||
# Ping you when a new news article has been posted. Use again to remove ping.
|
||||
# """
|
||||
|
||||
# await inter.response.send_message("response OK")
|
||||
|
||||
@channels_group.command(name="add-category")
|
||||
@app_commands.choices(category=category_choices)
|
||||
async def add_category(self, inter: Interaction, channel: discord.TextChannel, category: app_commands.Choice[int]):
|
||||
"""
|
||||
Assign a category to this channel. Articles under this category will be sent here.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(ServerChannels).values(server_id=inter.guild_id, channel_id=channel.id, news_category_id=category.value)
|
||||
await database.session.execute(query)
|
||||
await database.session.commit()
|
||||
|
||||
category_name = category.name.replace("_", " ").title()
|
||||
embed = discord.Embed(
|
||||
title=f"Category Added • {category_name}",
|
||||
description=f"{channel.mention} will now receive news from **{category_name}**",
|
||||
colour=discord.Colour.from_str("#FFFFFF")
|
||||
)
|
||||
|
||||
await inter.followup.send(embed=embed)
|
||||
|
||||
@channels_group.command(name="del-category")
|
||||
@app_commands.choices(category=category_choices)
|
||||
async def delete_category(self, inter: Interaction, channel: discord.TextChannel, category: app_commands.Choice[int]):
|
||||
"""
|
||||
Remove a category from this channel.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
ServerChannels.channel_id == channel.id,
|
||||
ServerChannels.news_category_id == category.value
|
||||
)
|
||||
query = delete(ServerChannels).where(whereclause)
|
||||
await database.session.execute(query)
|
||||
await database.session.commit()
|
||||
|
||||
category_name = category.name.replace("_", " ").title()
|
||||
embed = discord.Embed(
|
||||
title=f"Category Removed • {category_name}",
|
||||
description=f"{channel.mention} will no longer receive news from **{category_name}**",
|
||||
colour=discord.Colour.from_str("#FFFFFF")
|
||||
)
|
||||
|
||||
await inter.followup.send(embed=embed)
|
||||
|
||||
@channels_group.command(name="lst-category")
|
||||
async def list_category(self, inter: Interaction):
|
||||
"""
|
||||
List the categories assigned to this channel.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = select(ServerChannels)
|
||||
result = await database.session.execute(query)
|
||||
server_channels = result.scalars().all()
|
||||
|
||||
output = ""
|
||||
|
||||
for item in server_channels:
|
||||
if item.server_id != inter.guild_id:
|
||||
continue
|
||||
|
||||
channel = inter.guild.get_channel(item.channel_id)
|
||||
category_name = DefaultNewsCategories(item.news_category_id).name.replace("_", " ")
|
||||
|
||||
output += f"{channel.mention} - {category_name}\n"
|
||||
|
||||
output = output or "No categories set"
|
||||
await inter.followup.send(output)
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
Adds the `ErrorCog` cog to the bot.
|
||||
"""
|
||||
|
||||
cog = NewsCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
95
src/logs.py
Normal file
95
src/logs.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""
|
||||
Handle async logging for the project.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import queue
|
||||
import logging
|
||||
from logging.handlers import QueueHandler, QueueListener
|
||||
from datetime import datetime, timedelta
|
||||
from itertools import count
|
||||
from typing import TextIO
|
||||
from pathlib import Path
|
||||
|
||||
LOGS_DIRECTORY = "logs/"
|
||||
LOG_FILENAME_FORMAT_PREFIX = "%Y-%m-%d %H-%M-%S"
|
||||
MAX_LOGFILE_AGE_DAYS = 7
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
def _open_file() -> TextIO:
|
||||
"""
|
||||
Returns a file object for the current log file.
|
||||
"""
|
||||
|
||||
# Create the logs directory if it doesnt exist
|
||||
Path(LOGS_DIRECTORY).mkdir(exist_ok=True)
|
||||
|
||||
# Create a generator to generate a unique filename
|
||||
timestamp = datetime.now().strftime(LOG_FILENAME_FORMAT_PREFIX)
|
||||
filenames = (f'{timestamp}.log' if i == 0 else f'{timestamp}_({i}).log' for i in count())
|
||||
|
||||
# Find a filename that doesn't already exist and return it
|
||||
for filename in filenames:
|
||||
try:
|
||||
return (Path(f'{LOGS_DIRECTORY}/{filename}').open('x', encoding='utf-8'))
|
||||
except FileExistsError:
|
||||
continue
|
||||
|
||||
def _delete_old_logs():
|
||||
"""
|
||||
Search through the logs directory and delete any expired log files.
|
||||
"""
|
||||
|
||||
for path in Path(LOGS_DIRECTORY).glob('*.txt'):
|
||||
prefix = path.stem.split('_')[0]
|
||||
try:
|
||||
log_date = datetime.strptime(prefix, LOG_FILENAME_FORMAT_PREFIX)
|
||||
except ValueError:
|
||||
log.warning(f'{path.parent} contains a problematic filename: {path.name}')
|
||||
continue
|
||||
|
||||
age = datetime.now() - log_date
|
||||
if age >= timedelta(days=MAX_LOGFILE_AGE_DAYS):
|
||||
log.info(f'Removing expired log file: {path.name}')
|
||||
path.unlink()
|
||||
|
||||
def update_log_levels(logger_names:tuple[str], level:int):
|
||||
"""
|
||||
Quick way to update the log level of multiple loggers at once.
|
||||
"""
|
||||
for name in logger_names:
|
||||
logger=logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
|
||||
def setup_logs(log_level:int=logging.DEBUG) -> str:
|
||||
"""
|
||||
Setup a logging queue handler and queue listener.
|
||||
Also creates a new log file for the current session and deletes old log files.
|
||||
"""
|
||||
|
||||
# Create a queue to pass log records to the listener
|
||||
log_queue = queue.Queue()
|
||||
queue_handler = QueueHandler(log_queue)
|
||||
|
||||
# Configure the root logger to use the queue
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
handlers=(queue_handler,),
|
||||
format='[%(asctime)s] [%(levelname)-8s] [%(name)-17s]: %(message)s'
|
||||
)
|
||||
|
||||
# Create a new log file
|
||||
file = _open_file()
|
||||
|
||||
file_handler = logging.StreamHandler(file) # Stream logs to the log file
|
||||
sys_handler = logging.StreamHandler(sys.stdout) # Stream logs to the console
|
||||
|
||||
# Create a listener to handle the queue
|
||||
queue_listener = QueueListener(log_queue, file_handler, sys_handler)
|
||||
queue_listener.start()
|
||||
|
||||
# Clear up old log files
|
||||
_delete_old_logs()
|
||||
|
||||
return file.name
|
28
src/main.py
28
src/main.py
@ -1,16 +1,38 @@
|
||||
"""Entry point for the application."""
|
||||
"""
|
||||
Entry point for the application.
|
||||
Run this file to get started.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from bot import DiscordBot
|
||||
|
||||
from bot import DiscordBot
|
||||
from logs import setup_logs, update_log_levels
|
||||
|
||||
async def main():
|
||||
"""
|
||||
Entry point function for the application.
|
||||
Run this function to get started.
|
||||
"""
|
||||
|
||||
# Grab the token before anything else, because if there is no token
|
||||
# available then the bot cannot be started anyways.
|
||||
with open("TOKEN", "r") as token_file:
|
||||
token = token_file.read()
|
||||
|
||||
await DiscordBot().start(token)
|
||||
if not token:
|
||||
raise ValueError("Token file is empty")
|
||||
|
||||
setup_logs()
|
||||
update_log_levels(
|
||||
('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'),
|
||||
level=logging.WARNING
|
||||
)
|
||||
|
||||
async with DiscordBot() as bot:
|
||||
await bot.load_extensions()
|
||||
await bot.start(token)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
asyncio.run(main())
|
||||
|
Loading…
x
Reference in New Issue
Block a user