From 08ba99552007e5f8e21a1f1813fbd8a220d6a0d9 Mon Sep 17 00:00:00 2001 From: corbz Date: Sun, 11 Feb 2024 23:52:32 +0000 Subject: [PATCH] Working on new db model integration --- src/api.py | 288 ++++++++-- src/extensions/rss.py | 1205 ++++++++++++++++++++++++----------------- src/feed.py | 72 +++ 3 files changed, 1025 insertions(+), 540 deletions(-) diff --git a/src/api.py b/src/api.py index b7996a2..5ced5ea 100644 --- a/src/api.py +++ b/src/api.py @@ -39,6 +39,10 @@ class API: RSS_FEED_ENDPOINT = API_ENDPOINT + "rssfeed/" FEED_CHANNEL_ENDPOINT = API_ENDPOINT + "feedchannel/" + SUBSCRIPTION_ENDPOINT = API_ENDPOINT + "subscription/" + CHANNEL_ENDPOINT = SUBSCRIPTION_ENDPOINT + "channel/" + TRACKED_ENDPOINT = API_ENDPOINT + "tracked/" + def __init__(self, api_token: str, session: aiohttp.ClientSession): log.debug("API session initialised") self.session = session @@ -69,72 +73,262 @@ class API: return {"json": json, "text": text, "status": status} - async def create_new_rssfeed(self, name: str, url: str, image_url: str, discord_server_id: int) -> dict: - """Create a new RSS Feed. + async def _post_data(self, url: str, data: dict | aiohttp.FormData) -> dict: + return await self.make_request( + method="POST", + url=url, + data=data + ) - Args: - name (str): Name of the RSS Feed. - url (str): URL for the RSS Feed. - image_url (str): URL of the image representation of the RSS Feed. - discord_server_id (int): ID of the discord server behind this item. + async def _put_data(self, url: str, data: dict | aiohttp.FormData) -> dict: + return await self.make_request( + method="PUT", + url=url, + data=data + ) - Returns: - dict: JSON representation of the newly created RSS Feed. + async def _get_one(self, url: str) -> dict: + return await self.make_request( + method="GET", + url=url + ) + + async def _get_many(self, url: str, filters: dict) -> tuple[list[dict], int]: + data = await self.make_request( + method="GET", + url=url, + params=filters + ) + content = data["json"] + return content["results"], content["count"] + + async def _delete(self, url: str) -> None: + await self.make_request( + method="DELETE", + url=url + ) + + async def create_subscription(self, name: str, rss_url: str, image_url: str, server_id: str) -> dict: + """ + Create a new Subscription. """ - log.debug("creating rssfeed: %s %s %s %s", name, url, image_url, discord_server_id) + log.debug("subscribing '%s' to '%s'", server_id, rss_url) async with self.session.get(image_url) as response: image_data = await response.read() - # Using formdata to make the image transfer easier. - form = aiohttp.FormData({ - "name": name, - "url": url, - "discord_server_id": str(discord_server_id) - }) + form = aiohttp.FormData() + form.add_field("name", name) + form.add_field("rss_url", rss_url) + form.add_field("server", server_id) form.add_field("image", image_data, filename="file.jpg") - data = (await self.make_request("POST", self.RSS_FEED_ENDPOINT, data=form))["json"] - return data + data = await self._post_data(self.SUBSCRIPTION_ENDPOINT, form) - async def get_rssfeed(self, uuid: str) -> dict: - """Get a particular RSS Feed given it's UUID. + return data["json"] - Args: - uuid (str): Identifier of the desired RSS Feed. - - Returns: - dict: A JSON representation of the RSS Feed. + async def get_subscription(self, uuid: str) -> dict: + """ + Retreive a Subscription. """ - log.debug("getting rssfeed: %s", uuid) - endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/" - data = (await self.make_request("GET", endpoint))["json"] - return data + log.debug("retreiving subscription '%s'", uuid) - async def get_rssfeed_list(self, **filters) -> tuple[list[dict], int]: - """Get all RSS Feeds with the associated filters. + url=f"{self.SUBSCRIPTION_ENDPOINT}{uuid}/" + data = await self._get_one(url) - Returns: - tuple[list[dict], int] list contains dictionaries of each item, int is total items. + return data["json"] + + async def get_subscriptions(self, **filters) -> tuple[list[dict], int]: + """ + Retreive multiple Subscriptions. """ - log.debug("getting list of rss feeds with filters: %s", filters) - data = (await self.make_request("GET", self.RSS_FEED_ENDPOINT, params=filters))["json"] - return data["results"], data["count"] + log.debug("retreiving multiple subscriptions") - async def delete_rssfeed(self, uuid: str) -> int: - """Delete a specified RSS Feed. + return await self._get_many(self.SUBSCRIPTION_ENDPOINT, filters) - Args: - uuid (str): Identifier of the RSS Feed to delete. - - Returns: - int: Status code of the response. + async def delete_subscription(self, uuid: str) -> None: + """ + Delete an existing Subscription. """ - log.debug("deleting rssfeed: %s", uuid) - endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/" - status = (await self.make_request("DELETE", endpoint))["status"] - return status + log.debug("deleting subscription '%s'", uuid) + + url=f"{self.SUBSCRIPTION_ENDPOINT}{uuid}/" + await self._delete(url) + + async def create_subscription_channel(self, channel_id: int | str, sub_uuid: str) -> dict: + """ + Create a new Channel. + """ + + log.debug("creating new subscription channel '%s', '%s'", channel_id, sub_uuid) + + form = aiohttp.FormData() + form.add_field("id", str(channel_id)) + form.add_field("subscription", sub_uuid) + + data = await self._post_data(self.CHANNEL_ENDPOINT, form) + + return data["json"] + + async def get_subscription_channel(self, uuid: str) -> dict: + """ + Retreive a Channel. + """ + + log.debug("retreiving a subscription channel '%s'", uuid) + + url = f"{self.CHANNEL_ENDPOINT}{uuid}/" + data = await self._get_one(url) + + return data["json"] + + async def get_subscription_channels(self, **filters) -> tuple[list[dict], int]: + """ + Retreive multiple Channels. + """ + + log.debug("retreiving multiple channels") + + return await self._get_many(self.CHANNEL_ENDPOINT, filters) + + async def delete_subscription_channel(self, uuid: str) -> None: + """ + Delete an existing Channel. + """ + + log.debug("deleting channel '%s'", uuid) + + url=f"{self.CHANNEL_ENDPOINT}{uuid}/" + await self._delete(url) + + async def create_tracked_content(self, sub_uuid: str, content_url: str) -> dict: + """ + Create a Tracked Content. + """ + + log.debug("creating tracked content '%s', '%s'", sub_uuid, content_url) + + form = aiohttp.FormData() + form.add_field("subscription", sub_uuid) + form.add_field("content_url", content_url) + + data = await self._post_data(self.TRACKED_ENDPOINT, form) + + return data["json"] + + async def get_tracked_content(self, uuid: str = None, content_url: str = None) -> dict: + """ + Retreive a Tracked Content. + """ + + log.debug("retreiving tracked content '%s', '%s'", uuid, content_url) + + if not (uuid or content_url) or (uuid and content_url): + raise ValueError( + "Must use only 'uuid' or 'content_url' arguments, cannot use " + "both arguments or none." + ) + + url = f"{self.TRACKED_ENDPOINT}{uuid or content_url}/" + data = await self._get_one(url) + + return data["json"] + + async def get_tracked_contents(self, **filters) -> tuple[list[dict], int]: + """ + Retreive multiple Tracked Content. + """ + + log.debug("retreiving multiple tracked content") + + return await self._get_many(self.TRACKED_ENDPOINT, filters) + + async def delete_tracked_content(self, uuid: str) -> None: + """ + Delete a Tracked Content. + """ + + log.debug("deleting tracked content '%s'", uuid) + + url = f"{self.TRACKED_ENDPOINT}{uuid}/" + await self._delete(url) + + async def is_tracked(self, content_url: str) -> bool: + """ + Returns boolean if an item with the given url exists. + """ + raise NotImplementedError + + # async def create_new_rssfeed(self, name: str, url: str, image_url: str, discord_server_id: int) -> dict: + # """Create a new RSS Feed. + + # Args: + # name (str): Name of the RSS Feed. + # url (str): URL for the RSS Feed. + # image_url (str): URL of the image representation of the RSS Feed. + # discord_server_id (int): ID of the discord server behind this item. + + # Returns: + # dict: JSON representation of the newly created RSS Feed. + # """ + + # log.debug("creating rssfeed: %s %s %s %s", name, url, image_url, discord_server_id) + + # async with self.session.get(image_url) as response: + # image_data = await response.read() + + # # Using formdata to make the image transfer easier. + # form = aiohttp.FormData({ + # "name": name, + # "url": url, + # "discord_server_id": str(discord_server_id) + # }) + # form.add_field("image", image_data, filename="file.jpg") + + # data = (await self.make_request("POST", self.RSS_FEED_ENDPOINT, data=form))["json"] + # return data + + # async def get_rssfeed(self, uuid: str) -> dict: + # """Get a particular RSS Feed given it's UUID. + + # Args: + # uuid (str): Identifier of the desired RSS Feed. + + # Returns: + # dict: A JSON representation of the RSS Feed. + # """ + + # log.debug("getting rssfeed: %s", uuid) + # endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/" + # data = (await self.make_request("GET", endpoint))["json"] + # return data + + # async def get_rssfeed_list(self, **filters) -> tuple[list[dict], int]: + # """Get all RSS Feeds with the associated filters. + + # Returns: + # tuple[list[dict], int] list contains dictionaries of each item, int is total items. + # """ + + # log.debug("getting list of rss feeds with filters: %s", filters) + # data = (await self.make_request("GET", self.RSS_FEED_ENDPOINT, params=filters))["json"] + # return data["results"], data["count"] + + # async def delete_rssfeed(self, uuid: str) -> int: + # """Delete a specified RSS Feed. + + # Args: + # uuid (str): Identifier of the RSS Feed to delete. + + # Returns: + # int: Status code of the response. + # """ + + # log.debug("deleting rssfeed: %s", uuid) + # endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/" + # status = (await self.make_request("DELETE", endpoint))["status"] + # return status diff --git a/src/extensions/rss.py b/src/extensions/rss.py index cbd0b68..7bec1ed 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -5,18 +5,19 @@ Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bo import logging from typing import Tuple +from dataclasses import asdict import aiohttp import validators from feedparser import FeedParserDict, parse from discord.ext import commands from discord import Interaction, Embed, Colour, TextChannel, Permissions -from discord.app_commands import Choice, Group, autocomplete, choices, rename +from discord.app_commands import Choice, Group, autocomplete, choices, rename, command from sqlalchemy import insert, select, and_, delete from sqlalchemy.exc import NoResultFound, IntegrityError from api import API -from feed import Source, RSSFeed +from feed import Source, RSSFeed, Subscription, SubscriptionChannel, TrackedContent from errors import IllegalFeed from db import ( DatabaseManager, @@ -124,580 +125,798 @@ class FeedCog(commands.Cog): log.info("%s cog is ready", self.__class__.__name__) - async def autocomplete_rssfeed(self, inter: Interaction, name: str) -> list[Choice]: + # async def autocomplete_channels(self, inter: Interaction, name: str) -> list[Choice]: + # """""" - async with aiohttp.ClientSession() as session: - data, _ = await API(self.bot.api_token, session).get_rssfeed_list( - discord_server_id=inter.guild_id - ) - rssfeeds = RSSFeed.from_list(data) + # log.debug("autocompleting channels '%s'", name) - choices = [ - Choice(name=item.name, value=item.uuid) - for item in rssfeeds + # try: + # async with aiohttp.ClientSession() as session: + # api = API(self.bot.api_token, session) + # results, _ = await api.get_channel(server=inter.guild_id) + + # except Exception as exc: + # log.error(exc) + # return [] + + # channels = Channel.from_list(results) + + # return [ + # Choice(name=channel.get_textchannel(self.bot).name, value=channel.id) + # for channel in channels + # ] + + async def autocomplete_subscriptions(self, inter: Interaction, name: str) -> list[Choice]: + """""" + + log.debug("autocompleting subscriptions '%s'", name) + + try: + async with aiohttp.ClientSession() as session: + api = API(self.bot.api_token, session) + results, _ = await api.get_subscriptions(server=inter.guild_id, search=name) + + except Exception as exc: + log.error(exc) + return [] + + subscriptions = Subscription.from_list(results) + + return [ + Choice(name=sub.name, value=sub.uuid) + for sub in subscriptions ] - return choices + # channel_group = Group( + # name="channels", + # description="channel commands", + # guild_only=True + # ) - async def source_autocomplete(self, inter: Interaction, nickname: str): - """Provides RSS source autocomplete functionality for commands. + # @channel_group.command(name="add") + # @autocomplete(sub_uuid=autocomplete_subscriptions) + # @rename(sub_uuid="subscription", textchannel="channel") + # async def new_channel(self, inter: Interaction, sub_uuid: str, textchannel: TextChannel): + # """Create a new channel.""" - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - nickname : str - _description_ + # await inter.response.defer() - Returns - ------- - list of app_commands.Choice - _description_ - """ + # try: + # async with aiohttp.ClientSession() as session: + # api = API(self.bot.api_token, session) + # await api.create_channel(textchannel.id) + # # sub = await api.get_subscription(sub_uuid) + # # sub["channels"].append(textchannel.id) + # # await api.put_subscription(sub_uuid, sub) - async with DatabaseManager() as database: - whereclause = and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.nick.ilike(f"%{nickname}%") - ) - query = select(RssSourceModel).where(whereclause).order_by(RssSourceModel.nick) - result = await database.session.execute(query) - sources = [ - Choice(name=rss.nick, value=rss.rss_url) - for rss in result.scalars().all() - ] + # except Exception as exc: + # return await ( + # Followup(exc.__class__.__name__, str(exc)) + # .error() + # .send(inter) + # ) - return sources + # await ( + # Followup("Channel Assigned!") + # .fields( + # subscription=sub_uuid, + # channel=textchannel.mention + # ) + # .added() + # .send(inter) + # ) - # All RSS commands belong to this group. - feed_group = Group( - name="feed", - description="Commands for RSS sources.", - default_permissions=Permissions.elevated(), - guild_only=True # We store guild IDs in the database, so guild only = True + # @channel_group.command(name="remove") + # @autocomplete(id=autocomplete_channels) + # @rename(id="choice") + # async def remove_channel(self, inter: Interaction, id: int): + # """Remove a channel.""" + + # await inter.response.defer() + + # try: + # async with aiohttp.ClientSession() as session: + # api = API(self.bot.api_token, session) + # await api.delete_channel(id) + + # except Exception as exc: + # return await ( + # Followup(exc.__class__.__name__, str(exc)) + # .error() + # .send(inter) + # ) + + # await ( + # Followup("Channel Removed!", str(id)) + # .trash() + # .send(inter) + # ) + + # @channel_group.command(name="list") + # async def list_channels(self, inter: Interaction): + + # log.debug("Listing all subscription channels with this server.") + + # await inter.response.defer() + + # page = 1 + # pagesize = 10 + + # def formatdata(index, item): + # item = Channel.from_dict(item) + # text_channel = item.get_textchannel(self.bot) + + # key = f"{index}. {text_channel.mention}" + # value = f"[RSS]({item.rss_url}) · [API]({API.CHANNEL_ENDPOINT}{item.uuid}/)" + # return key, value + + # async def getdata(page): + # async with aiohttp.ClientSession() as session: + # api = API(self.bot.api_token, session) + # return await api.get_subscriptions( + # server=inter.guild.id, page=page, page_size=pagesize + # ) + + # embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed + # pagination = PaginationView(self.bot, inter, embed, getdata, formatdata, pagesize, page) + # await pagination.send() + + subscription_group = Group( + name="subscriptions", + description="subscription commands", + guild_only=True ) - @feed_group.command(name="new") - async def add_rssfeed(self, inter: Interaction, name: str, url: str): - """Add a new RSS Feed for this server. - - Args: - inter (Interaction): Represents the discord command interaction. - name (str): A nickname used to refer to this RSS Feed. - url (str): The URL of the RSS Feed. + @subscription_group.command(name="link") + @autocomplete(sub_uuid=autocomplete_subscriptions) + @rename(sub_uuid="subscription") + async def link_subscription_to_channel(self, inter: Interaction, sub_uuid: str, channel: TextChannel): + """ + Link Subscription to discord.TextChannel. """ await inter.response.defer() try: - rssfeed = await self.bot.functions.create_new_rssfeed(name, url, inter.guild_id) + async with aiohttp.ClientSession() as session: + api = API(self.bot.api_token, session) + data = await api.create_subscription_channel(str(channel.id), sub_uuid) + except Exception as exc: - await ( + return await ( Followup(exc.__class__.__name__, str(exc)) .error() .send(inter) ) - else: - await ( - Followup("New RSS Feed") - .image(rssfeed.image) - .fields(uuid=rssfeed.uuid, name=name, url=url) - .added() - .send(inter) - ) - @feed_group.command(name="delete") - @autocomplete(uuid=autocomplete_rssfeed) - @rename(uuid="rssfeed") - async def delete_rssfeed(self, inter: Interaction, uuid: str): - """Delete an existing RSS Feed for this server. + await ( + Followup("Linked!") + .fields(**data) + .added() + .send(inter) + ) - Args: - inter (Interaction): Represents the discord command interaction. - uuid (str): The UUID of the - """ + + @subscription_group.command(name="add") + async def new_subscription(self, inter: Interaction, name: str, rss_url: str): + """Subscribe this server to a new RSS Feed.""" await inter.response.defer() try: - rssfeed = await self.bot.functions.delete_rssfeed(uuid) - except NoResultFound: - await ( - Followup( - "Feed Not Found Error", - "A Feed with these parameters could not be found." - ) + parsed_rssfeed = await self.bot.functions.validate_feed(name, rss_url) + image_url = parsed_rssfeed.get("feed", {}).get("image", {}).get("href") + + async with aiohttp.ClientSession() as session: + api = API(self.bot.api_token, session) + data = await api.create_subscription(name, rss_url, image_url, str(inter.guild_id)) + + except Exception as exc: + return await ( + Followup(exc.__class__.__name__, str(exc)) .error() .send(inter) ) - else: - await ( - Followup("Feed Deleted") - .image(rssfeed.image) - .fields(uuid=rssfeed.uuid, name=rssfeed.name, url=rssfeed.url) - .trash() + + # Omit data we dont want the user to see + data.pop("image") + data.pop("server") + data.pop("creation_datetime") + + # Update keys to be more human readable + data["UUID"] = data.pop("uuid") + data["url"] = data.pop("rss_url") + + await ( + Followup("Subscription Added!") + .fields(**data) + .image(image_url) + .added() + .send(inter) + ) + + @subscription_group.command(name="remove") + @autocomplete(uuid=autocomplete_subscriptions) + @rename(uuid="choice") + async def remove_subscriptions(self, inter: Interaction, uuid: str): + """Unsubscribe this server from an existing RSS Feed.""" + + await inter.response.defer() + + try: + async with aiohttp.ClientSession() as session: + api = API(self.bot.api_token, session) + await api.delete_subscription(uuid) + + except Exception as exc: + return await ( + Followup(exc.__class__.__name__, str(exc)) + .error() .send(inter) ) - @feed_group.command(name="list") - async def list_rssfeeds(self, inter: Interaction): - """Provides a list of all RSS Feeds + await ( + Followup("Subscription Removed!", uuid) + .trash() + .send(inter) + ) - Args: - inter (Interaction): Represents the discord command interaction. - """ + @subscription_group.command(name="list") + async def list_subscription(self, inter: Interaction): + """List Subscriptions from this server.""" await inter.response.defer() page = 1 pagesize = 10 - try: - def formatdata(index, item): - key = f"{index}. {item.name}" - value = f"[RSS]({item.url}) · [API]({API.RSS_FEED_ENDPOINT}{item.uuid}/)" - return key, value + def formatdata(index, item): + item = Subscription.from_dict(item) + key = f"{index}. {item.name}" + value = f"[RSS]({item.rss_url}) · [API]({API.SUBSCRIPTION_ENDPOINT}{item.uuid}/)" + return key, value - async def getdata(page): - data, count = await self.bot.functions.get_rssfeeds(inter.guild_id, page, pagesize) - return data, count + async def getdata(page): + async with aiohttp.ClientSession() as session: + api = API(self.bot.api_token, session) + return await api.get_subscriptions( + server=inter.guild.id, page=page, page_size=pagesize + ) - embed = Followup(f"Available RSS Feeds in {inter.guild.name}").info()._embed - pagination = PaginationView(self.bot, inter, embed, getdata, formatdata, pagesize, 1) - await pagination.send() + embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed + pagination = PaginationView(self.bot, inter, embed, getdata, formatdata, pagesize, page) + await pagination.send() - except Exception as exc: - await ( - Followup(exc.__class__.__name__, str(exc)) - .error() - .send(inter) - ) - - # @feed_group.command(name="fetch") - # @rename(max_="max") - # @autocomplete(rss=source_autocomplete) - async def fetch_rss(self, inter: Interaction, rss: str, max_: int=1): - """Fetch an item from the specified RSS feed. - - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - rss : str - The RSS feed to fetch from. - max_ : int, optional - Maximum number of items to fetch, by default 1, limits at 5. - """ - - await inter.response.defer() - - if max_ > 5: - followup(inter, "It looks like you have requested too many articles.\nThe limit is 5") - return - - invalid_message, feed = await validate_rss_source("", rss) - if invalid_message: - await followup(inter, invalid_message) - return - - source = Source.from_parsed(feed) - articles = source.get_latest_articles(max_) - - if not articles: - await followup(inter, "Sorry, I couldn't find any articles from this feed.") - return - - async with aiohttp.ClientSession() as session: - embeds = [await article.to_embed(session) for article in articles] - - async with DatabaseManager() as database: - query = insert(SentArticleModel).values([ - { - "discord_server_id": inter.guild_id, - "discord_channel_id": inter.channel_id, - "discord_message_id": inter.id, - "article_url": article.url, - } - for article in articles - ]) - await database.session.execute(query) - await audit(self, - f"User is requesting {max_} articles from {source.name}", - inter.user.id, database=database - ) - - await followup(inter, embeds=embeds) - - # Help ---- ---- ---- - - @feed_group.command(name="help") - async def get_help(self, inter: Interaction): - """Get help on how to use my commands. - - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - """ - - await inter.response.defer() - - description = ( - "`/feed add ` \n\n" - "Save a new RSS feed to the bot. This can be referred to later, when assigning " - "channels to receive content from these RSS feeds." - - "\n\n\n`/feed remove