help commands, and tasks function

This commit is contained in:
Corban-Lee Jones 2023-12-20 08:22:37 +00:00
parent df60652da6
commit 8a1f623c6f
6 changed files with 201 additions and 37 deletions

View File

@ -24,7 +24,7 @@ class DatabaseManager:
""" """
def __init__(self, no_commit: bool = False): def __init__(self, no_commit: bool = False):
database_url = self.get_database_url() # TODO: This is called every time a connection is established, maybe make it once and reference it? database_url = self.get_database_url() # This is called every time a connection is established, maybe make it once and reference it?
self.engine = create_async_engine(database_url, future=True) self.engine = create_async_engine(database_url, future=True)
self.session_maker = sessionmaker(self.engine, class_=AsyncSession) self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
self.session = None self.session = None
@ -36,13 +36,10 @@ class DatabaseManager:
Returns a connection string for the database. Returns a connection string for the database.
""" """
# TODO finish support for mysql, mariadb, etc
url = f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}" url = f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}"
url_addon = "" url_addon = ""
# This looks fucking ugly # This looks fucking ugly
#
match use_async, DB_TYPE: match use_async, DB_TYPE:
case True, "sqlite": case True, "sqlite":
url_addon = "aiosqlite" url_addon = "aiosqlite"

View File

@ -3,12 +3,18 @@ Models and Enums for the database.
All table classes should be suffixed with `Model`. All table classes should be suffixed with `Model`.
""" """
from enum import Enum, auto
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint, ForeignKey
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
BigInteger,
UniqueConstraint,
ForeignKey
)
Base = declarative_base() Base = declarative_base()
@ -23,7 +29,7 @@ class AuditModel(Base):
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
discord_user_id = Column(BigInteger, nullable=False) discord_user_id = Column(BigInteger, nullable=False)
message = Column(String, nullable=False) message = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
class SentArticleModel(Base): class SentArticleModel(Base):
@ -38,7 +44,7 @@ class SentArticleModel(Base):
discord_channel_id = Column(BigInteger, nullable=False) discord_channel_id = Column(BigInteger, nullable=False)
discord_server_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False)
article_url = Column(String, nullable=False) article_url = Column(String, nullable=False)
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
class RssSourceModel(Base): class RssSourceModel(Base):
@ -52,7 +58,7 @@ class RssSourceModel(Base):
nick = Column(String, nullable=False) nick = Column(String, nullable=False)
discord_server_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False)
rss_url = Column(String, nullable=False) rss_url = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
feed_channels = relationship("FeedChannelModel", cascade="all, delete") feed_channels = relationship("FeedChannelModel", cascade="all, delete")

View File

