help commands, and tasks function

This commit is contained in:
Corban-Lee Jones 2023-12-20 08:22:37 +00:00
parent df60652da6
commit 8a1f623c6f
6 changed files with 201 additions and 37 deletions

View File

@ -24,7 +24,7 @@ class DatabaseManager:
"""
def __init__(self, no_commit: bool = False):
database_url = self.get_database_url() # TODO: This is called every time a connection is established, maybe make it once and reference it?
database_url = self.get_database_url() # This is called every time a connection is established, maybe make it once and reference it?
self.engine = create_async_engine(database_url, future=True)
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
self.session = None
@ -36,13 +36,10 @@ class DatabaseManager:
Returns a connection string for the database.
"""
# TODO finish support for mysql, mariadb, etc
url = f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}"
url_addon = ""
# This looks fucking ugly
#
match use_async, DB_TYPE:
case True, "sqlite":
url_addon = "aiosqlite"

View File

@ -3,12 +3,18 @@ Models and Enums for the database.
All table classes should be suffixed with `Model`.
"""
from enum import Enum, auto
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
BigInteger,
UniqueConstraint,
ForeignKey
)
Base = declarative_base()
@ -23,7 +29,7 @@ class AuditModel(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
discord_user_id = Column(BigInteger, nullable=False)
message = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
class SentArticleModel(Base):
@ -38,7 +44,7 @@ class SentArticleModel(Base):
discord_channel_id = Column(BigInteger, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
article_url = Column(String, nullable=False)
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
class RssSourceModel(Base):
@ -52,7 +58,7 @@ class RssSourceModel(Base):
nick = Column(String, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
rss_url = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
feed_channels = relationship("FeedChannelModel", cascade="all, delete")

View File

@ -1,6 +1,6 @@
"""
Extension for the `RssCog`.
Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot.
Extension for the `FeedCog`.
Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bot.
"""
import logging
@ -73,7 +73,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
return None, feed
class RssCog(commands.Cog):
class FeedCog(commands.Cog):
"""
Command cog.
"""
@ -226,7 +226,9 @@ class RssCog(commands.Cog):
@feed_group.command(name="list")
@choices(sort=rss_list_sort_choices)
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
async def list_rss_sources(
self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False
):
"""Provides a with a list of RSS sources available for the current server.
Parameters
@ -238,7 +240,7 @@ class RssCog(commands.Cog):
await inter.response.defer()
# Default to the first choice if not specified.
if type(sort) is Choice:
if isinstance(sort, Choice):
description = "Sort by "
description += "Nickname " if sort.value == 0 else "Date Added "
description += '\U000025BC' if sort_reverse else '\U000025B2'
@ -246,19 +248,17 @@ class RssCog(commands.Cog):
sort = rss_list_sort_choices[0]
description = ""
sort = sort if type(sort) == Choice else rss_list_sort_choices[0]
match sort.value, sort_reverse:
case 0, False:
order_by = RssSourceModel.nick.asc()
case 0, True:
order_by = RssSourceModel.nick.desc()
case 1, False: # NOTE:
order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest
case 1, True: # date first, not the oldest as it would sort otherwise.
case 1, False:
order_by = RssSourceModel.created.desc()
case 1, True:
order_by = RssSourceModel.created.asc()
case _, _:
raise ValueError("Unknown sort: %s" % sort)
raise ValueError(f"Unknown sort: {sort}")
async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
@ -272,7 +272,10 @@ class RssCog(commands.Cog):
await followup(inter, "It looks like you have no rss sources.")
return
output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)])
output = "\n".join([
f"{i}. **[{rss.nick}]({rss.rss_url})** "
for i, rss in enumerate(rss_sources)
])
embed = Embed(
title="Saved RSS Feeds",
@ -330,16 +333,79 @@ class RssCog(commands.Cog):
for article in articles
])
await database.session.execute(query)
await audit(self, f"User is requesting {max_} articles from {source.name}", inter.user.id, database=database)
await audit(self,
f"User is requesting {max_} articles from {source.name}",
inter.user.id, database=database
)
await followup(inter, embeds=embeds)
# Help ---- ---- ----
@feed_group.command(name="help")
async def get_help(self, inter: Interaction):
"""Get help on how to use my commands.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
await inter.response.defer()
description = (
"`/feed add <nickname> <url>` \n\n"
"Save a new RSS feed to the bot. This can be referred to later, when assigning "
"channels to receive content from these RSS feeds."
"\n\n\n`/feed remove <option>` \n\n"
"Remove a previously saved RSS feed. Select the nickname from the shown options "
"if any. You can re-add the RSS feed to the bot using the `/feeds add` command."
"\n\n\n`/feed list <sort> <sort_reverse>` \n\n"
"List all saved RSS feeds numerically. Use the `<sort>` option to order "
"the results by either nickname or date & time added. Use the `<sort_reverse>` "
"option to order by ascending or descending in conjunction with the `<sort>` option."
"\n\n\n`/feed assign <rss> <textchannel>` \n\n"
"Assign a channel to an RSS feed. Previously saved RSS feeds will be selectable "
"under the `<rss>` option. The channel will be assumed as the current channel, "
"unless specified otherwise using the `<textchannel` option."
"\n\n\n`/feed unassign <option>` \n\n"
"Unassigned channel from an RSS feed. Previously assigned channels will be shown "
"as an `<option>`, select one to remove it."
"\n\n\n`/feed channels` \n\n"
"List all channels assigned to an RSS feed numerically."
)
embed = Embed(
title="Help",
description=description,
colour=Colour.blue(),
)
await followup(inter, embed=embed)
# Channels ---- ---- ----
async def autocomplete_rss_sources(self, inter: Interaction, nickname: str):
""""""
"""_summary_
Parameters
----------
inter : Interaction
_description_
nickname : str
_description_
Returns
-------
_type_
_description_
"""
async with DatabaseManager() as database:
whereclause = and_(
@ -371,7 +437,7 @@ class RssCog(commands.Cog):
async with DatabaseManager() as database:
whereclause = and_(
FeedChannelModel.discord_server_id == inter.guild_id,
FeedChannelModel.search_name.ilike(f"%{current}%"), # is this secure from SQL Injection atk ?
FeedChannelModel.search_name.ilike(f"%{current}%"),
RssSourceModel.id == FeedChannelModel.rss_source_id
)
query = (
@ -384,7 +450,10 @@ class RssCog(commands.Cog):
feeds = []
for feed in result.scalars().all():
channel = inter.guild.get_channel(feed.discord_channel_id)
feeds.append(Choice(name=f"# {channel.name} | {feed.rss_source.nick}", value=feed.id))
feeds.append(Choice(
name=f"# {channel.name} | {feed.rss_source.nick}",
value=feed.id
))
log.debug("Autocomplete existing_feeds returned %s results", len(feeds))
@ -501,11 +570,14 @@ class RssCog(commands.Cog):
rowcount = len(feed_channels)
if not feed_channels:
await followup(inter, "It looks like there are no feed channels available.")
await followup(inter,
"It looks like there are no feed channels available."
)
return
output = "\n".join([
f"{i}. <#{feed.discord_channel_id}> · [{feed.rss_source.nick}]({feed.rss_source.rss_url})"
f"{i}. <#{feed.discord_channel_id}> · "
f"[{feed.rss_source.nick}]({feed.rss_source.rss_url})"
for i, feed in enumerate(feed_channels)
])
@ -519,13 +591,12 @@ class RssCog(commands.Cog):
await followup(inter, embed=embed)
async def setup(bot):
"""
Setup function for this extension.
Adds `RssCog` to the bot.
Adds `FeedCog` to the bot.
"""
cog = RssCog(bot)
cog = FeedCog(bot)
await bot.add_cog(cog)
log.info("Added %s cog", cog.__class__.__name__)

View File

@ -4,15 +4,24 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
"""
import logging
import async_timeout
from discord.ext import commands
import aiohttp
from feedparser import parse
from sqlalchemy import insert, select, and_
from discord import Interaction, app_commands, TextChannel
from discord.ext import commands, tasks
from discord.errors import Forbidden
from feed import Source, Article, get_unparsed_feed
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel
log = logging.getLogger(__name__)
class TaskCog(commands.Cog):
"""
Command cog.
Tasks cog.
"""
def __init__(self, bot):
@ -21,7 +30,78 @@ class TaskCog(commands.Cog):
@commands.Cog.listener()
async def on_ready(self):
log.info(f"{self.__class__.__name__} cog is ready")
"""Instructions to call when the cog is ready."""
self.rss_task.start() # pylint disable=E1101
log.info("%s cog is ready", self.__class__.__name__)
# @app_commands.command(name="test-trigger-task")
async def test_trigger_task(self, inter: Interaction):
await inter.response.defer()
await self.rss_task()
await inter.followup.send("done")
@tasks.loop(minutes=10)
async def rss_task(self):
log.info("Running rss task")
async with DatabaseManager() as database:
query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
result = await database.session.execute(query)
feeds = result.scalars().all()
for feed in feeds:
await self.process_feed(feed, database)
log.info("Finished rss task")
async def process_feed(self, feed: FeedChannelModel, database):
log.info("Processing feed: %s", feed.id)
channel = self.bot.get_channel(feed.discord_channel_id)
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
parsed_feed = parse(unparsed_content)
source = Source.from_parsed(parsed_feed)
articles = source.get_latest_articles(5)
if not articles:
log.info("No articles to process")
return
for article in articles:
await self.process_article(article, channel, database)
async def process_article(self, article: Article, channel: TextChannel, database):
log.info("Processing article: %s", article.url)
query = select(SentArticleModel).where(and_(
SentArticleModel.article_url == article.url,
SentArticleModel.discord_channel_id == channel.id
))
result = await database.session.execute(query)
if result.scalars().all():
log.info("Article already processed: %s", article.url)
return
embed = await article.to_embed()
try:
await channel.send(embed=embed)
except Forbidden:
log.error("Forbidden: %s · %s", channel.name, channel.id)
return
query = insert(SentArticleModel).values(
article_url = article.url,
discord_channel_id = channel.id,
discord_server_id = channel.guild.id,
discord_message_id = -1
)
await database.session.execute(query)
log.info("new Article processed: %s", article.url)
async def setup(bot):
@ -32,4 +112,4 @@ async def setup(bot):
cog = TaskCog(bot)
await bot.add_cog(cog)
log.info(f"Added {cog.__class__.__name__} cog")
log.info("Added %s cog", cog.__class__.__name__)

View File

@ -4,6 +4,7 @@
import json
import logging
import async_timeout
from dataclasses import dataclass
from datetime import datetime
@ -179,6 +180,15 @@ class Source:
for i, entry in enumerate(self.feed.entries)
if i < max
]
async def fetch(session, url: str) -> str:
async with async_timeout.timeout(20):
async with session.get(url) as response:
return await response.text()
async def get_unparsed_feed(url: str):
async with aiohttp.ClientSession() as session:
return await fetch(session, url)
def get_source(rss_url: str) -> Source:

View File

@ -33,7 +33,7 @@ async def main():
# Setup logging settings and mute spammy loggers
logsetup = LogSetup(BASE_DIR)
logsetup.setup_logs()
logsetup.setup_logs(logging.INFO)
logsetup.update_log_levels(
('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'),
level=logging.WARNING