From 56224d8d2064e67e108db6efc54a8008ff3e9aba Mon Sep 17 00:00:00 2001 From: corbz Date: Fri, 15 Dec 2023 17:48:43 +0000 Subject: [PATCH 01/11] autocomplete --- src/extensions/test.py | 43 +++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/src/extensions/test.py b/src/extensions/test.py index 18c2722..7651254 100644 --- a/src/extensions/test.py +++ b/src/extensions/test.py @@ -9,9 +9,9 @@ import textwrap from markdownify import markdownify from discord import app_commands, Interaction, Embed from discord.ext import commands, tasks -from sqlalchemy import insert, select +from sqlalchemy import insert, select, and_ -from db import DatabaseManager, AuditModel, SentArticleModel +from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel from feed import Feeds, get_source log = logging.getLogger(__name__) @@ -31,19 +31,39 @@ class Test(commands.Cog): async def on_ready(self): log.info(f"{self.__class__.__name__} cog is ready") + async def source_autocomplete(self, inter: Interaction, current: str): + """ + + """ + + async with DatabaseManager() as database: + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url.ilike(f"%{current}%") + ) + query = select(RssSourceModel).where(whereclause) + result = await database.session.execute(query) + sources = [ + app_commands.Choice(name=rss.rss_url, value=rss.rss_url) + for rss in result.scalars().all() + ] + + return sources + @app_commands.command(name="test-latest-article") - # @app_commands.choices(source=[ - # app_commands.Choice(name="The Babylon Bee", value=Feeds.THE_BABYLON_BEE), - # app_commands.Choice(name="The Upper Lip", value=Feeds.THE_UPPER_LIP), - # app_commands.Choice(name="BBC News", value=Feeds.BBC_NEWS), - # ]) - async def test_bee(self, inter: Interaction, source: Feeds): + @app_commands.autocomplete(source=source_autocomplete) + async def test_news(self, inter: Interaction, source: str): await inter.response.defer() await self.bot.audit("Requesting latest article.", inter.user.id) - source = get_source(source) - article = source.get_latest_article() + try: + source = get_source(source) + article = source.get_latest_article() + except IndexError as e: + log.error(e) + await inter.followup.send("An error occured, it's possible that the source provided was bad.") + return md_description = markdownify(article.description, strip=("img",)) article_description = textwrap.shorten(md_description, 4096) @@ -74,9 +94,6 @@ class Test(commands.Cog): await inter.followup.send(embed=embed) - - - async def setup(bot): """ Setup function for this extension. From 85b6f118bc07f815714b74d429567502e2d8d079 Mon Sep 17 00:00:00 2001 From: corbz Date: Fri, 15 Dec 2023 17:49:02 +0000 Subject: [PATCH 02/11] fix attribute error on Source and Article class --- src/feed.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/feed.py b/src/feed.py index 4183187..63184b0 100644 --- a/src/feed.py +++ b/src/feed.py @@ -21,19 +21,19 @@ class Feeds(Enum): @dataclass class Source: - name: str - url: str - icon_url: str + name: str | None + url: str | None + icon_url: str | None feed: FeedParserDict @classmethod def from_parsed(cls, feed:FeedParserDict): - # print(json.dumps(feed, indent=8)) + return cls( - name=feed.channel.title, - url=feed.channel.link, - icon_url=feed.feed.image.href, + name=feed.get("channel", {}).get("title"), + url=feed.get("channel", {}).get("link"), + icon_url=feed.get("feed", {}).get("image", {}).get("href"), feed=feed ) @@ -44,10 +44,10 @@ class Source: @dataclass class Article: - title: str - description: str - url: str - published: datetime + title: str | None + description: str | None + url: str | None + published: datetime | None author: str | None @classmethod @@ -55,12 +55,15 @@ class Article: entry = feed.entries[0] # log.debug(json.dumps(entry, indent=8)) + published_parsed = entry.get("published_parsed") + published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None + return cls( - title=entry.title, - description=entry.description, - url=entry.link, - published=datetime(*entry.published_parsed[0:-2]), - author = entry.get("author", None) + title=entry.get("title"), + description=entry.get("description"), + url=entry.get("link"), + published=published, + author = entry.get("author") ) async def get_thumbnail_url(self): @@ -78,12 +81,12 @@ class Article: return image_element.get("content") if image_element else None -def get_source(feed: Feeds) -> Source: +def get_source(rss_url: str) -> Source: """ """ - parsed_feed = parse("https://gitea.corbz.dev/corbz/BBC-News-Bot/rss/branch/main/src/extensions/news.py") + parsed_feed = parse(rss_url) return Source.from_parsed(parsed_feed) From 0770fb3f6f136ca3b58ad51ceeac6aa8ad226c79 Mon Sep 17 00:00:00 2001 From: corbz Date: Fri, 15 Dec 2023 23:13:39 +0000 Subject: [PATCH 03/11] Working on commands --- src/bot.py | 15 ++- src/db/models.py | 12 +- src/extensions/cmd.py | 256 ++++++++++++++++++++++++++++++------------ src/feed.py | 29 ++++- 4 files changed, 231 insertions(+), 81 deletions(-) diff --git a/src/bot.py b/src/bot.py index e509f59..69e3a1c 100644 --- a/src/bot.py +++ b/src/bot.py @@ -46,9 +46,16 @@ class DiscordBot(commands.Bot): if path.suffix == ".py": await self.load_extension(f"extensions.{path.stem}") - async def audit(self, message: str, user_id: int): + async def audit(self, message: str, user_id: int, database: DatabaseManager=None): + + message = f"Requesting latest article" + query = insert(AuditModel).values(discord_user_id=user_id, message=message) + + if database: + await database.session.execute(query) + return async with DatabaseManager() as database: - message = f"Requesting latest article" - query = insert(AuditModel).values(discord_user_id=user_id, message=message) - await database.session.execute(query) \ No newline at end of file + await database.session.execute(query) + + log.debug("Audit logged") diff --git a/src/db/models.py b/src/db/models.py index 8b5acbf..7f64668 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`. from enum import Enum, auto -from sqlalchemy import Column, Integer, String, DateTime, BigInteger +from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint from sqlalchemy.sql import func from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base @@ -24,7 +24,6 @@ class AuditModel(Base): discord_user_id = Column(BigInteger, nullable=False) message = Column(String, nullable=False) created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - active = Column(Integer, default=True, nullable=False) class SentArticleModel(Base): @@ -39,7 +38,6 @@ class SentArticleModel(Base): discord_channel_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False) article_url = Column(String, nullable=False) - active = Column(Integer, default=True, nullable=False) class RssSourceModel(Base): @@ -50,9 +48,14 @@ class RssSourceModel(Base): __tablename__ = "rss_source" id = Column(Integer, primary_key=True, autoincrement=True) + nick = Column(String, nullable=False) discord_server_id = Column(BigInteger, nullable=False) rss_url = Column(String, nullable=False) - active = Column(Integer, default=True, nullable=False) + + # the nickname must be unique, but only within the same discord server + __table_args__ = ( + UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'), + ) class FeedChannelModel(Base): @@ -64,4 +67,3 @@ class FeedChannelModel(Base): id = Column(Integer, primary_key=True, autoincrement=True) discord_channel_id = Column(BigInteger, nullable=False) - active = Column(Integer, default=True, nullable=False) diff --git a/src/extensions/cmd.py b/src/extensions/cmd.py index 7f54a3c..22515ad 100644 --- a/src/extensions/cmd.py +++ b/src/extensions/cmd.py @@ -12,7 +12,8 @@ import feedparser from markdownify import markdownify from discord import app_commands, Interaction, Embed from discord.ext import commands, tasks -from sqlalchemy import insert, select, update, and_, or_ +from discord.app_commands import Choice, Group, command, autocomplete +from sqlalchemy import insert, select, update, and_, or_, delete from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel from feed import Feeds, get_source @@ -26,6 +27,22 @@ async def get_rss_data(url: str): return items +async def followup(inter: Interaction, *args, **kwargs): + """Shorthand for following up on an interaction. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ + + await inter.followup.send(*args, **kwargs) + +async def audit(cog, *args, **kwargs): + """Shorthand for auditing an interaction.""" + + await cog.bot.audit(*args, **kwargs) + class CommandCog(commands.Cog): """ @@ -40,38 +57,87 @@ class CommandCog(commands.Cog): async def on_ready(self): log.info(f"{self.__class__.__name__} cog is ready") - rss_group = app_commands.Group( + async def source_autocomplete(self, inter: Interaction, nickname: str): + """Provides RSS source autocomplete functionality for commands. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + nickname : str + _description_ + + Returns + ------- + list of app_commands.Choice + _description_ + """ + + async with DatabaseManager() as database: + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.nick.ilike(f"%{nickname}%") + ) + query = select(RssSourceModel).where(whereclause) + result = await database.session.execute(query) + sources = [ + Choice(name=rss.nick, value=rss.rss_url) + for rss in result.scalars().all() + ] + + return sources + + rss_group = Group( name="rss", description="Commands for rss sources.", guild_only=True ) @rss_group.command(name="add") - async def add_rss_source(self, inter: Interaction, url: str): + async def add_rss_source(self, inter: Interaction, url: str, nickname: str): + """Add a new RSS source. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + url : str + The RSS feed URL. + nickname : str + A name used to identify the RSS source. + """ await inter.response.defer() - # validate the input + # Ensure the URL is valid if not validators.url(url): - await inter.followup.send( - "The URL you have entered is malformed or invalid:\n" - f"`{url=}`", + await followup(inter, + f"The URL you have entered is malformed or invalid:\n`{url=}`", suppress_embeds=True ) return - feed_data, status_code = await get_rss_data(url) + # Check the nickname is not a URL + if validators.url(nickname): + await followup(inter, + "It looks like the nickname you have entered is a URL.\n" + f"For security reasons, this is not allowed.\n`{nickname=}`", + suppress_embeds=True + ) + return + + # Check the URL points to an RSS feed. + feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this. if status_code != 200: - await inter.followup.send( - f"The URL provided returned an invalid status code:\n" - f"{url=}, {status_code=}", + await followup(inter, + f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", suppress_embeds=True ) return feed = feedparser.parse(feed_data) if not feed.version: - await inter.followup.send( + await followup(inter, f"The provided URL '{url}' does not seem to be a valid RSS feed.", suppress_embeds=True ) @@ -80,97 +146,147 @@ class CommandCog(commands.Cog): async with DatabaseManager() as database: query = insert(RssSourceModel).values( discord_server_id = inter.guild_id, - rss_url = url + rss_url = url, + nick=nickname ) await database.session.execute(query) - await inter.followup.send("RSS source added") + await audit(self, + f"Added RSS source ({nickname=}, {url=})", + inter.user.id, database=database + ) + + await followup(inter, f"RSS source added [{nickname}]({url})", suppress_embeds=True) @rss_group.command(name="remove") - async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None): + @autocomplete(source=source_autocomplete) + async def remove_rss_source(self, inter: Interaction, source: str): + """Delete an existing RSS source. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + source : str + The RSS source to be removed. Autocomplete or enter the URL. + """ await inter.response.defer() - def exists(item) -> bool: - """ - Shorthand for `is not None`. Cant just use `if not number` because 0 int will pass. - Ironically with this func & comment the code is longer, but at least I can read it ... - """ + log.debug(f"Attempting to remove RSS source ({source=})") - return item is not None + async with DatabaseManager() as database: + rss_source = (await database.session.execute( + select(RssSourceModel).filter( + and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url == source + ) + ) + )).fetchone() - url_exists = exists(url) - num_exists = exists(number) - - if (url_exists and num_exists) or (not url_exists and not num_exists): - await inter.followup.send( - "Please only specify either the existing rss number or url, " - "enter at least one of these, but don't enter both." + result = await database.session.execute( + delete(RssSourceModel).filter( + and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url == source + ) + ) ) - return - if url_exists and not validators.url(url): - await inter.followup.send( - "The URL you have entered is malformed or invalid:\n" - f"`{url=}`", + # TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works) + + if result.rowcount: + await followup(inter, + f"RSS source deleted successfully\n**[{rss_source.nick}]({rss_source.rss_url})**", suppress_embeds=True ) return - async with DatabaseManager() as database: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == url - ) - query = update(RssSourceModel).where(whereclause).values(active=False) - result = await database.session.execute(query) + await followup(inter, "Couldn't find any RSS sources with this name.") - await inter.followup.send(f"I've updated {result.rowcount} rows") + # potential_matches = await self.source_autocomplete(inter, source) @rss_group.command(name="list") - @app_commands.choices(filter=[ - app_commands.Choice(name="Active Only [default]", value=1), - app_commands.Choice(name="Inactive Only", value=0), - app_commands.Choice(name="All", value=2), - ]) - async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]): + async def list_rss_sources(self, inter: Interaction): + """Provides a with a list of RSS sources available for the current server. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ await inter.response.defer() - if filter.value == 2: - whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) - else: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.active == filter.value # should result to 0 or 1 - ) - async with DatabaseManager() as database: + whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) query = select(RssSourceModel).where(whereclause) result = await database.session.execute(query) rss_sources = result.scalars().all() - embed_fields = [{ - "name": f"[{i}]", - "value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}" - } for i, rss in enumerate(rss_sources)] - if not embed_fields: - await inter.followup.send("It looks like you have no rss sources.") + if not rss_sources: + await followup(inter, "It looks like you have no rss sources.") + return + + output = "## Available RSS Sources\n" + output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources]) + + await followup(inter, output, suppress_embeds=True) + + @rss_group.command(name="fetch") + @autocomplete(rss=source_autocomplete) + async def fetch_rss(self, inter: Interaction, rss: str, max: int=1): + # """""" + + await inter.response.defer() + + if max > 5: + followup(inter, "It looks like you have requested too many articles.\nThe limit is 5") return - embed = Embed( - title="RSS Sources", - description="Here are your rss sources:" - ) + source = get_source(rss) + articles = source.get_latest_articles(max) + + embeds = [] + for article in articles: + md_description = markdownify(article.description, strip=("img",)) + article_description = textwrap.shorten(md_description, 4096) + + embed = Embed( + title=article.title, + description=article_description, + url=article.url, + timestamp=article.published, + ) + embed.set_thumbnail(url=source.icon_url) + embed.set_image(url=await article.get_thumbnail_url()) + embed.set_footer(text=article.author) + embed.set_author( + name=source.name, + url=source.url, + ) + embeds.append(embed) + + async with DatabaseManager() as database: + query = insert(SentArticleModel).values([ + { + "discord_server_id": inter.guild_id, + "discord_channel_id": inter.channel_id, + "discord_message_id": inter.id, + "article_url": article.url, + } + for article in articles + ]) + await database.session.execute(query) + await audit(self, f"User is requesting {max} articles", inter.user.id, database=database) + + await followup(inter, embeds=embeds) + - for field in embed_fields: - embed.add_field(**field, inline=False) - # output = "Your rss sources:\n\n" - # output += "\n".join([f"[{i+1}] {rss.rss_url=} {bool(rss.active)=}" for i, rss in enumerate(rss_sources)]) - await inter.followup.send(embed=embed) async def setup(bot): diff --git a/src/feed.py b/src/feed.py index 63184b0..09b377d 100644 --- a/src/feed.py +++ b/src/feed.py @@ -37,8 +37,18 @@ class Source: feed=feed ) - def get_latest_article(self): - return Article.from_parsed(self.feed) + def get_latest_articles(self, max: int) -> list: + """""" + + articles = [] + + for i, entry in enumerate(self.feed.entries): + if i >= max: + break + + articles.append(Article.from_entry(entry)) + + return articles @dataclass @@ -66,6 +76,21 @@ class Article: author = entry.get("author") ) + @classmethod + def from_entry(cls, entry:FeedParserDict): + + published_parsed = entry.get("published_parsed") + published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None + + return cls( + title=entry.get("title"), + description=entry.get("description"), + url=entry.get("link"), + published=published, + author = entry.get("author") + ) + + async def get_thumbnail_url(self): """ From 41887472dface005e70e9c847b8f29e362794336 Mon Sep 17 00:00:00 2001 From: corbz Date: Sat, 16 Dec 2023 14:21:30 +0000 Subject: [PATCH 04/11] Embeds for interaction responses --- src/extensions/cmd.py | 32 +++++--- src/extensions/test.py | 105 ------------------------ src/feed.py | 177 +++++++++++++++++++++++------------------ 3 files changed, 121 insertions(+), 193 deletions(-) delete mode 100644 src/extensions/test.py diff --git a/src/extensions/cmd.py b/src/extensions/cmd.py index 22515ad..01ac0e3 100644 --- a/src/extensions/cmd.py +++ b/src/extensions/cmd.py @@ -3,6 +3,7 @@ Extension for the `CommandCog`. Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot. """ +import json import logging import validators @@ -10,13 +11,13 @@ import aiohttp import textwrap import feedparser from markdownify import markdownify -from discord import app_commands, Interaction, Embed +from discord import app_commands, Interaction, Embed, Colour from discord.ext import commands, tasks from discord.app_commands import Choice, Group, command, autocomplete from sqlalchemy import insert, select, update, and_, or_, delete from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel -from feed import Feeds, get_source +from feed import get_source, Source log = logging.getLogger(__name__) @@ -156,7 +157,15 @@ class CommandCog(commands.Cog): inter.user.id, database=database ) - await followup(inter, f"RSS source added [{nickname}]({url})", suppress_embeds=True) + embed = Embed( + title=f"New RSS Source: **{nickname}**", + url=url, + colour=Colour.from_str("#59ff00") + ) + embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href")) + + # , f"RSS source added [{nickname}]({url})", suppress_embeds=True + await followup(inter, embed=embed) @rss_group.command(name="remove") @autocomplete(source=source_autocomplete) @@ -176,16 +185,17 @@ class CommandCog(commands.Cog): log.debug(f"Attempting to remove RSS source ({source=})") async with DatabaseManager() as database: - rss_source = (await database.session.execute( + select_result = await database.session.execute( select(RssSourceModel).filter( and_( RssSourceModel.discord_server_id == inter.guild_id, RssSourceModel.rss_url == source ) ) - )).fetchone() + ) + rss_source = select_result.fetchone() - result = await database.session.execute( + delete_result = await database.session.execute( delete(RssSourceModel).filter( and_( RssSourceModel.discord_server_id == inter.guild_id, @@ -194,11 +204,13 @@ class CommandCog(commands.Cog): ) ) + nickname, rss_url = rss_source.nick, rss_source.rss_url + # TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works) - if result.rowcount: + if delete_result.rowcount: await followup(inter, - f"RSS source deleted successfully\n**[{rss_source.nick}]({rss_source.rss_url})**", + f"RSS source deleted successfully\n**[{nickname}]({rss_url})**", suppress_embeds=True ) return @@ -285,10 +297,6 @@ class CommandCog(commands.Cog): await followup(inter, embeds=embeds) - - - - async def setup(bot): """ Setup function for this extension. diff --git a/src/extensions/test.py b/src/extensions/test.py deleted file mode 100644 index 7651254..0000000 --- a/src/extensions/test.py +++ /dev/null @@ -1,105 +0,0 @@ -""" -Extension for the `test` cog. -Loading this file via `commands.Bot.load_extension` will add the `test` cog to the bot. -""" - -import logging - -import textwrap -from markdownify import markdownify -from discord import app_commands, Interaction, Embed -from discord.ext import commands, tasks -from sqlalchemy import insert, select, and_ - -from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel -from feed import Feeds, get_source - -log = logging.getLogger(__name__) - - -class Test(commands.Cog): - """ - News cog. - Delivers embeds of news articles to discord channels. - """ - - def __init__(self, bot): - super().__init__() - self.bot = bot - - @commands.Cog.listener() - async def on_ready(self): - log.info(f"{self.__class__.__name__} cog is ready") - - async def source_autocomplete(self, inter: Interaction, current: str): - """ - - """ - - async with DatabaseManager() as database: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url.ilike(f"%{current}%") - ) - query = select(RssSourceModel).where(whereclause) - result = await database.session.execute(query) - sources = [ - app_commands.Choice(name=rss.rss_url, value=rss.rss_url) - for rss in result.scalars().all() - ] - - return sources - - @app_commands.command(name="test-latest-article") - @app_commands.autocomplete(source=source_autocomplete) - async def test_news(self, inter: Interaction, source: str): - - await inter.response.defer() - await self.bot.audit("Requesting latest article.", inter.user.id) - - try: - source = get_source(source) - article = source.get_latest_article() - except IndexError as e: - log.error(e) - await inter.followup.send("An error occured, it's possible that the source provided was bad.") - return - - md_description = markdownify(article.description, strip=("img",)) - article_description = textwrap.shorten(md_description, 4096) - - embed = Embed( - title=article.title, - description=article_description, - url=article.url, - timestamp=article.published, - ) - embed.set_thumbnail(url=source.icon_url) - embed.set_image(url=await article.get_thumbnail_url()) - embed.set_footer(text=article.author) - embed.set_author( - name=source.name, - url=source.url, - ) - - async with DatabaseManager() as database: - query = insert(SentArticleModel).values( - discord_server_id=inter.guild_id, - discord_channel_id=inter.channel_id, - discord_message_id=inter.id, - article_url=article.url - ) - await database.session.execute(query) - - await inter.followup.send(embed=embed) - - -async def setup(bot): - """ - Setup function for this extension. - Adds the `ErrorCog` cog to the bot. - """ - - cog = Test(bot) - await bot.add_cog(cog) - log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/feed.py b/src/feed.py index 09b377d..d4a8474 100644 --- a/src/feed.py +++ b/src/feed.py @@ -1,7 +1,9 @@ +""" + +""" import json import logging -from enum import Enum from dataclasses import dataclass from datetime import datetime @@ -10,16 +12,70 @@ from bs4 import BeautifulSoup as bs4 from feedparser import FeedParserDict, parse log = logging.getLogger(__name__) +dumps = lambda _dict: json.dumps(_dict, indent=8) -class Feeds(Enum): - THE_UPPER_LIP = "https://theupperlip.co.uk/rss" - THE_BABYLON_BEE= "https://babylonbee.com/feed" - BBC_NEWS = "https://feeds.bbci.co.uk/news/rss.xml" +@dataclass +class Article: + """Represents a news article, or entry from an RSS feed.""" + + title: str | None + description: str | None + url: str | None + published: datetime | None + author: str | None + + @classmethod + def from_entry(cls, entry:FeedParserDict): + """Create an Article from an RSS feed entry. + + Parameters + ---------- + entry : FeedParserDict + An entry pulled from a complete FeedParserDict object. + + Returns + ------- + Article + The Article created from the feed entry. + """ + + log.debug("Creating Article from entry: %s", dumps(entry)) + + published_parsed = entry.get("published_parsed") + published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None + + return cls( + title=entry.get("title"), + description=entry.get("description"), + url=entry.get("link"), + published=published, + author = entry.get("author") + ) + + async def get_thumbnail_url(self) -> str | None: + """Returns the thumbnail URL for an article. + + Returns + ------- + str or None + The thumbnail URL, or None if not found. + """ + + log.debug("Fetching thumbnail for article: %s", self) + + async with aiohttp.ClientSession() as session: + async with session.get(self.url) as response: + html = await response.text() + + soup = bs4(html, "html.parser") + image_element = soup.select_one("meta[property='og:image']") + return image_element.get("content") if image_element else None @dataclass class Source: + """Represents an RSS source.""" name: str | None url: str | None @@ -28,7 +84,20 @@ class Source: @classmethod def from_parsed(cls, feed:FeedParserDict): - # print(json.dumps(feed, indent=8)) + """Returns a Source object from a parsed feed. + + Parameters + ---------- + feed : FeedParserDict + The feed used to create the Source. + + Returns + ------- + Source + The Source object + """ + + log.debug("Creating Source from feed: %s", dumps(feed)) return cls( name=feed.get("channel", {}).get("title"), @@ -37,86 +106,42 @@ class Source: feed=feed ) - def get_latest_articles(self, max: int) -> list: - """""" + def get_latest_articles(self, max: int) -> list[Article]: + """Returns a list of Article objects. - articles = [] + Parameters + ---------- + max : int + The maximum number of articles to return. - for i, entry in enumerate(self.feed.entries): - if i >= max: - break - - articles.append(Article.from_entry(entry)) - - return articles - - -@dataclass -class Article: - - title: str | None - description: str | None - url: str | None - published: datetime | None - author: str | None - - @classmethod - def from_parsed(cls, feed:FeedParserDict): - entry = feed.entries[0] - # log.debug(json.dumps(entry, indent=8)) - - published_parsed = entry.get("published_parsed") - published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None - - return cls( - title=entry.get("title"), - description=entry.get("description"), - url=entry.get("link"), - published=published, - author = entry.get("author") - ) - - @classmethod - def from_entry(cls, entry:FeedParserDict): - - published_parsed = entry.get("published_parsed") - published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None - - return cls( - title=entry.get("title"), - description=entry.get("description"), - url=entry.get("link"), - published=published, - author = entry.get("author") - ) - - - async def get_thumbnail_url(self): + Returns + ------- + list of Article + A list of Article objects. """ - """ + log.debug("Fetching latest articles from %s, max=%s", self, max) - async with aiohttp.ClientSession() as session: - async with session.get(self.url) as response: - html = await response.text() - - # Parse the thumbnail for the news story - soup = bs4(html, "html.parser") - image_element = soup.select_one("meta[property='og:image']") - return image_element.get("content") if image_element else None + return [ + Article.from_entry(entry) + for i, entry in enumerate(self.feed.entries) + if i < max + ] def get_source(rss_url: str) -> Source: - """ + """_summary_ + Parameters + ---------- + rss_url : str + _description_ + + Returns + ------- + Source + _description_ """ parsed_feed = parse(rss_url) return Source.from_parsed(parsed_feed) - - -def get_test(): - - parsed = parse(Feeds.THE_UPPER_LIP.value) - print(json.dumps(parsed, indent=4)) - return parsed From 1320ab5e26fbf1b9a46d6d9272d249f2b4369145 Mon Sep 17 00:00:00 2001 From: corbz Date: Sat, 16 Dec 2023 17:55:21 +0000 Subject: [PATCH 05/11] =?UTF-8?q?Add=20RSS=20Source=20=C2=B7=20Embed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/extensions/cmd.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/extensions/cmd.py b/src/extensions/cmd.py index 01ac0e3..0b4724e 100644 --- a/src/extensions/cmd.py +++ b/src/extensions/cmd.py @@ -157,11 +157,9 @@ class CommandCog(commands.Cog): inter.user.id, database=database ) - embed = Embed( - title=f"New RSS Source: **{nickname}**", - url=url, - colour=Colour.from_str("#59ff00") - ) + embed = Embed(title="RSS Feed Added", colour=Colour.from_str("#59ff00")) + embed.add_field(name="Nickname", value=nickname) + embed.add_field(name="URL", value=url) embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href")) # , f"RSS source added [{nickname}]({url})", suppress_embeds=True From 1d12a4b482724c5e2f98abd2d501595d272e6f32 Mon Sep 17 00:00:00 2001 From: corbz Date: Sat, 16 Dec 2023 22:43:56 +0000 Subject: [PATCH 06/11] Create launch.json --- .vscode/launch.json | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..e298949 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,17 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: NewsBot", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/src/main.py", + "python": "${workspaceFolder}/venv/bin/python", + "console": "integratedTerminal", + "justMyCode": true + } + ] +} From 866660ef8bc2ce24a2a771a1603f410975c62b8e Mon Sep 17 00:00:00 2001 From: corbz Date: Sat, 16 Dec 2023 23:53:50 +0000 Subject: [PATCH 07/11] =?UTF-8?q?Added=20Datetime=20Columns=20=C2=B7=20Sen?= =?UTF-8?q?tArticleModel,=20RssSourceModel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/db/db.py | 2 +- src/db/models.py | 2 ++ src/feed.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/db/db.py b/src/db/db.py index 1c61cfd..81c5684 100644 --- a/src/db/db.py +++ b/src/db/db.py @@ -24,7 +24,7 @@ class DatabaseManager: """ def __init__(self, no_commit: bool = False): - database_url = self.get_database_url() + database_url = self.get_database_url() # TODO: This is called every time a connection is established, maybe make it once and reference it? self.engine = create_async_engine(database_url, future=True) self.session_maker = sessionmaker(self.engine, class_=AsyncSession) self.session = None diff --git a/src/db/models.py b/src/db/models.py index 7f64668..e44731a 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -38,6 +38,7 @@ class SentArticleModel(Base): discord_channel_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False) article_url = Column(String, nullable=False) + when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) class RssSourceModel(Base): @@ -51,6 +52,7 @@ class RssSourceModel(Base): nick = Column(String, nullable=False) discord_server_id = Column(BigInteger, nullable=False) rss_url = Column(String, nullable=False) + created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # the nickname must be unique, but only within the same discord server __table_args__ = ( diff --git a/src/feed.py b/src/feed.py index d4a8474..c4f12db 100644 --- a/src/feed.py +++ b/src/feed.py @@ -143,5 +143,5 @@ def get_source(rss_url: str) -> Source: _description_ """ - parsed_feed = parse(rss_url) + parsed_feed = parse(rss_url) # TODO: make asyncronous return Source.from_parsed(parsed_feed) From 7ddfe09e4d0c2eb6d635646850babfb8f6810111 Mon Sep 17 00:00:00 2001 From: corbz Date: Sat, 16 Dec 2023 23:54:12 +0000 Subject: [PATCH 08/11] Updated implementation of interaction commands --- src/extensions/cmd.py | 194 +++++++++++++++++++++++++++--------------- 1 file changed, 127 insertions(+), 67 deletions(-) diff --git a/src/extensions/cmd.py b/src/extensions/cmd.py index 0b4724e..220e449 100644 --- a/src/extensions/cmd.py +++ b/src/extensions/cmd.py @@ -3,24 +3,29 @@ Extension for the `CommandCog`. Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot. """ -import json import logging import validators +from typing import Tuple import aiohttp import textwrap import feedparser from markdownify import markdownify -from discord import app_commands, Interaction, Embed, Colour -from discord.ext import commands, tasks -from discord.app_commands import Choice, Group, command, autocomplete -from sqlalchemy import insert, select, update, and_, or_, delete +from discord import Interaction, Embed, Colour +from discord.ext import commands +from discord.app_commands import Choice, Group, autocomplete, choices +from sqlalchemy import insert, select, and_, delete -from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel from feed import get_source, Source +from db import DatabaseManager, SentArticleModel, RssSourceModel log = logging.getLogger(__name__) +rss_list_sort_choices = [ + Choice(name="Nickname", value=0), + Choice(name="Date Added", value=1) +] + async def get_rss_data(url: str): async with aiohttp.ClientSession() as session: async with session.get(url) as response: @@ -44,6 +49,49 @@ async def audit(cog, *args, **kwargs): await cog.bot.audit(*args, **kwargs) +# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. +# Run a period task to check this. +async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]: + """Validate a provided RSS source. + + Parameters + ---------- + nickname : str + Nickname of the source. Must not contain URL. + url : str + URL of the source. Must be URL with valid status code and be an RSS feed. + + Returns + ------- + str or None + String invalid message if invalid, NoneType if valid. + FeedParserDict or None + The feed parsed from the given URL or None if invalid. + """ + + # Ensure the URL is valid + if not validators.url(url): + return f"The URL you have entered is malformed or invalid:\n`{url=}`", None + + # Check the nickname is not a URL + if validators.url(nickname): + return "It looks like the nickname you have entered is a URL.\n" \ + f"For security reasons, this is not allowed.\n`{nickname=}`", None + + + feed_data, status_code = await get_rss_data(url) + + # Check the URL status code is valid + if status_code != 200: + return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None + + # Check the contents is actually an RSS feed. + feed = feedparser.parse(feed_data) + if not feed.version: + return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None + + return None, feed + class CommandCog(commands.Cog): """ @@ -88,61 +136,35 @@ class CommandCog(commands.Cog): return sources + # All RSS commands belong to this group. rss_group = Group( name="rss", description="Commands for rss sources.", - guild_only=True + guild_only=True # We store guild IDs in the database, so guild only = True ) @rss_group.command(name="add") - async def add_rss_source(self, inter: Interaction, url: str, nickname: str): + async def add_rss_source(self, inter: Interaction, nickname: str, url: str): """Add a new RSS source. Parameters ---------- inter : Interaction Represents an app command interaction. - url : str - The RSS feed URL. nickname : str A name used to identify the RSS source. + url : str + The RSS feed URL. """ await inter.response.defer() - # Ensure the URL is valid - if not validators.url(url): - await followup(inter, - f"The URL you have entered is malformed or invalid:\n`{url=}`", - suppress_embeds=True - ) + illegal_message, feed = await validate_rss_source(nickname, url) + if illegal_message: + await followup(inter, illegal_message, suppress_embeds=True) return - # Check the nickname is not a URL - if validators.url(nickname): - await followup(inter, - "It looks like the nickname you have entered is a URL.\n" - f"For security reasons, this is not allowed.\n`{nickname=}`", - suppress_embeds=True - ) - return - - # Check the URL points to an RSS feed. - feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this. - if status_code != 200: - await followup(inter, - f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", - suppress_embeds=True - ) - return - - feed = feedparser.parse(feed_data) - if not feed.version: - await followup(inter, - f"The provided URL '{url}' does not seem to be a valid RSS feed.", - suppress_embeds=True - ) - return + log.debug("RSS feed added") async with DatabaseManager() as database: query = insert(RssSourceModel).values( @@ -157,68 +179,72 @@ class CommandCog(commands.Cog): inter.user.id, database=database ) - embed = Embed(title="RSS Feed Added", colour=Colour.from_str("#59ff00")) + embed = Embed(title="RSS Feed Added", colour=Colour.dark_green()) embed.add_field(name="Nickname", value=nickname) embed.add_field(name="URL", value=url) embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href")) - # , f"RSS source added [{nickname}]({url})", suppress_embeds=True await followup(inter, embed=embed) @rss_group.command(name="remove") - @autocomplete(source=source_autocomplete) - async def remove_rss_source(self, inter: Interaction, source: str): + @autocomplete(url=source_autocomplete) + async def remove_rss_source(self, inter: Interaction, url: str): """Delete an existing RSS source. Parameters ---------- inter : Interaction Represents an app command interaction. - source : str + url : str The RSS source to be removed. Autocomplete or enter the URL. """ await inter.response.defer() - log.debug(f"Attempting to remove RSS source ({source=})") + log.debug(f"Attempting to remove RSS source ({url=})") async with DatabaseManager() as database: select_result = await database.session.execute( select(RssSourceModel).filter( and_( RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == source + RssSourceModel.rss_url == url ) ) ) - rss_source = select_result.fetchone() + rss_source = select_result.scalars().one() + nickname = rss_source.nick delete_result = await database.session.execute( delete(RssSourceModel).filter( and_( RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == source + RssSourceModel.rss_url == url ) ) ) - nickname, rss_url = rss_source.nick, rss_source.rss_url - - # TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works) - - if delete_result.rowcount: - await followup(inter, - f"RSS source deleted successfully\n**[{nickname}]({rss_url})**", - suppress_embeds=True + await audit(self, + f"Added RSS source ({nickname=}, {url=})", + inter.user.id, database=database ) + + if not delete_result.rowcount: + await followup(inter, "Couldn't find any RSS sources with this name.") return - await followup(inter, "Couldn't find any RSS sources with this name.") + source = get_source(url) - # potential_matches = await self.source_autocomplete(inter, source) + embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red()) + embed.add_field(name="Nickname", value=nickname) + embed.add_field(name="URL", value=url) + embed.set_thumbnail(url=source.icon_url) + + await followup(inter, embed=embed) @rss_group.command(name="list") - async def list_rss_sources(self, inter: Interaction): + @choices(sort=rss_list_sort_choices) + async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False): """Provides a with a list of RSS sources available for the current server. Parameters @@ -229,9 +255,32 @@ class CommandCog(commands.Cog): await inter.response.defer() + # Default to the first choice if not specified. + if type(sort) is Choice: + description = "Sort by " + description += "Nickname " if sort.value == 0 else "Date Added " + description += '\U000025BC' if sort_reverse else '\U000025B2' + else: + sort = rss_list_sort_choices[0] + description = "" + + sort = sort if type(sort) == Choice else rss_list_sort_choices[0] + + match sort.value, sort_reverse: + case 0, False: + order_by = RssSourceModel.nick.asc() + case 0, True: + order_by = RssSourceModel.nick.desc() + case 1, False: # NOTE: + order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest + case 1, True: # date first, not the oldest as it would sort otherwise. + order_by = RssSourceModel.created.asc() + case _, _: + raise ValueError("Unknown sort: %s" % sort) + async with DatabaseManager() as database: whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) - query = select(RssSourceModel).where(whereclause) + query = select(RssSourceModel).where(whereclause).order_by(order_by) result = await database.session.execute(query) rss_sources = result.scalars().all() @@ -240,10 +289,15 @@ class CommandCog(commands.Cog): await followup(inter, "It looks like you have no rss sources.") return - output = "## Available RSS Sources\n" - output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources]) + output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)]) - await followup(inter, output, suppress_embeds=True) + embed = Embed( + title="Saved RSS Feeds", + description=f"{description}\n\n{output}", + colour=Colour.lighter_grey() + ) + + await followup(inter, embed=embed) @rss_group.command(name="fetch") @autocomplete(rss=source_autocomplete) @@ -256,7 +310,12 @@ class CommandCog(commands.Cog): followup(inter, "It looks like you have requested too many articles.\nThe limit is 5") return - source = get_source(rss) + invalid_message, feed = await validate_rss_source("", rss) + if invalid_message: + await followup(inter, invalid_message) + return + + source = Source.from_parsed(feed) articles = source.get_latest_articles(max) embeds = [] @@ -269,6 +328,7 @@ class CommandCog(commands.Cog): description=article_description, url=article.url, timestamp=article.published, + colour=Colour.brand_red() ) embed.set_thumbnail(url=source.icon_url) embed.set_image(url=await article.get_thumbnail_url()) @@ -290,7 +350,7 @@ class CommandCog(commands.Cog): for article in articles ]) await database.session.execute(query) - await audit(self, f"User is requesting {max} articles", inter.user.id, database=database) + await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database) await followup(inter, embeds=embeds) From 994b940e0fc00393952a16f1a19f3938a7fc9113 Mon Sep 17 00:00:00 2001 From: corbz Date: Sun, 17 Dec 2023 00:49:12 +0000 Subject: [PATCH 09/11] Create tasks.py --- src/extensions/tasks.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 src/extensions/tasks.py diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py new file mode 100644 index 0000000..9c3207b --- /dev/null +++ b/src/extensions/tasks.py @@ -0,0 +1,35 @@ +""" +Extension for the `TaskCog`. +Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot. +""" + +import logging + +from discord.ext import commands + +log = logging.getLogger(__name__) + + +class TaskCog(commands.Cog): + """ + Command cog. + """ + + def __init__(self, bot): + super().__init__() + self.bot = bot + + @commands.Cog.listener() + async def on_ready(self): + log.info(f"{self.__class__.__name__} cog is ready") + + +async def setup(bot): + """ + Setup function for this extension. + Adds `TaskCog` to the bot. + """ + + cog = TaskCog(bot) + await bot.add_cog(cog) + log.info(f"Added {cog.__class__.__name__} cog") From 428f459e98e25373dc9a136bbed51b5e570d253c Mon Sep 17 00:00:00 2001 From: corbz Date: Mon, 18 Dec 2023 01:06:45 +0000 Subject: [PATCH 10/11] cmd.py to rss.py + utils --- src/extensions/{cmd.py => rss.py} | 43 ++++++++++--------------------- src/utils.py | 31 ++++++++++++++++++++++ 2 files changed, 44 insertions(+), 30 deletions(-) rename src/extensions/{cmd.py => rss.py} (92%) create mode 100644 src/utils.py diff --git a/src/extensions/cmd.py b/src/extensions/rss.py similarity index 92% rename from src/extensions/cmd.py rename to src/extensions/rss.py index 220e449..2ab7bb7 100644 --- a/src/extensions/cmd.py +++ b/src/extensions/rss.py @@ -1,13 +1,12 @@ """ -Extension for the `CommandCog`. -Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot. +Extension for the `RssCog`. +Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot. """ import logging import validators from typing import Tuple -import aiohttp import textwrap import feedparser from markdownify import markdownify @@ -16,6 +15,7 @@ from discord.ext import commands from discord.app_commands import Choice, Group, autocomplete, choices from sqlalchemy import insert, select, and_, delete +from utils import get_rss_data, followup, audit from feed import get_source, Source from db import DatabaseManager, SentArticleModel, RssSourceModel @@ -26,29 +26,6 @@ rss_list_sort_choices = [ Choice(name="Date Added", value=1) ] -async def get_rss_data(url: str): - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - items = await response.text(), response.status - - return items - -async def followup(inter: Interaction, *args, **kwargs): - """Shorthand for following up on an interaction. - - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - """ - - await inter.followup.send(*args, **kwargs) - -async def audit(cog, *args, **kwargs): - """Shorthand for auditing an interaction.""" - - await cog.bot.audit(*args, **kwargs) - # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. # Run a period task to check this. async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]: @@ -93,7 +70,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feed return None, feed -class CommandCog(commands.Cog): +class RssCog(commands.Cog): """ Command cog. """ @@ -318,6 +295,10 @@ class CommandCog(commands.Cog): source = Source.from_parsed(feed) articles = source.get_latest_articles(max) + if not articles: + await followup(inter, "Sorry, I couldn't find any articles from this feed.") + return + embeds = [] for article in articles: md_description = markdownify(article.description, strip=("img",)) @@ -330,8 +311,10 @@ class CommandCog(commands.Cog): timestamp=article.published, colour=Colour.brand_red() ) + thumbail_url = await article.get_thumbnail_url() + thumbail_url = thumbail_url if validators.url(thumbail_url) else None embed.set_thumbnail(url=source.icon_url) - embed.set_image(url=await article.get_thumbnail_url()) + embed.set_image(url=thumbail_url) embed.set_footer(text=article.author) embed.set_author( name=source.name, @@ -358,9 +341,9 @@ class CommandCog(commands.Cog): async def setup(bot): """ Setup function for this extension. - Adds `CommandCog` to the bot. + Adds `RssCog` to the bot. """ - cog = CommandCog(bot) + cog = RssCog(bot) await bot.add_cog(cog) log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..d0c14bb --- /dev/null +++ b/src/utils.py @@ -0,0 +1,31 @@ +"""A collection of utility functions that can be used in various places.""" + +import aiohttp +import logging + +from discord import Interaction + +log = logging.getLogger(__name__) + +async def get_rss_data(url: str): + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + items = await response.text(), response.status + + return items + +async def followup(inter: Interaction, *args, **kwargs): + """Shorthand for following up on an interaction. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ + + await inter.followup.send(*args, **kwargs) + +async def audit(cog, *args, **kwargs): + """Shorthand for auditing an interaction.""" + + await cog.bot.audit(*args, **kwargs) From be720285cccc3543fafc77e4049b54101dbf691f Mon Sep 17 00:00:00 2001 From: corbz Date: Mon, 18 Dec 2023 01:07:16 +0000 Subject: [PATCH 11/11] New channels cog + FeedChannelModel changes --- src/db/models.py | 7 ++- src/extensions/channels.py | 106 +++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 src/extensions/channels.py diff --git a/src/db/models.py b/src/db/models.py index e44731a..d4477a0 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`. from enum import Enum, auto -from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint +from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint, ForeignKey from sqlalchemy.sql import func from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base @@ -54,6 +54,8 @@ class RssSourceModel(Base): rss_url = Column(String, nullable=False) created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + feed_channels = relationship("FeedChannelModel", cascade="all, delete") + # the nickname must be unique, but only within the same discord server __table_args__ = ( UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'), @@ -69,3 +71,6 @@ class FeedChannelModel(Base): id = Column(Integer, primary_key=True, autoincrement=True) discord_channel_id = Column(BigInteger, nullable=False) + discord_server_id = Column(BigInteger, nullable=False) + search_name = Column(String, nullable=False) + rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False) diff --git a/src/extensions/channels.py b/src/extensions/channels.py new file mode 100644 index 0000000..5f0876b --- /dev/null +++ b/src/extensions/channels.py @@ -0,0 +1,106 @@ +""" +Extension for the `ChannelCog`. +Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot. +""" + +import logging + +from sqlalchemy import select, and_ +from discord import Interaction, TextChannel +from discord.ext import commands +from discord.app_commands import Group, Choice, autocomplete + +from db import DatabaseManager, FeedChannelModel +from utils import followup + +log = logging.getLogger(__name__) + + +class ChannelCog(commands.Cog): + """ + Command cog. + """ + + def __init__(self, bot): + super().__init__() + self.bot = bot + + @commands.Cog.listener() + async def on_ready(self): + log.info(f"{self.__class__.__name__} cog is ready") + + async def autocomplete_existing_feeds(self, inter: Interaction, current: str): + """Returns a list of existing RSS + Channel feeds. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + current : str + The current text entered for the autocomplete. + """ + + async with DatabaseManager() as database: + whereclause = and_( + FeedChannelModel.discord_server_id == inter.guild_id, + FeedChannelModel.search_name.ilike(f"%{current}%") # is this secure from SQL Injection atk ? + ) + query = select(FeedChannelModel).where(whereclause) + result = await database.session.execute(query) + feeds = [ + Choice(name=feed.search_name, value=feed.id) + for feed in result.scalars().all() + ] + + return feeds + + # All RSS commands belong to this group. + channel_group = Group( + name="channel", + description="Commands for channel assignment.", + guild_only=True # We store guild IDs in the database, so guild only = True + ) + + channel_group.command(name="include-feed") + async def include_feed(self, inter: Interaction, rss: str, channel: TextChannel): + """Include a feed within the specified channel. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + rss : str + The RSS feed to include. + channel : TextChannel + The channel to include the feed in. + """ + + await inter.response.defer() + await followup(inter, "Ping") + + channel_group.command(name="exclude-feed") + @autocomplete(option=autocomplete_existing_feeds) + async def exclude_feed(self, inter: Interaction, option: int): + """Undo command for the `/channel include-feed` command. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + option : str + The RSS feed and channel to exclude. + """ + + await inter.response.defer() + await followup(inter, "Pong") + + +async def setup(bot): + """ + Setup function for this extension. + Adds `ChannelCog` to the bot. + """ + + cog = ChannelCog(bot) + await bot.add_cog(cog) + log.info(f"Added {cog.__class__.__name__} cog")