@ -1,6 +1,6 @@
""" """
Extension for the `RssCog`. Extension for the `FeedCog`.
Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot. Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bot.
""" """
import logging import logging
@ -73,7 +73,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
return None, feed return None, feed
class RssCog(commands.Cog): class FeedCog(commands.Cog):
""" """
Command cog. Command cog.
""" """
@ -226,7 +226,9 @@ class RssCog(commands.Cog):
@feed_group.command(name="list") @feed_group.command(name="list")
@choices(sort=rss_list_sort_choices) @choices(sort=rss_list_sort_choices)
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False): async def list_rss_sources(
self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False
):
"""Provides a with a list of RSS sources available for the current server. """Provides a with a list of RSS sources available for the current server.
Parameters Parameters
@ -238,7 +240,7 @@ class RssCog(commands.Cog):
await inter.response.defer() await inter.response.defer()
# Default to the first choice if not specified. # Default to the first choice if not specified.
if type(sort) is Choice: if isinstance(sort, Choice):
description = "Sort by " description = "Sort by "
description += "Nickname " if sort.value == 0 else "Date Added " description += "Nickname " if sort.value == 0 else "Date Added "
description += '\U000025BC' if sort_reverse else '\U000025B2' description += '\U000025BC' if sort_reverse else '\U000025B2'
@ -246,19 +248,17 @@ class RssCog(commands.Cog):
sort = rss_list_sort_choices[0] sort = rss_list_sort_choices[0]
description = "" description = ""
sort = sort if type(sort) == Choice else rss_list_sort_choices[0]
match sort.value, sort_reverse: match sort.value, sort_reverse:
case 0, False: case 0, False:
order_by = RssSourceModel.nick.asc() order_by = RssSourceModel.nick.asc()
case 0, True: case 0, True:
order_by = RssSourceModel.nick.desc() order_by = RssSourceModel.nick.desc()
case 1, False: # NOTE: case 1, False:
order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest order_by = RssSourceModel.created.desc()
case 1, True: # date first, not the oldest as it would sort otherwise. case 1, True:
order_by = RssSourceModel.created.asc() order_by = RssSourceModel.created.asc()
case _, _: case _, _:
raise ValueError("Unknown sort: %s" % sort) raise ValueError(f"Unknown sort: {sort}")
async with DatabaseManager() as database: async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
@ -272,7 +272,10 @@ class RssCog(commands.Cog):
await followup(inter, "It looks like you have no rss sources.") await followup(inter, "It looks like you have no rss sources.")
return return
output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)]) output = "\n".join([
f"{i}. **[{rss.nick}]({rss.rss_url})** "
for i, rss in enumerate(rss_sources)
])
embed = Embed( embed = Embed(
title="Saved RSS Feeds", title="Saved RSS Feeds",
@ -330,16 +333,79 @@ class RssCog(commands.Cog):
for article in articles for article in articles
]) ])
await database.session.execute(query) await database.session.execute(query)
await audit(self, f"User is requesting {max_} articles from {source.name}", inter.user.id, database=database) await audit(self,
f"User is requesting {max_} articles from {source.name}",
inter.user.id, database=database
)
await followup(inter, embeds=embeds) await followup(inter, embeds=embeds)
# Help ---- ---- ----
@feed_group.command(name="help")
async def get_help(self, inter: Interaction):
"""Get help on how to use my commands.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
await inter.response.defer()
description = (
"`/feed add <nickname> <url>` \n\n"
"Save a new RSS feed to the bot. This can be referred to later, when assigning "
"channels to receive content from these RSS feeds."
"\n\n\n`/feed remove <option>` \n\n"
"Remove a previously saved RSS feed. Select the nickname from the shown options "
"if any. You can re-add the RSS feed to the bot using the `/feeds add` command."
"\n\n\n`/feed list <sort> <sort_reverse>` \n\n"
"List all saved RSS feeds numerically. Use the `<sort>` option to order "
"the results by either nickname or date & time added. Use the `<sort_reverse>` "
"option to order by ascending or descending in conjunction with the `<sort>` option."
"\n\n\n`/feed assign <rss> <textchannel>` \n\n"
"Assign a channel to an RSS feed. Previously saved RSS feeds will be selectable "
"under the `<rss>` option. The channel will be assumed as the current channel, "
"unless specified otherwise using the `<textchannel` option."
"\n\n\n`/feed unassign <option>` \n\n"
"Unassigned channel from an RSS feed. Previously assigned channels will be shown "
"as an `<option>`, select one to remove it."
"\n\n\n`/feed channels` \n\n"
"List all channels assigned to an RSS feed numerically."
)
embed = Embed(
title="Help",
description=description,
colour=Colour.blue(),
)
await followup(inter, embed=embed)
# Channels ---- ---- ---- # Channels ---- ---- ----
async def autocomplete_rss_sources(self, inter: Interaction, nickname: str): async def autocomplete_rss_sources(self, inter: Interaction, nickname: str):
"""""" """_summary_
Parameters
----------
inter : Interaction
_description_
nickname : str
_description_
Returns
-------
_type_
_description_
"""
async with DatabaseManager() as database: async with DatabaseManager() as database:
whereclause = and_( whereclause = and_(
@ -371,7 +437,7 @@ class RssCog(commands.Cog):
async with DatabaseManager() as database: async with DatabaseManager() as database:
whereclause = and_( whereclause = and_(
FeedChannelModel.discord_server_id == inter.guild_id, FeedChannelModel.discord_server_id == inter.guild_id,
FeedChannelModel.search_name.ilike(f"%{current}%"), # is this secure from SQL Injection atk ? FeedChannelModel.search_name.ilike(f"%{current}%"),
RssSourceModel.id == FeedChannelModel.rss_source_id RssSourceModel.id == FeedChannelModel.rss_source_id
) )
query = ( query = (
@ -384,7 +450,10 @@ class RssCog(commands.Cog):
feeds = [] feeds = []
for feed in result.scalars().all(): for feed in result.scalars().all():
channel = inter.guild.get_channel(feed.discord_channel_id) channel = inter.guild.get_channel(feed.discord_channel_id)
feeds.append(Choice(name=f"# {channel.name} | {feed.rss_source.nick}", value=feed.id)) feeds.append(Choice(
name=f"# {channel.name} | {feed.rss_source.nick}",
value=feed.id
))
log.debug("Autocomplete existing_feeds returned %s results", len(feeds)) log.debug("Autocomplete existing_feeds returned %s results", len(feeds))
@ -501,11 +570,14 @@ class RssCog(commands.Cog):
rowcount = len(feed_channels) rowcount = len(feed_channels)
if not feed_channels: if not feed_channels:
await followup(inter, "It looks like there are no feed channels available.") await followup(inter,
"It looks like there are no feed channels available."
)
return return
output = "\n".join([ output = "\n".join([
f"{i}. <#{feed.discord_channel_id}> · [{feed.rss_source.nick}]({feed.rss_source.rss_url})" f"{i}. <#{feed.discord_channel_id}> · "
f"[{feed.rss_source.nick}]({feed.rss_source.rss_url})"
for i, feed in enumerate(feed_channels) for i, feed in enumerate(feed_channels)
]) ])
@ -519,13 +591,12 @@ class RssCog(commands.Cog):
await followup(inter, embed=embed) await followup(inter, embed=embed)
async def setup(bot): async def setup(bot):
""" """
Setup function for this extension. Setup function for this extension.
Adds `RssCog` to the bot. Adds `FeedCog` to the bot.
""" """
cog = RssCog(bot) cog = FeedCog(bot)
await bot.add_cog(cog) await bot.add_cog(cog)
log.info("Added %s cog", cog.__class__.__name__) log.info("Added %s cog", cog.__class__.__name__)

View File

@ -4,15 +4,24 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
""" """
import logging import logging
import async_timeout
from discord.ext import commands import aiohttp
from feedparser import parse
from sqlalchemy import insert, select, and_
from discord import Interaction, app_commands, TextChannel
from discord.ext import commands, tasks
from discord.errors import Forbidden
from feed import Source, Article, get_unparsed_feed
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class TaskCog(commands.Cog): class TaskCog(commands.Cog):
""" """
Command cog. Tasks cog.
""" """
def __init__(self, bot): def __init__(self, bot):
@ -21,7 +30,78 @@ class TaskCog(commands.Cog):
@commands.Cog.listener() @commands.Cog.listener()
async def on_ready(self): async def on_ready(self):
log.info(f"{self.__class__.__name__} cog is ready") """Instructions to call when the cog is ready."""
self.rss_task.start() # pylint disable=E1101
log.info("%s cog is ready", self.__class__.__name__)
# @app_commands.command(name="test-trigger-task")
async def test_trigger_task(self, inter: Interaction):
await inter.response.defer()
await self.rss_task()
await inter.followup.send("done")
@tasks.loop(minutes=10)
async def rss_task(self):
log.info("Running rss task")
async with DatabaseManager() as database:
query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
result = await database.session.execute(query)
feeds = result.scalars().all()
for feed in feeds:
await self.process_feed(feed, database)
log.info("Finished rss task")
async def process_feed(self, feed: FeedChannelModel, database):
log.info("Processing feed: %s", feed.id)
channel = self.bot.get_channel(feed.discord_channel_id)
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
parsed_feed = parse(unparsed_content)
source = Source.from_parsed(parsed_feed)
articles = source.get_latest_articles(5)
if not articles:
log.info("No articles to process")
return
for article in articles:
await self.process_article(article, channel, database)
async def process_article(self, article: Article, channel: TextChannel, database):
log.info("Processing article: %s", article.url)
query = select(SentArticleModel).where(and_(
SentArticleModel.article_url == article.url,
SentArticleModel.discord_channel_id == channel.id
))
result = await database.session.execute(query)
if result.scalars().all():
log.info("Article already processed: %s", article.url)
return
embed = await article.to_embed()
try:
await channel.send(embed=embed)
except Forbidden:
log.error("Forbidden: %s · %s", channel.name, channel.id)
return
query = insert(SentArticleModel).values(
article_url = article.url,
discord_channel_id = channel.id,
discord_server_id = channel.guild.id,
discord_message_id = -1
)
await database.session.execute(query)
log.info("new Article processed: %s", article.url)
async def setup(bot): async def setup(bot):
@ -32,4 +112,4 @@ async def setup(bot):
cog = TaskCog(bot) cog = TaskCog(bot)
await bot.add_cog(cog) await bot.add_cog(cog)
log.info(f"Added {cog.__class__.__name__} cog") log.info("Added %s cog", cog.__class__.__name__)

View File

@ -4,6 +4,7 @@
import json import json
import logging import logging
import async_timeout
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
@ -179,6 +180,15 @@ class Source:
for i, entry in enumerate(self.feed.entries) for i, entry in enumerate(self.feed.entries)
if i < max if i < max
] ]
async def fetch(session, url: str) -> str:
async with async_timeout.timeout(20):
async with session.get(url) as response:
return await response.text()
async def get_unparsed_feed(url: str):
async with aiohttp.ClientSession() as session:
return await fetch(session, url)
def get_source(rss_url: str) -> Source: def get_source(rss_url: str) -> Source:

View File

@ -33,7 +33,7 @@ async def main():
# Setup logging settings and mute spammy loggers # Setup logging settings and mute spammy loggers
logsetup = LogSetup(BASE_DIR) logsetup = LogSetup(BASE_DIR)
logsetup.setup_logs() logsetup.setup_logs(logging.INFO)
logsetup.update_log_levels( logsetup.update_log_levels(
('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'), ('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'),
level=logging.WARNING level=logging.WARNING