diff --git a/src/db/models.py b/src/db/models.py index c462038..97554af 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -4,8 +4,7 @@ All table classes should be suffixed with `Model`. """ from sqlalchemy.sql import func -from sqlalchemy.orm import relationship -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship, declarative_base from sqlalchemy import ( Column, Integer, diff --git a/src/errors.py b/src/errors.py index 3f528af..d21d74b 100644 --- a/src/errors.py +++ b/src/errors.py @@ -1,4 +1,5 @@ class IllegalFeed(Exception): - pass - + def __init__(self, message: str, **items): + super().__init__(message) + self.items = items diff --git a/src/extensions/rss.py b/src/extensions/rss.py index 8d91f13..b9d4b03 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -12,18 +12,25 @@ from discord.ext import commands from discord import Interaction, Embed, Colour, TextChannel, Permissions from discord.app_commands import Choice, Group, autocomplete, choices, rename from sqlalchemy import insert, select, and_, delete -from sqlalchemy.exc import NoResultFound +from sqlalchemy.exc import NoResultFound, IntegrityError -from utils import get_rss_data, followup, audit, followup_error, extract_error_info # pylint: disable=E0401 -from feed import Source # pylint: disable=E0401 -from db import ( # pylint: disable=E0401 +from feed import Source +from errors import IllegalFeed +from db import ( DatabaseManager, SentArticleModel, RssSourceModel, FeedChannelModel, AuditModel ) -from errors import IllegalFeed +from utils import ( + Followup, + get_rss_data, + followup, + audit, + extract_error_info, + get_unparsed_feed +) log = logging.getLogger(__name__) @@ -81,7 +88,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed return None, feed async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str): - unparsed_feed = await self.bot.functions.get_unparsed_feed(rss_url) + unparsed_feed = await get_unparsed_feed(rss_url) source = Source.from_parsed(parse(unparsed_feed)) articles = source.get_latest_articles() @@ -148,101 +155,89 @@ class FeedCog(commands.Cog): feed_group = Group( name="feed", description="Commands for rss sources.", - guild_only=True, # We store guild IDs in the database, so guild only = True - default_permissions=Permissions.elevated() + default_permissions=Permissions.elevated(), + guild_only=True # We store guild IDs in the database, so guild only = True ) @feed_group.command(name="add") async def add_rss_source(self, inter: Interaction, nickname: str, url: str): - """Add a new RSS source. + """Add a new Feed for this server. Parameters ---------- inter : Interaction Represents an app command interaction. nickname : str - A name used to identify the RSS source. + A name used to identify the Feed. url : str - The RSS feed URL. + The Feed URL. """ await inter.response.defer() try: - source = self.bot.functions.create_new_feed(nickname, url) + source = await self.bot.functions.create_new_feed(nickname, url, inter.guild_id) except IllegalFeed as error: title, desc = extract_error_info(error) - followup_error(inter, title=title, description=desc) - - embed = Embed(title="RSS Feed Added", colour=Colour.dark_green()) - embed.add_field(name="Nickname", value=nickname) - embed.add_field(name="URL", value=url) - embed.set_thumbnail(url=source.thumb_url) - - await followup(inter, embed=embed) + await Followup(title, desc).fields(**error.items).error().send(inter) + except IntegrityError as error: + await ( + Followup( + "Duplicate Feed Error", + "A Feed with the same nickname already exist." + ) + .fields(nickname=nickname) + .error() + .send(inter) + ) + else: + await ( + Followup("Feed Added") + .image(source.icon_url) + .fields(nickname=nickname, url=url) + .added() + .send(inter) + ) @feed_group.command(name="remove") @rename(url="option") @autocomplete(url=source_autocomplete) async def remove_rss_source(self, inter: Interaction, url: str): - """Delete an existing RSS source. + """Delete an existing Feed from this server. Parameters ---------- inter : Interaction Represents an app command interaction. url : str - The RSS source to be removed. Autocomplete or enter the URL. + The Feed to be removed. Autocomplete or enter the URL. """ await inter.response.defer() - log.debug("Attempting to remove RSS source (url=%s)", url) - - async with DatabaseManager() as database: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == url - ) - - # We will select the item first, so we can reference it's nickname later. - select_query = select(RssSourceModel).filter(whereclause) - select_result = await database.session.execute(select_query) - - try: - rss_source = select_result.scalars().one() - except NoResultFound: - await followup_error(inter, - title="Error Deleting Feed", - message=f"I couldn't find anything for `{url}`" + try: + source = await self.bot.functions.delete_feed(url, inter.guild_id) + except NoResultFound: + await ( + Followup( + "Feed Not Found Error", + "A Feed with these parameters could not be found." ) - return - - nickname = rss_source.nick - - delete_query = delete(RssSourceModel).filter(whereclause) - delete_result = await database.session.execute(delete_query) - - await audit(self, - f"Deleted RSS source ({nickname=}, {url=})", - inter.user.id, database=database + .error() + .send(inter) + ) + else: + await ( + Followup("Feed Deleted") + .image(source.icon_url) + .fields(url=url) + .trash() + .send(inter) ) - - source = await Source.from_url(url) - - embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red()) - embed.add_field(name="Nickname", value=nickname) - embed.add_field(name="URL", value=url) - embed.set_thumbnail(url=source.icon_url) - - await followup(inter, embed=embed) @feed_group.command(name="list") - @choices(sort=rss_list_sort_choices) - async def list_rss_sources( - self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False - ): - """Provides a with a list of RSS sources available for the current server. + async def list_rss_sources(self, inter: Interaction): + """Provides a with a list of Feeds available for this server. Parameters ---------- @@ -252,55 +247,32 @@ class FeedCog(commands.Cog): await inter.response.defer() - # Default to the first choice if not specified. - if isinstance(sort, Choice): - description = "Sort by " - description += "Nickname " if sort.value == 0 else "Date Added " - description += '\U000025BC' if sort_reverse else '\U000025B2' - else: - sort = rss_list_sort_choices[0] - description = "" - - match sort.value, sort_reverse: - case 0, False: - order_by = RssSourceModel.nick.asc() - case 0, True: - order_by = RssSourceModel.nick.desc() - case 1, False: - order_by = RssSourceModel.created.desc() - case 1, True: - order_by = RssSourceModel.created.asc() - case _, _: - raise ValueError(f"Unknown sort: {sort}") - - async with DatabaseManager() as database: - whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) - query = select(RssSourceModel).where(whereclause).order_by(order_by) - result = await database.session.execute(query) - - rss_sources = result.scalars().all() - rowcount = len(rss_sources) - - if not rss_sources: - await followup_error(inter, - title="No Feeds Found", - message="I couldn't find any Feeds for this server." + try: + sources = await self.bot.functions.get_feeds(inter.guild_id) + except NoResultFound: + await ( + Followup( + "Feeds Not Found Error", + "There are no available Feeds for this server.\n" + "Add a new feed with `/feed add`." ) - return - - output = "\n".join([ - f"{i}. **[{rss.nick}]({rss.rss_url})** " - for i, rss in enumerate(rss_sources) + .error() + .send() + ) + else: + description = "\n".join([ + f"{i}. **[{source.name}]({source.url})**" + for i, source in enumerate(sources) ]) + await ( + Followup( + f"Available Feeds in {inter.guild.name}", + description + ) + .info() + .send(inter) + ) - embed = Embed( - title="Saved RSS Feeds", - description=f"{description}\n\n{output}", - colour=Colour.blue() - ) - embed.set_footer(text=f"Showing {rowcount} results") - - await followup(inter, embed=embed) # @feed_group.command(name="fetch") # @rename(max_="max") @@ -431,7 +403,7 @@ class FeedCog(commands.Cog): query = select(RssSourceModel).where(whereclause) result = await database.session.execute(query) sources = [ - Choice(name=rss.nick, value=rss.id) + Choice(name=rss.nick, value=rss.rss_url) for rss in result.scalars().all() ] @@ -483,10 +455,10 @@ class FeedCog(commands.Cog): # ) @feed_group.command(name="assign") - @rename(rss="feed") - @autocomplete(rss=autocomplete_rss_sources) + @rename(url="feed") + @autocomplete(url=autocomplete_rss_sources) async def include_feed( - self, inter: Interaction, rss: int, channel: TextChannel = None, prevent_spam: bool = True + self, inter: Interaction, url: str, channel: TextChannel = None, prevent_spam: bool = True ): """Include a feed within the specified channel. @@ -494,7 +466,7 @@ class FeedCog(commands.Cog): ---------- inter : Interaction Represents an app command interaction. - rss : int + url : int The RSS feed to include. channel : TextChannel The channel to include the feed in. @@ -504,30 +476,41 @@ class FeedCog(commands.Cog): channel = channel or inter.channel - async with DatabaseManager() as database: - select_query = select(RssSourceModel).where(and_( - RssSourceModel.id == rss, - RssSourceModel.discord_server_id == inter.guild_id - )) - - select_result = await database.session.execute(select_query) - rss_source = select_result.scalars().one() - nick, rss_url = rss_source.nick, rss_source.rss_url - - insert_query = insert(FeedChannelModel).values( - discord_server_id = inter.guild_id, - discord_channel_id = channel.id, - rss_source_id=rss, - search_name=f"{nick} #{channel.name}" + try: + feed_id, source = await self.bot.functions.assign_feed( + url, channel.name, channel.id, inter.guild_id + ) + except IntegrityError: + await ( + Followup( + "Duplicate Assigned Feed Error", + f"This Feed has already been assigned to {channel.mention}" + ) + .error() + .send(inter) + ) + except NoResultFound: + await ( + Followup( + "Feed Not Found Error", + "A Feed with these parameters could not be found." + ) + .error() + .send(inter) + ) + else: + await ( + Followup( + "Feed Assigned", + f"I've assigned {channel.mention} to receive content from " + f"[{source.name}]({source.url})." + ) + .assign() + .send(inter) ) - insert_result = await database.session.execute(insert_query) - feed_id = insert_result.inserted_primary_key.id - if prevent_spam: - await set_all_articles_as_sent(inter, channel, feed_id, rss_url) - - await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}") + await set_all_articles_as_sent(inter, channel, feed_id, url) @feed_group.command(name="unassign") @autocomplete(option=autocomplete_existing_feeds) @@ -544,20 +527,41 @@ class FeedCog(commands.Cog): await inter.response.defer() - async with DatabaseManager() as database: - query = delete(FeedChannelModel).where(and_( - FeedChannelModel.id == option, - FeedChannelModel.discord_server_id == inter.guild_id - )) - - result = await database.session.execute(query) - - if not result.rowcount: - await followup_error(inter, - title="Assigned Feed Not Found", - message=f"I couldn't find any assigned feeds for the option: {option}" + try: + await self.bot.functions.unassign_feed(option, inter.guild_id) + except NoResultFound: + await ( + Followup( + "Assigned Feed Not Found", + "The assigned Feed doesn't exist." + ) + .error() + .send(inter) ) - return + else: + await ( + Followup( + "Unassigned Feed", + "Feed has been unassigned." + ) + .trash() + .send(inter) + ) + + # async with DatabaseManager() as database: + # query = delete(FeedChannelModel).where(and_( + # FeedChannelModel.id == option, + # FeedChannelModel.discord_server_id == inter.guild_id + # )) + + # result = await database.session.execute(query) + + # if not result.rowcount: + # await followup_error(inter, + # title="Assigned Feed Not Found", + # message=f"I couldn't find any assigned feeds for the option: {option}" + # ) + # return await followup(inter, "I've removed this item (placeholder response)") diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py index 5e25966..8b5599a 100644 --- a/src/extensions/tasks.py +++ b/src/extensions/tasks.py @@ -6,14 +6,20 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo import logging from time import process_time -from feedparser import parse -from sqlalchemy import insert, select, and_ -from discord import Interaction, TextChannel +from discord import TextChannel from discord.ext import commands, tasks from discord.errors import Forbidden +from sqlalchemy import insert, select, and_ +from feedparser import parse -from feed import Source, Article # pylint disable=E0401 -from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401 +from feed import Source, Article +from db import ( + DatabaseManager, + FeedChannelModel, + RssSourceModel, + SentArticleModel +) +from utils import get_unparsed_feed log = logging.getLogger(__name__) @@ -68,7 +74,7 @@ class TaskCog(commands.Cog): channel = self.bot.get_channel(feed.discord_channel_id) - unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url) + unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url) parsed_feed = parse(unparsed_content) source = Source.from_parsed(parsed_feed) articles = source.get_latest_articles(5) diff --git a/src/feed.py b/src/feed.py index 1bdcc37..7472c6d 100644 --- a/src/feed.py +++ b/src/feed.py @@ -1,22 +1,23 @@ import json import logging -import async_timeout from dataclasses import dataclass from datetime import datetime -from typing import Tuple -import aiohttp +import aiohttp import validators -from textwrap import shorten -from markdownify import markdownify from discord import Embed, Colour from bs4 import BeautifulSoup as bs4 from feedparser import FeedParserDict, parse +from markdownify import markdownify +from sqlalchemy import select, insert, delete, and_ +from sqlalchemy.exc import NoResultFound +from textwrap import shorten -from utils import audit from errors import IllegalFeed +from db import DatabaseManager, RssSourceModel, FeedChannelModel +from utils import get_rss_data, get_unparsed_feed log = logging.getLogger(__name__) dumps = lambda _dict: json.dumps(_dict, indent=8) @@ -162,8 +163,8 @@ class Source: @classmethod async def from_url(cls, url: str): - unparsed_content = await Functions.get_unparsed_feed(url) - return + unparsed_content = await get_unparsed_feed(url) + return cls.from_parsed(parse(unparsed_content)) def get_latest_articles(self, max: int = 999) -> list[Article]: """Returns a list of Article objects. @@ -193,19 +194,26 @@ class Functions: def __init__(self, bot): self.bot = bot - @staticmethod - async def fetch(session, url: str) -> str: - async with async_timeout.timeout(20): - async with session.get(url) as response: - return await response.text() - - @staticmethod - async def get_unparsed_feed(url: str): - async with aiohttp.ClientSession() as session: - return await self.fetch(session, url) # TODO: work from here - async def validate_feed(self, nickname: str, url: str) -> FeedParserDict: - """""" + """Validates a feed based on the given nickname and url. + + Parameters + ---------- + nickname : str + Human readable nickname used to refer to the feed. + url : str + URL to fetch content from the feed. + + Returns + ------- + FeedParserDict + A Parsed Dictionary of the feed. + + Raises + ------ + IllegalFeed + If the feed is invalid. + """ # Ensure the URL is valid if not validators.url(url): @@ -215,28 +223,47 @@ class Functions: if validators.url(nickname): raise IllegalFeed( "It looks like the nickname you have entered is a URL.\n" \ - f"For security reasons, this is not allowed.\n`{nickname=}`" + "For security reasons, this is not allowed.", + nickname=nickname ) feed_data, status_code = await get_rss_data(url) if status_code != 200: raise IllegalFeed( - f"The URL provided returned an invalid status code:\n{url=}, {status_code=}" + "The URL provided returned an invalid status code:", + url=url, status_code=status_code ) # Check the contents is actually an RSS feed. feed = parse(feed_data) if not feed.version: raise IllegalFeed( - f"The provided URL '{url}' does not seem to be a valid RSS feed." + "The provided URL does not seem to be a valid RSS feed.", + url=url ) return feed - async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source: - """""" + """Create a new Feed, and return it as a Source object. + + Parameters + ---------- + nickname : str + Human readable nickname used to refer to the feed. + url : str + URL to fetch content from the feed. + guild_id : int + Discord Server ID associated with the feed. + + Returns + ------- + Source + Dataclass containing attributes of the feed. + """ + + log.info("Creating new Feed: %s - %s", nickname, guild_id) parsed_feed = await self.validate_feed(nickname, url) @@ -248,4 +275,110 @@ class Functions: ) await database.session.execute(query) + log.info("Created Feed: %s - %s", nickname, guild_id) + return Source.from_parsed(parsed_feed) + + async def delete_feed(self, url: str, guild_id: int) -> Source: + """Delete an existing Feed, then return it as a Source object. + + Parameters + ---------- + url : str + URL of the feed, used in the whereclause. + guild_id : int + Discord Server ID of the feed, used in the whereclause. + + Returns + ------- + Source + Dataclass containing attributes of the feed. + """ + + log.info("Deleting Feed: %s - %s", url, guild_id) + + async with DatabaseManager() as database: + whereclause = and_( + RssSourceModel.discord_server_id == guild_id, + RssSourceModel.rss_url == url + ) + + # Select the Feed entry, because an exception is raised if not found. + select_query = select(RssSourceModel).filter(whereclause) + select_result = await database.session.execute(select_query) + select_result.scalars().one() + + delete_query = delete(RssSourceModel).filter(whereclause) + await database.session.execute(delete_query) + + log.info("Deleted Feed: %s - %s", url, guild_id) + + return await Source.from_url(url) + + async def get_feeds(self, guild_id: int) -> list[Source]: + """Returns a list of fetched Feed objects from the database. + Note: a request will be made too all found Feed URLs. + + Parameters + ---------- + guild_id : int + The Discord Server ID, used to filter down the Feed query. + + Returns + ------- + list[Source] + List of Source objects, resulting from the query. + + Raises + ------ + NoResultFound + Raised if no results are found. + """ + + async with DatabaseManager() as database: + whereclause = and_(RssSourceModel.discord_server_id == guild_id) + query = select(RssSourceModel).where(whereclause) + result = await database.session.execute(query) + rss_sources = result.scalars().all() + + if not rss_sources: + raise NoResultFound + + return [await Source.from_url(feed.rss_url) for feed in rss_sources] + + async def assign_feed( + self, url: str, channel_name: str, channel_id: int, guild_id: int + ) -> tuple[int, Source]: + """""" + + async with DatabaseManager() as database: + select_query = select(RssSourceModel).where(and_( + RssSourceModel.rss_url == url, + RssSourceModel.discord_server_id == guild_id + )) + select_result = await database.session.execute(select_query) + + rss_source = select_result.scalars().one() + + insert_query = insert(FeedChannelModel).values( + discord_server_id = guild_id, + discord_channel_id = channel_id, + rss_source_id=rss_source.id, + search_name=f"{rss_source.nick} #{channel_name}" + ) + + insert_result = await database.session.execute(insert_query) + return insert_result.inserted_primary_key.id, await Source.from_url(url) + + async def unassign_feed( self, assigned_feed_id: int, guild_id: int): + """""" + + async with DatabaseManager() as database: + query = delete(FeedChannelModel).where(and_( + FeedChannelModel.id == assigned_feed_id, + FeedChannelModel.discord_server_id == guild_id + )) + + result = await database.session.execute(query) + if not result.rowcount: + raise NoResultFound diff --git a/src/utils.py b/src/utils.py index 8bc0ed9..d4a03f5 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,11 +2,21 @@ import aiohttp import logging +import async_timeout from discord import Interaction, Embed, Colour log = logging.getLogger(__name__) +async def fetch(session, url: str) -> str: + async with async_timeout.timeout(20): + async with session.get(url) as response: + return await response.text() + +async def get_unparsed_feed(url: str): + async with aiohttp.ClientSession() as session: + return await fetch(session, url) + async def get_rss_data(url: str): async with aiohttp.ClientSession() as session: async with session.get(url) as response: @@ -30,25 +40,94 @@ async def audit(cog, *args, **kwargs): await cog.bot.audit(*args, **kwargs) -async def followup_error(inter: Interaction, title: str, message: str, *args, **kwargs): - """Shorthand for following up on an interaction, except returns an embed styled in - error colours. - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - """ +# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png - await inter.followup.send( - *args, - embed=Embed( + +class FollowupIcons: + error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png" + success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png" + trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png" + info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png" + added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png" + assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png" + + +class Followup: + """Wrapper for a discord embed to follow up an interaction.""" + + def __init__( + self, + title: str = None, + description: str = None, + ): + self._embed = Embed( title=title, - description=message, - colour=Colour.red() - ), - **kwargs - ) + description=description + ) + + async def send(self, inter: Interaction, message: str = None): + """""" + + await inter.followup.send(content=message, embed=self._embed) + + def fields(self, inline: bool = False, **fields: dict): + """""" + + for key, value in fields.items(): + self._embed.add_field(name=key, value=value, inline=inline) + + return self + + def image(self, url: str): + """""" + + self._embed.set_image(url=url) + + return self + + def error(self): + """""" + + self._embed.colour = Colour.red() + self._embed.set_thumbnail(url=FollowupIcons.error) + return self + + def success(self): + """""" + + self._embed.colour = Colour.green() + self._embed.set_thumbnail(url=FollowupIcons.success) + return self + + def info(self): + """""" + + self._embed.colour = Colour.blue() + self._embed.set_thumbnail(url=FollowupIcons.info) + return self + + def added(self): + """""" + + self._embed.colour = Colour.blue() + self._embed.set_thumbnail(url=FollowupIcons.added) + return self + + def assign(self): + """""" + + self._embed.colour = Colour.blue() + self._embed.set_thumbnail(url=FollowupIcons.assigned) + return self + + def trash(self): + """""" + + self._embed.colour = Colour.red() + self._embed.set_thumbnail(url=FollowupIcons.trash) + return self + def extract_error_info(error: Exception) -> str: class_name = error.__class__.__name__