From 94b7972a38f4863dc494399261055ffe9a0fd791 Mon Sep 17 00:00:00 2001 From: Corban-Lee Date: Sun, 24 Dec 2023 16:40:04 +0000 Subject: [PATCH] incomplete commit, used to work on other machine --- src/bot.py | 2 + src/errors.py | 4 ++ src/extensions/rss.py | 35 +++++----------- src/extensions/tasks.py | 4 +- src/feed.py | 89 +++++++++++++++++++++++++++++++---------- src/main.py | 3 +- src/utils.py | 5 +++ 7 files changed, 92 insertions(+), 50 deletions(-) create mode 100644 src/errors.py diff --git a/src/bot.py b/src/bot.py index e78422d..079eda0 100644 --- a/src/bot.py +++ b/src/bot.py @@ -9,6 +9,7 @@ from discord import Intents from discord.ext import commands from sqlalchemy import insert +from feed import Functions from db import DatabaseManager, AuditModel log = logging.getLogger(__name__) @@ -18,6 +19,7 @@ class DiscordBot(commands.Bot): def __init__(self, BASE_DIR: Path): super().__init__(command_prefix="-", intents=Intents.all()) + self.functions = Functions(self) self.BASE_DIR = BASE_DIR async def sync_app_commands(self): diff --git a/src/errors.py b/src/errors.py new file mode 100644 index 0000000..3f528af --- /dev/null +++ b/src/errors.py @@ -0,0 +1,4 @@ + +class IllegalFeed(Exception): + pass + diff --git a/src/extensions/rss.py b/src/extensions/rss.py index 60c4b72..8d91f13 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -14,8 +14,8 @@ from discord.app_commands import Choice, Group, autocomplete, choices, rename from sqlalchemy import insert, select, and_, delete from sqlalchemy.exc import NoResultFound -from utils import get_rss_data, followup, audit, followup_error # pylint: disable=E0401 -from feed import get_source, get_unparsed_feed, Source # pylint: disable=E0401 +from utils import get_rss_data, followup, audit, followup_error, extract_error_info # pylint: disable=E0401 +from feed import Source # pylint: disable=E0401 from db import ( # pylint: disable=E0401 DatabaseManager, SentArticleModel, @@ -23,6 +23,7 @@ from db import ( # pylint: disable=E0401 FeedChannelModel, AuditModel ) +from errors import IllegalFeed log = logging.getLogger(__name__) @@ -80,7 +81,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed return None, feed async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str): - unparsed_feed = await get_unparsed_feed(rss_url) + unparsed_feed = await self.bot.functions.get_unparsed_feed(rss_url) source = Source.from_parsed(parse(unparsed_feed)) articles = source.get_latest_articles() @@ -167,30 +168,16 @@ class FeedCog(commands.Cog): await inter.response.defer() - illegal_message, feed = await validate_rss_source(nickname, url) - if illegal_message: - await followup(inter, illegal_message, suppress_embeds=True) - return - - log.debug("RSS feed added") - - async with DatabaseManager() as database: - query = insert(RssSourceModel).values( - discord_server_id = inter.guild_id, - rss_url = url, - nick=nickname - ) - await database.session.execute(query) - - await audit(self, - f"Added RSS source ({nickname=}, {url=})", - inter.user.id, database=database - ) + try: + source = self.bot.functions.create_new_feed(nickname, url) + except IllegalFeed as error: + title, desc = extract_error_info(error) + followup_error(inter, title=title, description=desc) embed = Embed(title="RSS Feed Added", colour=Colour.dark_green()) embed.add_field(name="Nickname", value=nickname) embed.add_field(name="URL", value=url) - embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href")) + embed.set_thumbnail(url=source.thumb_url) await followup(inter, embed=embed) @@ -241,7 +228,7 @@ class FeedCog(commands.Cog): inter.user.id, database=database ) - source = get_source(url) # TODO: replace with async function + source = await Source.from_url(url) embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red()) embed.add_field(name="Nickname", value=nickname) diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py index bb77533..5e25966 100644 --- a/src/extensions/tasks.py +++ b/src/extensions/tasks.py @@ -12,7 +12,7 @@ from discord import Interaction, TextChannel from discord.ext import commands, tasks from discord.errors import Forbidden -from feed import Source, Article, get_unparsed_feed # pylint disable=E0401 +from feed import Source, Article # pylint disable=E0401 from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401 log = logging.getLogger(__name__) @@ -68,7 +68,7 @@ class TaskCog(commands.Cog): channel = self.bot.get_channel(feed.discord_channel_id) - unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url) + unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url) parsed_feed = parse(unparsed_content) source = Source.from_parsed(parsed_feed) articles = source.get_latest_articles(5) diff --git a/src/feed.py b/src/feed.py index dacb0de..1bdcc37 100644 --- a/src/feed.py +++ b/src/feed.py @@ -4,6 +4,7 @@ import logging import async_timeout from dataclasses import dataclass from datetime import datetime +from typing import Tuple import aiohttp @@ -14,6 +15,9 @@ from discord import Embed, Colour from bs4 import BeautifulSoup as bs4 from feedparser import FeedParserDict, parse +from utils import audit +from errors import IllegalFeed + log = logging.getLogger(__name__) dumps = lambda _dict: json.dumps(_dict, indent=8) @@ -156,6 +160,11 @@ class Source: feed=feed ) + @classmethod + async def from_url(cls, url: str): + unparsed_content = await Functions.get_unparsed_feed(url) + return + def get_latest_articles(self, max: int = 999) -> list[Article]: """Returns a list of Article objects. @@ -177,30 +186,66 @@ class Source: for i, entry in enumerate(self.feed.entries) if i < max ] - -async def fetch(session, url: str) -> str: - async with async_timeout.timeout(20): - async with session.get(url) as response: - return await response.text() - -async def get_unparsed_feed(url: str): - async with aiohttp.ClientSession() as session: - return await fetch(session, url) -def get_source(rss_url: str) -> Source: - """_summary_ +class Functions: - Parameters - ---------- - rss_url : str - _description_ + def __init__(self, bot): + self.bot = bot - Returns - ------- - Source - _description_ - """ + @staticmethod + async def fetch(session, url: str) -> str: + async with async_timeout.timeout(20): + async with session.get(url) as response: + return await response.text() - parsed_feed = parse(rss_url) # TODO: make asyncronous - return Source.from_parsed(parsed_feed) + @staticmethod + async def get_unparsed_feed(url: str): + async with aiohttp.ClientSession() as session: + return await self.fetch(session, url) # TODO: work from here + + async def validate_feed(self, nickname: str, url: str) -> FeedParserDict: + """""" + + # Ensure the URL is valid + if not validators.url(url): + raise IllegalFeed(f"The URL you have entered is malformed or invalid:\n`{url=}`") + + # Check the nickname is not a URL + if validators.url(nickname): + raise IllegalFeed( + "It looks like the nickname you have entered is a URL.\n" \ + f"For security reasons, this is not allowed.\n`{nickname=}`" + ) + + feed_data, status_code = await get_rss_data(url) + + if status_code != 200: + raise IllegalFeed( + f"The URL provided returned an invalid status code:\n{url=}, {status_code=}" + ) + + # Check the contents is actually an RSS feed. + feed = parse(feed_data) + if not feed.version: + raise IllegalFeed( + f"The provided URL '{url}' does not seem to be a valid RSS feed." + ) + + return feed + + + async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source: + """""" + + parsed_feed = await self.validate_feed(nickname, url) + + async with DatabaseManager() as database: + query = insert(RssSourceModel).values( + discord_server_id=guild_id, + rss_url=url, + nick=nickname + ) + await database.session.execute(query) + + return Source.from_parsed(parsed_feed) diff --git a/src/main.py b/src/main.py index 1194781..7117f6a 100644 --- a/src/main.py +++ b/src/main.py @@ -44,5 +44,4 @@ async def main(): await bot.start(token, reconnect=True) if __name__ == "__main__": - - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/src/utils.py b/src/utils.py index bdeaf26..8bc0ed9 100644 --- a/src/utils.py +++ b/src/utils.py @@ -49,3 +49,8 @@ async def followup_error(inter: Interaction, title: str, message: str, *args, ** ), **kwargs ) + +def extract_error_info(error: Exception) -> str: + class_name = error.__class__.__name__ + desc = str(error) + return class_name, desc