379 lines
13 KiB
Python

"""
Extension for the `FeedCog`.
Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bot.
"""
import logging
from typing import Tuple

import aiohttp
import validators
from discord import Interaction, TextChannel
from discord.app_commands import (
    Choice,
    Group,
    autocomplete,
    command,
    guild_only,
    rename,
)
from discord.ext import commands
from feedparser import FeedParserDict, parse

from api import API
from feed import Subscription, SubscriptionChannel, TrackedContent
from utils import (
    Followup,
    PaginationView,
    get_rss_data,
)
log = logging.getLogger(__name__)

# Sort options for listing RSS feeds. Each option's value is its position in
# the tuple. NOTE(review): not referenced in the visible code; presumably meant
# for a list command's "sort" parameter.
rss_list_sort_choices = [
    Choice(name=label, value=position)
    for position, label in enumerate(("Nickname", "Date Added"))
]

# Sort options for listing subscription/channel links. Each option's value is
# its position in the tuple. NOTE(review): also not referenced in the visible code.
channels_list_sort_choices = [
    Choice(name=label, value=position)
    for position, label in enumerate(("Feed Nickname", "Channel ID", "Date Added"))
]
# TODO SECURITY: a potential attack is that a user submits a valid RSS feed and later
# changes what that URL serves. Run a periodic task to re-validate subscribed feeds.
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
    """Run the acceptance checks for a new RSS source.

    The checks run in a fixed order, and the first failure is reported:

    1. the URL must be well formed;
    2. the nickname must not itself be a URL;
    3. fetching the URL must report HTTP status 200;
    4. ``feedparser`` must recognise the fetched data as a feed.

    No request is made unless the two offline checks pass.

    Parameters
    ----------
    nickname : str
        Human-readable name for the source; rejected if it validates as a URL.
    url : str
        Location of the feed to fetch and parse.

    Returns
    -------
    str or None
        A user-facing error message on failure, ``None`` on success.
    FeedParserDict or None
        The parsed feed on success, ``None`` on failure.
    """
    # Offline checks come first, so bad input never triggers a network request.
    if not validators.url(url):
        problem = f"The URL you have entered is malformed or invalid:\n`{url=}`"
        return problem, None

    if validators.url(nickname):
        problem = (
            "It looks like the nickname you have entered is a URL.\n"
            f"For security reasons, this is not allowed.\n`{nickname=}`"
        )
        return problem, None

    # ``status_code`` must keep this exact name: the ``{status_code=}``
    # specifier below writes the variable name into the user-facing message.
    raw_feed, status_code = await get_rss_data(url)
    if status_code != 200:
        problem = f"The URL provided returned an invalid status code:\n{url=}, {status_code=}"
        return problem, None

    parsed_feed = parse(raw_feed)
    # feedparser leaves ``version`` empty when it cannot identify the feed format.
    if parsed_feed.version:
        return None, parsed_feed

    return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
