From 8d02605019464b9cb099fd3e84a16a644772dd59 Mon Sep 17 00:00:00 2001 From: Corban-Lee Jones Date: Fri, 26 Jan 2024 17:52:33 +0000 Subject: [PATCH 01/12] Update README.md --- README.md | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index bdd0618..07e5819 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,7 @@ -# NewsBot +# PYRSS -Bot delivering news articles to discord servers. +An RSS driven Discord bot written in Python. -Plans +Provides user commands for storing RSS feed URLs that can be assigned to any given discord channel. -- Multiple news providers -- Choose how much of each provider should be delivered -- Check for duplicate articles between providers, and only deliver preferred provider article - - -## Dev Notes: - -For the sake of development, the following defintions apply: - -- Feed - An RSS feed stored within the database, submitted by a user. -- Assigned Feed - A discord channel set to receive content from a Feed. \ No newline at end of file +Content is shared every 10 minutes as an Embed. \ No newline at end of file From d8e6e5ba0635e3878da3f3abfd58f23c11fdb645 Mon Sep 17 00:00:00 2001 From: corbz Date: Tue, 30 Jan 2024 13:51:13 +0000 Subject: [PATCH 02/12] development activity status indicates when the bot is in development mode --- src/bot.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/bot.py b/src/bot.py index 079eda0..1e9c8d7 100644 --- a/src/bot.py +++ b/src/bot.py @@ -5,7 +5,7 @@ The discord bot for the application. import logging from pathlib import Path -from discord import Intents +from discord import Intents, Game from discord.ext import commands from sqlalchemy import insert @@ -18,7 +18,8 @@ log = logging.getLogger(__name__) class DiscordBot(commands.Bot): def __init__(self, BASE_DIR: Path): - super().__init__(command_prefix="-", intents=Intents.all()) + activity = Game("Indev") + super().__init__(command_prefix="-", intents=Intents.all(), activity=activity) self.functions = Functions(self) self.BASE_DIR = BASE_DIR From a3850a26468f9a932558103cd3bfad919a0e0a0d Mon Sep 17 00:00:00 2001 From: corbz Date: Tue, 30 Jan 2024 13:51:32 +0000 Subject: [PATCH 03/12] API Interactions --- src/api.py | 109 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 src/api.py diff --git a/src/api.py b/src/api.py new file mode 100644 index 0000000..7211bda --- /dev/null +++ b/src/api.py @@ -0,0 +1,109 @@ + +import logging + +import aiohttp + +log = logging.getLogger(__name__) + + +class APIException(Exception): + pass + + +class NotCreatedException(APIException): + pass + +class BadStatusException(APIException): + pass + + +class API: + """Interactions with the API.""" + + API_HOST = "http://localhost:8000/" + API_ENDPOINT = API_HOST + "api/" + + RSS_FEED_ENDPOINT = API_ENDPOINT + "rssfeed/" + FEED_CHANNEL_ENDPOINT = API_ENDPOINT + "feedchannel/" + + def __init__(self, session: aiohttp.ClientSession): + log.debug("API session initialised") + self.session = session + self.token_headers = {"Authorization": f"Token 12bccad74fb8575b3242902014f8f3807016f4fe"} + + async def fetch_data(self, url: str, params=None): + log.debug("api fetching from url: %s", url) + async with self.session.get(url, params=params, headers=self.token_headers) as response: + return await response.json(), await response.text(), response.status + + async def send_data(self, url: str, data: dict): + log.debug("api posting to url: %s", url) + async with self.session.post(url, data=data, headers=self.token_headers) as response: + return await response.json(), await response.text(), response.status + + async def delete_data(self, url: str): + log.debug("api deleting to url %s", url) + async with self.session.delete(url, headers=self.token_headers) as response: + return await response.text(), response.status + + async def create_new_rssfeed(self, name: str, url: str, image_url: str, discord_server_id: int): + + log.debug("api creating rss feed: %s %s %s", name, url, image_url) + + async with self.session.get(image_url) as response: + image_data = await response.read() + + form = aiohttp.FormData() + form.add_field("name", name) + form.add_field("url", url) + form.add_field("image", image_data, filename="file.jpg") + form.add_field("discord_server_id", str(discord_server_id)) + + resp_json, resp_text, resp_status = await self.send_data(self.RSS_FEED_ENDPOINT, form) + + if resp_status != 201: + log.error(resp_text) + raise NotCreatedException(f"Expected HTTP 201, not HTTP {response.status} - {resp_text}") + + log.debug(resp_text) + + return resp_json + + async def get_rssfeed(self, uuid: str) -> dict: + + log.debug("api getting rss feed") + + endpoint = f"{self.RSS_FEED_ENDPOINT}{uuid}" + resp_json, resp_text, resp_status = await self.fetch_data(endpoint) + + if resp_status != 200: + log.error(resp_text) + raise BadStatusException(f"Expected HTTP 200, not HTTP {resp_status} - {resp_text}") + + return resp_json + + async def get_rssfeed_list(self, **filters) -> dict: + + log.debug("api getting list of rss feed") + + resp_json, resp_text, resp_status = await self.fetch_data(self.RSS_FEED_ENDPOINT, params=filters) + + if resp_status != 200: + log.error(resp_text) + raise BadStatusException(f"Expected HTTP 200, not HTTP {resp_status} - {resp_text}") + + log.debug(resp_text) + + return resp_json["results"], resp_json["count"] + + async def delete_rssfeed(self, uuid: str): + + log.debug("api deleting rss feed") + + resp_text, resp_status = await self.delete_data(f"{self.RSS_FEED_ENDPOINT}{uuid}/") + + if resp_status != 204: + log.error(resp_text) + raise BadStatusException(f"Expected HTTP 204, not HTTP {resp_status} - {resp_text}") + + log.debug(resp_text) From 93fe4ebfcef12c8443f1a26dc1d1ca276b1a65f0 Mon Sep 17 00:00:00 2001 From: corbz Date: Tue, 30 Jan 2024 13:51:54 +0000 Subject: [PATCH 04/12] Pagination View --- src/utils.py | 120 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 117 insertions(+), 3 deletions(-) diff --git a/src/utils.py b/src/utils.py index d4a03f5..add5dc5 100644 --- a/src/utils.py +++ b/src/utils.py @@ -3,8 +3,10 @@ import aiohttp import logging import async_timeout +from typing import Callable -from discord import Interaction, Embed, Colour +from discord import Interaction, Embed, Colour, ButtonStyle, Button +from discord.ui import View, button log = logging.getLogger(__name__) @@ -53,6 +55,111 @@ class FollowupIcons: assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png" +class PaginationView(View): + def __init__( + self, inter: Interaction, embed: Embed, getdata: Callable, + formatdata: Callable, maxpage: int, initpage: int=1 + ): + self.inter = inter + self.embed = embed + self.getdata = getdata + self.formatdata = formatdata + self.maxpage = maxpage + self.index = initpage + super().__init__(timeout=100) + + async def check_user_is_author(self, inter: Interaction) -> bool: + """Ensure the user is the author of the original command.""" + + if inter.user == self.inter.user: + return True + + await inter.response.defer() + await ( + Followup(None, "Only the author can interact with this.") + .error() + .send(inter, ephemeral=True) + ) + return False + + async def on_timeout(self): + """Erase the controls on timeout.""" + + message = await self.inter.original_response() + await message.edit(view=None) + + @staticmethod + def calc_total_pages(results: int, max_pagesize: int) -> int: + result = ((results - 1) // max_pagesize) + 1 + log.debug("total pages calculated: %s", result) + return result + + @button(emoji="◀️", style=ButtonStyle.blurple) + async def backward(self, inter: Interaction, button: Button): + self.index -= 1 + await inter.response.defer() + self.inter = inter + await self.navigate() + + @button(emoji="▶️", style=ButtonStyle.blurple) + async def forward(self, inter: Interaction, button: Button): + self.index += 1 + await inter.response.defer() + self.inter = inter + await self.navigate() + + @button(emoji="⏭️", style=ButtonStyle.blurple) + async def start_or_end(self, inter: Interaction, button: Button): + if self.index <= self.maxpage // 2: + self.index = self.maxpage + else: + self.index = 1 + + await inter.response.defer() + self.inter = inter + await self.navigate() + + async def navigate(self): + log.debug("navigating to page: %s", self.index) + + self.update_buttons() + paged_embed = await self.create_paged_embed() + await self.inter.edit_original_response(embed=paged_embed, view=self) + + async def create_paged_embed(self) -> Embed: + embed = self.embed.copy() + data = await self.getdata(self.index) + + for item in data: + key, value = self.formatdata(item) + embed.add_field(name=key, value=value, inline=False) + + if self.maxpage != 1: + embed.set_footer(text=f"Page {self.index}/{self.maxpage}") + + return embed + + def update_buttons(self): + if self.index >= self.maxpage: + self.children[2].emoji = "⏮️" + else: + self.children[2].emoji = "⏭️" + + self.children[0].disabled = self.index == 1 + self.children[1].disabled = self.index == self.maxpage + + async def send(self): + embed = await self.create_paged_embed() + + if self.maxpage == 1: + await self.inter.edit_original_response(embed=embed) + return + + self.update_buttons() + await self.inter.edit_original_response(embed=embed, view=self) + + + class Followup: """Wrapper for a discord embed to follow up an interaction.""" @@ -66,10 +173,10 @@ class Followup: description=description ) - async def send(self, inter: Interaction, message: str = None): + async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False): """""" - await inter.followup.send(content=message, embed=self._embed) + await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral) def fields(self, inline: bool = False, **fields: dict): """""" @@ -86,6 +193,13 @@ class Followup: return self + def footer(self, text: str, icon_url: str = None): + """""" + + self._embed.set_footer(text=text, icon_url=icon_url) + + return self + def error(self): """""" From 32df589ed216aef90f7d51f0170f3dd19afe516e Mon Sep 17 00:00:00 2001 From: corbz Date: Tue, 30 Jan 2024 13:53:04 +0000 Subject: [PATCH 05/12] API Integration --- src/extensions/rss.py | 173 ++++++++++++++++++++++++++---------------- src/feed.py | 162 ++++++++++++++++----------------------- 2 files changed, 172 insertions(+), 163 deletions(-) diff --git a/src/extensions/rss.py b/src/extensions/rss.py index b9d4b03..9ac1308 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -6,6 +6,7 @@ Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bo import logging from typing import Tuple +import aiohttp import validators from feedparser import FeedParserDict, parse from discord.ext import commands @@ -14,7 +15,8 @@ from discord.app_commands import Choice, Group, autocomplete, choices, rename from sqlalchemy import insert, select, and_, delete from sqlalchemy.exc import NoResultFound, IntegrityError -from feed import Source +from api import API +from feed import Source, RSSFeed from errors import IllegalFeed from db import ( DatabaseManager, @@ -25,6 +27,7 @@ from db import ( ) from utils import ( Followup, + PaginationView, get_rss_data, followup, audit, @@ -121,6 +124,19 @@ class FeedCog(commands.Cog): log.info("%s cog is ready", self.__class__.__name__) + async def autocomplete_rssfeed(self, inter: Interaction, name: str) -> list[Choice]: + + async with aiohttp.ClientSession() as session: + data = await API(session).get_rssfeed_list() + rssfeeds = RSSFeed.from_list(data) + + choices = [ + Choice(name=item.name, value=item.uuid) + for item in rssfeeds + ] + + return choices + async def source_autocomplete(self, inter: Interaction, nickname: str): """Provides RSS source autocomplete functionality for commands. @@ -154,69 +170,55 @@ class FeedCog(commands.Cog): # All RSS commands belong to this group. feed_group = Group( name="feed", - description="Commands for rss sources.", + description="Commands for RSS sources.", default_permissions=Permissions.elevated(), guild_only=True # We store guild IDs in the database, so guild only = True ) - @feed_group.command(name="add") - async def add_rss_source(self, inter: Interaction, nickname: str, url: str): - """Add a new Feed for this server. + @feed_group.command(name="new") + async def add_rssfeed(self, inter: Interaction, name: str, url: str): + """Add a new RSS Feed for this server. - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - nickname : str - A name used to identify the Feed. - url : str - The Feed URL. + Args: + inter (Interaction): Represents the discord command interaction. + name (str): A nickname used to refer to this RSS Feed. + url (str): The URL of the RSS Feed. """ await inter.response.defer() try: - source = await self.bot.functions.create_new_feed(nickname, url, inter.guild_id) - except IllegalFeed as error: - title, desc = extract_error_info(error) - await Followup(title, desc).fields(**error.items).error().send(inter) - except IntegrityError as error: + rssfeed = await self.bot.functions.create_new_rssfeed(name, url, inter.guild_id) + except Exception as exc: await ( - Followup( - "Duplicate Feed Error", - "A Feed with the same nickname already exist." - ) - .fields(nickname=nickname) + Followup(exc.__class__.__name__, str(exc)) .error() .send(inter) ) else: await ( - Followup("Feed Added") - .image(source.icon_url) - .fields(nickname=nickname, url=url) + Followup("New RSS Feed") + .image(rssfeed.image) + .fields(uuid=rssfeed.uuid, name=name, url=url) .added() .send(inter) ) - @feed_group.command(name="remove") - @rename(url="option") - @autocomplete(url=source_autocomplete) - async def remove_rss_source(self, inter: Interaction, url: str): - """Delete an existing Feed from this server. + @feed_group.command(name="delete") + @autocomplete(uuid=autocomplete_rssfeed) + @rename(uuid="rssfeed") + async def delete_rssfeed(self, inter: Interaction, uuid: str): + """Delete an existing RSS Feed for this server. - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - url : str - The Feed to be removed. Autocomplete or enter the URL. + Args: + inter (Interaction): Represents the discord command interaction. + uuid (str): The UUID of the """ await inter.response.defer() try: - source = await self.bot.functions.delete_feed(url, inter.guild_id) + rssfeed = await self.bot.functions.delete_rssfeed(uuid) except NoResultFound: await ( Followup( @@ -229,49 +231,88 @@ class FeedCog(commands.Cog): else: await ( Followup("Feed Deleted") - .image(source.icon_url) - .fields(url=url) + .image(rssfeed.image) + .fields(uuid=rssfeed.uuid, name=rssfeed.name, url=rssfeed.url) .trash() .send(inter) ) @feed_group.command(name="list") - async def list_rss_sources(self, inter: Interaction): - """Provides a with a list of Feeds available for this server. + async def list_rssfeeds(self, inter: Interaction): + """Provides a list of all RSS Feeds - Parameters - ---------- - inter : Interaction - Represents an app command interaction. + Args: + inter (Interaction): Represents the discord command interaction. """ await inter.response.defer() + page = 1 + try: - sources = await self.bot.functions.get_feeds(inter.guild_id) - except NoResultFound: + rssfeeds, total_results = await self.bot.functions.get_rssfeeds(inter.guild_id, page) + except Exception as exc: await ( - Followup( - "Feeds Not Found Error", - "There are no available Feeds for this server.\n" - "Add a new feed with `/feed add`." - ) + Followup(exc.__class__.__name__, str(exc)) .error() - .send() - ) - else: - description = "\n".join([ - f"{i}. **[{source.name}]({source.url})**" - for i, source in enumerate(sources) - ]) - await ( - Followup( - f"Available Feeds in {inter.guild.name}", - description - ) - .info() .send(inter) ) + else: + # description = "\n\n".join( + # f"{item.name}\n{item.url}\n{item.uuid}" + # for item in rssfeeds + # ) + + # fields = { + # f"{i+1}.": f"{item.name}\n{item.url}\n{item.uuid}" + # for i, item in enumerate(rssfeeds) + # } + + def formatdata(item): + return item.name, f"{item.url}\n{item.uuid}" + + async def getdata(page): + data, count = await self.bot.functions.get_rssfeeds(inter.guild_id, page) + return data + + embed = Followup(f"Available RSS Feeds in {inter.guild.name}").info()._embed + maxpage = PaginationView.calc_total_pages(total_results, 10) + pagination = PaginationView(inter, embed, getdata, formatdata, maxpage, 1) + await pagination.send() + + # await ( + # Followup(f"Available RSS Feeds in {inter.guild.name}") + # .info() + # .fields(**fields) + # .footer(f"Page {page}") + # .send(inter) + # ) + + # try: + # sources = await self.bot.functions.get_feeds(inter.guild_id) + # except NoResultFound: + # await ( + # Followup( + # "Feeds Not Found Error", + # "There are no available Feeds for this server.\n" + # "Add a new feed with `/feed add`." + # ) + # .error() + # .send() + # ) + # else: + # description = "\n".join([ + # f"{i}. **[{source.name}]({source.url})**" + # for i, source in enumerate(sources) + # ]) + # await ( + # Followup( + # f"Available Feeds in {inter.guild.name}", + # description + # ) + # .info() + # .send(inter) + # ) # @feed_group.command(name="fetch") diff --git a/src/feed.py b/src/feed.py index 7472c6d..bf200d9 100644 --- a/src/feed.py +++ b/src/feed.py @@ -1,4 +1,5 @@ +import ssl import json import logging from dataclasses import dataclass @@ -18,6 +19,7 @@ from textwrap import shorten from errors import IllegalFeed from db import DatabaseManager, RssSourceModel, FeedChannelModel from utils import get_rss_data, get_unparsed_feed +from api import API log = logging.getLogger(__name__) dumps = lambda _dict: json.dumps(_dict, indent=8) @@ -131,7 +133,7 @@ class Article: @dataclass class Source: """Represents an RSS source.""" - + name: str | None url: str | None icon_url: str | None @@ -164,7 +166,9 @@ class Source: @classmethod async def from_url(cls, url: str): unparsed_content = await get_unparsed_feed(url) - return cls.from_parsed(parse(unparsed_content)) + source = cls.from_parsed(parse(unparsed_content)) + source.url = url + return source def get_latest_articles(self, max: int = 999) -> list[Article]: """Returns a list of Article objects. @@ -189,6 +193,32 @@ class Source: ] +@dataclass +class RSSFeed: + + uuid: str + name: str + url: str + image: str + discord_server_id: id + created_at: str + + @classmethod + def from_list(cls, data: list) -> list: + result = [] + + for item in data: + key = "discord_server_id" + item[key] = int(item.get(key)) + result.append(cls(**item)) + + return result + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + class Functions: def __init__(self, bot): @@ -245,106 +275,44 @@ class Functions: return feed - async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source: - """Create a new Feed, and return it as a Source object. + async def create_new_rssfeed(self, name: str, url: str, guild_id: int) -> RSSFeed: - Parameters - ---------- - nickname : str - Human readable nickname used to refer to the feed. - url : str - URL to fetch content from the feed. - guild_id : int - Discord Server ID associated with the feed. - Returns - ------- - Source - Dataclass containing attributes of the feed. + log.info("Creating new Feed: %s", name) + + parsed_feed = await self.validate_feed(name, url) + source = Source.from_parsed(parsed_feed) + + async with aiohttp.ClientSession() as session: + data = await API(session).create_new_rssfeed(name, url, source.icon_url, guild_id) + + return RSSFeed.from_dict(data) + + async def delete_rssfeed(self, uuid: str) -> RSSFeed: + + log.info("Deleting Feed '%s'", uuid) + + async with aiohttp.ClientSession() as session: + api = API(session) + data = await api.get_rssfeed(uuid) + await api.delete_rssfeed(uuid) + + return RSSFeed.from_dict(data) + + async def get_rssfeeds(self, guild_id: int, page: int) -> list[RSSFeed]: + """Get a list of RSS Feeds. + + Args: + guild_id (int): The guild_id to filter by. + + Returns: + list[RSSFeed]: Resulting list of RSS Feeds """ - log.info("Creating new Feed: %s - %s", nickname, guild_id) + async with aiohttp.ClientSession() as session: + data, count = await API(session).get_rssfeed_list(discord_server_id=guild_id, page=page) - parsed_feed = await self.validate_feed(nickname, url) - - async with DatabaseManager() as database: - query = insert(RssSourceModel).values( - discord_server_id=guild_id, - rss_url=url, - nick=nickname - ) - await database.session.execute(query) - - log.info("Created Feed: %s - %s", nickname, guild_id) - - return Source.from_parsed(parsed_feed) - - async def delete_feed(self, url: str, guild_id: int) -> Source: - """Delete an existing Feed, then return it as a Source object. - - Parameters - ---------- - url : str - URL of the feed, used in the whereclause. - guild_id : int - Discord Server ID of the feed, used in the whereclause. - - Returns - ------- - Source - Dataclass containing attributes of the feed. - """ - - log.info("Deleting Feed: %s - %s", url, guild_id) - - async with DatabaseManager() as database: - whereclause = and_( - RssSourceModel.discord_server_id == guild_id, - RssSourceModel.rss_url == url - ) - - # Select the Feed entry, because an exception is raised if not found. - select_query = select(RssSourceModel).filter(whereclause) - select_result = await database.session.execute(select_query) - select_result.scalars().one() - - delete_query = delete(RssSourceModel).filter(whereclause) - await database.session.execute(delete_query) - - log.info("Deleted Feed: %s - %s", url, guild_id) - - return await Source.from_url(url) - - async def get_feeds(self, guild_id: int) -> list[Source]: - """Returns a list of fetched Feed objects from the database. - Note: a request will be made too all found Feed URLs. - - Parameters - ---------- - guild_id : int - The Discord Server ID, used to filter down the Feed query. - - Returns - ------- - list[Source] - List of Source objects, resulting from the query. - - Raises - ------ - NoResultFound - Raised if no results are found. - """ - - async with DatabaseManager() as database: - whereclause = and_(RssSourceModel.discord_server_id == guild_id) - query = select(RssSourceModel).where(whereclause) - result = await database.session.execute(query) - rss_sources = result.scalars().all() - - if not rss_sources: - raise NoResultFound - - return [await Source.from_url(feed.rss_url) for feed in rss_sources] + return RSSFeed.from_list(data), count async def assign_feed( self, url: str, channel_name: str, channel_id: int, guild_id: int From 8c35f42a0e4e28aa3d5cabfdee99daa4ca8f2e38 Mon Sep 17 00:00:00 2001 From: corbz Date: Tue, 30 Jan 2024 19:29:42 +0000 Subject: [PATCH 06/12] Improved pagination view Removed double API call added pagesize to API call added calc_dataitem_index method for properly calculating the index of each data item, given the current page. --- src/extensions/rss.py | 72 +++++++++---------------------------------- src/feed.py | 8 +++-- src/utils.py | 33 +++++++++++++++++--- 3 files changed, 48 insertions(+), 65 deletions(-) diff --git a/src/extensions/rss.py b/src/extensions/rss.py index 9ac1308..378f3bc 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -248,72 +248,28 @@ class FeedCog(commands.Cog): await inter.response.defer() page = 1 + pagesize = 10 try: - rssfeeds, total_results = await self.bot.functions.get_rssfeeds(inter.guild_id, page) + def formatdata(index, item): + key = f"{index}. {item.name}" + value = f"[RSS]({item.url}) · [API](http://localhost:8000/api/rssfeed/{item.uuid}/)" + return key, value + + async def getdata(page): + data, count = await self.bot.functions.get_rssfeeds(inter.guild_id, page, pagesize) + return data, count + + embed = Followup(f"Available RSS Feeds in {inter.guild.name}").info()._embed + pagination = PaginationView(inter, embed, getdata, formatdata, pagesize, 1) + await pagination.send() + except Exception as exc: await ( Followup(exc.__class__.__name__, str(exc)) .error() .send(inter) ) - else: - # description = "\n\n".join( - # f"{item.name}\n{item.url}\n{item.uuid}" - # for item in rssfeeds - # ) - - # fields = { - # f"{i+1}.": f"{item.name}\n{item.url}\n{item.uuid}" - # for i, item in enumerate(rssfeeds) - # } - - def formatdata(item): - return item.name, f"{item.url}\n{item.uuid}" - - async def getdata(page): - data, count = await self.bot.functions.get_rssfeeds(inter.guild_id, page) - return data - - embed = Followup(f"Available RSS Feeds in {inter.guild.name}").info()._embed - maxpage = PaginationView.calc_total_pages(total_results, 10) - pagination = PaginationView(inter, embed, getdata, formatdata, maxpage, 1) - await pagination.send() - - # await ( - # Followup(f"Available RSS Feeds in {inter.guild.name}") - # .info() - # .fields(**fields) - # .footer(f"Page {page}") - # .send(inter) - # ) - - # try: - # sources = await self.bot.functions.get_feeds(inter.guild_id) - # except NoResultFound: - # await ( - # Followup( - # "Feeds Not Found Error", - # "There are no available Feeds for this server.\n" - # "Add a new feed with `/feed add`." - # ) - # .error() - # .send() - # ) - # else: - # description = "\n".join([ - # f"{i}. **[{source.name}]({source.url})**" - # for i, source in enumerate(sources) - # ]) - # await ( - # Followup( - # f"Available Feeds in {inter.guild.name}", - # description - # ) - # .info() - # .send(inter) - # ) - # @feed_group.command(name="fetch") # @rename(max_="max") diff --git a/src/feed.py b/src/feed.py index bf200d9..5af67b8 100644 --- a/src/feed.py +++ b/src/feed.py @@ -299,7 +299,7 @@ class Functions: return RSSFeed.from_dict(data) - async def get_rssfeeds(self, guild_id: int, page: int) -> list[RSSFeed]: + async def get_rssfeeds(self, guild_id: int, page: int, pagesize: int) -> list[RSSFeed]: """Get a list of RSS Feeds. Args: @@ -310,7 +310,11 @@ class Functions: """ async with aiohttp.ClientSession() as session: - data, count = await API(session).get_rssfeed_list(discord_server_id=guild_id, page=page) + data, count = await API(session).get_rssfeed_list( + discord_server_id=guild_id, + page=page, + page_size=pagesize + ) return RSSFeed.from_list(data), count diff --git a/src/utils.py b/src/utils.py index add5dc5..58a2b6c 100644 --- a/src/utils.py +++ b/src/utils.py @@ -56,15 +56,29 @@ class FollowupIcons: class PaginationView(View): + """A Discord UI View that adds pagination to an embed.""" + def __init__( self, inter: Interaction, embed: Embed, getdata: Callable, - formatdata: Callable, maxpage: int, initpage: int=1 + formatdata: Callable, pagesize: int, initpage: int=1 ): + """_summary_ + + Args: + inter (Interaction): Represents a discord command interaction. + embed (Embed): The base embed to paginate. + getdata (Callable): A function that provides data, must return Tuple[List[Any], int]. + formatdata (Callable): A formatter function that determines how the data is displayed. + pagesize (int): The size of each page. + initpage (int, optional): The inital page. Defaults to 1. + """ + self.inter = inter self.embed = embed self.getdata = getdata self.formatdata = formatdata - self.maxpage = maxpage + self.maxpage = None + self.pagesize = pagesize self.index = initpage super().__init__(timeout=100) @@ -94,6 +108,13 @@ class PaginationView(View): log.debug("total pages calculated: %s", result) return result + def calc_dataitem_index(self, dataitem_index: int): + if self.index > 1: + dataitem_index += self.pagesize * (self.index - 1) + + dataitem_index += 1 + return dataitem_index + @button(emoji="◀️", style=ButtonStyle.blurple) async def backward(self, inter: Interaction, button: Button): self.index -= 1 @@ -128,10 +149,12 @@ class PaginationView(View): async def create_paged_embed(self) -> Embed: embed = self.embed.copy() - data = await self.getdata(self.index) + data, total_results = await self.getdata(self.index) + self.maxpage = self.calc_total_pages(total_results, self.pagesize) - for item in data: - key, value = self.formatdata(item) + for i, item in enumerate(data): + i = self.calc_dataitem_index(i) + key, value = self.formatdata(i, item) embed.add_field(name=key, value=value, inline=False) if self.maxpage != 1: From c785e5eeed1bf26a4964f133d172baa41489cfd3 Mon Sep 17 00:00:00 2001 From: corbz Date: Wed, 31 Jan 2024 11:49:49 +0000 Subject: [PATCH 07/12] API token as environment variable can be included in .env as API_TOKEN --- src/api.py | 4 ++-- src/bot.py | 5 +++-- src/feed.py | 11 +++++++---- src/main.py | 20 +++++++++++++------- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/api.py b/src/api.py index 7211bda..0cfce3b 100644 --- a/src/api.py +++ b/src/api.py @@ -26,10 +26,10 @@ class API: RSS_FEED_ENDPOINT = API_ENDPOINT + "rssfeed/" FEED_CHANNEL_ENDPOINT = API_ENDPOINT + "feedchannel/" - def __init__(self, session: aiohttp.ClientSession): + def __init__(self, api_token: str, session: aiohttp.ClientSession): log.debug("API session initialised") self.session = session - self.token_headers = {"Authorization": f"Token 12bccad74fb8575b3242902014f8f3807016f4fe"} + self.token_headers = {"Authorization": f"Token {api_token}"} async def fetch_data(self, url: str, params=None): log.debug("api fetching from url: %s", url) diff --git a/src/bot.py b/src/bot.py index 0c24713..d3697d3 100644 --- a/src/bot.py +++ b/src/bot.py @@ -17,12 +17,13 @@ log = logging.getLogger(__name__) class DiscordBot(commands.Bot): - def __init__(self, BASE_DIR: Path, developing: bool): + def __init__(self, BASE_DIR: Path, developing: bool, api_token: str): activity = Game("Indev") if developing else None super().__init__(command_prefix="-", intents=Intents.all(), activity=activity) - self.functions = Functions(self) + self.functions = Functions(self, api_token) self.BASE_DIR = BASE_DIR self.developing = developing + self.api_token = api_token log.info("developing=%s", developing) diff --git a/src/feed.py b/src/feed.py index 5af67b8..cda05dd 100644 --- a/src/feed.py +++ b/src/feed.py @@ -221,8 +221,9 @@ class RSSFeed: class Functions: - def __init__(self, bot): + def __init__(self, bot, api_token: str): self.bot = bot + self.api_token = api_token async def validate_feed(self, nickname: str, url: str) -> FeedParserDict: """Validates a feed based on the given nickname and url. @@ -284,7 +285,9 @@ class Functions: source = Source.from_parsed(parsed_feed) async with aiohttp.ClientSession() as session: - data = await API(session).create_new_rssfeed(name, url, source.icon_url, guild_id) + data = await API(self.api_token, session).create_new_rssfeed( + name, url, source.icon_url, guild_id + ) return RSSFeed.from_dict(data) @@ -293,7 +296,7 @@ class Functions: log.info("Deleting Feed '%s'", uuid) async with aiohttp.ClientSession() as session: - api = API(session) + api = API(self.api_token, session) data = await api.get_rssfeed(uuid) await api.delete_rssfeed(uuid) @@ -310,7 +313,7 @@ class Functions: """ async with aiohttp.ClientSession() as session: - data, count = await API(session).get_rssfeed_list( + data, count = await API(self.api_token, session).get_rssfeed_list( discord_server_id=guild_id, page=page, page_size=pagesize diff --git a/src/main.py b/src/main.py index bfb6807..6cfc7ab 100644 --- a/src/main.py +++ b/src/main.py @@ -9,7 +9,7 @@ from os import getenv from pathlib import Path # it's important to load environment variables before -# importing the packages that depend on them. +# importing the modules that depend on them. from dotenv import load_dotenv load_dotenv() @@ -26,12 +26,17 @@ async def main(): # Grab the token before anything else, because if there is no token # available then the bot cannot be started anyways. - token = getenv("BOT_TOKEN") + bot_token = getenv("BOT_TOKEN") + if not bot_token: + raise ValueError("Bot Token is empty") - if not token: - raise ValueError("Token is empty") + # ^ same story for the API token. Without it the API cannot be + # interacted with, so grab it first. + api_token = getenv("API_TOKEN") + if not api_token: + raise ValueError("API Token is empty") - developing = bool(getenv("DEVELOPING")) + developing = bool(getenv("DEVELOPING", False)) # Setup logging settings and mute spammy loggers logsetup = LogSetup(BASE_DIR) @@ -41,9 +46,10 @@ async def main(): level=logging.WARNING ) - async with DiscordBot(BASE_DIR, developing=developing) as bot: + + async with DiscordBot(BASE_DIR, developing=developing, api_token=api_token) as bot: await bot.load_extensions() - await bot.start(token, reconnect=True) + await bot.start(bot_token, reconnect=True) if __name__ == "__main__": asyncio.run(main()) From a5eb297f1b8446686eb09ae051efef6b23079927 Mon Sep 17 00:00:00 2001 From: corbz Date: Wed, 31 Jan 2024 11:50:17 +0000 Subject: [PATCH 08/12] fixed runtime error regarding unpacking dataset --- src/extensions/rss.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/extensions/rss.py b/src/extensions/rss.py index 378f3bc..0ce51a8 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -127,7 +127,9 @@ class FeedCog(commands.Cog): async def autocomplete_rssfeed(self, inter: Interaction, name: str) -> list[Choice]: async with aiohttp.ClientSession() as session: - data = await API(session).get_rssfeed_list() + data, _ = await API(self.bot.api_token, session).get_rssfeed_list( + discord_server_id=inter.guild_id + ) rssfeeds = RSSFeed.from_list(data) choices = [ From 9031cc90c9ce99450da7e3d1bc4c2afdda747708 Mon Sep 17 00:00:00 2001 From: corbz Date: Mon, 5 Feb 2024 22:25:28 +0000 Subject: [PATCH 09/12] fixed issue with reading boolean env var --- src/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main.py b/src/main.py index 6cfc7ab..1bc664e 100644 --- a/src/main.py +++ b/src/main.py @@ -11,7 +11,7 @@ from pathlib import Path # it's important to load environment variables before # importing the modules that depend on them. from dotenv import load_dotenv -load_dotenv() +load_dotenv(override=True) from bot import DiscordBot from logs import LogSetup @@ -36,7 +36,7 @@ async def main(): if not api_token: raise ValueError("API Token is empty") - developing = bool(getenv("DEVELOPING", False)) + developing = getenv("DEVELOPING", "False") == "True" # Setup logging settings and mute spammy loggers logsetup = LogSetup(BASE_DIR) From fb432631d3e00178f2ae5fbb840c1e5a1a8f57da Mon Sep 17 00:00:00 2001 From: corbz Date: Tue, 6 Feb 2024 23:49:31 +0000 Subject: [PATCH 10/12] API code improvements + docstrings --- src/api.py | 126 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 72 insertions(+), 54 deletions(-) diff --git a/src/api.py b/src/api.py index 0cfce3b..3b22e43 100644 --- a/src/api.py +++ b/src/api.py @@ -31,79 +31,97 @@ class API: self.session = session self.token_headers = {"Authorization": f"Token {api_token}"} - async def fetch_data(self, url: str, params=None): - log.debug("api fetching from url: %s", url) - async with self.session.get(url, params=params, headers=self.token_headers) as response: - return await response.json(), await response.text(), response.status + async def make_request(self, method: str, url: str, **kwargs) -> dict: + """Make a request to the given API endpoint. - async def send_data(self, url: str, data: dict): - log.debug("api posting to url: %s", url) - async with self.session.post(url, data=data, headers=self.token_headers) as response: - return await response.json(), await response.text(), response.status + Args: + method (str): The request method to use, examples: GET, POST, DELETE... + url (str): The API endpoint to request to. + **kwargs: Passed into self.session.request. - async def delete_data(self, url: str): - log.debug("api deleting to url %s", url) - async with self.session.delete(url, headers=self.token_headers) as response: - return await response.text(), response.status + Returns: + dict: Dictionary containing status code, json or text. + """ - async def create_new_rssfeed(self, name: str, url: str, image_url: str, discord_server_id: int): + async with self.session.request(method, url, headers=self.token_headers, **kwargs) as response: + response.raise_for_status() + try: + json = await response.json() + text = None + except aiohttp.ContentTypeError: + json = None + text = await response.text() - log.debug("api creating rss feed: %s %s %s", name, url, image_url) + status = response.status + + return {"json": json, "text": text, "status": status} + + async def create_new_rssfeed(self, name: str, url: str, image_url: str, discord_server_id: int) -> dict: + """Create a new RSS Feed. + + Args: + name (str): Name of the RSS Feed. + url (str): URL for the RSS Feed. + image_url (str): URL of the image representation of the RSS Feed. + discord_server_id (int): ID of the discord server behind this item. + + Returns: + dict: JSON representation of the newly created RSS Feed. + """ + + log.debug("creating rssfeed: %s %s %s %s", name, url, image_url, discord_server_id) async with self.session.get(image_url) as response: image_data = await response.read() - form = aiohttp.FormData() - form.add_field("name", name) - form.add_field("url", url) + # Using formdata to make the image transfer easier. + form = aiohttp.FormData({ + "name": name, + "url": url, + "discord_server_id": discord_server_id + }) form.add_field("image", image_data, filename="file.jpg") - form.add_field("discord_server_id", str(discord_server_id)) - resp_json, resp_text, resp_status = await self.send_data(self.RSS_FEED_ENDPOINT, form) - - if resp_status != 201: - log.error(resp_text) - raise NotCreatedException(f"Expected HTTP 201, not HTTP {response.status} - {resp_text}") - - log.debug(resp_text) - - return resp_json + data = (await self.make_response("POST", self.RSS_FEED_ENDPOINT, data=form))["json"] + return data async def get_rssfeed(self, uuid: str) -> dict: + """Get a particular RSS Feed given it's UUID. - log.debug("api getting rss feed") + Args: + uuid (str): Identifier of the desired RSS Feed. - endpoint = f"{self.RSS_FEED_ENDPOINT}{uuid}" - resp_json, resp_text, resp_status = await self.fetch_data(endpoint) + Returns: + dict: A JSON representation of the RSS Feed. + """ - if resp_status != 200: - log.error(resp_text) - raise BadStatusException(f"Expected HTTP 200, not HTTP {resp_status} - {resp_text}") + log.debug("getting rssfeed: %s", uuid) + endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/" + data = (await self.make_request("GET", endpoint))["json"] + return data - return resp_json + async def get_rssfeed_list(self, **filters) -> tuple[list[dict], int]: + """Get all RSS Feeds with the associated filters. - async def get_rssfeed_list(self, **filters) -> dict: + Returns: + tuple[list[dict], int] list contains dictionaries of each item, int is total items. + """ - log.debug("api getting list of rss feed") + log.debug("getting list of rss feeds with filters: %s", filters) + data = (await self.make_request("GET", self.RSS_FEED_ENDPOINT, params=filters))["json"] + return data["results"], data["count"] - resp_json, resp_text, resp_status = await self.fetch_data(self.RSS_FEED_ENDPOINT, params=filters) + async def delete_rssfeed(self, uuid: str) -> int: + """Delete a specified RSS Feed. - if resp_status != 200: - log.error(resp_text) - raise BadStatusException(f"Expected HTTP 200, not HTTP {resp_status} - {resp_text}") + Args: + uuid (str): Identifier of the RSS Feed to delete. - log.debug(resp_text) + Returns: + int: Status code of the response. + """ - return resp_json["results"], resp_json["count"] - - async def delete_rssfeed(self, uuid: str): - - log.debug("api deleting rss feed") - - resp_text, resp_status = await self.delete_data(f"{self.RSS_FEED_ENDPOINT}{uuid}/") - - if resp_status != 204: - log.error(resp_text) - raise BadStatusException(f"Expected HTTP 204, not HTTP {resp_status} - {resp_text}") - - log.debug(resp_text) + log.debug("deleting rssfeed: %s", uuid) + endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/" + status = (await self.make_request("DELETE", endpoint))["status"] + return status From 343767c755848a6f804c003f945814f772c1a6ee Mon Sep 17 00:00:00 2001 From: corbz Date: Wed, 7 Feb 2024 01:02:42 +0000 Subject: [PATCH 11/12] pagination emoji integration 1/? --- src/extensions/rss.py | 2 +- src/utils.py | 32 ++++++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/extensions/rss.py b/src/extensions/rss.py index 0ce51a8..030c87a 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -263,7 +263,7 @@ class FeedCog(commands.Cog): return data, count embed = Followup(f"Available RSS Feeds in {inter.guild.name}").info()._embed - pagination = PaginationView(inter, embed, getdata, formatdata, pagesize, 1) + pagination = PaginationView(self.bot, inter, embed, getdata, formatdata, pagesize, 1) await pagination.send() except Exception as exc: diff --git a/src/utils.py b/src/utils.py index 58a2b6c..b0b5169 100644 --- a/src/utils.py +++ b/src/utils.py @@ -7,6 +7,7 @@ from typing import Callable from discord import Interaction, Embed, Colour, ButtonStyle, Button from discord.ui import View, button +from discord.ext.commands import Bot log = logging.getLogger(__name__) @@ -59,12 +60,13 @@ class PaginationView(View): """A Discord UI View that adds pagination to an embed.""" def __init__( - self, inter: Interaction, embed: Embed, getdata: Callable, + self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable, formatdata: Callable, pagesize: int, initpage: int=1 ): """_summary_ Args: + bot (commands.Bot) The discord bot inter (Interaction): Represents a discord command interaction. embed (Embed): The base embed to paginate. getdata (Callable): A function that provides data, must return Tuple[List[Any], int]. @@ -72,7 +74,8 @@ class PaginationView(View): pagesize (int): The size of each page. initpage (int, optional): The inital page. Defaults to 1. """ - + + self.bot = bot self.inter = inter self.embed = embed self.getdata = getdata @@ -80,6 +83,13 @@ class PaginationView(View): self.maxpage = None self.pagesize = pagesize self.index = initpage + + # emoji reference + next_emoji = bot.get_emoji(1204542366602502265) + prev_emoji = bot.get_emoji(1204542365432422470) + self.start_emoji = bot.get_emoji(1204542364073463818) + self.end_emoji = bot.get_emoji(1204542367752003624) + super().__init__(timeout=100) async def check_user_is_author(self, inter: Interaction) -> bool: @@ -109,6 +119,13 @@ class PaginationView(View): return result def calc_dataitem_index(self, dataitem_index: int): + """Calculates a given index to be relative to the sum of all pages items + + Example: dataitem_index = 6 + pagesize = 10 + if page == 1 then return 6 + else return 6 + 10 * (page - 1)""" + if self.index > 1: dataitem_index += self.pagesize * (self.index - 1) @@ -157,16 +174,19 @@ class PaginationView(View): key, value = self.formatdata(i, item) embed.add_field(name=key, value=value, inline=False) - if self.maxpage != 1: + if not total_results: + embed.description = "There are no results" + + if self.maxpage > 1: embed.set_footer(text=f"Page {self.index}/{self.maxpage}") return embed def update_buttons(self): if self.index >= self.maxpage: - self.children[2].emoji = "⏮️" + self.children[2].emoji = self.start_emoji else: - self.children[2].emoji = "⏭️" + self.children[2].emoji = self.end_emoji self.children[0].disabled = self.index == 1 self.children[1].disabled = self.index == self.maxpage @@ -174,7 +194,7 @@ class PaginationView(View): async def send(self): embed = await self.create_paged_embed() - if self.maxpage == 1: + if self.maxpage <= 1: await self.inter.edit_original_response(embed=embed) return From 5b8ef98ecaea45304d7093180258a19d225a20bb Mon Sep 17 00:00:00 2001 From: corbz Date: Wed, 7 Feb 2024 01:03:00 +0000 Subject: [PATCH 12/12] placeholder api integration --- src/extensions/tasks.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py index 9f5160a..19ad9d0 100644 --- a/src/extensions/tasks.py +++ b/src/extensions/tasks.py @@ -8,13 +8,15 @@ import datetime from os import getenv from time import process_time +import aiohttp from discord import TextChannel +from discord import app_commands from discord.ext import commands, tasks from discord.errors import Forbidden from sqlalchemy import insert, select, and_ from feedparser import parse -from feed import Source, Article +from feed import Source, Article, RSSFeed from db import ( DatabaseManager, FeedChannelModel, @@ -22,11 +24,13 @@ from db import ( SentArticleModel ) from utils import get_unparsed_feed +from api import API log = logging.getLogger(__name__) TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES") +# task trigger times : must be of type list times = [ datetime.time(hour, minute, tzinfo=datetime.timezone.utc) for hour in range(24) @@ -61,20 +65,37 @@ class TaskCog(commands.Cog): self.rss_task.cancel() - @tasks.loop(minutes=10) + @app_commands.command(name="debug-trigger-task") + async def debug_trigger_task(self, inter): + await inter.response.defer() + await self.rss_task() + await inter.followup.send("done") + + @tasks.loop(time=times) async def rss_task(self): """Automated task responsible for processing rss feeds.""" log.info("Running rss task") time = process_time() - async with DatabaseManager() as database: - query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel) - result = await database.session.execute(query) - feeds = result.scalars().all() + # async with DatabaseManager() as database: + # query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel) + # result = await database.session.execute(query) + # feeds = result.scalars().all() + + # for feed in feeds: + # await self.process_feed(feed, database) + + guild_ids = [guild.id for guild in self.bot.guilds] + + async with aiohttp.ClientSession() as session: + api = API(self.bot.api_token, session) + data, count = await api.get_rssfeed_list(discord_server_id__in=guild_ids) + rssfeeds = RSSFeed.from_list(data) + for item in rssfeeds: + log.info(item.name) + - for feed in feeds: - await self.process_feed(feed, database) log.info("Finished rss task, time elapsed: %s", process_time() - time) @@ -95,7 +116,7 @@ class TaskCog(commands.Cog): # TODO: integrate the `validate_feed` code into here, also do on list command and show errors. - unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url) + unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url) parsed_feed = parse(unparsed_content) source = Source.from_parsed(parsed_feed) articles = source.get_latest_articles(5)