Merge pull request 'v0.2.0' (#20) from staging into master
All checks were successful
Build and Push Docker Image / build (push) Successful in 10s
All checks were successful
Build and Push Docker Image / build (push) Successful in 10s
Reviewed-on: https://gitea.corbz.dev/corbz/PYRSS-Bot/pulls/20
This commit is contained in:
commit
7b1a293891
@ -1,4 +1,4 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.1.1
|
current_version = 0.2.0
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
|
2
.vscode/launch.json
vendored
2
.vscode/launch.json
vendored
@ -5,7 +5,7 @@
|
|||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
{
|
||||||
"name": "Python: NewsBot",
|
"name": "Python: PYRSS Bot",
|
||||||
"type": "python",
|
"type": "python",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "${workspaceFolder}/src/main.py",
|
"program": "${workspaceFolder}/src/main.py",
|
||||||
|
11
CHANGELOG.md
11
CHANGELOG.md
@ -1,5 +1,14 @@
|
|||||||
|
|
||||||
|
**v0.2.0**
|
||||||
|
|
||||||
|
- Fix: Fetch channels if not found in bot cache (error fix)
|
||||||
|
- Enhancement: command to test a channel's permissions allow for the Bot to function
|
||||||
|
- Enhancement: account for active state from a server's settings (`GuildSettings`)
|
||||||
|
- Enhancement: command to view tracked content from the server or a given subscription of the same server.
|
||||||
|
- Other: code optimisation & `GuildSettings` dataclass
|
||||||
|
- Other: Cleaned out many instances of unused code
|
||||||
|
|
||||||
**v0.1.1**
|
**v0.1.1**
|
||||||
|
|
||||||
- Docs: Start of changelog
|
- Docs: Start of changelog
|
||||||
- Enhancement: Versioning with tagged docker images
|
- Enhancement: Versioning with tagged docker images
|
||||||
|
@ -132,6 +132,15 @@ class API:
|
|||||||
|
|
||||||
return await self._get_many(self.API_ENDPOINT + "subchannel/", filters)
|
return await self._get_many(self.API_ENDPOINT + "subchannel/", filters)
|
||||||
|
|
||||||
|
async def get_guild_settings(self, **filters) -> tuple[list[dict], int]:
|
||||||
|
"""
|
||||||
|
Get many guild settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
log.debug("getting multiple guild settings")
|
||||||
|
|
||||||
|
return await self._get_many(self.API_ENDPOINT + "guild-settings/", filters)
|
||||||
|
|
||||||
async def create_tracked_content(self, **data) -> dict:
|
async def create_tracked_content(self, **data) -> dict:
|
||||||
"""
|
"""
|
||||||
Create an instance of tracked content.
|
Create an instance of tracked content.
|
||||||
|
230
src/extensions/cmds.py
Normal file
230
src/extensions/cmds.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
"""
|
||||||
|
Extension for the `FeedCog`.
|
||||||
|
Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bot.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import validators
|
||||||
|
from feedparser import FeedParserDict, parse
|
||||||
|
from discord.ext import commands
|
||||||
|
from discord import Interaction, TextChannel, Embed, Colour
|
||||||
|
from discord.app_commands import Choice, Group, autocomplete, rename, command
|
||||||
|
from discord.errors import Forbidden
|
||||||
|
|
||||||
|
from api import API
|
||||||
|
from feed import Subscription, TrackedContent
|
||||||
|
from utils import (
|
||||||
|
Followup,
|
||||||
|
PaginationView,
|
||||||
|
get_rss_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
rss_list_sort_choices = [
|
||||||
|
Choice(name="Nickname", value=0),
|
||||||
|
Choice(name="Date Added", value=1)
|
||||||
|
]
|
||||||
|
channels_list_sort_choices=[
|
||||||
|
Choice(name="Feed Nickname", value=0),
|
||||||
|
Choice(name="Channel ID", value=1),
|
||||||
|
Choice(name="Date Added", value=2)
|
||||||
|
]
|
||||||
|
|
||||||
|
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
|
||||||
|
# target resource. Run a period task to check this.
|
||||||
|
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
|
||||||
|
"""Validate a provided RSS source.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
nickname : str
|
||||||
|
Nickname of the source. Must not contain URL.
|
||||||
|
url : str
|
||||||
|
URL of the source. Must be URL with valid status code and be an RSS feed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str or None
|
||||||
|
String invalid message if invalid, NoneType if valid.
|
||||||
|
FeedParserDict or None
|
||||||
|
The feed parsed from the given URL or None if invalid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Ensure the URL is valid
|
||||||
|
if not validators.url(url):
|
||||||
|
return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
|
||||||
|
|
||||||
|
# Check the nickname is not a URL
|
||||||
|
if validators.url(nickname):
|
||||||
|
return "It looks like the nickname you have entered is a URL.\n" \
|
||||||
|
f"For security reasons, this is not allowed.\n`{nickname=}`", None
|
||||||
|
|
||||||
|
|
||||||
|
feed_data, status_code = await get_rss_data(url)
|
||||||
|
|
||||||
|
# Check the URL status code is valid
|
||||||
|
if status_code != 200:
|
||||||
|
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
||||||
|
|
||||||
|
# Check the contents is actually an RSS feed.
|
||||||
|
feed = parse(feed_data)
|
||||||
|
if not feed.version:
|
||||||
|
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
||||||
|
|
||||||
|
return None, feed
|
||||||
|
|
||||||
|
|
||||||
|
class CommandsCog(commands.Cog):
|
||||||
|
"""
|
||||||
|
Command cog.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bot: commands.Bot):
|
||||||
|
super().__init__()
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_ready(self):
|
||||||
|
"""Instructions to call when the cog is ready."""
|
||||||
|
|
||||||
|
log.info("%s cog is ready", self.__class__.__name__)
|
||||||
|
|
||||||
|
# Group for commands about viewing data
|
||||||
|
view_group = Group(
|
||||||
|
name="view",
|
||||||
|
description="View data.",
|
||||||
|
guild_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@view_group.command(name="subscriptions")
|
||||||
|
async def cmd_list_subs(self, inter: Interaction):
|
||||||
|
"""List Subscriptions from this server."""
|
||||||
|
|
||||||
|
await inter.response.defer()
|
||||||
|
|
||||||
|
def formatdata(index, item):
|
||||||
|
item = Subscription.from_dict(item)
|
||||||
|
|
||||||
|
channels = f"{item.channels_count}{' channels' if item.channels_count != 1 else ' channel'}"
|
||||||
|
filters = f"{len(item.filters)}{' filters' if len(item.filters) != 1 else ' filter'}"
|
||||||
|
notes = item.extra_notes[:25] + "..." if len(item.extra_notes) > 28 else item.extra_notes
|
||||||
|
links = f"[RSS Link]({item.url}) · [API Link]({API.API_EXTERNAL_ENDPOINT}subscription/{item.id}/)"
|
||||||
|
|
||||||
|
description = f"{channels}, {filters}\n"
|
||||||
|
description += f"{notes}\n" if notes else ""
|
||||||
|
description += links
|
||||||
|
|
||||||
|
key = f"{index}. {item.name}"
|
||||||
|
return key, description # key, value pair
|
||||||
|
|
||||||
|
async def getdata(page: int, pagesize: int):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
api = API(self.bot.api_token, session)
|
||||||
|
return await api.get_subscriptions(
|
||||||
|
guild_id=inter.guild.id, page=page, page_size=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed
|
||||||
|
pagination = PaginationView(
|
||||||
|
self.bot,
|
||||||
|
inter=inter,
|
||||||
|
embed=embed,
|
||||||
|
getdata=getdata,
|
||||||
|
formatdata=formatdata,
|
||||||
|
pagesize=10,
|
||||||
|
initpage=1
|
||||||
|
)
|
||||||
|
await pagination.send()
|
||||||
|
|
||||||
|
@view_group.command(name="tracked-content")
|
||||||
|
async def cmd_list_tracked(self, inter: Interaction):
|
||||||
|
"""List Tracked Content from this server, or a given sub"""
|
||||||
|
|
||||||
|
await inter.response.defer()
|
||||||
|
|
||||||
|
def formatdata(index, item):
|
||||||
|
item = TrackedContent.from_dict(item)
|
||||||
|
sub = Subscription.from_dict(item.subscription)
|
||||||
|
|
||||||
|
links = f"[Content Link]({item.url}) · [Message Link](https://discord.com/channels/{sub.guild_id}/{item.channel_id}/{item.message_id}/)"
|
||||||
|
description = f"Subscription: {sub.name}\n{links}"
|
||||||
|
|
||||||
|
key = f"{item.id}. {item.title}"
|
||||||
|
return key, description
|
||||||
|
|
||||||
|
async def getdata(page: int, pagesize: int):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
api = API(self.bot.api_token, session)
|
||||||
|
return await api.get_tracked_content(
|
||||||
|
subscription__guild_id=inter.guild_id, page=page, page_size=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
embed = Followup(f"Tracked Content in {inter.guild.name}").info()._embed
|
||||||
|
pagination = PaginationView(
|
||||||
|
self.bot,
|
||||||
|
inter=inter,
|
||||||
|
embed=embed,
|
||||||
|
getdata=getdata,
|
||||||
|
formatdata=formatdata,
|
||||||
|
pagesize=10,
|
||||||
|
initpage=1
|
||||||
|
)
|
||||||
|
await pagination.send()
|
||||||
|
|
||||||
|
# Group for test related commands
|
||||||
|
test_group = Group(
|
||||||
|
name="test",
|
||||||
|
description="Commands to test Bot functionality.",
|
||||||
|
guild_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@test_group.command(name="channel-permissions")
|
||||||
|
async def cmd_test_channel_perms(self, inter: Interaction):
|
||||||
|
"""Test that the current channel's permissions allow for PYRSS to operate in it."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
test_message = await inter.channel.send(content="... testing permissions ...")
|
||||||
|
await self.test_channel_perms(inter.channel)
|
||||||
|
except Exception as error:
|
||||||
|
await inter.response.send_message(content=f"Failed: {error}")
|
||||||
|
return
|
||||||
|
|
||||||
|
await test_message.delete()
|
||||||
|
await inter.response.send_message(content="Success")
|
||||||
|
|
||||||
|
async def test_channel_perms(self, channel: TextChannel):
|
||||||
|
|
||||||
|
# Test generic message and delete
|
||||||
|
msg = await channel.send(content="test message")
|
||||||
|
await msg.delete()
|
||||||
|
|
||||||
|
# Test detailed embed
|
||||||
|
embed = Embed(
|
||||||
|
title="test title",
|
||||||
|
description="test description",
|
||||||
|
colour=Colour.random(),
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
url="https://google.com"
|
||||||
|
)
|
||||||
|
embed.set_author(name="test author")
|
||||||
|
embed.set_footer(text="test footer")
|
||||||
|
embed.set_thumbnail(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
||||||
|
embed.set_image(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
||||||
|
embed_msg = await channel.send(embed=embed)
|
||||||
|
await embed_msg.delete()
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
"""
|
||||||
|
Setup function for this extension.
|
||||||
|
Adds `CommandsCog` to the bot.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cog = CommandsCog(bot)
|
||||||
|
await bot.add_cog(cog)
|
||||||
|
log.info("Added %s cog", cog.__class__.__name__)
|
@ -1,378 +0,0 @@
|
|||||||
"""
|
|
||||||
Extension for the `FeedCog`.
|
|
||||||
Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bot.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import validators
|
|
||||||
from feedparser import FeedParserDict, parse
|
|
||||||
from discord.ext import commands
|
|
||||||
from discord import Interaction, TextChannel
|
|
||||||
from discord.app_commands import Choice, Group, autocomplete, rename, command
|
|
||||||
|
|
||||||
from api import API
|
|
||||||
from feed import Subscription, SubscriptionChannel, TrackedContent
|
|
||||||
from utils import (
|
|
||||||
Followup,
|
|
||||||
PaginationView,
|
|
||||||
get_rss_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
rss_list_sort_choices = [
|
|
||||||
Choice(name="Nickname", value=0),
|
|
||||||
Choice(name="Date Added", value=1)
|
|
||||||
]
|
|
||||||
channels_list_sort_choices=[
|
|
||||||
Choice(name="Feed Nickname", value=0),
|
|
||||||
Choice(name="Channel ID", value=1),
|
|
||||||
Choice(name="Date Added", value=2)
|
|
||||||
]
|
|
||||||
|
|
||||||
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
|
|
||||||
# target resource. Run a period task to check this.
|
|
||||||
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
|
|
||||||
"""Validate a provided RSS source.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
nickname : str
|
|
||||||
Nickname of the source. Must not contain URL.
|
|
||||||
url : str
|
|
||||||
URL of the source. Must be URL with valid status code and be an RSS feed.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
str or None
|
|
||||||
String invalid message if invalid, NoneType if valid.
|
|
||||||
FeedParserDict or None
|
|
||||||
The feed parsed from the given URL or None if invalid.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Ensure the URL is valid
|
|
||||||
if not validators.url(url):
|
|
||||||
return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
|
|
||||||
|
|
||||||
# Check the nickname is not a URL
|
|
||||||
if validators.url(nickname):
|
|
||||||
return "It looks like the nickname you have entered is a URL.\n" \
|
|
||||||
f"For security reasons, this is not allowed.\n`{nickname=}`", None
|
|
||||||
|
|
||||||
|
|
||||||
feed_data, status_code = await get_rss_data(url)
|
|
||||||
|
|
||||||
# Check the URL status code is valid
|
|
||||||
if status_code != 200:
|
|
||||||
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
|
||||||
|
|
||||||
# Check the contents is actually an RSS feed.
|
|
||||||
feed = parse(feed_data)
|
|
||||||
if not feed.version:
|
|
||||||
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
|
||||||
|
|
||||||
return None, feed
|
|
||||||
|
|
||||||
|
|
||||||
class FeedCog(commands.Cog):
|
|
||||||
"""
|
|
||||||
Command cog.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, bot: commands.Bot):
|
|
||||||
super().__init__()
|
|
||||||
self.bot = bot
|
|
||||||
|
|
||||||
@commands.Cog.listener()
|
|
||||||
async def on_ready(self):
|
|
||||||
"""Instructions to call when the cog is ready."""
|
|
||||||
|
|
||||||
log.info("%s cog is ready", self.__class__.__name__)
|
|
||||||
|
|
||||||
# async def autocomplete_subscriptions(self, inter: Interaction, name: str) -> list[Choice]:
|
|
||||||
# """"""
|
|
||||||
|
|
||||||
# log.debug("autocompleting subscriptions '%s'", name)
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# async with aiohttp.ClientSession() as session:
|
|
||||||
# api = API(self.bot.api_token, session)
|
|
||||||
# results, _ = await api.get_subscriptions(server=inter.guild_id, search=name)
|
|
||||||
|
|
||||||
# except Exception as exc:
|
|
||||||
# log.error(exc)
|
|
||||||
# return []
|
|
||||||
|
|
||||||
# subscriptions = Subscription.from_list(results)
|
|
||||||
|
|
||||||
# return [
|
|
||||||
# Choice(name=sub.name, value=sub.uuid)
|
|
||||||
# for sub in subscriptions
|
|
||||||
# ]
|
|
||||||
|
|
||||||
# async def autocomplete_subscription_channels(self, inter: Interaction, uuid: str):
|
|
||||||
# """"""
|
|
||||||
|
|
||||||
# log.debug("autocompleting subscription channels")
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# async with aiohttp.ClientSession() as session:
|
|
||||||
# api = API(self.bot.api_token, session)
|
|
||||||
# results, _ = await api.get_subscription_channels()
|
|
||||||
|
|
||||||
# except Exception as exc:
|
|
||||||
# log.error(exc)
|
|
||||||
# return []
|
|
||||||
|
|
||||||
# subscription_channels = SubscriptionChannel.from_list(results)
|
|
||||||
|
|
||||||
# async def name(link):
|
|
||||||
# result = self.bot.get_channel(link.id) or await self.bot.fetch_channel(link.id)
|
|
||||||
# return f"{link.subscription.name} -> #{result.name}"
|
|
||||||
|
|
||||||
# return [
|
|
||||||
# Choice(name=await name(link), value=link.uuid)
|
|
||||||
# for link in subscription_channels
|
|
||||||
# ]
|
|
||||||
|
|
||||||
# subscription_group = Group(
|
|
||||||
# name="subscriptions",
|
|
||||||
# description="subscription commands",
|
|
||||||
# guild_only=True
|
|
||||||
# )
|
|
||||||
|
|
||||||
# @subscription_group.command(name="link")
|
|
||||||
# @autocomplete(sub_uuid=autocomplete_subscriptions)
|
|
||||||
# @rename(sub_uuid="subscription")
|
|
||||||
# async def link_subscription_channel(self, inter: Interaction, sub_uuid: str, channel: TextChannel):
|
|
||||||
# """
|
|
||||||
# Link Subscription to discord.TextChannel.
|
|
||||||
# """
|
|
||||||
|
|
||||||
# await inter.response.defer()
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# async with aiohttp.ClientSession() as session:
|
|
||||||
# api = API(self.bot.api_token, session)
|
|
||||||
# data = await api.create_subscription_channel(str(channel.id), sub_uuid)
|
|
||||||
|
|
||||||
# except aiohttp.ClientResponseError as exc:
|
|
||||||
# return await (
|
|
||||||
# Followup(
|
|
||||||
# f"Error · {exc.message}",
|
|
||||||
# "Ensure you haven't: \n"
|
|
||||||
# "- Already linked this subscription to this channel\n"
|
|
||||||
# "- Already linked this subscription to the maximum of 4 channels"
|
|
||||||
# )
|
|
||||||
# .footer(f"HTTP {exc.code}")
|
|
||||||
# .error()
|
|
||||||
# .send(inter)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# subscription = Subscription.from_dict(data.pop("subscription"))
|
|
||||||
# data["subscription"] = (
|
|
||||||
# f"{subscription.name}\n"
|
|
||||||
# f"[RSS]({subscription.rss_url}) · "
|
|
||||||
# f"[API Subscription]({API.SUBSCRIPTION_ENDPOINT}{subscription.uuid}) · "
|
|
||||||
# f"[API Link]({API.CHANNEL_ENDPOINT}{data['uuid']})"
|
|
||||||
# )
|
|
||||||
|
|
||||||
# channel_id = int(data.pop("id"))
|
|
||||||
# channel = self.bot.get_channel(channel_id) or await self.bot.fetch_channel(channel_id)
|
|
||||||
# data["channel"] = channel.mention
|
|
||||||
|
|
||||||
# data.pop("creation_datetime")
|
|
||||||
# data.pop("uuid")
|
|
||||||
|
|
||||||
# await (
|
|
||||||
# Followup("Linked!")
|
|
||||||
# .fields(**data)
|
|
||||||
# .added()
|
|
||||||
# .send(inter)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# @subscription_group.command(name="unlink")
|
|
||||||
# @autocomplete(uuid=autocomplete_subscription_channels)
|
|
||||||
# @rename(uuid="link")
|
|
||||||
# async def unlink_subscription_channel(self, inter: Interaction, uuid: str):
|
|
||||||
# """
|
|
||||||
# Unlink subscription from discord.TextChannel.
|
|
||||||
# """
|
|
||||||
|
|
||||||
# await inter.response.defer()
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# async with aiohttp.ClientSession() as session:
|
|
||||||
# api = API(self.bot.api_token, session)
|
|
||||||
# # data = await api.get_subscription(uuid=uuid)
|
|
||||||
# await api.delete_subscription_channel(uuid=uuid)
|
|
||||||
# # sub_channel = await SubscriptionChannel.from_dict(data)
|
|
||||||
|
|
||||||
# except Exception as exc:
|
|
||||||
# return await (
|
|
||||||
# Followup(exc.__class__.__name__, str(exc))
|
|
||||||
# .error()
|
|
||||||
# .send(inter)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# await (
|
|
||||||
# Followup("Subscription unlinked!", uuid)
|
|
||||||
# .added()
|
|
||||||
# .send(inter)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# @subscription_group.command(name="list-links")
|
|
||||||
# async def list_subscription(self, inter: Interaction):
|
|
||||||
# """List Subscriptions Channels in this server."""
|
|
||||||
|
|
||||||
# await inter.response.defer()
|
|
||||||
|
|
||||||
# async def formatdata(index: int, item: dict) -> tuple[str, str]:
|
|
||||||
# item = SubscriptionChannel.from_dict(item)
|
|
||||||
# next_emoji = self.bot.get_emoji(1204542366602502265)
|
|
||||||
# key = f"{index}. {item.subscription.name} {next_emoji} {item.mention}"
|
|
||||||
# return key, item.hyperlinks_string
|
|
||||||
|
|
||||||
# async def getdata(page: int, pagesize: int) -> dict:
|
|
||||||
# async with aiohttp.ClientSession() as session:
|
|
||||||
# api = API(self.bot.api_token, session)
|
|
||||||
# return await api.get_subscription_channels(
|
|
||||||
# subscription__server=inter.guild.id, page=page, page_size=pagesize
|
|
||||||
# )
|
|
||||||
|
|
||||||
# embed = Followup(f"Links in {inter.guild.name}").info()._embed
|
|
||||||
# pagination = PaginationView(
|
|
||||||
# self.bot,
|
|
||||||
# inter=inter,
|
|
||||||
# embed=embed,
|
|
||||||
# getdata=getdata,
|
|
||||||
# formatdata=formatdata,
|
|
||||||
# pagesize=10,
|
|
||||||
# initpage=1
|
|
||||||
# )
|
|
||||||
# await pagination.send()
|
|
||||||
|
|
||||||
# @subscription_group.command(name="add")
|
|
||||||
# async def new_subscription(self, inter: Interaction, name: str, rss_url: str):
|
|
||||||
# """Subscribe this server to a new RSS Feed."""
|
|
||||||
|
|
||||||
# await inter.response.defer()
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# parsed_rssfeed = await self.bot.functions.validate_feed(name, rss_url)
|
|
||||||
# image_url = parsed_rssfeed.get("feed", {}).get("image", {}).get("href")
|
|
||||||
|
|
||||||
# async with aiohttp.ClientSession() as session:
|
|
||||||
# api = API(self.bot.api_token, session)
|
|
||||||
# data = await api.create_subscription(name, rss_url, image_url, str(inter.guild_id), [-1])
|
|
||||||
|
|
||||||
# except aiohttp.ClientResponseError as exc:
|
|
||||||
# return await (
|
|
||||||
# Followup(
|
|
||||||
# f"Error · {exc.message}",
|
|
||||||
# "Ensure you haven't: \n"
|
|
||||||
# "- Reused an identical name of an existing Subscription\n"
|
|
||||||
# "- Already created the maximum of 25 Subscriptions"
|
|
||||||
# )
|
|
||||||
# .footer(f"HTTP {exc.code}")
|
|
||||||
# .error()
|
|
||||||
# .send(inter)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # Omit data we dont want the user to see
|
|
||||||
# data.pop("uuid")
|
|
||||||
# data.pop("image")
|
|
||||||
# data.pop("server")
|
|
||||||
# data.pop("creation_datetime")
|
|
||||||
|
|
||||||
# # Update keys to be more human readable
|
|
||||||
# data["url"] = data.pop("rss_url")
|
|
||||||
|
|
||||||
# await (
|
|
||||||
# Followup("Subscription Added!")
|
|
||||||
# .fields(**data)
|
|
||||||
# .image(image_url)
|
|
||||||
# .added()
|
|
||||||
# .send(inter)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# @subscription_group.command(name="remove")
|
|
||||||
# @autocomplete(uuid=autocomplete_subscriptions)
|
|
||||||
# @rename(uuid="choice")
|
|
||||||
# async def remove_subscriptions(self, inter: Interaction, uuid: str):
|
|
||||||
# """Unsubscribe this server from an existing RSS Feed."""
|
|
||||||
|
|
||||||
# await inter.response.defer()
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# async with aiohttp.ClientSession() as session:
|
|
||||||
# api = API(self.bot.api_token, session)
|
|
||||||
# await api.delete_subscription(uuid)
|
|
||||||
|
|
||||||
# except Exception as exc:
|
|
||||||
# return await (
|
|
||||||
# Followup(exc.__class__.__name__, str(exc))
|
|
||||||
# .error()
|
|
||||||
# .send(inter)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# await (
|
|
||||||
# Followup("Subscription Removed!", uuid)
|
|
||||||
# .trash()
|
|
||||||
# .send(inter)
|
|
||||||
# )
|
|
||||||
|
|
||||||
@command(name="subscriptions")
|
|
||||||
async def list_subscription(self, inter: Interaction):
|
|
||||||
"""List Subscriptions from this server."""
|
|
||||||
|
|
||||||
await inter.response.defer()
|
|
||||||
|
|
||||||
def formatdata(index, item):
|
|
||||||
item = Subscription.from_dict(item)
|
|
||||||
|
|
||||||
channels = f"{item.channels_count}{' channels' if item.channels_count != 1 else ' channel'}"
|
|
||||||
filters = f"{len(item.filters)}{' filters' if len(item.filters) != 1 else ' filter'}"
|
|
||||||
notes = item.extra_notes[:25] + "..." if len(item.extra_notes) > 28 else item.extra_notes
|
|
||||||
links = f"[RSS URL]({item.url}) · [API URL]({API.API_EXTERNAL_ENDPOINT}subscription/{item.id}/)"
|
|
||||||
|
|
||||||
description = f"{channels}, {filters}\n"
|
|
||||||
description += f"{notes}\n" if notes else ""
|
|
||||||
description += links
|
|
||||||
|
|
||||||
key = f"{index}. {item.name}"
|
|
||||||
return key, description # key, value pair
|
|
||||||
|
|
||||||
async def getdata(page: int, pagesize: int):
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
api = API(self.bot.api_token, session)
|
|
||||||
return await api.get_subscriptions(
|
|
||||||
guild_id=inter.guild.id, page=page, page_size=pagesize
|
|
||||||
)
|
|
||||||
|
|
||||||
embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed
|
|
||||||
pagination = PaginationView(
|
|
||||||
self.bot,
|
|
||||||
inter=inter,
|
|
||||||
embed=embed,
|
|
||||||
getdata=getdata,
|
|
||||||
formatdata=formatdata,
|
|
||||||
pagesize=10,
|
|
||||||
initpage=1
|
|
||||||
)
|
|
||||||
await pagination.send()
|
|
||||||
# await Followup("results", str(await getdata(1, 10))).send(inter)
|
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot):
|
|
||||||
"""
|
|
||||||
Setup function for this extension.
|
|
||||||
Adds `FeedCog` to the bot.
|
|
||||||
"""
|
|
||||||
|
|
||||||
cog = FeedCog(bot)
|
|
||||||
await bot.add_cog(cog)
|
|
||||||
log.info("Added %s cog", cog.__class__.__name__)
|
|
@ -18,7 +18,7 @@ from discord.ext import commands, tasks
|
|||||||
from discord.errors import Forbidden
|
from discord.errors import Forbidden
|
||||||
from feedparser import parse
|
from feedparser import parse
|
||||||
|
|
||||||
from feed import RSSFeed, Subscription, RSSItem
|
from feed import RSSFeed, Subscription, RSSItem, GuildSettings
|
||||||
from utils import get_unparsed_feed
|
from utils import get_unparsed_feed
|
||||||
from filters import match_text
|
from filters import match_text
|
||||||
from api import API
|
from api import API
|
||||||
@ -83,43 +83,88 @@ class TaskCog(commands.Cog):
|
|||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
self.api = API(self.bot.api_token, session)
|
self.api = API(self.bot.api_token, session)
|
||||||
subscriptions = await self.get_subscriptions()
|
await self.execute_task()
|
||||||
await self.process_subscriptions(subscriptions)
|
|
||||||
|
|
||||||
end_time = perf_counter()
|
end_time = perf_counter()
|
||||||
log.debug(f"task completed in {end_time - start_time:.4f} seconds")
|
log.debug(f"task completed in {end_time - start_time:.4f} seconds")
|
||||||
|
|
||||||
async def get_subscriptions(self) -> list[Subscription]:
|
async def execute_task(self):
|
||||||
|
"""Execute the task directly."""
|
||||||
|
|
||||||
|
# Filter out inactive guild IDs using related settings
|
||||||
guild_ids = [guild.id for guild in self.bot.guilds]
|
guild_ids = [guild.id for guild in self.bot.guilds]
|
||||||
sub_data = []
|
guild_settings = await self.get_guild_settings(guild_ids)
|
||||||
|
active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
|
||||||
|
|
||||||
|
subscriptions = await self.get_subscriptions(active_guild_ids)
|
||||||
|
await self.process_subscriptions(subscriptions)
|
||||||
|
|
||||||
|
async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
|
||||||
|
"""Returns a list of guild settings from the Bot's guilds, if they exist."""
|
||||||
|
|
||||||
|
guild_settings = []
|
||||||
|
|
||||||
|
# Iterate infinitely taking the iter no. as `page`
|
||||||
|
# data will be empty after last page reached.
|
||||||
for page, _ in enumerate(iter(int, 1)):
|
for page, _ in enumerate(iter(int, 1)):
|
||||||
try:
|
data = await self.get_guild_settings_page(guild_ids, page)
|
||||||
log.debug("fetching page '%s'", page + 1)
|
if not data:
|
||||||
sub_data.extend(
|
|
||||||
(await self.api.get_subscriptions(server__in=guild_ids, page=page+1))[0]
|
|
||||||
)
|
|
||||||
except aiohttp.ClientResponseError as error:
|
|
||||||
match error.status:
|
|
||||||
case 404:
|
|
||||||
log.debug("final page reached '%s'", page)
|
|
||||||
break
|
|
||||||
case 403:
|
|
||||||
log.critical(error)
|
|
||||||
self.subscription_task.cancel()
|
|
||||||
return [] # returning an empty list should gracefully end the task
|
|
||||||
case _:
|
|
||||||
log.error(error)
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as error:
|
|
||||||
log.error("Exception while gathering page data %s", error)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
guild_settings.extend(data[0])
|
||||||
|
|
||||||
return Subscription.from_list(sub_data)
|
# Only return active guild IDs
|
||||||
|
return GuildSettings.from_list(guild_settings)
|
||||||
|
|
||||||
|
async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
|
||||||
|
"""Returns an individual page of guild settings."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
|
||||||
|
except aiohttp.ClientResponseError as error:
|
||||||
|
self.handle_pagination_error(error)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def handle_pagination_error(self, error: aiohttp.ClientResponseError):
|
||||||
|
"""Handle the error cases from pagination attempts."""
|
||||||
|
|
||||||
|
match error.status:
|
||||||
|
case 404:
|
||||||
|
log.debug("final page reached")
|
||||||
|
case 403:
|
||||||
|
log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
|
||||||
|
self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
|
||||||
|
case _:
|
||||||
|
log.debug(error)
|
||||||
|
|
||||||
|
async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
|
||||||
|
"""Get a list of `Subscription`s matching the given `guild_ids`."""
|
||||||
|
|
||||||
|
subscriptions = []
|
||||||
|
|
||||||
|
# Iterate infinitely taking the iter no. as `page`
|
||||||
|
# data will be empty after last page reached.
|
||||||
|
for page, _ in enumerate(iter(int, 1)):
|
||||||
|
data = await self.get_subs_page(guild_ids, page)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
|
||||||
|
subscriptions.extend(data[0])
|
||||||
|
|
||||||
|
return Subscription.from_list(subscriptions)
|
||||||
|
|
||||||
|
async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
|
||||||
|
"""Returns an individual page of subscriptions."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
|
||||||
|
except aiohttp.ClientResponseError as error:
|
||||||
|
self.handle_pagination_error(error)
|
||||||
|
return []
|
||||||
|
|
||||||
async def process_subscriptions(self, subscriptions: list[Subscription]):
|
async def process_subscriptions(self, subscriptions: list[Subscription]):
|
||||||
|
"""Process a given list of `Subscription`s."""
|
||||||
|
|
||||||
async def process_single_subscription(sub: Subscription):
|
async def process_single_subscription(sub: Subscription):
|
||||||
log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
|
log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
|
||||||
|
|
||||||
@ -143,7 +188,7 @@ class TaskCog(commands.Cog):
|
|||||||
async def process_items(self, sub: Subscription, feed: RSSFeed):
|
async def process_items(self, sub: Subscription, feed: RSSFeed):
|
||||||
log.debug("processing items")
|
log.debug("processing items")
|
||||||
|
|
||||||
channels = [self.bot.get_channel(channel.channel_id) for channel in await sub.get_channels(self.api)]
|
channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
|
||||||
filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
|
filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
|
||||||
|
|
||||||
for item in feed.items:
|
for item in feed.items:
|
||||||
@ -159,6 +204,18 @@ class TaskCog(commands.Cog):
|
|||||||
for channel in channels:
|
for channel in channels:
|
||||||
await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
|
await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
|
||||||
|
|
||||||
|
async def fetch_or_get_channels(self, channels_data: list[dict]):
|
||||||
|
channels = []
|
||||||
|
|
||||||
|
for data in channels_data:
|
||||||
|
try:
|
||||||
|
channel = self.bot.get_channel(data.channel_id)
|
||||||
|
channels.append(channel or await self.bot.fetch_channel(data.channel_id))
|
||||||
|
except Forbidden:
|
||||||
|
log.error(f"Forbidden Channel '{data.channel_id}'")
|
||||||
|
|
||||||
|
return channels
|
||||||
|
|
||||||
def filter_item(self, _filter: dict, item: RSSItem) -> bool:
|
def filter_item(self, _filter: dict, item: RSSItem) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns `True` if item should be ignored due to filters.
|
Returns `True` if item should be ignored due to filters.
|
||||||
@ -188,8 +245,8 @@ class TaskCog(commands.Cog):
|
|||||||
log.debug("sending '%s', exists '%s'", item.guid, result[1])
|
log.debug("sending '%s', exists '%s'", item.guid, result[1])
|
||||||
message = await channel.send(embed=await mutated_item.to_embed(sub, feed, self.api.session))
|
message = await channel.send(embed=await mutated_item.to_embed(sub, feed, self.api.session))
|
||||||
message_id = message.id
|
message_id = message.id
|
||||||
except Forbidden as error:
|
except Forbidden:
|
||||||
log.error(error)
|
log.error(f"Forbidden to send to channel {channel.id}")
|
||||||
|
|
||||||
await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
|
await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
|
||||||
|
|
||||||
|
59
src/feed.py
59
src/feed.py
@ -270,6 +270,19 @@ class DjangoDataModel(ABC):
|
|||||||
return cls(**cls.parser(data))
|
return cls(**cls.parser(data))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class GuildSettings(DjangoDataModel):
|
||||||
|
|
||||||
|
id: int
|
||||||
|
guild_id: int
|
||||||
|
default_embed_colour: str
|
||||||
|
active: bool
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parser(item: dict) -> dict:
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class Subscription(DjangoDataModel):
|
class Subscription(DjangoDataModel):
|
||||||
|
|
||||||
@ -324,51 +337,21 @@ class SubChannel(DjangoDataModel):
|
|||||||
return f"<#{self.channel_id}>"
|
return f"<#{self.channel_id}>"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class SubscriptionChannel(DjangoDataModel):
|
|
||||||
|
|
||||||
uuid: str
|
|
||||||
id: int
|
|
||||||
subscription: Subscription
|
|
||||||
creation_datetime: datetime
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parser(item: dict) -> dict:
|
|
||||||
|
|
||||||
item["id"] = int(item["id"])
|
|
||||||
item["subscription"] = Subscription.from_dict(item.pop("subscription"))
|
|
||||||
item["creation_datetime"] = datetime.strptime(item["creation_datetime"], DATETIME_FORMAT)
|
|
||||||
return item
|
|
||||||
|
|
||||||
@property
|
|
||||||
def mention(self) -> str:
|
|
||||||
"""
|
|
||||||
Returns the `id` as a string in the discord mention format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return f"<#{self.id}>"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def hyperlinks_string(self) -> str:
|
|
||||||
""""""
|
|
||||||
|
|
||||||
api_hyperlink = f"[API]({API.CHANNEL_ENDPOINT}{self.uuid}/)"
|
|
||||||
rss_hyperlink = f"[RSS]({self.subscription.rss_url})"
|
|
||||||
value = f"{rss_hyperlink} · {api_hyperlink}"
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class TrackedContent(DjangoDataModel):
|
class TrackedContent(DjangoDataModel):
|
||||||
|
|
||||||
uuid: str
|
id: int
|
||||||
|
guid: str
|
||||||
|
title: str
|
||||||
|
url: str
|
||||||
subscription: str
|
subscription: str
|
||||||
content_url: str
|
channel_id: int
|
||||||
|
message_id: int
|
||||||
|
blocked: bool
|
||||||
creation_datetime: datetime
|
creation_datetime: datetime
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parser(item: dict) -> dict:
|
def parser(item: dict) -> dict:
|
||||||
|
|
||||||
item["creation_datetime"] = datetime.strptime(item["creation_datetime"], DATETIME_FORMAT)
|
item["creation_datetime"] = datetime.strptime(item["creation_datetime"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||||
return item
|
return item
|
||||||
|
@ -83,8 +83,6 @@ class PaginationView(View):
|
|||||||
self.index = initpage
|
self.index = initpage
|
||||||
|
|
||||||
# emoji reference
|
# emoji reference
|
||||||
next_emoji = bot.get_emoji(1204542366602502265)
|
|
||||||
prev_emoji = bot.get_emoji(1204542365432422470)
|
|
||||||
self.start_emoji = bot.get_emoji(1204542364073463818)
|
self.start_emoji = bot.get_emoji(1204542364073463818)
|
||||||
self.end_emoji = bot.get_emoji(1204542367752003624)
|
self.end_emoji = bot.get_emoji(1204542367752003624)
|
||||||
|
|
||||||
@ -113,7 +111,6 @@ class PaginationView(View):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def calc_total_pages(results: int, max_pagesize: int) -> int:
|
def calc_total_pages(results: int, max_pagesize: int) -> int:
|
||||||
result = ((results - 1) // max_pagesize) + 1
|
result = ((results - 1) // max_pagesize) + 1
|
||||||
log.debug("total pages calculated: %s", result)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def calc_dataitem_index(self, dataitem_index: int):
|
def calc_dataitem_index(self, dataitem_index: int):
|
||||||
@ -200,6 +197,7 @@ class PaginationView(View):
|
|||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
self.maxpage = self.calc_total_pages(total_results, self.pagesize)
|
self.maxpage = self.calc_total_pages(total_results, self.pagesize)
|
||||||
|
log.debug(f"{self.maxpage=!r}")
|
||||||
|
|
||||||
for i, item in enumerate(data):
|
for i, item in enumerate(data):
|
||||||
i = self.calc_dataitem_index(i)
|
i = self.calc_dataitem_index(i)
|
||||||
@ -228,6 +226,9 @@ class PaginationView(View):
|
|||||||
self.children[1].disabled = self.index == self.maxpage
|
self.children[1].disabled = self.index == self.maxpage
|
||||||
|
|
||||||
async def send(self):
|
async def send(self):
|
||||||
|
"""Send the pagination view. It may be important to defer before invoking this method."""
|
||||||
|
|
||||||
|
log.debug("sending pagination view")
|
||||||
embed = await self.create_paged_embed()
|
embed = await self.create_paged_embed()
|
||||||
|
|
||||||
if self.maxpage <= 1:
|
if self.maxpage <= 1:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user