diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..e298949 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,17 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: NewsBot", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/src/main.py", + "python": "${workspaceFolder}/venv/bin/python", + "console": "integratedTerminal", + "justMyCode": true + } + ] +} diff --git a/src/bot.py b/src/bot.py index e509f59..69e3a1c 100644 --- a/src/bot.py +++ b/src/bot.py @@ -46,9 +46,16 @@ class DiscordBot(commands.Bot): if path.suffix == ".py": await self.load_extension(f"extensions.{path.stem}") - async def audit(self, message: str, user_id: int): + async def audit(self, message: str, user_id: int, database: DatabaseManager=None): + + message = f"Requesting latest article" + query = insert(AuditModel).values(discord_user_id=user_id, message=message) + + if database: + await database.session.execute(query) + return async with DatabaseManager() as database: - message = f"Requesting latest article" - query = insert(AuditModel).values(discord_user_id=user_id, message=message) - await database.session.execute(query) \ No newline at end of file + await database.session.execute(query) + + log.debug("Audit logged") diff --git a/src/db/db.py b/src/db/db.py index 1c61cfd..81c5684 100644 --- a/src/db/db.py +++ b/src/db/db.py @@ -24,7 +24,7 @@ class DatabaseManager: """ def __init__(self, no_commit: bool = False): - database_url = self.get_database_url() + database_url = self.get_database_url() # TODO: This is called every time a connection is established, maybe make it once and reference it? self.engine = create_async_engine(database_url, future=True) self.session_maker = sessionmaker(self.engine, class_=AsyncSession) self.session = None diff --git a/src/db/models.py b/src/db/models.py index 8b5acbf..d4477a0 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`. from enum import Enum, auto -from sqlalchemy import Column, Integer, String, DateTime, BigInteger +from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint, ForeignKey from sqlalchemy.sql import func from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base @@ -24,7 +24,6 @@ class AuditModel(Base): discord_user_id = Column(BigInteger, nullable=False) message = Column(String, nullable=False) created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - active = Column(Integer, default=True, nullable=False) class SentArticleModel(Base): @@ -39,7 +38,7 @@ class SentArticleModel(Base): discord_channel_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False) article_url = Column(String, nullable=False) - active = Column(Integer, default=True, nullable=False) + when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) class RssSourceModel(Base): @@ -50,9 +49,17 @@ class RssSourceModel(Base): __tablename__ = "rss_source" id = Column(Integer, primary_key=True, autoincrement=True) + nick = Column(String, nullable=False) discord_server_id = Column(BigInteger, nullable=False) rss_url = Column(String, nullable=False) - active = Column(Integer, default=True, nullable=False) + created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + + feed_channels = relationship("FeedChannelModel", cascade="all, delete") + + # the nickname must be unique, but only within the same discord server + __table_args__ = ( + UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'), + ) class FeedChannelModel(Base): @@ -64,4 +71,6 @@ class FeedChannelModel(Base): id = Column(Integer, primary_key=True, autoincrement=True) discord_channel_id = Column(BigInteger, nullable=False) - active = Column(Integer, default=True, nullable=False) + discord_server_id = Column(BigInteger, nullable=False) + search_name = Column(String, nullable=False) + rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False) diff --git a/src/extensions/channels.py b/src/extensions/channels.py new file mode 100644 index 0000000..5f0876b --- /dev/null +++ b/src/extensions/channels.py @@ -0,0 +1,106 @@ +""" +Extension for the `ChannelCog`. +Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot. +""" + +import logging + +from sqlalchemy import select, and_ +from discord import Interaction, TextChannel +from discord.ext import commands +from discord.app_commands import Group, Choice, autocomplete + +from db import DatabaseManager, FeedChannelModel +from utils import followup + +log = logging.getLogger(__name__) + + +class ChannelCog(commands.Cog): + """ + Command cog. + """ + + def __init__(self, bot): + super().__init__() + self.bot = bot + + @commands.Cog.listener() + async def on_ready(self): + log.info(f"{self.__class__.__name__} cog is ready") + + async def autocomplete_existing_feeds(self, inter: Interaction, current: str): + """Returns a list of existing RSS + Channel feeds. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + current : str + The current text entered for the autocomplete. + """ + + async with DatabaseManager() as database: + whereclause = and_( + FeedChannelModel.discord_server_id == inter.guild_id, + FeedChannelModel.search_name.ilike(f"%{current}%") # is this secure from SQL Injection atk ? + ) + query = select(FeedChannelModel).where(whereclause) + result = await database.session.execute(query) + feeds = [ + Choice(name=feed.search_name, value=feed.id) + for feed in result.scalars().all() + ] + + return feeds + + # All RSS commands belong to this group. + channel_group = Group( + name="channel", + description="Commands for channel assignment.", + guild_only=True # We store guild IDs in the database, so guild only = True + ) + + channel_group.command(name="include-feed") + async def include_feed(self, inter: Interaction, rss: str, channel: TextChannel): + """Include a feed within the specified channel. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + rss : str + The RSS feed to include. + channel : TextChannel + The channel to include the feed in. + """ + + await inter.response.defer() + await followup(inter, "Ping") + + channel_group.command(name="exclude-feed") + @autocomplete(option=autocomplete_existing_feeds) + async def exclude_feed(self, inter: Interaction, option: int): + """Undo command for the `/channel include-feed` command. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + option : str + The RSS feed and channel to exclude. + """ + + await inter.response.defer() + await followup(inter, "Pong") + + +async def setup(bot): + """ + Setup function for this extension. + Adds `ChannelCog` to the bot. + """ + + cog = ChannelCog(bot) + await bot.add_cog(cog) + log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/extensions/cmd.py b/src/extensions/cmd.py deleted file mode 100644 index 7f54a3c..0000000 --- a/src/extensions/cmd.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -Extension for the `CommandCog`. -Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot. -""" - -import logging -import validators - -import aiohttp -import textwrap -import feedparser -from markdownify import markdownify -from discord import app_commands, Interaction, Embed -from discord.ext import commands, tasks -from sqlalchemy import insert, select, update, and_, or_ - -from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel -from feed import Feeds, get_source - -log = logging.getLogger(__name__) - -async def get_rss_data(url: str): - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - items = await response.text(), response.status - - return items - - -class CommandCog(commands.Cog): - """ - Command cog. - """ - - def __init__(self, bot): - super().__init__() - self.bot = bot - - @commands.Cog.listener() - async def on_ready(self): - log.info(f"{self.__class__.__name__} cog is ready") - - rss_group = app_commands.Group( - name="rss", - description="Commands for rss sources.", - guild_only=True - ) - - @rss_group.command(name="add") - async def add_rss_source(self, inter: Interaction, url: str): - - await inter.response.defer() - - # validate the input - if not validators.url(url): - await inter.followup.send( - "The URL you have entered is malformed or invalid:\n" - f"`{url=}`", - suppress_embeds=True - ) - return - - feed_data, status_code = await get_rss_data(url) - if status_code != 200: - await inter.followup.send( - f"The URL provided returned an invalid status code:\n" - f"{url=}, {status_code=}", - suppress_embeds=True - ) - return - - feed = feedparser.parse(feed_data) - if not feed.version: - await inter.followup.send( - f"The provided URL '{url}' does not seem to be a valid RSS feed.", - suppress_embeds=True - ) - return - - async with DatabaseManager() as database: - query = insert(RssSourceModel).values( - discord_server_id = inter.guild_id, - rss_url = url - ) - await database.session.execute(query) - - await inter.followup.send("RSS source added") - - @rss_group.command(name="remove") - async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None): - - await inter.response.defer() - - def exists(item) -> bool: - """ - Shorthand for `is not None`. Cant just use `if not number` because 0 int will pass. - Ironically with this func & comment the code is longer, but at least I can read it ... - """ - - return item is not None - - url_exists = exists(url) - num_exists = exists(number) - - if (url_exists and num_exists) or (not url_exists and not num_exists): - await inter.followup.send( - "Please only specify either the existing rss number or url, " - "enter at least one of these, but don't enter both." - ) - return - - if url_exists and not validators.url(url): - await inter.followup.send( - "The URL you have entered is malformed or invalid:\n" - f"`{url=}`", - suppress_embeds=True - ) - return - - async with DatabaseManager() as database: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == url - ) - query = update(RssSourceModel).where(whereclause).values(active=False) - result = await database.session.execute(query) - - await inter.followup.send(f"I've updated {result.rowcount} rows") - - @rss_group.command(name="list") - @app_commands.choices(filter=[ - app_commands.Choice(name="Active Only [default]", value=1), - app_commands.Choice(name="Inactive Only", value=0), - app_commands.Choice(name="All", value=2), - ]) - async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]): - - await inter.response.defer() - - if filter.value == 2: - whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) - else: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.active == filter.value # should result to 0 or 1 - ) - - async with DatabaseManager() as database: - query = select(RssSourceModel).where(whereclause) - result = await database.session.execute(query) - - rss_sources = result.scalars().all() - embed_fields = [{ - "name": f"[{i}]", - "value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}" - } for i, rss in enumerate(rss_sources)] - - if not embed_fields: - await inter.followup.send("It looks like you have no rss sources.") - return - - embed = Embed( - title="RSS Sources", - description="Here are your rss sources:" - ) - - for field in embed_fields: - embed.add_field(**field, inline=False) - - # output = "Your rss sources:\n\n" - # output += "\n".join([f"[{i+1}] {rss.rss_url=} {bool(rss.active)=}" for i, rss in enumerate(rss_sources)]) - - await inter.followup.send(embed=embed) - - -async def setup(bot): - """ - Setup function for this extension. - Adds `CommandCog` to the bot. - """ - - cog = CommandCog(bot) - await bot.add_cog(cog) - log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/extensions/rss.py b/src/extensions/rss.py new file mode 100644 index 0000000..2ab7bb7 --- /dev/null +++ b/src/extensions/rss.py @@ -0,0 +1,349 @@ +""" +Extension for the `RssCog`. +Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot. +""" + +import logging +import validators +from typing import Tuple + +import textwrap +import feedparser +from markdownify import markdownify +from discord import Interaction, Embed, Colour +from discord.ext import commands +from discord.app_commands import Choice, Group, autocomplete, choices +from sqlalchemy import insert, select, and_, delete + +from utils import get_rss_data, followup, audit +from feed import get_source, Source +from db import DatabaseManager, SentArticleModel, RssSourceModel + +log = logging.getLogger(__name__) + +rss_list_sort_choices = [ + Choice(name="Nickname", value=0), + Choice(name="Date Added", value=1) +] + +# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. +# Run a period task to check this. +async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]: + """Validate a provided RSS source. + + Parameters + ---------- + nickname : str + Nickname of the source. Must not contain URL. + url : str + URL of the source. Must be URL with valid status code and be an RSS feed. + + Returns + ------- + str or None + String invalid message if invalid, NoneType if valid. + FeedParserDict or None + The feed parsed from the given URL or None if invalid. + """ + + # Ensure the URL is valid + if not validators.url(url): + return f"The URL you have entered is malformed or invalid:\n`{url=}`", None + + # Check the nickname is not a URL + if validators.url(nickname): + return "It looks like the nickname you have entered is a URL.\n" \ + f"For security reasons, this is not allowed.\n`{nickname=}`", None + + + feed_data, status_code = await get_rss_data(url) + + # Check the URL status code is valid + if status_code != 200: + return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None + + # Check the contents is actually an RSS feed. + feed = feedparser.parse(feed_data) + if not feed.version: + return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None + + return None, feed + + +class RssCog(commands.Cog): + """ + Command cog. + """ + + def __init__(self, bot): + super().__init__() + self.bot = bot + + @commands.Cog.listener() + async def on_ready(self): + log.info(f"{self.__class__.__name__} cog is ready") + + async def source_autocomplete(self, inter: Interaction, nickname: str): + """Provides RSS source autocomplete functionality for commands. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + nickname : str + _description_ + + Returns + ------- + list of app_commands.Choice + _description_ + """ + + async with DatabaseManager() as database: + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.nick.ilike(f"%{nickname}%") + ) + query = select(RssSourceModel).where(whereclause) + result = await database.session.execute(query) + sources = [ + Choice(name=rss.nick, value=rss.rss_url) + for rss in result.scalars().all() + ] + + return sources + + # All RSS commands belong to this group. + rss_group = Group( + name="rss", + description="Commands for rss sources.", + guild_only=True # We store guild IDs in the database, so guild only = True + ) + + @rss_group.command(name="add") + async def add_rss_source(self, inter: Interaction, nickname: str, url: str): + """Add a new RSS source. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + nickname : str + A name used to identify the RSS source. + url : str + The RSS feed URL. + """ + + await inter.response.defer() + + illegal_message, feed = await validate_rss_source(nickname, url) + if illegal_message: + await followup(inter, illegal_message, suppress_embeds=True) + return + + log.debug("RSS feed added") + + async with DatabaseManager() as database: + query = insert(RssSourceModel).values( + discord_server_id = inter.guild_id, + rss_url = url, + nick=nickname + ) + await database.session.execute(query) + + await audit(self, + f"Added RSS source ({nickname=}, {url=})", + inter.user.id, database=database + ) + + embed = Embed(title="RSS Feed Added", colour=Colour.dark_green()) + embed.add_field(name="Nickname", value=nickname) + embed.add_field(name="URL", value=url) + embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href")) + + await followup(inter, embed=embed) + + @rss_group.command(name="remove") + @autocomplete(url=source_autocomplete) + async def remove_rss_source(self, inter: Interaction, url: str): + """Delete an existing RSS source. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + url : str + The RSS source to be removed. Autocomplete or enter the URL. + """ + + await inter.response.defer() + + log.debug(f"Attempting to remove RSS source ({url=})") + + async with DatabaseManager() as database: + select_result = await database.session.execute( + select(RssSourceModel).filter( + and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url == url + ) + ) + ) + rss_source = select_result.scalars().one() + nickname = rss_source.nick + + delete_result = await database.session.execute( + delete(RssSourceModel).filter( + and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url == url + ) + ) + ) + + await audit(self, + f"Added RSS source ({nickname=}, {url=})", + inter.user.id, database=database + ) + + if not delete_result.rowcount: + await followup(inter, "Couldn't find any RSS sources with this name.") + return + + source = get_source(url) + + embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red()) + embed.add_field(name="Nickname", value=nickname) + embed.add_field(name="URL", value=url) + embed.set_thumbnail(url=source.icon_url) + + await followup(inter, embed=embed) + + @rss_group.command(name="list") + @choices(sort=rss_list_sort_choices) + async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False): + """Provides a with a list of RSS sources available for the current server. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ + + await inter.response.defer() + + # Default to the first choice if not specified. + if type(sort) is Choice: + description = "Sort by " + description += "Nickname " if sort.value == 0 else "Date Added " + description += '\U000025BC' if sort_reverse else '\U000025B2' + else: + sort = rss_list_sort_choices[0] + description = "" + + sort = sort if type(sort) == Choice else rss_list_sort_choices[0] + + match sort.value, sort_reverse: + case 0, False: + order_by = RssSourceModel.nick.asc() + case 0, True: + order_by = RssSourceModel.nick.desc() + case 1, False: # NOTE: + order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest + case 1, True: # date first, not the oldest as it would sort otherwise. + order_by = RssSourceModel.created.asc() + case _, _: + raise ValueError("Unknown sort: %s" % sort) + + async with DatabaseManager() as database: + whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) + query = select(RssSourceModel).where(whereclause).order_by(order_by) + result = await database.session.execute(query) + + rss_sources = result.scalars().all() + + if not rss_sources: + await followup(inter, "It looks like you have no rss sources.") + return + + output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)]) + + embed = Embed( + title="Saved RSS Feeds", + description=f"{description}\n\n{output}", + colour=Colour.lighter_grey() + ) + + await followup(inter, embed=embed) + + @rss_group.command(name="fetch") + @autocomplete(rss=source_autocomplete) + async def fetch_rss(self, inter: Interaction, rss: str, max: int=1): + # """""" + + await inter.response.defer() + + if max > 5: + followup(inter, "It looks like you have requested too many articles.\nThe limit is 5") + return + + invalid_message, feed = await validate_rss_source("", rss) + if invalid_message: + await followup(inter, invalid_message) + return + + source = Source.from_parsed(feed) + articles = source.get_latest_articles(max) + + if not articles: + await followup(inter, "Sorry, I couldn't find any articles from this feed.") + return + + embeds = [] + for article in articles: + md_description = markdownify(article.description, strip=("img",)) + article_description = textwrap.shorten(md_description, 4096) + + embed = Embed( + title=article.title, + description=article_description, + url=article.url, + timestamp=article.published, + colour=Colour.brand_red() + ) + thumbail_url = await article.get_thumbnail_url() + thumbail_url = thumbail_url if validators.url(thumbail_url) else None + embed.set_thumbnail(url=source.icon_url) + embed.set_image(url=thumbail_url) + embed.set_footer(text=article.author) + embed.set_author( + name=source.name, + url=source.url, + ) + embeds.append(embed) + + async with DatabaseManager() as database: + query = insert(SentArticleModel).values([ + { + "discord_server_id": inter.guild_id, + "discord_channel_id": inter.channel_id, + "discord_message_id": inter.id, + "article_url": article.url, + } + for article in articles + ]) + await database.session.execute(query) + await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database) + + await followup(inter, embeds=embeds) + + +async def setup(bot): + """ + Setup function for this extension. + Adds `RssCog` to the bot. + """ + + cog = RssCog(bot) + await bot.add_cog(cog) + log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py new file mode 100644 index 0000000..9c3207b --- /dev/null +++ b/src/extensions/tasks.py @@ -0,0 +1,35 @@ +""" +Extension for the `TaskCog`. +Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot. +""" + +import logging + +from discord.ext import commands + +log = logging.getLogger(__name__) + + +class TaskCog(commands.Cog): + """ + Command cog. + """ + + def __init__(self, bot): + super().__init__() + self.bot = bot + + @commands.Cog.listener() + async def on_ready(self): + log.info(f"{self.__class__.__name__} cog is ready") + + +async def setup(bot): + """ + Setup function for this extension. + Adds `TaskCog` to the bot. + """ + + cog = TaskCog(bot) + await bot.add_cog(cog) + log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/extensions/test.py b/src/extensions/test.py deleted file mode 100644 index 18c2722..0000000 --- a/src/extensions/test.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Extension for the `test` cog. -Loading this file via `commands.Bot.load_extension` will add the `test` cog to the bot. -""" - -import logging - -import textwrap -from markdownify import markdownify -from discord import app_commands, Interaction, Embed -from discord.ext import commands, tasks -from sqlalchemy import insert, select - -from db import DatabaseManager, AuditModel, SentArticleModel -from feed import Feeds, get_source - -log = logging.getLogger(__name__) - - -class Test(commands.Cog): - """ - News cog. - Delivers embeds of news articles to discord channels. - """ - - def __init__(self, bot): - super().__init__() - self.bot = bot - - @commands.Cog.listener() - async def on_ready(self): - log.info(f"{self.__class__.__name__} cog is ready") - - @app_commands.command(name="test-latest-article") - # @app_commands.choices(source=[ - # app_commands.Choice(name="The Babylon Bee", value=Feeds.THE_BABYLON_BEE), - # app_commands.Choice(name="The Upper Lip", value=Feeds.THE_UPPER_LIP), - # app_commands.Choice(name="BBC News", value=Feeds.BBC_NEWS), - # ]) - async def test_bee(self, inter: Interaction, source: Feeds): - - await inter.response.defer() - await self.bot.audit("Requesting latest article.", inter.user.id) - - source = get_source(source) - article = source.get_latest_article() - - md_description = markdownify(article.description, strip=("img",)) - article_description = textwrap.shorten(md_description, 4096) - - embed = Embed( - title=article.title, - description=article_description, - url=article.url, - timestamp=article.published, - ) - embed.set_thumbnail(url=source.icon_url) - embed.set_image(url=await article.get_thumbnail_url()) - embed.set_footer(text=article.author) - embed.set_author( - name=source.name, - url=source.url, - ) - - async with DatabaseManager() as database: - query = insert(SentArticleModel).values( - discord_server_id=inter.guild_id, - discord_channel_id=inter.channel_id, - discord_message_id=inter.id, - article_url=article.url - ) - await database.session.execute(query) - - await inter.followup.send(embed=embed) - - - - - -async def setup(bot): - """ - Setup function for this extension. - Adds the `ErrorCog` cog to the bot. - """ - - cog = Test(bot) - await bot.add_cog(cog) - log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/feed.py b/src/feed.py index 4183187..c4f12db 100644 --- a/src/feed.py +++ b/src/feed.py @@ -1,7 +1,9 @@ +""" + +""" import json import logging -from enum import Enum from dataclasses import dataclass from datetime import datetime @@ -10,85 +12,136 @@ from bs4 import BeautifulSoup as bs4 from feedparser import FeedParserDict, parse log = logging.getLogger(__name__) - - -class Feeds(Enum): - THE_UPPER_LIP = "https://theupperlip.co.uk/rss" - THE_BABYLON_BEE= "https://babylonbee.com/feed" - BBC_NEWS = "https://feeds.bbci.co.uk/news/rss.xml" - - -@dataclass -class Source: - - name: str - url: str - icon_url: str - feed: FeedParserDict - - @classmethod - def from_parsed(cls, feed:FeedParserDict): - - # print(json.dumps(feed, indent=8)) - return cls( - name=feed.channel.title, - url=feed.channel.link, - icon_url=feed.feed.image.href, - feed=feed - ) - - def get_latest_article(self): - return Article.from_parsed(self.feed) +dumps = lambda _dict: json.dumps(_dict, indent=8) @dataclass class Article: - - title: str - description: str - url: str - published: datetime + """Represents a news article, or entry from an RSS feed.""" + + title: str | None + description: str | None + url: str | None + published: datetime | None author: str | None @classmethod - def from_parsed(cls, feed:FeedParserDict): - entry = feed.entries[0] - # log.debug(json.dumps(entry, indent=8)) + def from_entry(cls, entry:FeedParserDict): + """Create an Article from an RSS feed entry. + + Parameters + ---------- + entry : FeedParserDict + An entry pulled from a complete FeedParserDict object. + + Returns + ------- + Article + The Article created from the feed entry. + """ + + log.debug("Creating Article from entry: %s", dumps(entry)) + + published_parsed = entry.get("published_parsed") + published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None return cls( - title=entry.title, - description=entry.description, - url=entry.link, - published=datetime(*entry.published_parsed[0:-2]), - author = entry.get("author", None) + title=entry.get("title"), + description=entry.get("description"), + url=entry.get("link"), + published=published, + author = entry.get("author") ) - async def get_thumbnail_url(self): + async def get_thumbnail_url(self) -> str | None: + """Returns the thumbnail URL for an article. + + Returns + ------- + str or None + The thumbnail URL, or None if not found. """ - """ + log.debug("Fetching thumbnail for article: %s", self) async with aiohttp.ClientSession() as session: async with session.get(self.url) as response: html = await response.text() - # Parse the thumbnail for the news story soup = bs4(html, "html.parser") image_element = soup.select_one("meta[property='og:image']") return image_element.get("content") if image_element else None -def get_source(feed: Feeds) -> Source: +@dataclass +class Source: + """Represents an RSS source.""" + + name: str | None + url: str | None + icon_url: str | None + feed: FeedParserDict + + @classmethod + def from_parsed(cls, feed:FeedParserDict): + """Returns a Source object from a parsed feed. + + Parameters + ---------- + feed : FeedParserDict + The feed used to create the Source. + + Returns + ------- + Source + The Source object + """ + + log.debug("Creating Source from feed: %s", dumps(feed)) + + return cls( + name=feed.get("channel", {}).get("title"), + url=feed.get("channel", {}).get("link"), + icon_url=feed.get("feed", {}).get("image", {}).get("href"), + feed=feed + ) + + def get_latest_articles(self, max: int) -> list[Article]: + """Returns a list of Article objects. + + Parameters + ---------- + max : int + The maximum number of articles to return. + + Returns + ------- + list of Article + A list of Article objects. + """ + + log.debug("Fetching latest articles from %s, max=%s", self, max) + + return [ + Article.from_entry(entry) + for i, entry in enumerate(self.feed.entries) + if i < max + ] + + +def get_source(rss_url: str) -> Source: + """_summary_ + + Parameters + ---------- + rss_url : str + _description_ + + Returns + ------- + Source + _description_ """ - """ - - parsed_feed = parse("https://gitea.corbz.dev/corbz/BBC-News-Bot/rss/branch/main/src/extensions/news.py") + parsed_feed = parse(rss_url) # TODO: make asyncronous return Source.from_parsed(parsed_feed) - - -def get_test(): - - parsed = parse(Feeds.THE_UPPER_LIP.value) - print(json.dumps(parsed, indent=4)) - return parsed diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..d0c14bb --- /dev/null +++ b/src/utils.py @@ -0,0 +1,31 @@ +"""A collection of utility functions that can be used in various places.""" + +import aiohttp +import logging + +from discord import Interaction + +log = logging.getLogger(__name__) + +async def get_rss_data(url: str): + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + items = await response.text(), response.status + + return items + +async def followup(inter: Interaction, *args, **kwargs): + """Shorthand for following up on an interaction. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ + + await inter.followup.send(*args, **kwargs) + +async def audit(cog, *args, **kwargs): + """Shorthand for auditing an interaction.""" + + await cog.bot.audit(*args, **kwargs)