"""
Extension for the `FeedCog`.
Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bot.
"""

import io
import logging
from typing import Tuple

import aiohttp
import validators
from feedparser import FeedParserDict, parse
from discord.ext import commands
from discord import Interaction, TextChannel
from discord.app_commands import Choice, Group, autocomplete, rename, command, guild_only

from api import API
from feed import Subscription, SubscriptionChannel, TrackedContent
from utils import (
    Followup,
    PaginationView,
    get_rss_data,
)

log = logging.getLogger(__name__)

rss_list_sort_choices = [
    Choice(name="Nickname", value=0),
    Choice(name="Date Added", value=1)
]
channels_list_sort_choices = [
    Choice(name="Feed Nickname", value=0),
    Choice(name="Channel ID", value=1),
    Choice(name="Date Added", value=2)
]


# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
# target resource. Run a periodic task to check this.
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
    """Validate a provided RSS source.

    Parameters
    ----------
    nickname : str
        Nickname of the source. Must not itself be a URL.
    url : str
        URL of the source. Must be a URL that responds with status 200 and
        serves an RSS/Atom feed.

    Returns
    -------
    str or None
        Error message if invalid, None if valid.
    FeedParserDict or None
        The feed parsed from the given URL, or None if invalid.
    """

    # Ensure the URL is valid
    if not validators.url(url):
        return f"The URL you have entered is malformed or invalid:\n`{url=}`", None

    # Check the nickname is not a URL
    if validators.url(nickname):
        return "It looks like the nickname you have entered is a URL.\n" \
            f"For security reasons, this is not allowed.\n`{nickname=}`", None

    feed_data, status_code = await get_rss_data(url)

    # Check the URL status code is valid
    if status_code != 200:
        return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None

    # feedparser.parse() interprets a str/bytes argument that looks like a URL
    # or a filesystem path as something to fetch/open. The body comes from a
    # remote server we do not trust, so hand it over as a byte stream to force
    # it to be parsed as document content. Text is UTF-8 encoded, matching
    # what feedparser does itself for plain-string input.
    payload = feed_data.encode("utf-8") if isinstance(feed_data, str) else feed_data
    feed = parse(io.BytesIO(payload))

    # Check the contents is actually an RSS feed. `.get` is used because
    # feedparser omits the "version" key entirely for an empty body.
    if not feed.get("version"):
        return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None

    return None, feed