class FeedCog(commands.Cog):
    """
    Cog holding the RSS feed subscription commands.

    Only the ``/subscriptions`` listing command is active. The
    ``subscriptions`` command group is kept below in commented-out form.
    """

    def __init__(self, bot: commands.Bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        """Instructions to call when the cog is ready."""
        log.info("%s cog is ready", self.__class__.__name__)

    # NOTE: the `subscriptions` command group below is disabled (commented out)
    # and kept for reference only.

    # async def autocomplete_subscriptions(self, inter: Interaction, name: str) -> list[Choice]:
    #     """"""
    #     log.debug("autocompleting subscriptions '%s'", name)
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             results, _ = await api.get_subscriptions(server=inter.guild_id, search=name)
    #     except Exception as exc:
    #         log.error(exc)
    #         return []
    #     subscriptions = Subscription.from_list(results)
    #     return [
    #         Choice(name=sub.name, value=sub.uuid)
    #         for sub in subscriptions
    #     ]

    # async def autocomplete_subscription_channels(self, inter: Interaction, uuid: str):
    #     """"""
    #     log.debug("autocompleting subscription channels")
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             results, _ = await api.get_subscription_channels()
    #     except Exception as exc:
    #         log.error(exc)
    #         return []
    #     subscription_channels = SubscriptionChannel.from_list(results)
    #     async def name(link):
    #         result = self.bot.get_channel(link.id) or await self.bot.fetch_channel(link.id)
    #         return f"{link.subscription.name} -> #{result.name}"
    #     return [
    #         Choice(name=await name(link), value=link.uuid)
    #         for link in subscription_channels
    #     ]

    # subscription_group = Group(
    #     name="subscriptions",
    #     description="subscription commands",
    #     guild_only=True
    # )

    # @subscription_group.command(name="link")
    # @autocomplete(sub_uuid=autocomplete_subscriptions)
    # @rename(sub_uuid="subscription")
    # async def link_subscription_channel(self, inter: Interaction, sub_uuid: str, channel: TextChannel):
    #     """
    #     Link Subscription to discord.TextChannel.
    #     """
    #     await inter.response.defer()
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             data = await api.create_subscription_channel(str(channel.id), sub_uuid)
    #     except aiohttp.ClientResponseError as exc:
    #         return await (
    #             Followup(
    #                 f"Error · {exc.message}",
    #                 "Ensure you haven't: \n"
    #                 "- Already linked this subscription to this channel\n"
    #                 "- Already linked this subscription to the maximum of 4 channels"
    #             )
    #             .footer(f"HTTP {exc.code}")
    #             .error()
    #             .send(inter)
    #         )
    #     subscription = Subscription.from_dict(data.pop("subscription"))
    #     data["subscription"] = (
    #         f"{subscription.name}\n"
    #         f"[RSS]({subscription.rss_url}) · "
    #         f"[API Subscription]({API.SUBSCRIPTION_ENDPOINT}{subscription.uuid}) · "
    #         f"[API Link]({API.CHANNEL_ENDPOINT}{data['uuid']})"
    #     )
    #     channel_id = int(data.pop("id"))
    #     channel = self.bot.get_channel(channel_id) or await self.bot.fetch_channel(channel_id)
    #     data["channel"] = channel.mention
    #     data.pop("creation_datetime")
    #     data.pop("uuid")
    #     await (
    #         Followup("Linked!")
    #         .fields(**data)
    #         .added()
    #         .send(inter)
    #     )

    # @subscription_group.command(name="unlink")
    # @autocomplete(uuid=autocomplete_subscription_channels)
    # @rename(uuid="link")
    # async def unlink_subscription_channel(self, inter: Interaction, uuid: str):
    #     """
    #     Unlink subscription from discord.TextChannel.
    #     """
    #     await inter.response.defer()
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             # data = await api.get_subscription(uuid=uuid)
    #             await api.delete_subscription_channel(uuid=uuid)
    #             # sub_channel = await SubscriptionChannel.from_dict(data)
    #     except Exception as exc:
    #         return await (
    #             Followup(exc.__class__.__name__, str(exc))
    #             .error()
    #             .send(inter)
    #         )
    #     await (
    #         Followup("Subscription unlinked!", uuid)
    #         .added()
    #         .send(inter)
    #     )

    # @subscription_group.command(name="list-links")
    # async def list_subscription(self, inter: Interaction):
    #     """List Subscriptions Channels in this server."""
    #     await inter.response.defer()
    #     async def formatdata(index: int, item: dict) -> tuple[str, str]:
    #         item = SubscriptionChannel.from_dict(item)
    #         next_emoji = self.bot.get_emoji(1204542366602502265)
    #         key = f"{index}. {item.subscription.name} {next_emoji} {item.mention}"
    #         return key, item.hyperlinks_string
    #     async def getdata(page: int, pagesize: int) -> dict:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             return await api.get_subscription_channels(
    #                 subscription__server=inter.guild.id, page=page, page_size=pagesize
    #             )
    #     embed = Followup(f"Links in {inter.guild.name}").info()._embed
    #     pagination = PaginationView(
    #         self.bot,
    #         inter=inter,
    #         embed=embed,
    #         getdata=getdata,
    #         formatdata=formatdata,
    #         pagesize=10,
    #         initpage=1
    #     )
    #     await pagination.send()

    # @subscription_group.command(name="add")
    # async def new_subscription(self, inter: Interaction, name: str, rss_url: str):
    #     """Subscribe this server to a new RSS Feed."""
    #     await inter.response.defer()
    #     try:
    #         parsed_rssfeed = await self.bot.functions.validate_feed(name, rss_url)
    #         image_url = parsed_rssfeed.get("feed", {}).get("image", {}).get("href")
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             data = await api.create_subscription(name, rss_url, image_url, str(inter.guild_id), [-1])
    #     except aiohttp.ClientResponseError as exc:
    #         return await (
    #             Followup(
    #                 f"Error · {exc.message}",
    #                 "Ensure you haven't: \n"
    #                 "- Reused an identical name of an existing Subscription\n"
    #                 "- Already created the maximum of 25 Subscriptions"
    #             )
    #             .footer(f"HTTP {exc.code}")
    #             .error()
    #             .send(inter)
    #         )
    #     # Omit data we dont want the user to see
    #     data.pop("uuid")
    #     data.pop("image")
    #     data.pop("server")
    #     data.pop("creation_datetime")
    #     # Update keys to be more human readable
    #     data["url"] = data.pop("rss_url")
    #     await (
    #         Followup("Subscription Added!")
    #         .fields(**data)
    #         .image(image_url)
    #         .added()
    #         .send(inter)
    #     )

    # @subscription_group.command(name="remove")
    # @autocomplete(uuid=autocomplete_subscriptions)
    # @rename(uuid="choice")
    # async def remove_subscriptions(self, inter: Interaction, uuid: str):
    #     """Unsubscribe this server from an existing RSS Feed."""
    #     await inter.response.defer()
    #     try:
    #         async with aiohttp.ClientSession() as session:
    #             api = API(self.bot.api_token, session)
    #             await api.delete_subscription(uuid)
    #     except Exception as exc:
    #         return await (
    #             Followup(exc.__class__.__name__, str(exc))
    #             .error()
    #             .send(inter)
    #         )
    #     await (
    #         Followup("Subscription Removed!", uuid)
    #         .trash()
    #         .send(inter)
    #     )

    @command(name="subscriptions")
    # The command lists this server's subscriptions. Restricting it to guilds
    # stops it being run in DMs, where ``inter.guild`` is None and the command
    # would crash after deferring.
    @guild_only()
    async def list_subscription(self, inter: Interaction):
        """List Subscriptions from this server."""
        # The docstring above is the command description shown in Discord;
        # explanatory notes therefore live in comments instead.
        await inter.response.defer()

        def formatdata(index, item):
            # Turn one raw API record into the (key, value) pair shown by PaginationView.
            item = Subscription.from_dict(item)

            channels_count = item.channels_count
            channels = f"{channels_count}{' channels' if channels_count != 1 else ' channel'}"

            filters_count = len(item.filters)
            filters = f"{filters_count}{' filters' if filters_count != 1 else ' filter'}"

            # NOTE(review): ``or ""`` keeps a None note from crashing len(); a
            # non-empty string note is handled exactly as before.
            extra_notes = item.extra_notes or ""
            # Notes longer than 28 characters are cut to 25 characters plus "...".
            notes = extra_notes[:25] + "..." if len(extra_notes) > 28 else extra_notes

            links = f"[RSS URL]({item.url}) · [API URL]({API.API_ENDPOINT}subscription/{item.id}/)"

            description = f"{channels}, {filters}\n"
            description += f"{notes}\n" if notes else ""
            description += links
            key = f"{index}. {item.name}"
            return key, description  # key, value pair

        async def getdata(page: int, pagesize: int):
            # Fetch one page of this server's subscriptions from the API.
            async with aiohttp.ClientSession() as session:
                api = API(self.bot.api_token, session)
                # guild_id comes from the interaction itself, so it is available
                # even when the Guild object is not in the bot's cache.
                return await api.get_subscriptions(
                    guild_id=inter.guild_id, page=page, page_size=pagesize
                )

        # Use a generic label if the Guild object is not cached.
        guild_name = inter.guild.name if inter.guild is not None else "this server"
        embed = Followup(f"Subscriptions in {guild_name}").info()._embed
        pagination = PaginationView(
            self.bot,
            inter=inter,
            embed=embed,
            getdata=getdata,
            formatdata=formatdata,
            pagesize=10,
            initpage=1
        )
        await pagination.send()
async def setup(bot):
    """
    Extension entry point, called by `commands.Bot.load_extension`.

    Creates a `FeedCog` bound to *bot*, registers it, then logs the cog's name.
    """
    feed_cog = FeedCog(bot)
    await bot.add_cog(feed_cog)
    log.info("Added %s cog", type(feed_cog).__name__)