diff --git a/src/extensions/channels.py b/src/extensions/channels.py deleted file mode 100644 index 87f0298..0000000 --- a/src/extensions/channels.py +++ /dev/null @@ -1,222 +0,0 @@ -""" -Extension for the `ChannelCog`. -Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot. -""" - -import logging - -from sqlalchemy.orm import aliased -from sqlalchemy import select, insert, delete, and_ -from discord import Interaction, TextChannel, Embed, Colour -from discord.ext import commands -from discord.app_commands import Group, Choice, autocomplete, choices - -from db import DatabaseManager, FeedChannelModel, RssSourceModel -from utils import followup - -log = logging.getLogger(__name__) - - -class ChannelCog(commands.Cog): - """ - Command cog. - """ - - def __init__(self, bot): - super().__init__() - self.bot = bot - - @commands.Cog.listener() - async def on_ready(self): - log.info(f"{self.__class__.__name__} cog is ready") - - async def autocomplete_rss_sources(self, inter: Interaction, nickname: str): - """""" - - async with DatabaseManager() as database: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.nick.ilike(f"%{nickname}%") - ) - query = select(RssSourceModel).where(whereclause) - result = await database.session.execute(query) - sources = [ - Choice(name=rss.nick, value=rss.id) - for rss in result.scalars().all() - ] - - log.debug(f"Autocomplete rss_sources returned {len(sources)} results") - - return sources - - async def autocomplete_existing_feeds(self, inter: Interaction, current: str): - """Returns a list of existing RSS + Channel feeds. - - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - current : str - The current text entered for the autocomplete. - """ - - async with DatabaseManager() as database: - whereclause = and_( - FeedChannelModel.discord_server_id == inter.guild_id, - FeedChannelModel.search_name.ilike(f"%{current}%"), # is this secure from SQL Injection atk ? - RssSourceModel.id == FeedChannelModel.rss_source_id - ) - query = ( - select(FeedChannelModel, RssSourceModel) - .where(whereclause) - .join(RssSourceModel) - .order_by(FeedChannelModel.discord_channel_id) - ) - result = await database.session.execute(query) - feeds = [] - for feed in result.scalars().all(): - channel = inter.guild.get_channel(feed.discord_channel_id) - feeds.append(Choice(name=f"# {channel.name} | {feed.rss_source.nick}", value=feed.id)) - - log.debug(f"Autocomplete existing_feeds returned {len(feeds)} results") - - return feeds - - # All RSS commands belong to this group. - channel_group = Group( - name="channels", - description="Commands for channel assignment.", - guild_only=True # These commands belong to channels of - ) - - @channel_group.command(name="include-feed") - @autocomplete(rss=autocomplete_rss_sources) - async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None): - """Include a feed within the specified channel. - - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - rss : str - The RSS feed to include. - channel : TextChannel - The channel to include the feed in. - """ - - await inter.response.defer() - - channel = channel or inter.channel - - async with DatabaseManager() as database: - select_query = select(RssSourceModel).where(and_( - RssSourceModel.id == rss, - RssSourceModel.discord_server_id == inter.guild_id - )) - - select_result = await database.session.execute(select_query) - rss_source = select_result.scalars().one() - nick, rss_url = rss_source.nick, rss_source.rss_url - - insert_query = insert(FeedChannelModel).values( - discord_server_id = inter.guild_id, - discord_channel_id = channel.id, - rss_source_id=rss, - search_name=f"{nick} #{channel.name}" - ) - - insert_result = await database.session.execute(insert_query) - - - await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}") - - @channel_group.command(name="exclude-feed") - @autocomplete(option=autocomplete_existing_feeds) - async def exclude_feed(self, inter: Interaction, option: int): - """Undo command for the `/channel include-feed` command. - - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - option : str - The RSS feed and channel to exclude. - """ - - await inter.response.defer() - - async with DatabaseManager() as database: - query = delete(FeedChannelModel).where(and_( - FeedChannelModel.id == option, - FeedChannelModel.discord_server_id == inter.guild_id - )) - - result = await database.session.execute(query) - - if not result.rowcount: - await followup(inter, "I couldn't find any items under that ID (placeholder response)") - return - - await followup(inter, "I've removed this item (placeholder response)") - - @channel_group.command(name="list") - # @choices(sort=[ - # Choice(name="RSS Nickname", value=0), - # Choice(name="Channel ID", value=1), - # Choice(name="Date Added", value=2) - # ]) - async def list_feeds(self, inter: Interaction): # sort: int - """List all of the channels and their respective included feeds. - - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - """ - - await inter.response.defer() - - async with DatabaseManager() as database: - whereclause = and_( - FeedChannelModel.discord_server_id == inter.guild_id, - RssSourceModel.id == FeedChannelModel.rss_source_id - ) - query = ( - select(FeedChannelModel, RssSourceModel) - .where(whereclause) - .join(RssSourceModel) - .order_by(FeedChannelModel.discord_channel_id) - ) - result = await database.session.execute(query) - - feed_channels = result.scalars().all() - rowcount = len(feed_channels) - - if not feed_channels: - await followup(inter, "It looks like there are no feed channels available.") - return - - output = "\n".join([ - f"{i}. <#{feed.discord_channel_id}> · [{feed.rss_source.nick}]({feed.rss_source.rss_url})" - for i, feed in enumerate(feed_channels) - ]) - - embed = Embed( - title="Saved Feed Channels", - description=f"{output}", - colour=Colour.blue() - ) - embed.set_footer(text=f"Showing {rowcount} results") - - await followup(inter, embed=embed) - - -async def setup(bot): - """ - Setup function for this extension. - Adds `ChannelCog` to the bot. - """ - - cog = ChannelCog(bot) - await bot.add_cog(cog) - log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/extensions/rss.py b/src/extensions/rss.py index 9023f16..56bb18e 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -4,20 +4,23 @@ Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot """ import logging -import validators from typing import Tuple -import textwrap -import feedparser -from markdownify import markdownify -from discord import Interaction, Embed, Colour +import validators +from feedparser import FeedParserDict, parse from discord.ext import commands -from discord.app_commands import Choice, Group, autocomplete, choices +from discord import Interaction, Embed, Colour, TextChannel +from discord.app_commands import Choice, Group, autocomplete, choices, rename from sqlalchemy import insert, select, and_, delete -from utils import get_rss_data, followup, audit -from feed import get_source, Source -from db import DatabaseManager, SentArticleModel, RssSourceModel +from utils import get_rss_data, followup, audit # pylint: disable=E0401 +from feed import get_source, Source # pylint: disable=E0401 +from db import ( # pylint: disable=E0401 + DatabaseManager, + SentArticleModel, + RssSourceModel, + FeedChannelModel +) log = logging.getLogger(__name__) @@ -26,9 +29,9 @@ rss_list_sort_choices = [ Choice(name="Date Added", value=1) ] -# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. -# Run a period task to check this. -async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]: +# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the +# target resource. Run a period task to check this. +async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]: """Validate a provided RSS source. Parameters @@ -63,7 +66,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feed return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None # Check the contents is actually an RSS feed. - feed = feedparser.parse(feed_data) + feed = parse(feed_data) if not feed.version: return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None @@ -81,7 +84,9 @@ class RssCog(commands.Cog): @commands.Cog.listener() async def on_ready(self): - log.info(f"{self.__class__.__name__} cog is ready") + """Instructions to call when the cog is ready.""" + + log.info("%s cog is ready", self.__class__.__name__) async def source_autocomplete(self, inter: Interaction, nickname: str): """Provides RSS source autocomplete functionality for commands. @@ -114,13 +119,13 @@ class RssCog(commands.Cog): return sources # All RSS commands belong to this group. - rss_group = Group( - name="rss", + feed_group = Group( + name="feed", description="Commands for rss sources.", guild_only=True # We store guild IDs in the database, so guild only = True ) - @rss_group.command(name="add") + @feed_group.command(name="add") async def add_rss_source(self, inter: Interaction, nickname: str, url: str): """Add a new RSS source. @@ -163,7 +168,7 @@ class RssCog(commands.Cog): await followup(inter, embed=embed) - @rss_group.command(name="remove") + @feed_group.command(name="remove") @autocomplete(url=source_autocomplete) async def remove_rss_source(self, inter: Interaction, url: str): """Delete an existing RSS source. @@ -178,7 +183,7 @@ class RssCog(commands.Cog): await inter.response.defer() - log.debug(f"Attempting to remove RSS source ({url=})") + log.debug("Attempting to remove RSS source (url=%s)", url) async with DatabaseManager() as database: select_result = await database.session.execute( @@ -219,7 +224,7 @@ class RssCog(commands.Cog): await followup(inter, embed=embed) - @rss_group.command(name="list") + @feed_group.command(name="list") @choices(sort=rss_list_sort_choices) async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False): """Provides a with a list of RSS sources available for the current server. @@ -278,9 +283,10 @@ class RssCog(commands.Cog): await followup(inter, embed=embed) - @rss_group.command(name="fetch") + @feed_group.command(name="fetch") + @rename(max_="max") @autocomplete(rss=source_autocomplete) - async def fetch_rss(self, inter: Interaction, rss: str, max: int=1): + async def fetch_rss(self, inter: Interaction, rss: str, max_: int=1): """Fetch an item from the specified RSS feed. Parameters @@ -289,13 +295,13 @@ class RssCog(commands.Cog): Represents an app command interaction. rss : str The RSS feed to fetch from. - max : int, optional + max_ : int, optional Maximum number of items to fetch, by default 1, limits at 5. """ await inter.response.defer() - if max > 5: + if max_ > 5: followup(inter, "It looks like you have requested too many articles.\nThe limit is 5") return @@ -305,7 +311,7 @@ class RssCog(commands.Cog): return source = Source.from_parsed(feed) - articles = source.get_latest_articles(max) + articles = source.get_latest_articles(max_) if not articles: await followup(inter, "Sorry, I couldn't find any articles from this feed.") @@ -324,11 +330,196 @@ class RssCog(commands.Cog): for article in articles ]) await database.session.execute(query) - await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database) + await audit(self, f"User is requesting {max_} articles from {source.name}", inter.user.id, database=database) await followup(inter, embeds=embeds) + # Channels ---- ---- ---- + + + async def autocomplete_rss_sources(self, inter: Interaction, nickname: str): + """""" + + async with DatabaseManager() as database: + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.nick.ilike(f"%{nickname}%") + ) + query = select(RssSourceModel).where(whereclause) + result = await database.session.execute(query) + sources = [ + Choice(name=rss.nick, value=rss.id) + for rss in result.scalars().all() + ] + + log.debug("Autocomplete rss_sources returned %s results", len(sources)) + + return sources + + async def autocomplete_existing_feeds(self, inter: Interaction, current: str): + """Returns a list of existing RSS + Channel feeds. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + current : str + The current text entered for the autocomplete. + """ + + async with DatabaseManager() as database: + whereclause = and_( + FeedChannelModel.discord_server_id == inter.guild_id, + FeedChannelModel.search_name.ilike(f"%{current}%"), # is this secure from SQL Injection atk ? + RssSourceModel.id == FeedChannelModel.rss_source_id + ) + query = ( + select(FeedChannelModel, RssSourceModel) + .where(whereclause) + .join(RssSourceModel) + .order_by(FeedChannelModel.discord_channel_id) + ) + result = await database.session.execute(query) + feeds = [] + for feed in result.scalars().all(): + channel = inter.guild.get_channel(feed.discord_channel_id) + feeds.append(Choice(name=f"# {channel.name} | {feed.rss_source.nick}", value=feed.id)) + + log.debug("Autocomplete existing_feeds returned %s results", len(feeds)) + + return feeds + + # # All RSS commands belong to this group. + # channel_group = Group( + # name="channels", + # description="Commands for channel assignment.", + # guild_only=True # These commands belong to channels of + # ) + + @feed_group.command(name="assign") + @autocomplete(rss=autocomplete_rss_sources) + async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None): + """Include a feed within the specified channel. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + rss : str + The RSS feed to include. + channel : TextChannel + The channel to include the feed in. + """ + + await inter.response.defer() + + channel = channel or inter.channel + + async with DatabaseManager() as database: + select_query = select(RssSourceModel).where(and_( + RssSourceModel.id == rss, + RssSourceModel.discord_server_id == inter.guild_id + )) + + select_result = await database.session.execute(select_query) + rss_source = select_result.scalars().one() + nick, rss_url = rss_source.nick, rss_source.rss_url + + insert_query = insert(FeedChannelModel).values( + discord_server_id = inter.guild_id, + discord_channel_id = channel.id, + rss_source_id=rss, + search_name=f"{nick} #{channel.name}" + ) + + await database.session.execute(insert_query) + + + await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}") + + @feed_group.command(name="unassign") + @autocomplete(option=autocomplete_existing_feeds) + async def exclude_feed(self, inter: Interaction, option: int): + """Undo command for the `/channel include-feed` command. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + option : str + The RSS feed and channel to exclude. + """ + + await inter.response.defer() + + async with DatabaseManager() as database: + query = delete(FeedChannelModel).where(and_( + FeedChannelModel.id == option, + FeedChannelModel.discord_server_id == inter.guild_id + )) + + result = await database.session.execute(query) + + if not result.rowcount: + await followup(inter, "I couldn't find any items under that ID (placeholder response)") + return + + await followup(inter, "I've removed this item (placeholder response)") + + @feed_group.command(name="channels") + # @choices(sort=[ + # Choice(name="RSS Nickname", value=0), + # Choice(name="Channel ID", value=1), + # Choice(name="Date Added", value=2) + # ]) + async def list_feeds(self, inter: Interaction): # sort: int + """List all of the channels and their respective included feeds. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ + + await inter.response.defer() + + async with DatabaseManager() as database: + whereclause = and_( + FeedChannelModel.discord_server_id == inter.guild_id, + RssSourceModel.id == FeedChannelModel.rss_source_id + ) + query = ( + select(FeedChannelModel, RssSourceModel) + .where(whereclause) + .join(RssSourceModel) + .order_by(FeedChannelModel.discord_channel_id) + ) + result = await database.session.execute(query) + + feed_channels = result.scalars().all() + rowcount = len(feed_channels) + + if not feed_channels: + await followup(inter, "It looks like there are no feed channels available.") + return + + output = "\n".join([ + f"{i}. <#{feed.discord_channel_id}> · [{feed.rss_source.nick}]({feed.rss_source.rss_url})" + for i, feed in enumerate(feed_channels) + ]) + + embed = Embed( + title="Saved Feed Channels", + description=f"{output}", + colour=Colour.blue() + ) + embed.set_footer(text=f"Showing {rowcount} results") + + await followup(inter, embed=embed) + + + async def setup(bot): """ Setup function for this extension. @@ -337,4 +528,4 @@ async def setup(bot): cog = RssCog(bot) await bot.add_cog(cog) - log.info(f"Added {cog.__class__.__name__} cog") + log.info("Added %s cog", cog.__class__.__name__)