diff --git a/src/bot.py b/src/bot.py index e509f59..69e3a1c 100644 --- a/src/bot.py +++ b/src/bot.py @@ -46,9 +46,16 @@ class DiscordBot(commands.Bot): if path.suffix == ".py": await self.load_extension(f"extensions.{path.stem}") - async def audit(self, message: str, user_id: int): + async def audit(self, message: str, user_id: int, database: DatabaseManager=None): + + message = f"Requesting latest article" + query = insert(AuditModel).values(discord_user_id=user_id, message=message) + + if database: + await database.session.execute(query) + return async with DatabaseManager() as database: - message = f"Requesting latest article" - query = insert(AuditModel).values(discord_user_id=user_id, message=message) - await database.session.execute(query) \ No newline at end of file + await database.session.execute(query) + + log.debug("Audit logged") diff --git a/src/db/models.py b/src/db/models.py index 8b5acbf..7f64668 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`. from enum import Enum, auto -from sqlalchemy import Column, Integer, String, DateTime, BigInteger +from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint from sqlalchemy.sql import func from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base @@ -24,7 +24,6 @@ class AuditModel(Base): discord_user_id = Column(BigInteger, nullable=False) message = Column(String, nullable=False) created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - active = Column(Integer, default=True, nullable=False) class SentArticleModel(Base): @@ -39,7 +38,6 @@ class SentArticleModel(Base): discord_channel_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False) article_url = Column(String, nullable=False) - active = Column(Integer, default=True, nullable=False) class RssSourceModel(Base): @@ -50,9 +48,14 @@ class RssSourceModel(Base): __tablename__ = "rss_source" id = Column(Integer, primary_key=True, autoincrement=True) + nick = Column(String, nullable=False) discord_server_id = Column(BigInteger, nullable=False) rss_url = Column(String, nullable=False) - active = Column(Integer, default=True, nullable=False) + + # the nickname must be unique, but only within the same discord server + __table_args__ = ( + UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'), + ) class FeedChannelModel(Base): @@ -64,4 +67,3 @@ class FeedChannelModel(Base): id = Column(Integer, primary_key=True, autoincrement=True) discord_channel_id = Column(BigInteger, nullable=False) - active = Column(Integer, default=True, nullable=False) diff --git a/src/extensions/cmd.py b/src/extensions/cmd.py index 7f54a3c..22515ad 100644 --- a/src/extensions/cmd.py +++ b/src/extensions/cmd.py @@ -12,7 +12,8 @@ import feedparser from markdownify import markdownify from discord import app_commands, Interaction, Embed from discord.ext import commands, tasks -from sqlalchemy import insert, select, update, and_, or_ +from discord.app_commands import Choice, Group, command, autocomplete +from sqlalchemy import insert, select, update, and_, or_, delete from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel from feed import Feeds, get_source @@ -26,6 +27,22 @@ async def get_rss_data(url: str): return items +async def followup(inter: Interaction, *args, **kwargs): + """Shorthand for following up on an interaction. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ + + await inter.followup.send(*args, **kwargs) + +async def audit(cog, *args, **kwargs): + """Shorthand for auditing an interaction.""" + + await cog.bot.audit(*args, **kwargs) + class CommandCog(commands.Cog): """ @@ -40,38 +57,87 @@ class CommandCog(commands.Cog): async def on_ready(self): log.info(f"{self.__class__.__name__} cog is ready") - rss_group = app_commands.Group( + async def source_autocomplete(self, inter: Interaction, nickname: str): + """Provides RSS source autocomplete functionality for commands. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + nickname : str + _description_ + + Returns + ------- + list of app_commands.Choice + _description_ + """ + + async with DatabaseManager() as database: + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.nick.ilike(f"%{nickname}%") + ) + query = select(RssSourceModel).where(whereclause) + result = await database.session.execute(query) + sources = [ + Choice(name=rss.nick, value=rss.rss_url) + for rss in result.scalars().all() + ] + + return sources + + rss_group = Group( name="rss", description="Commands for rss sources.", guild_only=True ) @rss_group.command(name="add") - async def add_rss_source(self, inter: Interaction, url: str): + async def add_rss_source(self, inter: Interaction, url: str, nickname: str): + """Add a new RSS source. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + url : str + The RSS feed URL. + nickname : str + A name used to identify the RSS source. + """ await inter.response.defer() - # validate the input + # Ensure the URL is valid if not validators.url(url): - await inter.followup.send( - "The URL you have entered is malformed or invalid:\n" - f"`{url=}`", + await followup(inter, + f"The URL you have entered is malformed or invalid:\n`{url=}`", suppress_embeds=True ) return - feed_data, status_code = await get_rss_data(url) + # Check the nickname is not a URL + if validators.url(nickname): + await followup(inter, + "It looks like the nickname you have entered is a URL.\n" + f"For security reasons, this is not allowed.\n`{nickname=}`", + suppress_embeds=True + ) + return + + # Check the URL points to an RSS feed. + feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this. if status_code != 200: - await inter.followup.send( - f"The URL provided returned an invalid status code:\n" - f"{url=}, {status_code=}", + await followup(inter, + f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", suppress_embeds=True ) return feed = feedparser.parse(feed_data) if not feed.version: - await inter.followup.send( + await followup(inter, f"The provided URL '{url}' does not seem to be a valid RSS feed.", suppress_embeds=True ) @@ -80,97 +146,147 @@ class CommandCog(commands.Cog): async with DatabaseManager() as database: query = insert(RssSourceModel).values( discord_server_id = inter.guild_id, - rss_url = url + rss_url = url, + nick=nickname ) await database.session.execute(query) - await inter.followup.send("RSS source added") + await audit(self, + f"Added RSS source ({nickname=}, {url=})", + inter.user.id, database=database + ) + + await followup(inter, f"RSS source added [{nickname}]({url})", suppress_embeds=True) @rss_group.command(name="remove") - async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None): + @autocomplete(source=source_autocomplete) + async def remove_rss_source(self, inter: Interaction, source: str): + """Delete an existing RSS source. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + source : str + The RSS source to be removed. Autocomplete or enter the URL. + """ await inter.response.defer() - def exists(item) -> bool: - """ - Shorthand for `is not None`. Cant just use `if not number` because 0 int will pass. - Ironically with this func & comment the code is longer, but at least I can read it ... - """ + log.debug(f"Attempting to remove RSS source ({source=})") - return item is not None + async with DatabaseManager() as database: + rss_source = (await database.session.execute( + select(RssSourceModel).filter( + and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url == source + ) + ) + )).fetchone() - url_exists = exists(url) - num_exists = exists(number) - - if (url_exists and num_exists) or (not url_exists and not num_exists): - await inter.followup.send( - "Please only specify either the existing rss number or url, " - "enter at least one of these, but don't enter both." + result = await database.session.execute( + delete(RssSourceModel).filter( + and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url == source + ) + ) ) - return - if url_exists and not validators.url(url): - await inter.followup.send( - "The URL you have entered is malformed or invalid:\n" - f"`{url=}`", + # TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works) + + if result.rowcount: + await followup(inter, + f"RSS source deleted successfully\n**[{rss_source.nick}]({rss_source.rss_url})**", suppress_embeds=True ) return - async with DatabaseManager() as database: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == url - ) - query = update(RssSourceModel).where(whereclause).values(active=False) - result = await database.session.execute(query) + await followup(inter, "Couldn't find any RSS sources with this name.") - await inter.followup.send(f"I've updated {result.rowcount} rows") + # potential_matches = await self.source_autocomplete(inter, source) @rss_group.command(name="list") - @app_commands.choices(filter=[ - app_commands.Choice(name="Active Only [default]", value=1), - app_commands.Choice(name="Inactive Only", value=0), - app_commands.Choice(name="All", value=2), - ]) - async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]): + async def list_rss_sources(self, inter: Interaction): + """Provides a with a list of RSS sources available for the current server. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ await inter.response.defer() - if filter.value == 2: - whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) - else: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.active == filter.value # should result to 0 or 1 - ) - async with DatabaseManager() as database: + whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) query = select(RssSourceModel).where(whereclause) result = await database.session.execute(query) rss_sources = result.scalars().all() - embed_fields = [{ - "name": f"[{i}]", - "value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}" - } for i, rss in enumerate(rss_sources)] - if not embed_fields: - await inter.followup.send("It looks like you have no rss sources.") + if not rss_sources: + await followup(inter, "It looks like you have no rss sources.") + return + + output = "## Available RSS Sources\n" + output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources]) + + await followup(inter, output, suppress_embeds=True) + + @rss_group.command(name="fetch") + @autocomplete(rss=source_autocomplete) + async def fetch_rss(self, inter: Interaction, rss: str, max: int=1): + # """""" + + await inter.response.defer() + + if max > 5: + followup(inter, "It looks like you have requested too many articles.\nThe limit is 5") return - embed = Embed( - title="RSS Sources", - description="Here are your rss sources:" - ) + source = get_source(rss) + articles = source.get_latest_articles(max) + + embeds = [] + for article in articles: + md_description = markdownify(article.description, strip=("img",)) + article_description = textwrap.shorten(md_description, 4096) + + embed = Embed( + title=article.title, + description=article_description, + url=article.url, + timestamp=article.published, + ) + embed.set_thumbnail(url=source.icon_url) + embed.set_image(url=await article.get_thumbnail_url()) + embed.set_footer(text=article.author) + embed.set_author( + name=source.name, + url=source.url, + ) + embeds.append(embed) + + async with DatabaseManager() as database: + query = insert(SentArticleModel).values([ + { + "discord_server_id": inter.guild_id, + "discord_channel_id": inter.channel_id, + "discord_message_id": inter.id, + "article_url": article.url, + } + for article in articles + ]) + await database.session.execute(query) + await audit(self, f"User is requesting {max} articles", inter.user.id, database=database) + + await followup(inter, embeds=embeds) + - for field in embed_fields: - embed.add_field(**field, inline=False) - # output = "Your rss sources:\n\n" - # output += "\n".join([f"[{i+1}] {rss.rss_url=} {bool(rss.active)=}" for i, rss in enumerate(rss_sources)]) - await inter.followup.send(embed=embed) async def setup(bot): diff --git a/src/feed.py b/src/feed.py index 63184b0..09b377d 100644 --- a/src/feed.py +++ b/src/feed.py @@ -37,8 +37,18 @@ class Source: feed=feed ) - def get_latest_article(self): - return Article.from_parsed(self.feed) + def get_latest_articles(self, max: int) -> list: + """""" + + articles = [] + + for i, entry in enumerate(self.feed.entries): + if i >= max: + break + + articles.append(Article.from_entry(entry)) + + return articles @dataclass @@ -66,6 +76,21 @@ class Article: author = entry.get("author") ) + @classmethod + def from_entry(cls, entry:FeedParserDict): + + published_parsed = entry.get("published_parsed") + published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None + + return cls( + title=entry.get("title"), + description=entry.get("description"), + url=entry.get("link"), + published=published, + author = entry.get("author") + ) + + async def get_thumbnail_url(self): """