class FeedCog(commands.Cog):
    """
    Command cog exposing the feed subscription commands.
    """

    def __init__(self, bot: commands.Bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        """Instructions to call when the cog is ready."""

        log.info("%s cog is ready", self.__class__.__name__)

    # NOTE(review): the commented-out commands below are kept as reference
    # while the API integration is being migrated; they are not registered.

    # async def autocomplete_subscriptions(self, inter: Interaction, name: str) -> list[Choice]:
    #     """"""
    #     log.debug("autocompleting subscriptions '%s'", name)
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             results, _ = await api.get_subscriptions(server=inter.guild_id, search=name)
    #     except Exception as exc:
    #         log.error(exc)
    #         return []
    #     subscriptions = Subscription.from_list(results)
    #     return [
    #         Choice(name=sub.name, value=sub.uuid)
    #         for sub in subscriptions
    #     ]

    # async def autocomplete_subscription_channels(self, inter: Interaction, uuid: str):
    #     """"""
    #     log.debug("autocompleting subscription channels")
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             results, _ = await api.get_subscription_channels()
    #     except Exception as exc:
    #         log.error(exc)
    #         return []
    #     subscription_channels = SubscriptionChannel.from_list(results)
    #     async def name(link):
    #         result = self.bot.get_channel(link.id) or await self.bot.fetch_channel(link.id)
    #         return f"{link.subscription.name} -> #{result.name}"
    #     return [
    #         Choice(name=await name(link), value=link.uuid)
    #         for link in subscription_channels
    #     ]

    # subscription_group = Group(
    #     name="subscriptions",
    #     description="subscription commands",
    #     guild_only=True
    # )

    # @subscription_group.command(name="link")
    # @autocomplete(sub_uuid=autocomplete_subscriptions)
    # @rename(sub_uuid="subscription")
    # async def link_subscription_channel(self, inter: Interaction, sub_uuid: str, channel: TextChannel):
    #     """
    #     Link Subscription to discord.TextChannel.
    #     """
    #     await inter.response.defer()
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             data = await api.create_subscription_channel(str(channel.id), sub_uuid)
    #     except aiohttp.ClientResponseError as exc:
    #         return await (
    #             Followup(
    #                 f"Error · {exc.message}",
    #                 "Ensure you haven't: \n"
    #                 "- Already linked this subscription to this channel\n"
    #                 "- Already linked this subscription to the maximum of 4 channels"
    #             )
    #             .footer(f"HTTP {exc.code}")
    #             .error()
    #             .send(inter)
    #         )
    #     subscription = Subscription.from_dict(data.pop("subscription"))
    #     data["subscription"] = (
    #         f"{subscription.name}\n"
    #         f"[RSS]({subscription.rss_url}) · "
    #         f"[API Subscription]({API.SUBSCRIPTION_ENDPOINT}{subscription.uuid}) · "
    #         f"[API Link]({API.CHANNEL_ENDPOINT}{data['uuid']})"
    #     )
    #     channel_id = int(data.pop("id"))
    #     channel = self.bot.get_channel(channel_id) or await self.bot.fetch_channel(channel_id)
    #     data["channel"] = channel.mention
    #     data.pop("creation_datetime")
    #     data.pop("uuid")
    #     await (
    #         Followup("Linked!")
    #         .fields(**data)
    #         .added()
    #         .send(inter)
    #     )

    # @subscription_group.command(name="unlink")
    # @autocomplete(uuid=autocomplete_subscription_channels)
    # @rename(uuid="link")
    # async def unlink_subscription_channel(self, inter: Interaction, uuid: str):
    #     """
    #     Unlink subscription from discord.TextChannel.
    #     """
    #     await inter.response.defer()
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             # data = await api.get_subscription(uuid=uuid)
    #             await api.delete_subscription_channel(uuid=uuid)
    #             # sub_channel = await SubscriptionChannel.from_dict(data)
    #     except Exception as exc:
    #         return await (
    #             Followup(exc.__class__.__name__, str(exc))
    #             .error()
    #             .send(inter)
    #         )
    #     await (
    #         Followup("Subscription unlinked!", uuid)
    #         .added()
    #         .send(inter)
    #     )

    # @subscription_group.command(name="list-links")
    # async def list_subscription(self, inter: Interaction):
    #     """List Subscriptions Channels in this server."""
    #     await inter.response.defer()
    #     async def formatdata(index: int, item: dict) -> tuple[str, str]:
    #         item = SubscriptionChannel.from_dict(item)
    #         next_emoji = self.bot.get_emoji(1204542366602502265)
    #         key = f"{index}. {item.subscription.name} {next_emoji} {item.mention}"
    #         return key, item.hyperlinks_string
    #     async def getdata(page: int, pagesize: int) -> dict:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             return await api.get_subscription_channels(
    #                 subscription__server=inter.guild.id, page=page, page_size=pagesize
    #             )
    #     embed = Followup(f"Links in {inter.guild.name}").info()._embed
    #     pagination = PaginationView(
    #         self.bot,
    #         inter=inter,
    #         embed=embed,
    #         getdata=getdata,
    #         formatdata=formatdata,
    #         pagesize=10,
    #         initpage=1
    #     )
    #     await pagination.send()

    # @subscription_group.command(name="add")
    # async def new_subscription(self, inter: Interaction, name: str, rss_url: str):
    #     """Subscribe this server to a new RSS Feed."""
    #     await inter.response.defer()
    #     try:
    #         parsed_rssfeed = await self.bot.functions.validate_feed(name, rss_url)
    #         image_url = parsed_rssfeed.get("feed", {}).get("image", {}).get("href")
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             data = await api.create_subscription(name, rss_url, image_url, str(inter.guild_id), [-1])
    #     except aiohttp.ClientResponseError as exc:
    #         return await (
    #             Followup(
    #                 f"Error · {exc.message}",
    #                 "Ensure you haven't: \n"
    #                 "- Reused an identical name of an existing Subscription\n"
    #                 "- Already created the maximum of 25 Subscriptions"
    #             )
    #             .footer(f"HTTP {exc.code}")
    #             .error()
    #             .send(inter)
    #         )
    #     # Omit data we dont want the user to see
    #     data.pop("uuid")
    #     data.pop("image")
    #     data.pop("server")
    #     data.pop("creation_datetime")
    #     # Update keys to be more human readable
    #     data["url"] = data.pop("rss_url")
    #     await (
    #         Followup("Subscription Added!")
    #         .fields(**data)
    #         .image(image_url)
    #         .added()
    #         .send(inter)
    #     )

    # @subscription_group.command(name="remove")
    # @autocomplete(uuid=autocomplete_subscriptions)
    # @rename(uuid="choice")
    # async def remove_subscriptions(self, inter: Interaction, uuid: str):
    #     """Unsubscribe this server from an existing RSS Feed."""
    #     await inter.response.defer()
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             await api.delete_subscription(uuid)
    #     except Exception as exc:
    #         return await (
    #             Followup(exc.__class__.__name__, str(exc))
    #             .error()
    #             .send(inter)
    #         )
    #     await (
    #         Followup("Subscription Removed!", uuid)
    #         .trash()
    #         .send(inter)
    #     )

    @command(name="subscriptions")
    @guild_only()  # the body dereferences inter.guild, which is None in DMs
    async def list_subscription(self, inter: Interaction):
        """List Subscriptions from this server."""

        await inter.response.defer()

        def formatdata(index, item):
            """Render one subscription record as a (key, description) pair for the paginated embed."""

            item = Subscription.from_dict(item)

            channels = f"{item.channels_count}{' channels' if item.channels_count != 1 else ' channel'}"
            filters = f"{len(item.filters)}{' filters' if len(item.filters) != 1 else ' filter'}"
            # Notes of up to 28 characters are shown whole; longer ones are
            # cut to 25 characters plus "..." (also 28 characters).
            notes = item.extra_notes[:25] + "..." if len(item.extra_notes) > 28 else item.extra_notes
            links = f"[RSS URL]({item.url}) · [API URL]({API.API_EXTERNAL_ENDPOINT}subscription/{item.id}/)"

            description = f"{channels}, {filters}\n"
            description += f"{notes}\n" if notes else ""
            description += links

            key = f"{index}. {item.name}"
            return key, description  # key, value pair

        async def getdata(page: int, pagesize: int):
            """Fetch one page of this guild's subscriptions from the API."""

            async with aiohttp.ClientSession() as session:
                api = API(self.bot.api_token, session)
                return await api.get_subscriptions(
                    guild_id=inter.guild.id, page=page, page_size=pagesize
                )

        embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed
        pagination = PaginationView(
            self.bot,
            inter=inter,
            embed=embed,
            getdata=getdata,
            formatdata=formatdata,
            pagesize=10,
            initpage=1
        )
        await pagination.send()


async def setup(bot):
    """
    Setup function for this extension.
    Adds `FeedCog` to the bot.
    """

    cog = FeedCog(bot)
    await bot.add_cog(cog)
    log.info("Added %s cog", cog.__class__.__name__)