Compare commits

...

36 Commits
master...dev

Author SHA1 Message Date
CI status shown is for the Build and Push Docker Image / build (push) workflow.
7d395d2001 tests for other filter cases & a duplicate case (CI: Failing after 7m2s) 2024-11-15 23:47:31 +00:00
b0b02aae9f log before/after duplicates are removed 2024-11-15 23:47:07 +00:00
ca2037528d fix 'all' filter failing 2024-11-15 23:46:49 +00:00
a9c08dea4d write proper tests using pytest (CI: Failing after 7m4s) 2024-11-15 23:08:38 +00:00
8fcccc7ac9 merge embeds into one (CI: Failing after 6m57s) 2024-11-13 23:19:55 +00:00
efbb4b18ff needed more log info 2024-11-13 21:40:41 +00:00
81795feb65 duplicate content enforcement (CI: Failing after 7m4s) 2024-11-03 22:37:52 +00:00
1f9075ce60 simple console logging (CI: Has been cancelled) 2024-11-03 22:37:20 +00:00
6e153782c2 check for duplicate content & loop 'do_task' (CI: Failing after 6m52s) 2024-11-02 00:29:30 +00:00
d86fc0eb71 save method for content (CI: Waiting to run) 2024-11-02 00:29:06 +00:00
08295dfea6 efficiency and function for tasks (CI: Failing after 7m7s) 2024-10-31 23:52:29 +00:00
470f78c144 write small test (CI: Failing after 7m9s) 2024-10-31 13:11:14 +00:00
94b154742e remove unused requirements 2024-10-31 13:11:08 +00:00
1294909188 mute annoying httpx logger (CI: Waiting to run) 2024-10-31 13:10:59 +00:00
cf8fb34a29 temp trim commands (CI: Waiting to run) 2024-10-31 13:10:46 +00:00
eb97dca5c6 process models in task (CI: Has been cancelled) 2024-10-31 13:10:27 +00:00
b8f1ffb8d9 models for task & task work (CI: Failing after 6m58s) 2024-10-30 17:01:07 +00:00
ccfa35adda reworking task (CI: Failing after 6m55s) 2024-10-29 23:44:49 +00:00
cc06d3e09f Update README.md (CI: Successful in 23s) 2024-09-25 14:39:36 +00:00
f9de8ff085 Update feed.py (CI: Successful in 12s) 2024-09-22 21:35:47 +01:00
02a4917152 Update CHANGELOG.md (CI: Successful in 12s) 2024-09-22 21:33:37 +01:00
ff30e31cf1 fix error from yet-to-be implemented feature (CI: Failing after 7m4s) 2024-09-22 21:19:58 +01:00
50f7a62cd4 dates and correct a version mistake (CI: Successful in 12s) 2024-09-18 15:16:36 +01:00
4c31fe69e2 update changelog 2024-09-18 15:13:04 +01:00
bb3475b79d Support latest pyrss-website version (CI: Successful in 13s) 2024-09-18 15:04:11 +01:00
82fe6bea9a Bump version: 0.2.0 → 0.2.1 (CI: Successful in 12s) 2024-09-10 11:03:34 +01:00
32b8092034 release notes (CI: Successful in 12s) 2024-09-10 11:02:46 +01:00
26f697cf78 update changelog (CI: Successful in 11s) 2024-09-10 11:01:30 +01:00
ab472e1979 Merge branch 'staging' into dev (CI: Successful in 12s) 2024-09-10 11:00:51 +01:00
7515d6b86e Update CHANGELOG.md (CI: Successful in 12s) 2024-08-26 20:27:19 +01:00
a5c593bb14 proper type hinting (CI: Successful in 11s) 2024-08-26 20:10:19 +01:00
48697c08a6 last_build_date fix (CI: Successful in 13s) 2024-08-26 20:06:53 +01:00
cbe15aceb8 whitespace remove (CI: Successful in 12s) 2024-08-23 18:01:29 +01:00
e8a9c270e4 enhancements for view commands (CI: Successful in 13s) 2024-08-23 17:47:01 +01:00
2a1aaa689e function get many filters 2024-08-23 17:45:43 +01:00
1a4f25ec97 ContentFilter model 2024-08-23 17:45:26 +01:00
14 changed files with 1658 additions and 667 deletions

View File

@@ -1,4 +1,4 @@
 [bumpversion]
-current_version = 0.2.0
+current_version = 0.2.1
 commit = True
 tag = True
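The `[bumpversion]` stanza above is what produces the `Bump version: 0.2.0 → 0.2.1` commit in this compare: bump2version rewrites `current_version`, then commits and tags. A minimal sketch of the patch-level bump it performs (the helper name `bump_patch` is illustrative, not part of the repo):

```python
def bump_patch(version: str) -> str:
    """Sketch of the patch bump bump2version applies (e.g. 0.2.0 -> 0.2.1)."""
    major, minor, patch = (int(part) for part in version.split("."))
    return f"{major}.{minor}.{patch + 1}"

print(bump_patch("0.2.0"))  # → 0.2.1
```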

View File

@@ -1,14 +1,54 @@
 # Changelog
-**v0.2.0**
-- Fix: Fetch channels if not found in bot cache (error fix)
-- Enhancement: command to test a channel's permissions allow for the Bot to function
-- Enhancement: account for active state from a server's settings (`GuildSettings`)
-- Enhancement: command to view tracked content from the server or a given subscription of the same server.
-- Other: code optimisation & `GuildSettings` dataclass
-- Other: Cleaned out many instances of unused code
-**v0.1.1**
-- Docs: Start of changelog
-- Enhancement: Versioning with tagged docker images
+All notable changes to this project will be documented in this file.
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [Unreleased]
+### Added
+- Search and filter controls for the data viewing commands
+### Fixed
+- RSS feeds without a build date would break the subscription task
+- TypeError when an RSS item lacks a title or description
+- Fix an issue with a missing field on the Subscription model
+### Changed
+- Show whether a subscription is active or inactive when using a data view command
+- Added `unique_content_rules` field to `Subscription` dataclass (support latest pyrss-website version)
+- Update changelog to follow [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
+## [0.2.0] - 2024-08-19
+### Added
+- Command to view tracked content from the relevant server
+- Command to test the bot's permissions in a specified channel
+- `GuildSettings` dataclass
+### Fixed
+- channels are `NoneType` because they didn't exist in the cache, fixed by fetching from API
+### Changed
+- Subscription task will ignore subscriptions flagged as 'inactive'
+- Code optimisation
+### Removed
+- Unused and commented out code
+## [0.1.1] - 2024-08-17
+### Added
+- Start of changelog
+- Versioning with tagged docker images
+## [0.1.0] - 2024-08-13

View File

@@ -4,4 +4,4 @@ An RSS driven Discord bot written in Python.
 Provides user commands for storing RSS feed URLs that can be assigned to any given discord channel.
-Depends on the [web application](https://gitea.corbz.dev/corbz/PYRSS-Website). Check the releases for compatible versions.
+Depends on the [web application](https://gitea.cor.bz/corbz/PYRSS-Website). Check the releases for compatible versions.

View File

@@ -1,15 +1,15 @@
 {
     "version": 1,
-    "disable_existing_loggers": false,
+    "disable_existing_loggers": true,
     "formatters": {
         "simple": {
-            "format": "%(levelname)s %(message)s"
+            "format": "[%(module)s|%(message)s]"
         },
         "detail": {
-            "format": "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
+            "format": "[%(asctime)s] [%(levelname)s] [%(module)s]: %(message)s"
         },
         "complex": {
-            "format": "[%(levelname)s|%(module)s|L%(lineno)d] %(asctime)s %(message)s",
+            "format": "[%(levelname)s|%(name)s|L%(lineno)d] %(asctime)s %(message)s",
             "datefmt": "%Y-%m-%dT%H:%M:%S%z"
         }
     },
@@ -17,7 +17,7 @@
         "stdout": {
             "class": "logging.StreamHandler",
             "level": "DEBUG",
-            "formatter": "simple",
+            "formatter": "detail",
             "stream": "ext://sys.stdout"
         },
         "file": {
@@ -46,6 +46,14 @@
         },
         "discord": {
             "level": "INFO"
+        },
+        "httpx": {
+            "level": "WARNING",
+            "propagate": false
+        },
+        "httpcore": {
+            "level": "WARNING",
+            "propagate": false
         }
     }
 }
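The config changes above do two things: route stdout through the more verbose `detail` formatter, and quiet the chatty `httpx`/`httpcore` loggers (the "mute annoying httpx logger" commit). A trimmed, runnable sketch of the same idea via `logging.config.dictConfig` (handler set reduced to stdout only; this is not the full config file):

```python
import logging
import logging.config

LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": True,
    "formatters": {
        "detail": {
            "format": "[%(asctime)s] [%(levelname)s] [%(module)s]: %(message)s"
        }
    },
    "handlers": {
        "stdout": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "detail",
            "stream": "ext://sys.stdout",
        }
    },
    "loggers": {
        # Raise the HTTP client loggers so per-request INFO/DEBUG lines
        # never reach any handler.
        "httpx": {"level": "WARNING", "propagate": False},
        "httpcore": {"level": "WARNING", "propagate": False},
    },
    "root": {"level": "DEBUG", "handlers": ["stdout"]},
}

logging.config.dictConfig(LOGGING_CONFIG)

# httpx INFO records are now filtered out at the logger level.
print(logging.getLogger("httpx").isEnabledFor(logging.INFO))  # → False
```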

pytest.ini (new file, 3 additions)

View File

@@ -0,0 +1,3 @@
+[pytest]
+filterwarnings =
+    ignore:'audioop' is deprecated and slated for removal in Python 3.13:DeprecationWarning
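The pytest `filterwarnings` entry uses the `action:message:category` syntax. The same rule can be expressed with the stdlib `warnings` API; a sketch showing the ignore rule swallowing the matching `DeprecationWarning` while an unrelated warning still gets through:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # baseline: record everything
    # Same rule the pytest.ini entry installs: message regex, then category.
    warnings.filterwarnings(
        "ignore",
        message="'audioop' is deprecated and slated for removal in Python 3.13",
        category=DeprecationWarning,
    )
    warnings.warn(
        "'audioop' is deprecated and slated for removal in Python 3.13",
        DeprecationWarning,
    )
    warnings.warn("some other warning", UserWarning)

# The audioop deprecation was suppressed; only the UserWarning was recorded.
print(len(caught))  # → 1
```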

View File

@@ -1,29 +1,30 @@
-aiocache==0.12.2
+aiohappyeyeballs==2.4.3
-aiohttp==3.9.3
+aiohttp==3.10.10
 aiosignal==1.3.1
-aiosqlite==0.19.0
-async-timeout==4.0.3
-asyncpg==0.29.0
-attrs==23.2.0
+anyio==4.6.2.post1
+attrs==24.2.0
 beautifulsoup4==4.12.3
 bump2version==1.0.1
-click==8.1.7
+certifi==2024.8.30
 discord.py==2.3.2
 feedparser==6.0.11
-frozenlist==1.4.1
+frozenlist==1.5.0
-greenlet==3.0.3
-idna==3.6
+h11==0.14.0
+httpcore==1.0.6
+httpx==0.27.2
+idna==3.10
+iniconfig==2.0.0
 markdownify==0.11.6
-multidict==6.0.5
+multidict==6.1.0
-pip-chill==1.0.3
-psycopg2-binary==2.9.9
+packaging==24.2
+pluggy==1.5.0
+propcache==0.2.0
+pytest==8.3.3
 python-dotenv==1.0.0
 rapidfuzz==3.9.4
 sgmllib3k==1.0.0
 six==1.16.0
-soupsieve==2.5
-SQLAlchemy==2.0.23
-typing_extensions==4.10.0
+sniffio==1.3.1
+soupsieve==2.6
 uwuipy==0.1.9
-validators==0.22.0
-yarl==1.9.4
+yarl==1.17.0

View File

@@ -157,7 +157,7 @@ class API:
         log.debug("getting tracked content")
-        return await self._get_many(self.API_ENDPOINT + f"tracked-content/", filters)
+        return await self._get_many(self.API_ENDPOINT + "tracked-content/", filters)
     async def get_filter(self, filter_id: int) -> dict:
         """
@@ -167,3 +167,10 @@ class API:
         log.debug("getting a filter")
         return await self._get_one(f"{self.API_ENDPOINT}filter/{filter_id}")
+    async def get_filters(self, **filters) -> tuple[list[dict], int]:
+        """
+        Get many instances of Filter.
+        """
+        return await self._get_many(self.API_ENDPOINT + "filter/", filters)
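The new `get_filters` delegates to `_get_many` with keyword filters; `_get_many` itself is not part of this diff. Purely as an illustration, here is how keyword filters like `guild_id`/`page`/`search` typically become query parameters on a list endpoint (the endpoint URL and helper name below are made up):

```python
from urllib.parse import urlencode

API_ENDPOINT = "https://example.invalid/api/"  # placeholder, not the real endpoint

def build_list_url(resource: str, filters: dict) -> str:
    """Build the kind of filtered list URL a `_get_many`-style helper would
    request. Empty/None filter values are dropped so `search=""` adds nothing."""
    query = urlencode({k: v for k, v in filters.items() if v not in (None, "")})
    return f"{API_ENDPOINT}{resource}/" + (f"?{query}" if query else "")

print(build_list_url("filter", {"guild_id": 123, "page": 1, "search": ""}))
# → https://example.invalid/api/filter/?guild_id=123&page=1
```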

View File

@@ -7,76 +7,82 @@ import logging
 from typing import Tuple
 from datetime import datetime
-import aiohttp
-import validators
+# import aiohttp
+# import validators
 from feedparser import FeedParserDict, parse
 from discord.ext import commands
 from discord import Interaction, TextChannel, Embed, Colour
-from discord.app_commands import Choice, Group, autocomplete, rename, command
+from discord.app_commands import Choice, choices, Group, autocomplete, rename, command
 from discord.errors import Forbidden
-from api import API
-from feed import Subscription, TrackedContent
-from utils import (
-    Followup,
-    PaginationView,
-    get_rss_data,
-)
+# from api import API
+# from feed import Subscription, TrackedContent, ContentFilter
+# from utils import (
+#     Followup,
+#     PaginationView,
+#     get_rss_data,
+# )
 log = logging.getLogger(__name__)
-rss_list_sort_choices = [
-    Choice(name="Nickname", value=0),
-    Choice(name="Date Added", value=1)
-]
-channels_list_sort_choices=[
-    Choice(name="Feed Nickname", value=0),
-    Choice(name="Channel ID", value=1),
-    Choice(name="Date Added", value=2)
-]
+# rss_list_sort_choices = [
+#     Choice(name="Nickname", value=0),
+#     Choice(name="Date Added", value=1)
+# ]
+# channels_list_sort_choices=[
+#     Choice(name="Feed Nickname", value=0),
+#     Choice(name="Channel ID", value=1),
+#     Choice(name="Date Added", value=2)
+# ]
-# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
-# target resource. Run a period task to check this.
-async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
-    """Validate a provided RSS source.
-    Parameters
-    ----------
-    nickname : str
-        Nickname of the source. Must not contain URL.
-    url : str
-        URL of the source. Must be URL with valid status code and be an RSS feed.
-    Returns
-    -------
-    str or None
-        String invalid message if invalid, NoneType if valid.
-    FeedParserDict or None
-        The feed parsed from the given URL or None if invalid.
-    """
-    # Ensure the URL is valid
-    if not validators.url(url):
-        return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
-    # Check the nickname is not a URL
-    if validators.url(nickname):
-        return "It looks like the nickname you have entered is a URL.\n" \
-            f"For security reasons, this is not allowed.\n`{nickname=}`", None
-    feed_data, status_code = await get_rss_data(url)
-    # Check the URL status code is valid
-    if status_code != 200:
-        return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
-    # Check the contents is actually an RSS feed.
-    feed = parse(feed_data)
-    if not feed.version:
-        return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
-    return None, feed
+# # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
+# # target resource. Run a period task to check this.
+# async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
+#     """Validate a provided RSS source.
+#     Parameters
+#     ----------
+#     nickname : str
+#         Nickname of the source. Must not contain URL.
+#     url : str
+#         URL of the source. Must be URL with valid status code and be an RSS feed.
+#     Returns
+#     -------
+#     str or None
+#         String invalid message if invalid, NoneType if valid.
+#     FeedParserDict or None
+#         The feed parsed from the given URL or None if invalid.
+#     """
+#     # Ensure the URL is valid
+#     if not validators.url(url):
+#         return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
+#     # Check the nickname is not a URL
+#     if validators.url(nickname):
+#         return "It looks like the nickname you have entered is a URL.\n" \
+#             f"For security reasons, this is not allowed.\n`{nickname=}`", None
+#     feed_data, status_code = await get_rss_data(url)
+#     # Check the URL status code is valid
+#     if status_code != 200:
+#         return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
+#     # Check the contents is actually an RSS feed.
+#     feed = parse(feed_data)
+#     if not feed.version:
+#         return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
+#     return None, feed
+# tri_choices = [
+#     Choice(name="Yes", value=2),
+#     Choice(name="No (default)", value=1),
+#     Choice(name="All", value=0),
+# ]
 class CommandsCog(commands.Cog):
@@ -94,135 +100,206 @@ class CommandsCog(commands.Cog):
         log.info("%s cog is ready", self.__class__.__name__)
-    # Group for commands about viewing data
-    view_group = Group(
-        name="view",
-        description="View data.",
-        guild_only=True
-    )
-    @view_group.command(name="subscriptions")
-    async def cmd_list_subs(self, inter: Interaction, search: str = ""):
-        """List Subscriptions from this server."""
-        await inter.response.defer()
-        def formatdata(index, item):
-            item = Subscription.from_dict(item)
-            channels = f"{item.channels_count}{' channels' if item.channels_count != 1 else ' channel'}"
-            filters = f"{len(item.filters)}{' filters' if len(item.filters) != 1 else ' filter'}"
-            notes = item.extra_notes[:25] + "..." if len(item.extra_notes) > 28 else item.extra_notes
-            links = f"[RSS Link]({item.url}) · [API Link]({API.API_EXTERNAL_ENDPOINT}subscription/{item.id}/)"
-            description = f"{channels}, {filters}\n"
-            description += f"{notes}\n" if notes else ""
-            description += links
-            key = f"{index}. {item.name}"
-            return key, description # key, value pair
-        async def getdata(page: int, pagesize: int):
-            async with aiohttp.ClientSession() as session:
-                api = API(self.bot.api_token, session)
-                return await api.get_subscriptions(
-                    guild_id=inter.guild.id,
-                    page=page,
-                    page_size=pagesize,
-                    search=search
-                )
-        embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed
-        pagination = PaginationView(
-            self.bot,
-            inter=inter,
-            embed=embed,
-            getdata=getdata,
-            formatdata=formatdata,
-            pagesize=10,
-            initpage=1
-        )
-        await pagination.send()
-    @view_group.command(name="tracked-content")
-    async def cmd_list_tracked(self, inter: Interaction, search: str = ""):
-        """List Tracked Content from this server, or a given sub"""
-        await inter.response.defer()
-        def formatdata(index, item):
-            item = TrackedContent.from_dict(item)
-            sub = Subscription.from_dict(item.subscription)
-            links = f"[Content Link]({item.url}) · [Message Link](https://discord.com/channels/{sub.guild_id}/{item.channel_id}/{item.message_id}/)"
-            description = f"Subscription: {sub.name}\n{links}"
-            key = f"{item.id}. {item.title}"
-            return key, description
-        async def getdata(page: int, pagesize: int):
-            async with aiohttp.ClientSession() as session:
-                api = API(self.bot.api_token, session)
-                return await api.get_tracked_content(
-                    subscription__guild_id=inter.guild_id,
-                    page=page,
-                    page_size=pagesize,
-                    search=search
-                )
-        embed = Followup(f"Tracked Content in {inter.guild.name}").info()._embed
-        pagination = PaginationView(
-            self.bot,
-            inter=inter,
-            embed=embed,
-            getdata=getdata,
-            formatdata=formatdata,
-            pagesize=10,
-            initpage=1
-        )
-        await pagination.send()
-    # Group for test related commands
-    test_group = Group(
-        name="test",
-        description="Commands to test Bot functionality.",
-        guild_only=True
-    )
-    @test_group.command(name="channel-permissions")
-    async def cmd_test_channel_perms(self, inter: Interaction):
-        """Test that the current channel's permissions allow for PYRSS to operate in it."""
-        try:
-            test_message = await inter.channel.send(content="... testing permissions ...")
-            await self.test_channel_perms(inter.channel)
-        except Exception as error:
-            await inter.response.send_message(content=f"Failed: {error}")
-            return
-        await test_message.delete()
-        await inter.response.send_message(content="Success")
-    async def test_channel_perms(self, channel: TextChannel):
-        # Test generic message and delete
-        msg = await channel.send(content="test message")
-        await msg.delete()
-        # Test detailed embed
-        embed = Embed(
-            title="test title",
-            description="test description",
-            colour=Colour.random(),
-            timestamp=datetime.now(),
-            url="https://google.com"
-        )
-        embed.set_author(name="test author")
-        embed.set_footer(text="test footer")
-        embed.set_thumbnail(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
-        embed.set_image(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
-        embed_msg = await channel.send(embed=embed)
-        await embed_msg.delete()
+    # # Group for commands about viewing data
+    # view_group = Group(
+    #     name="view",
+    #     description="View data.",
+    #     guild_only=True
+    # )
+    # @view_group.command(name="subscriptions")
+    # async def cmd_list_subs(self, inter: Interaction, search: str = ""):
+    #     """List Subscriptions from this server."""
+    #     await inter.response.defer()
+    #     def formatdata(index, item):
+    #         item = Subscription.from_dict(item)
+    #         notes = item.extra_notes[:25] + "..." if len(item.extra_notes) > 28 else item.extra_notes
+    #         links = f"[RSS Link]({item.url}) • [API Link]({API.API_EXTERNAL_ENDPOINT}subscription/{item.id}/)"
+    #         activeness = "✅ `enabled`" if item.active else "🚫 `disabled`"
+    #         description = f"🆔 `{item.id}`\n{activeness}\n#️⃣ `{item.channels_count}` 🔽 `{len(item.filters)}`\n"
+    #         description = f"{notes}\n" + description if notes else description
+    #         description += links
+    #         key = f"{index}. {item.name}"
+    #         return key, description # key, value pair
+    #     async def getdata(page: int, pagesize: int):
+    #         async with aiohttp.ClientSession() as session:
+    #             api = API(self.bot.api_token, session)
+    #             return await api.get_subscriptions(
+    #                 guild_id=inter.guild.id,
+    #                 page=page,
+    #                 page_size=pagesize,
+    #                 search=search
+    #             )
+    #     embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed
+    #     pagination = PaginationView(
+    #         self.bot,
+    #         inter=inter,
+    #         embed=embed,
+    #         getdata=getdata,
+    #         formatdata=formatdata,
+    #         pagesize=10,
+    #         initpage=1
+    #     )
+    #     await pagination.send()
+    # @view_group.command(name="tracked-content")
+    # @choices(blocked=tri_choices)
+    # async def cmd_list_tracked(self, inter: Interaction, search: str = "", blocked: Choice[int] = 1):
+    #     """List Tracked Content from this server""" # TODO: , or a given sub
+    #     await inter.response.defer()
+    #     # If the user picks an option it's an instance of `Choice` otherwise `str`
+    #     # Can't figure a way to select a default choices, so blame discordpy for this mess.
+    #     if isinstance(blocked, Choice):
+    #         blocked = blocked.value
+    #     def formatdata(index, item):
+    #         item = TrackedContent.from_dict(item)
+    #         sub = Subscription.from_dict(item.subscription)
+    #         links = f"[Content Link]({item.url}) · [Message Link](https://discord.com/channels/{sub.guild_id}/{item.channel_id}/{item.message_id}/) · [API Link]({API.API_EXTERNAL_ENDPOINT}tracked-content/{item.id}/)"
+    #         delivery_state = "✅ Delivered" if not item.blocked else "🚫 Blocked"
+    #         description = f"🆔 `{item.id}`\n"
+    #         description += f"{delivery_state}\n" if blocked == 0 else ""
+    #         description += f"➡️ *{sub.name}*\n{links}"
+    #         key = f"{index}. {item.title}"
+    #         return key, description
+    #     def determine_blocked():
+    #         match blocked:
+    #             case 0: return ""
+    #             case 1: return "false"
+    #             case 2: return "true"
+    #             case _: return ""
+    #     async def getdata(page: int, pagesize: int):
+    #         async with aiohttp.ClientSession() as session:
+    #             api = API(self.bot.api_token, session)
+    #             is_blocked = determine_blocked()
+    #             return await api.get_tracked_content(
+    #                 subscription__guild_id=inter.guild_id,
+    #                 blocked=is_blocked,
+    #                 page=page,
+    #                 page_size=pagesize,
+    #                 search=search,
+    #             )
+    #     embed = Followup(f"Tracked Content in {inter.guild.name}").info()._embed
+    #     pagination = PaginationView(
+    #         self.bot,
+    #         inter=inter,
+    #         embed=embed,
+    #         getdata=getdata,
+    #         formatdata=formatdata,
+    #         pagesize=10,
+    #         initpage=1
+    #     )
+    #     await pagination.send()
+    # @view_group.command(name="filters")
+    # async def cmd_list_filters(self, inter: Interaction, search: str = ""):
+    #     """List Filters from this server."""
+    #     await inter.response.defer()
+    #     def formatdata(index, item):
+    #         item = ContentFilter.from_dict(item)
+    #         matching_algorithm = get_algorithm_name(item.matching_algorithm)
+    #         whitelist = "Whitelist" if item.is_whitelist else "Blacklist"
+    #         sensitivity = "Case insensitive" if item.is_insensitive else "Case sensitive"
+    #         description = f"🆔 `{item.id}`\n"
+    #         description += f"🔄 `{matching_algorithm}`\n🟰 `{item.match}`\n"
+    #         description += f"✅ `{whitelist}` 🔠 `{sensitivity}`\n"
+    #         description += f"[API Link]({API.API_EXTERNAL_ENDPOINT}filter/{item.id}/)"
+    #         key = f"{index}. {item.name}"
+    #         return key, description
+    #     def get_algorithm_name(matching_algorithm: int):
+    #         match matching_algorithm:
+    #             case 0: return "None"
+    #             case 1: return "Any word"
+    #             case 2: return "All words"
+    #             case 3: return "Exact match"
+    #             case 4: return "Regex match"
+    #             case 5: return "Fuzzy match"
+    #             case _: return "unknown"
+    #     async def getdata(page, pagesize):
+    #         async with aiohttp.ClientSession() as session:
+    #             api = API(self.bot.api_token, session)
+    #             return await api.get_filters(
+    #                 guild_id=inter.guild_id,
+    #                 page=page,
+    #                 page_size=pagesize,
+    #                 search=search
+    #             )
+    #     embed = Followup(f"Filters in {inter.guild.name}").info()._embed
+    #     pagination = PaginationView(
+    #         self.bot,
+    #         inter=inter,
+    #         embed=embed,
+    #         getdata=getdata,
+    #         formatdata=formatdata,
+    #         pagesize=10,
+    #         initpage=1
+    #     )
+    #     await pagination.send()
+    # # Group for test related commands
+    # test_group = Group(
+    #     name="test",
+    #     description="Commands to test Bot functionality.",
+    #     guild_only=True
+    # )
+    # @test_group.command(name="channel-permissions")
+    # async def cmd_test_channel_perms(self, inter: Interaction):
+    #     """Test that the current channel's permissions allow for PYRSS to operate in it."""
+    #     try:
+    #         test_message = await inter.channel.send(content="... testing permissions ...")
+    #         await self.test_channel_perms(inter.channel)
+    #     except Exception as error:
+    #         await inter.response.send_message(content=f"Failed: {error}")
+    #         return
+    #     await test_message.delete()
+    #     await inter.response.send_message(content="Success")
+    # async def test_channel_perms(self, channel: TextChannel):
+    #     # Test generic message and delete
+    #     msg = await channel.send(content="test message")
+    #     await msg.delete()
+    #     # Test detailed embed
+    #     embed = Embed(
+    #         title="test title",
+    #         description="test description",
+    #         colour=Colour.random(),
+    #         timestamp=datetime.now(),
+    #         url="https://google.com"
+    #     )
+    #     embed.set_author(name="test author")
+    #     embed.set_footer(text="test footer")
+    #     embed.set_thumbnail(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
+    #     embed.set_image(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
+    #     embed_msg = await channel.send(embed=embed)
+    #     await embed_msg.delete()
 async def setup(bot):

View File

@@ -3,29 +3,37 @@ Extension for the `TaskCog`.
 Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
 """
+import json
 import asyncio
 import logging
 import datetime
+import traceback
 from os import getenv
 from time import perf_counter
 from collections import deque
+from textwrap import shorten
-import aiohttp
-from aiocache import Cache
+# import aiohttp
+import httpx
+import feedparser
+import discord
+# from aiocache import Cache
 from discord import TextChannel
-from discord import app_commands
+from discord import app_commands, Interaction
 from discord.ext import commands, tasks
 from discord.errors import Forbidden
-from feedparser import parse
+from markdownify import markdownify
-from feed import RSSFeed, Subscription, RSSItem, GuildSettings
-from utils import get_unparsed_feed
-from filters import match_text
+import models
+from utils import do_batch_job
+# from feed import RSSFeed, Subscription, RSSItem, GuildSettings
+# from utils import get_unparsed_feed
+# from filters import match_text
 from api import API
 log = logging.getLogger(__name__)
-cache = Cache(Cache.MEMORY)
+# cache = Cache(Cache.MEMORY)
 BATCH_SIZE = 100
@@ -35,7 +43,7 @@ subscription_task_times = [
     for hour in range(24)
     for minute in range(0, 60, int(TASK_INTERVAL_MINUTES))
 ]
-log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
+log.info("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
 class TaskCog(commands.Cog):
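The hunk above shows only the tail of the `subscription_task_times` comprehension. Assuming each element is a `datetime.time` (which is what `tasks.loop(time=...)` expects) and taking an example interval of 10 minutes in place of the env-derived `TASK_INTERVAL_MINUTES`, it builds one trigger per interval step of every hour:

```python
import datetime

TASK_INTERVAL_MINUTES = 10  # assumed example; the real value comes from getenv

# One trigger time per TASK_INTERVAL_MINUTES step of every hour of the day.
subscription_task_times = [
    datetime.time(hour=hour, minute=minute)
    for hour in range(24)
    for minute in range(0, 60, TASK_INTERVAL_MINUTES)
]

print(len(subscription_task_times))  # → 144 (24 hours x 6 slots per hour)
```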
@@ -46,16 +54,23 @@ class TaskCog(commands.Cog):
     api: API | None = None
     content_queue = deque()
+    api_base_url: str
+    api_headers: dict
+    client: httpx.AsyncClient | None
     def __init__(self, bot):
         super().__init__()
         self.bot = bot
+        self.api_base_url = "http://localhost:8000/api/"
+        self.api_headers = {"Authorization": f"Token {self.bot.api_token}"}
     @commands.Cog.listener()
     async def on_ready(self):
         """
         Instructions to execute when the cog is ready.
         """
-        self.subscription_task.start()
+        # self.subscription_task.start()
+        self.do_task.start()
         log.info("%s cog is ready", self.__class__.__name__)
     @commands.Cog.listener(name="cog_unload")
@ -72,219 +87,434 @@ class TaskCog(commands.Cog):
) )
@group.command(name="trigger") @group.command(name="trigger")
async def cmd_trigger_task(self, inter): async def cmd_trigger_task(self, inter: Interaction):
await inter.response.defer() await inter.response.defer()
start_time = perf_counter() start_time = perf_counter()
try: try:
await self.subscription_task() await self.do_task()
except Exception as error: except Exception as exc:
await inter.followup.send(str(error)) log.exception(exc)
await inter.followup.send(str(exc) or "unknown error")
finally: finally:
end_time = perf_counter() end_time = perf_counter()
await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds") await inter.followup.send(f"completed command in {end_time - start_time:.4f} seconds")
@tasks.loop(time=subscription_task_times) @tasks.loop(time=subscription_task_times)
async def subscription_task(self): async def do_task(self):
""" log.info("Running task")
Task for fetching and processing subscriptions.
"""
log.info("Running subscription task")
start_time = perf_counter() start_time = perf_counter()
async with aiohttp.ClientSession() as session: async with httpx.AsyncClient() as client:
self.api = API(self.bot.api_token, session) self.client = client
await self.execute_task() servers = await self.get_servers()
await do_batch_job(servers, self.process_server, 10)
end_time = perf_counter() end_time = perf_counter()
log.debug(f"task completed in {end_time - start_time:.4f} seconds") log.info(f"completed task in {end_time - start_time:.4f} seconds")
async def execute_task(self): async def iterate_pages(self, url: str, params: dict={}):
"""Execute the task directly."""
# Filter out inactive guild IDs using related settings for page_number, _ in enumerate(iterable=iter(int, 1), start=1):
guild_ids = [guild.id for guild in self.bot.guilds] params.update({"page": page_number})
guild_settings = await self.get_guild_settings(guild_ids) response = await self.client.get(
active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active] self.api_base_url + url,
headers=self.api_headers,
subscriptions = await self.get_subscriptions(active_guild_ids) params=params
await self.process_subscriptions(subscriptions)
async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
"""Returns a list of guild settings from the Bot's guilds, if they exist."""
guild_settings = []
# Iterate infinitely taking the iter no. as `page`
# data will be empty after last page reached.
for page, _ in enumerate(iter(int, 1)):
data = await self.get_guild_settings_page(guild_ids, page)
if not data:
break
guild_settings.extend(data[0])
# Only return active guild IDs
return GuildSettings.from_list(guild_settings)
async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
"""Returns an individual page of guild settings."""
try:
return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
except aiohttp.ClientResponseError as error:
self.handle_pagination_error(error)
return []
def handle_pagination_error(self, error: aiohttp.ClientResponseError):
"""Handle the error cases from pagination attempts."""
match error.status:
case 404:
log.debug("final page reached")
case 403:
log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
case _:
log.debug(error)
async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
"""Get a list of `Subscription`s matching the given `guild_ids`."""
subscriptions = []
# Iterate infinitely taking the iter no. as `page`
# data will be empty after last page reached.
for page, _ in enumerate(iter(int, 1)):
data = await self.get_subs_page(guild_ids, page)
if not data:
break
subscriptions.extend(data[0])
return Subscription.from_list(subscriptions)
async def get_subs_page(self, guild_ids: list[int], page: int) -> list[dict]:
"""Returns an individual page of subscriptions."""
try:
return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
except aiohttp.ClientResponseError as error:
self.handle_pagination_error(error)
return []
async def process_subscriptions(self, subscriptions: list[Subscription]):
"""Process a given list of `Subscription`s."""
async def process_single_subscription(sub: Subscription):
log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
if not sub.active or not sub.channels_count:
return
unparsed_feed = await get_unparsed_feed(sub.url)
parsed_feed = parse(unparsed_feed)
rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
await self.process_items(sub, rss_feed)
semaphore = asyncio.Semaphore(10)
async def semaphore_process(sub: Subscription):
async with semaphore:
await process_single_subscription(sub)
await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
async def process_items(self, sub: Subscription, feed: RSSFeed):
log.debug("processing items")
channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
for item in feed.items:
log.debug("processing item '%s'", item.guid)
if item.pub_date < sub.published_threshold:
log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
continue
blocked = any(self.filter_item(_filter, item) for _filter in filters)
mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
for channel in channels:
await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
async def fetch_or_get_channels(self, channels_data: list[dict]):
channels = []
for data in channels_data:
try:
channel = self.bot.get_channel(data.channel_id)
channels.append(channel or await self.bot.fetch_channel(data.channel_id))
except Forbidden:
log.error(f"Forbidden Channel '{data.channel_id}'")
return channels
def filter_item(self, _filter: dict, item: RSSItem) -> bool:
"""
Returns `True` if item should be ignored due to filters.
"""
match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
log.debug("filter match found? '%s'", match_found)
return match_found
async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
message_id = -1
log.debug("track and send func %s, %s", item.guid, item.title)
result = await self.api.get_tracked_content(guid=item.guid)
if result[1]:
log.debug(f"This item is already tracked, skipping '{item.guid}'")
return
result = await self.api.get_tracked_content(url=item.link)
if result[1]:
log.debug(f"This item is already tracked, skipping '{item.guid}'")
return
if not blocked:
try:
log.debug("sending '%s', exists '%s'", item.guid, result[1])
sendable_item = mutated_item or item
message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
message_id = message.id
except Forbidden:
log.error(f"Forbidden to send to channel {channel.id}")
await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
async def process_batch(self):
pass
async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
try:
log.debug("marking as tracked")
await self.api.create_tracked_content(
guid=item.guid,
title=item.title,
url=item.link,
subscription=sub.id,
channel_id=channel_id,
message_id=message_id,
blocked=blocked
)
return True
except aiohttp.ClientResponseError as error:
if error.status == 409:
log.debug(error)
else:
log.error(error)
return False

response = await self.client.get(
self.api_base_url + url,
headers=self.api_headers,
params=params
)
response.raise_for_status()
content = response.json()
yield content.get("results", [])
if not content.get("next"):
break
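The new-code fragments above come from the cog's paginated-fetch helper, which walks an API that returns `results` plus a `next` pointer. A minimal stand-alone sketch of the same pattern (the `fetch_page` stub and its page-number cursor are assumptions for illustration; the real API presumably returns a `next` URL):

```python
import asyncio

# Stub standing in for the HTTP API (assumption for illustration);
# each page carries its results and a pointer to the next page.
PAGES = {
    1: {"results": ["a", "b"], "next": 2},
    2: {"results": ["c"], "next": None},
}

async def fetch_page(page):
    # Real code would await client.get(url, params={"page": page}) instead.
    return PAGES[page]

async def iterate_pages():
    """Yield each page's results until the API reports no next page."""
    page = 1
    while True:
        content = await fetch_page(page)
        yield content.get("results", [])
        if not content.get("next"):
            break
        page = content["next"]

async def collect():
    items = []
    async for batch in iterate_pages():
        items.extend(batch)
    return items

collected = asyncio.run(collect())
```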
async def get_servers(self) -> list[models.Server]:
servers = []
async for servers_batch in self.iterate_pages("servers/"):
if servers_batch:
servers.extend(servers_batch)
return models.Server.from_list(servers)
async def get_subscriptions(self, server: models.Server) -> list[models.Subscription]:
subscriptions = []
params = {"server": server.id, "active": True}
async for subscriptions_batch in self.iterate_pages("subscriptions/", params):
if subscriptions_batch:
subscriptions.extend(subscriptions_batch)
return models.Subscription.from_list(subscriptions)
async def get_contents(self, subscription: models.Subscription, raw_rss_content: dict):
contents = await models.Content.from_raw_rss(raw_rss_content, subscription, self.client)
duplicate_contents = []
async def check_duplicate_content(content: models.Content):
exists = await content.exists_via_api(
url=self.api_base_url + "content/",
headers=self.api_headers,
client=self.client
)
if exists:
log.debug(f"Removing duplicate {content}")
duplicate_contents.append(content)
await do_batch_job(contents, check_duplicate_content, 15)
log.debug(f"before removing duplicates: {len(contents)}")
for duplicate in duplicate_contents:
contents.remove(duplicate)
log.debug(f"after removing duplicates: {len(contents)}")
return contents
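`do_batch_job` is imported from `utils`, which is not part of this diff; judging from its call sites (`do_batch_job(items, coroutine, limit)`), a plausible semaphore-bounded implementation could look like this sketch — an assumption, not the project's actual code:

```python
import asyncio

async def do_batch_job(items, func, limit):
    """Run func(item) for every item with at most `limit` coroutines in flight."""
    semaphore = asyncio.Semaphore(limit)

    async def bounded(item):
        async with semaphore:
            await func(item)

    await asyncio.gather(*(bounded(item) for item in items))

# Demonstration: square some numbers with bounded concurrency.
results = []

async def square(n):
    await asyncio.sleep(0)  # stand-in for real I/O
    results.append(n * n)

asyncio.run(do_batch_job([1, 2, 3, 4], square, 2))
```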
async def process_server(self, server: models.Server):
log.debug(f"processing server: {server.name}")
start_time = perf_counter()
subscriptions = await self.get_subscriptions(server)
for subscription in subscriptions:
subscription.server = server
await do_batch_job(subscriptions, self.process_subscription, 10)
end_time = perf_counter()
log.debug(f"Finished processing server: {server.name} in {end_time - start_time:.4f} seconds")
async def process_subscription(self, subscription: models.Subscription):
log.debug(f"processing subscription {subscription.name}")
start_time = perf_counter()
raw_rss_content = await subscription.get_rss_content(self.client)
if not raw_rss_content:
return
contents = await self.get_contents(subscription, raw_rss_content)
if not contents:
log.debug("no contents to process")
return
channels = await subscription.get_discord_channels(self.bot)
valid_contents, invalid_contents = subscription.filter_entries(contents)
async def send_content(channel: discord.TextChannel):
# BUG: duplicate embeds can appear here;
# Discord shows only one embed when the URLs match, but merges the images from both into that one.
# embeds = [content.embed for content in valid_contents]
# batch_size = 10
# for i in range(0, len(embeds), batch_size):
# batch = embeds[i:i + batch_size]
# await channel.send(embeds=batch)
batch_size = 10
total_batches = (len(valid_contents) + batch_size - 1) // batch_size
for batch_number, i in enumerate(range(0, len(valid_contents), batch_size)):
contents_batch = valid_contents[i:i + batch_size]
embeds = await self.create_embeds(
contents=contents_batch,
subscription_name=subscription.name,
colour=subscription.message_style.colour,
batch_number=batch_number,
total_batches=total_batches
)
await channel.send(embeds=embeds)
await do_batch_job(channels, send_content, 5)
combined = valid_contents.copy()
combined.extend(invalid_contents)
tasks = [content.save(self.client, self.api_base_url, self.api_headers) for content in combined]
await asyncio.gather(*tasks)
# TODO: mark invalid contents as blocked
end_time = perf_counter()
log.debug(f"Finished processing subscription: {subscription.name} in {end_time - start_time:.4f}")
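The send loop above chunks content into groups of ten because Discord allows at most 10 embeds per message; the slicing and the ceil-division batch count can be shown in isolation:

```python
def split_batches(items, batch_size=10):
    """Split a list into consecutive batches of at most batch_size items."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

def total_batches(count, batch_size=10):
    # Ceil division without math.ceil: any remainder rounds up to a full batch.
    return (count + batch_size - 1) // batch_size

batches = split_batches(list(range(25)), 10)
```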
async def create_embeds(self, contents: list[models.Content], subscription_name: str, colour: str, batch_number: int, total_batches: int):
discord_colour = discord.Colour.from_str(
colour if colour.startswith("#")
else f"#{colour}"
)
url = "https://pyrss.cor.bz"
title = subscription_name
if total_batches > 1:
title += f" [{batch_number+1}/{total_batches}]"
embed = discord.Embed(title=title, colour=discord_colour, url=url)
embeds = [embed]
for content in contents:
description = shorten(markdownify(content.item_description, strip=("img",)), 256)
description += f"\n[View Article]({content.item_url})"
embed.add_field(
name=shorten(markdownify(content.item_title, strip=("img", "a")), 256),
value=description,
inline=False
)
# If there is only one content, set the main embed's image and return it.
# Otherwise, proceeding the normal way would create a wonky-looking embed
# where the lone image is aligned left against an invisible, non-existent
# right-hand image.
if len(contents) == 1:
embed.set_image(url=content.item_image_url)
return embeds
if len(embeds) <= 5:
image_embed = discord.Embed(title="dummy", url=url)
image_embed.set_image(url=content.item_image_url)
embeds.append(image_embed)
return embeds
# async def process_valid_contents(
# self,
# contents: list[models.Content],
# channels: list[discord.TextChannel],
# client: httpx.AsyncClient
# ):
# semaphore = asyncio.Semaphore(5)
# async def batch_process(
# content: models.Content,
# channels: list[discord.TextChannel],
# client: httpx.AsyncClient
# ):
# async with semaphore: await self.process_valid_content(content, channels, client)
# tasks = [
# batch_process()
# ]
# @group.command(name="trigger")
# async def cmd_trigger_task(self, inter):
# await inter.response.defer()
# start_time = perf_counter()
# try:
# await self.subscription_task()
# except Exception as error:
# await inter.followup.send(str(error))
# finally:
# end_time = perf_counter()
# await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
# @tasks.loop(time=subscription_task_times)
# async def subscription_task(self):
# """
# Task for fetching and processing subscriptions.
# """
# log.info("Running subscription task")
# start_time = perf_counter()
# async with aiohttp.ClientSession() as session:
# self.api = API(self.bot.api_token, session)
# await self.execute_task()
# end_time = perf_counter()
# log.debug(f"task completed in {end_time - start_time:.4f} seconds")
# async def execute_task(self):
# """Execute the task directly."""
# # Filter out inactive guild IDs using related settings
# guild_ids = [guild.id for guild in self.bot.guilds]
# guild_settings = await self.get_guild_settings(guild_ids)
# active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
# subscriptions = await self.get_subscriptions(active_guild_ids)
# await self.process_subscriptions(subscriptions)
# async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
# """Returns a list of guild settings from the Bot's guilds, if they exist."""
# guild_settings = []
# # Iterate infinitely taking the iter no. as `page`
# # data will be empty after last page reached.
# for page, _ in enumerate(iter(int, 1)):
# data = await self.get_guild_settings_page(guild_ids, page)
# if not data:
# break
# guild_settings.extend(data[0])
# # Only return active guild IDs
# return GuildSettings.from_list(guild_settings)
# async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
# """Returns an individual page of guild settings."""
# try:
# return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
# except aiohttp.ClientResponseError as error:
# self.handle_pagination_error(error)
# return []
# def handle_pagination_error(self, error: aiohttp.ClientResponseError):
# """Handle the error cases from pagination attempts."""
# match error.status:
# case 404:
# log.debug("final page reached")
# case 403:
# log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
# self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
# case _:
# log.debug(error)
# async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
# """Get a list of `Subscription`s matching the given `guild_ids`."""
# subscriptions = []
# # Iterate infinitely taking the iter no. as `page`
# # data will be empty after last page reached.
# for page, _ in enumerate(iter(int, 1)):
# data = await self.get_subs_page(guild_ids, page)
# if not data:
# break
# subscriptions.extend(data[0])
# return Subscription.from_list(subscriptions)
# async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
# """Returns an individual page of subscriptions."""
# try:
# return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
# except aiohttp.ClientResponseError as error:
# self.handle_pagination_error(error)
# return []
# async def process_subscriptions(self, subscriptions: list[Subscription]):
# """Process a given list of `Subscription`s."""
# async def process_single_subscription(sub: Subscription):
# log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
# if not sub.active or not sub.channels_count:
# return
# unparsed_feed = await get_unparsed_feed(sub.url)
# parsed_feed = parse(unparsed_feed)
# rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
# await self.process_items(sub, rss_feed)
# semaphore = asyncio.Semaphore(10)
# async def semaphore_process(sub: Subscription):
# async with semaphore:
# await process_single_subscription(sub)
# await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
# async def process_items(self, sub: Subscription, feed: RSSFeed):
# log.debug("processing items")
# channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
# filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
# for item in feed.items:
# log.debug("processing item '%s'", item.guid)
# if item.pub_date < sub.published_threshold:
# log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
# continue
# blocked = any(self.filter_item(_filter, item) for _filter in filters)
# mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
# for channel in channels:
# await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
# async def fetch_or_get_channels(self, channels_data: list[dict]):
# channels = []
# for data in channels_data:
# try:
# channel = self.bot.get_channel(data.channel_id)
# channels.append(channel or await self.bot.fetch_channel(data.channel_id))
# except Forbidden:
# log.error(f"Forbidden Channel '{data.channel_id}'")
# return channels
# def filter_item(self, _filter: dict, item: RSSItem) -> bool:
# """
# Returns `True` if item should be ignored due to filters.
# """
# match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
# log.debug("filter match found? '%s'", match_found)
# return match_found
# async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
# message_id = -1
# log.debug("track and send func %s, %s", item.guid, item.title)
# result = await self.api.get_tracked_content(guid=item.guid)
# if result[1]:
# log.debug(f"This item is already tracked, skipping '{item.guid}'")
# return
# result = await self.api.get_tracked_content(url=item.link)
# if result[1]:
# log.debug(f"This item is already tracked, skipping '{item.guid}'")
# return
# if not blocked:
# try:
# log.debug("sending '%s', exists '%s'", item.guid, result[1])
# sendable_item = mutated_item or item
# message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
# message_id = message.id
# except Forbidden:
# log.error(f"Forbidden to send to channel {channel.id}")
# await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
# async def process_batch(self):
# pass
# async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
# try:
# log.debug("marking as tracked")
# await self.api.create_tracked_content(
# guid=item.guid,
# title=item.title,
# url=item.link,
# subscription=sub.id,
# channel_id=channel_id,
# message_id=message_id,
# blocked=blocked
# )
# return True
# except aiohttp.ClientResponseError as error:
# if error.status == 409:
# log.debug(error)
# else:
# log.error(error)
# return False
async def setup(bot):


@@ -199,7 +199,7 @@ class RSSFeed:
description: str
link: str
lang: str
last_build_date: datetime | None
image_href: str
items: list[RSSItem] = None
@@ -240,7 +240,8 @@ class RSSFeed:
language = pf.feed.get('language', None)
last_build_date = pf.feed.get('updated_parsed', None)
if last_build_date:
last_build_date = datetime(*last_build_date[0:-2])
image_href = pf.feed.get("image", {}).get("href")
@@ -250,7 +251,7 @@ class RSSFeed:
item = RSSItem.from_parsed_entry(entry)
feed.add_item(item)
feed.items.reverse()  # order so that older items are processed first
return feed
@@ -301,6 +302,7 @@ class Subscription(DjangoDataModel):
published_threshold: datetime
active: bool
channels_count: int
unique_content_rules: list

@staticmethod
def parser(item: dict) -> dict:
@@ -311,6 +313,7 @@
"description": item.pop("article_desc_mutators")
}
item["published_threshold"] = datetime.strptime(item["published_threshold"], "%Y-%m-%dT%H:%M:%S%z")
item["unique_content_rules"] = item.get("unique_content_rules", [])
return item
@@ -357,3 +360,21 @@ class TrackedContent(DjangoDataModel):
item["creation_datetime"] = datetime.strptime(item["creation_datetime"], "%Y-%m-%dT%H:%M:%S.%f%z")
return item
@dataclass(slots=True)
class ContentFilter(DjangoDataModel):
id: int
name: str
matching_algorithm: int
match: str
is_insensitive: bool
is_whitelist: bool
guild_id: int
@staticmethod
def parser(item: dict) -> dict:
item["guild_id"] = int(item["guild_id"]) # stored as str due to a django/sqlite bug, convert back to int
return item
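Both `parser` implementations above lean on `datetime.strptime` with the `%z` directive, which since Python 3.7 also accepts offsets written with a colon. For example (the timestamps are illustrative):

```python
from datetime import datetime

# Format used for the published/publish threshold fields (no fractional seconds).
threshold = datetime.strptime("2024-11-15T23:47:31+00:00", "%Y-%m-%dT%H:%M:%S%z")

# Format used for creation/updated timestamps (fractional seconds).
created = datetime.strptime("2024-11-15T23:47:31.123456+00:00", "%Y-%m-%dT%H:%M:%S.%f%z")
```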

src/models.py Normal file

@@ -0,0 +1,527 @@
import re
import logging
import hashlib
import asyncio
from enum import Enum
from time import perf_counter
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from textwrap import shorten
import feedparser.parsers
import httpx
import discord
import rapidfuzz
import feedparser
from bs4 import BeautifulSoup
from markdownify import markdownify
from utils import do_batch_job
log = logging.getLogger(__name__)
@dataclass
class DjangoDataModel(ABC):
@staticmethod
@abstractmethod
def parser(item: dict) -> dict:
return item
@classmethod
def from_list(cls, data: list[dict]) -> list:
return [cls(**cls.parser(item)) for item in data]
@classmethod
def from_dict(cls, data: dict):
return cls(**cls.parser(data))
@dataclass(slots=True)
class Server(DjangoDataModel):
id: int
name: str
icon_hash: str
is_bot_operational: bool
active: bool
@staticmethod
def parser(item: dict) -> dict:
item["id"] = int(item.pop("id"))
return item
class MatchingAlgorithm(Enum):
NONE = 0
ANY = 1
ALL = 2
LITERAL = 3
REGEX = 4
FUZZY = 5
AUTO = 6
@classmethod
def from_value(cls, value: int):
for member in cls:
if member.value == value:
return member
raise ValueError(f"No {cls.__name__} for value: {value}")
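For reference, `Enum` already performs value lookup through its constructor, so `from_value` behaves like calling the class directly; both raise `ValueError` for an unknown value. A quick sketch using the same members:

```python
from enum import Enum

class MatchingAlgorithm(Enum):
    NONE = 0
    ANY = 1
    ALL = 2
    LITERAL = 3
    REGEX = 4
    FUZZY = 5
    AUTO = 6

# Value lookup is built into Enum's constructor.
member = MatchingAlgorithm(4)
```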
@dataclass(slots=True)
class ContentFilter(DjangoDataModel):
id: int
server_id: int
name: str
matching_pattern: str
matching_algorithm: MatchingAlgorithm
is_insensitive: bool
is_whitelist: bool
@staticmethod
def parser(item: dict) -> dict:
item["id"] = item.pop("id")
item["server_id"] = item.pop("server")
item["matching_pattern"] = item.pop("match")
item["matching_algorithm"] = MatchingAlgorithm.from_value(item.pop("matching_algorithm"))
return item
@property
def _regex_flags(self):
return re.IGNORECASE if self.is_insensitive else 0
@property
def cleaned_matching_pattern(self):
"""
Splits the pattern to individual keywords, getting rid of unnecessary
spaces and grouping quoted words together.
"""
findterms = re.compile(r'"([^"]+)"|(\S+)').findall
normspace = re.compile(r"\s+").sub
return [
re.escape(normspace(" ", (t[0] or t[1]).strip())).replace(r"\ ", r"\s+")
for t in findterms(self.matching_pattern)
]
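The splitting described in the docstring can be exercised on its own: quoted phrases stay grouped, and since `re.escape` still escapes spaces (CPython 3.7+), the `\ ` to `\s+` substitution holds. A self-contained sketch of the same logic:

```python
import re

def clean_pattern(pattern):
    """Split a pattern into escaped keyword regexes, keeping quoted phrases whole."""
    findterms = re.compile(r'"([^"]+)"|(\S+)').findall
    normspace = re.compile(r"\s+").sub
    return [
        re.escape(normspace(" ", (t[0] or t[1]).strip())).replace(r"\ ", r"\s+")
        for t in findterms(pattern)
    ]

terms = clean_pattern('"big news" petrol')
```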
def _match_any(self, matching_against: str):
for word in self.cleaned_matching_pattern:
if re.search(rf"\b{word}\b", matching_against, self._regex_flags):
return True
return False
def _match_all(self, matching_against: str):
for word in self.cleaned_matching_pattern:
if not re.search(rf"\b{word}\b", matching_against, self._regex_flags):
return False
return True
def _match_literal(self, matching_against: str):
return bool(
re.search(
rf"\b{re.escape(self.matching_pattern)}\b",
matching_against,
self._regex_flags
)
)
def _match_regex(self, matching_against: str):
try:
return bool(re.search(
re.compile(self.matching_pattern, self._regex_flags),
matching_against
))
except re.error as exc:
log.error(f"Filter regex error: {exc}")
return False
def _match_fuzzy(self, matching_against: str):
matching_against = re.sub(r"[^\w\s]", "", matching_against)
matching_pattern = re.sub(r"[^\w\s]", "", self.matching_pattern)
if self.is_insensitive:
matching_against = matching_against.lower()
matching_pattern = matching_pattern.lower()
return rapidfuzz.fuzz.partial_ratio(
matching_against,
matching_pattern,
score_cutoff=90
)
def _get_algorithm_func(self):
match self.matching_algorithm:
case MatchingAlgorithm.NONE: return
case MatchingAlgorithm.ANY: return self._match_any
case MatchingAlgorithm.ALL: return self._match_all
case MatchingAlgorithm.LITERAL: return self._match_literal
case MatchingAlgorithm.REGEX: return self._match_regex
case MatchingAlgorithm.FUZZY: return self._match_fuzzy
case _: return
def matches(self, content) -> bool:
log.debug(f"applying filter: {self}")
if not self.matching_pattern.strip():
return False
if self.matching_algorithm == MatchingAlgorithm.ALL:
match_found = self._match_all(content.item_title + " " + content.item_description)
else:
algorithm_func = self._get_algorithm_func()
if not algorithm_func:
log.error(f"Bad algorithm function: {self.matching_algorithm}")
return False
match_found = algorithm_func(content.item_title) or algorithm_func(content.item_description)
log.debug(f"filter match found: {match_found}")
return not match_found if self.is_whitelist else match_found
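The last line of `matches` inverts the result for whitelist filters, so a whitelist blocks content that does not match. The logic reduced to a helper:

```python
def should_block(match_found, is_whitelist):
    """True when the content should be treated as blocked."""
    return not match_found if is_whitelist else match_found

blacklist_hit = should_block(match_found=True, is_whitelist=False)
whitelist_miss = should_block(match_found=False, is_whitelist=True)
whitelist_hit = should_block(match_found=True, is_whitelist=True)
```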
@dataclass(slots=True)
class MessageMutator(DjangoDataModel):
id: int
name: str
value: str
@staticmethod
def parser(item: dict) -> dict:
item["id"] = item.pop("id")
return item
@dataclass(slots=True)
class MessageStyle(DjangoDataModel):
id: int
server_id: int
name: str
colour: str
is_embed: bool
is_hyperlinked: bool
show_author: bool
show_timestamp: bool
show_images: bool
fetch_images: bool
title_mutator: dict | None
description_mutator: dict | None
auto_created: bool
@staticmethod
def parser(item: dict) -> dict:
item["id"] = int(item.pop("id"))
item["server_id"] = int(item.pop("server"))
item["title_mutator"] = item.pop("title_mutator_detail")
item["description_mutator"] = item.pop("description_mutator_detail")
return item
@dataclass(slots=True)
class DiscordChannel(DjangoDataModel):
id: int
name: str
is_nsfw: bool
@staticmethod
def parser(item: dict) -> dict:
item["id"] = int(item.pop("id"))
return item
@dataclass(slots=True)
class Subscription(DjangoDataModel):
id: int
server_id: int
name: str
url: str
created_at: datetime
updated_at: datetime
extra_notes: str
active: bool
publish_threshold: datetime
channels: list[DiscordChannel]
filters: list[ContentFilter]
message_style: MessageStyle
_server: Server | None = None
@staticmethod
def parser(item: dict) -> dict:
item["id"] = int(item.pop("id"))
item["server_id"] = int(item.pop("server"))
item["created_at"] = datetime.strptime(item.pop("created_at"), "%Y-%m-%dT%H:%M:%S.%f%z")
item["updated_at"] = datetime.strptime(item.pop("updated_at"), "%Y-%m-%dT%H:%M:%S.%f%z")
item["publish_threshold"] = datetime.strptime(item.pop("publish_threshold"), "%Y-%m-%dT%H:%M:%S%z")
item["channels"] = DiscordChannel.from_list(item.pop("channels_detail"))
item["filters"] = ContentFilter.from_list(item.pop("filters_detail"))
item["message_style"] = MessageStyle.from_dict(item.pop("message_style_detail"))
return item
@property
def server(self) -> Server:
return self._server
@server.setter
def server(self, server: Server):
self._server = server
async def get_rss_content(self, client: httpx.AsyncClient) -> str:
start_time = perf_counter()
try:
response = await client.get(self.url)
response.raise_for_status()
except httpx.HTTPError as exc:
log.error("(%s) HTTP Exception for %s - %s", type(exc), exc.request.url, exc)
return
finally:
log.debug(f"Got rss content in {perf_counter() - start_time:.4f} seconds")
content_type = response.headers.get("Content-Type", "")
if "text/xml" not in content_type:
log.warning("Invalid 'Content-Type' header: %s (must contain 'text/xml')", content_type)
return
return response.text
async def get_discord_channels(self, bot) -> list[discord.TextChannel]:
start_time = perf_counter()
channels = []
for channel_detail in self.channels:
try:
channel = bot.get_channel(channel_detail.id)
channels.append(channel or await bot.fetch_channel(channel_detail.id))
except Exception as exc:
channel_reference = f"({channel_detail.name}, {channel_detail.id})"
server_reference = f"({self.server.name}, {self.server.id})"
log.debug(f"Failed to get channel {channel_reference} from {server_reference}: {exc}")
log.debug(f"Got channels in {perf_counter() - start_time:.4f} seconds")
return channels
def filter_entries(self, contents: list) -> tuple[list, list]:
log.debug(f"filtering entries for {self.name} in {self.server.name}")
valid_contents = []
invalid_contents = []
for content in contents:
log.debug(f"filtering: '{content.item_title}'")
if any(content_filter.matches(content) for content_filter in self.filters):
content.blocked = True
invalid_contents.append(content)
else:
valid_contents.append(content)
log.debug(f"filtered content: valid:{len(valid_contents)}, invalid:{len(invalid_contents)}")
return valid_contents, invalid_contents
@dataclass(slots=True)
class Content(DjangoDataModel):
id: int
subscription_id: int
item_id: str
item_guid: str
item_url: str
item_title: str
item_description: str
item_content_hash: str
item_image_url: str | None
item_thumbnail_url: str | None
item_published: datetime | None
item_author: str
item_author_url: str | None
item_feed_title: str
item_feed_url: str
_subscription: Subscription | None = None
blocked: bool = False
@staticmethod
def parser(item: dict) -> dict:
item["id"] = item.pop("id")
item["subscription_id"] = item.pop("subscription")
return item
async def exists_via_api(self, url: str, headers: dict, client: httpx.AsyncClient):
log.debug(f"checking if {self.item_content_hash} exists via API")
params = {
"match_any": True, # allows any param to match, instead of needing all
"item_id": self.item_id,
"item_guid": self.item_guid,
"item_url": self.item_url,
"item_title": self.item_title,
"item_content_hash": self.item_content_hash,
"subscription": self.subscription_id
}
log.debug(f"params: {params}")
try:
response = await client.get(
url=url,
headers=headers,
params=params
)
response.raise_for_status()
except httpx.HTTPError as exc:
log.error(f"assuming not duplicate due to error: {exc}")
return False
return response.json().get("results", [])
def is_duplicate(self, other):
if not isinstance(other, Content):
raise ValueError(f"Expected Content, received {type(other)}")
other_details = other.duplicate_details
return any(
other_details.get(key) == value
for key, value in self.duplicate_details.items()
)
@property
def duplicate_details(self):
keys = [
"item_id",
"item_guid",
"item_url",
"item_title",
"item_content_hash"
]
data = asdict(self)
return { key: data[key] for key in keys }
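`is_duplicate` treats two items as duplicates when any one identifying field coincides. The same overlap check on plain dicts:

```python
def any_field_matches(a, b):
    """True when any key in `a` has an equal value in `b`."""
    return any(b.get(key) == value for key, value in a.items())

first = {"item_guid": "guid-1", "item_url": "https://example.com/a"}
second = {"item_guid": "guid-1", "item_url": "https://example.com/b"}  # same guid
third = {"item_guid": "guid-2", "item_url": "https://example.com/c"}   # nothing shared
```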
async def save(self, client: httpx.AsyncClient, base_url: str, headers: dict):
log.debug(f"saving content {self.item_content_hash}")
data = asdict(self)
data.pop("id")
data["subscription"] = data.pop("subscription_id")
item_published = data.pop("item_published")
data["item_published"] = item_published.strftime("%Y-%m-%d") if item_published else None
data.pop("_subscription")
response = await client.post(
url=base_url + "content/",
headers=headers,
data=data
)
response.raise_for_status()
log.debug(f"save success for {self.item_content_hash}")
@classmethod
async def from_raw_rss(cls, rss: str, subscription: Subscription, client: httpx.AsyncClient):
style = subscription.message_style
parsed_rss = feedparser.parse(rss)
contents = []
async def create_content(entry: feedparser.FeedParserDict):
published = entry.get("published_parsed")
if not published:
log.debug("skipping entry without a published date")
return
published = datetime(*published[0:6], tzinfo=timezone.utc)
if published < subscription.publish_threshold:
log.debug("skipping due to publish threshold")
return
content_hash = hashlib.new("sha256")
content_hash.update(entry.get("description", "").encode())
item_url = entry.get("link", "")
item_image_url = entry.get("media_thumbnail", [{}])[0].get("url")
if style.fetch_images:
item_image_url = await cls.get_image_url(item_url, client)
content = Content.from_dict({
"id": -1,
"subscription": subscription.id,
"item_id": entry.get("id", ""),
"item_guid": entry.get("guid", ""),
"item_url": item_url,
"item_title": entry.get("title", ""),
"item_description": entry.get("description", ""),
"item_content_hash": content_hash.hexdigest(),
"item_image_url": item_image_url,
"item_thumbnail_url": parsed_rss.feed.get("image", {}).get("href"),
"item_published": published,
"item_author": entry.get("author", ""),
"item_author_url": entry.get("author_detail", {}).get("href"),
"item_feed_title": parsed_rss.get("feed", {}).get("title"),
"item_feed_url": parsed_rss.get("feed", {}).get("link")
})
# Weed out duplicates
log.debug("weeding out duplicates")
if any(content.is_duplicate(other) for other in contents):
log.debug("found duplicate while loading rss data")
return
content.subscription = subscription
contents.append(content)
await do_batch_job(parsed_rss.entries, create_content, 15)
contents.sort(key=lambda k: k.item_published)
return contents
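The `item_content_hash` fed into duplicate detection is a SHA-256 hex digest of the entry description; the same digest in isolation:

```python
import hashlib

def content_hash(description):
    """Hex SHA-256 digest of an entry's description, as used for deduplication."""
    digest = hashlib.new("sha256")
    digest.update(description.encode())
    return digest.hexdigest()

empty_hash = content_hash("")
```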
@staticmethod
async def get_image_url(url: str, client: httpx.AsyncClient) -> str | None:
log.debug("Fetching image url")
try:
response = await client.get(url, timeout=15)
except httpx.HTTPError:
return None
soup = BeautifulSoup(response.text, "html.parser")
image_element = soup.select_one("meta[property='og:image']")
if not image_element:
return None
return image_element.get("content")
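`get_image_url` pulls the page's `og:image` meta tag with BeautifulSoup; the same extraction is possible with only the standard library's `html.parser`, as in this sketch (not the project's code):

```python
from html.parser import HTMLParser

class OGImageParser(HTMLParser):
    """Capture the content of the first <meta property="og:image"> tag."""

    def __init__(self):
        super().__init__()
        self.og_image = None

    def handle_starttag(self, tag, attrs):
        if tag != "meta" or self.og_image is not None:
            return
        attributes = dict(attrs)
        if attributes.get("property") == "og:image":
            self.og_image = attributes.get("content")

def extract_og_image(html):
    parser = OGImageParser()
    parser.feed(html)
    return parser.og_image

page = '<html><head><meta property="og:image" content="https://example.com/img.png"></head></html>'
image_url = extract_og_image(page)
```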
@property
def subscription(self) -> Subscription:
return self._subscription
@subscription.setter
def subscription(self, subscription: Subscription):
self._subscription = subscription
@property
def embed(self):
style_colour = self.subscription.message_style.colour
colour = discord.Colour.from_str(
style_colour if style_colour.startswith("#") else f"#{style_colour}"
)
# ensure content fits within character limits
title = shorten(markdownify(self.item_title, strip=("img", "a")), 256)
description = shorten(markdownify(self.item_description, strip=("img",)), 4096)
author = self.item_author or self.item_feed_title
combined_length = len(title) + len(description) + (len(author) * 2)
overflow = combined_length - 6000
description = shorten(description, len(description) - overflow) if overflow > 0 else description
embed = discord.Embed(
title=title,
description=description,
url=self.item_url,
colour=colour,
timestamp=self.item_published
)
embed.set_image(url=self.item_image_url)
embed.set_thumbnail(url=self.item_thumbnail_url)
embed.set_author(
name=author,
url=self.item_author_url or self.item_feed_url
)
embed.set_footer(text=self.subscription.name)
log.debug(f"created embed: {embed.to_dict()}")
return embed
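The trimming in the embed property exists because Discord caps an embed title at 256 characters, a description at 4096, and the combined embed text at 6000. `textwrap.shorten` from the standard library enforces such budgets:

```python
from textwrap import shorten

TITLE_LIMIT = 256
TOTAL_LIMIT = 6000

title = shorten("word " * 100, TITLE_LIMIT)
description = "lorem " * 1500  # 9000 characters, far over budget

overflow = len(title) + len(description) - TOTAL_LIMIT
if overflow > 0:
    # Shrink the description by exactly the amount we are over budget.
    description = shorten(description, len(description) - overflow)
```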

src/tests/__init__.py Normal file

src/tests/test_content.py Normal file

@@ -0,0 +1,67 @@
from dataclasses import replace

import pytest

from models import MatchingAlgorithm, ContentFilter, Content


@pytest.fixture
def content() -> Content:
    return Content(
        id=0,
        subscription_id=0,
        item_id="",
        item_guid="",
        item_url="",
        item_title="This week in the papers:",
        item_description="The price of petrol has risen by over 2% since the previous financial report. Read the full article here.",
        item_content_hash="",
        item_image_url=None,
        item_thumbnail_url=None,
        item_published=None,
        item_author="",
        item_author_url=None,
        item_feed_title="",
        item_feed_url=""
    )


@pytest.fixture
def content_filter() -> ContentFilter:
    return ContentFilter(
        id=0,
        server_id=0,
        name="Test Content Filter",
        matching_pattern="",
        matching_algorithm=MatchingAlgorithm.NONE,
        is_insensitive=True,
        is_whitelist=False
    )


def test_content_filter_any(content: Content, content_filter: ContentFilter):
    content_filter.matching_pattern = "france twenty report grass lately"
    content_filter.matching_algorithm = MatchingAlgorithm.ANY
    assert content_filter.matches(content)


def test_content_filter_all(content: Content, content_filter: ContentFilter):
    content_filter.matching_pattern = "week petrol risen since"
    content_filter.matching_algorithm = MatchingAlgorithm.ALL
    assert content_filter.matches(content)


def test_content_filter_literal(content: Content, content_filter: ContentFilter):
    content_filter.matching_pattern = "this week in the papers"
    content_filter.matching_algorithm = MatchingAlgorithm.LITERAL
    assert content_filter.matches(content)


def test_content_filter_regex(content: Content, content_filter: ContentFilter):
    content_filter.matching_pattern = r"\b(The Papers|weekly quiz)\b"
    content_filter.matching_algorithm = MatchingAlgorithm.REGEX
    assert content_filter.matches(content)


# def test_content_filter_fuzzy(content: Content, content_filter: ContentFilter):
#     content_filter.matching_pattern = "this week in the papers"
#     content_filter.matching_algorithm = MatchingAlgorithm.FUZZY
#     assert content_filter.matches(content)


def test_content_duplicates(content: Content):
    copy_of_content = replace(content)
    assert content.is_duplicate(copy_of_content)
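The duplicate test copies the fixture with `dataclasses.replace`, which builds a new instance field-for-field. The pattern can be shown in isolation; the `Item` dataclass and its hash-based `is_duplicate` below are hypothetical stand-ins for the project's `Content` model, not its actual implementation:

```python
import hashlib
from dataclasses import dataclass, replace


@dataclass
class Item:
    item_title: str
    item_description: str

    @property
    def content_hash(self) -> str:
        # hash only the fields that identify "the same" content
        raw = f"{self.item_title}|{self.item_description}".encode()
        return hashlib.sha256(raw).hexdigest()

    def is_duplicate(self, other: "Item") -> bool:
        return self.content_hash == other.content_hash


original = Item("This week in the papers:", "The price of petrol has risen.")
copy_ = replace(original)                         # field-for-field copy
changed = replace(original, item_title="Other")   # copy with one field overridden
```

`replace` with no overrides produces an equal-but-distinct object, which is exactly what a duplicate check should flag; overriding any hashed field breaks the match.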


@@ -1,10 +1,10 @@
"""A collection of utility functions that can be used in various places.""" """A collection of utility functions that can be used in various places."""
import asyncio import asyncio
import aiohttp # import aiohttp
import logging import logging
import async_timeout # import async_timeout
from typing import Callable # from typing import Callable
from discord import Interaction, Embed, Colour, ButtonStyle, Button from discord import Interaction, Embed, Colour, ButtonStyle, Button
from discord.ui import View, button from discord.ui import View, button
@@ -12,325 +12,335 @@ from discord.ext.commands import Bot
log = logging.getLogger(__name__)


async def do_batch_job(iterable: list, func, batch_size: int):
    semaphore = asyncio.Semaphore(batch_size)

    async def batch_job(item):
        async with semaphore:
            await func(item)

    tasks = [batch_job(item) for item in iterable]
    await asyncio.gather(*tasks)

# async def fetch(session, url: str) -> str:
#     async with async_timeout.timeout(20):
#         async with session.get(url) as response:
#             return await response.text()

# async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None):
#     if session is not None:
#         return await fetch(session, url)
#     async with aiohttp.ClientSession() as session:
#         return await fetch(session, url)

# async def get_rss_data(url: str):
#     async with aiohttp.ClientSession() as session:
#         async with session.get(url) as response:
#             items = await response.text(), response.status
#     return items

# async def followup(inter: Interaction, *args, **kwargs):
#     """Shorthand for following up on an interaction.
#     Parameters
#     ----------
#     inter : Interaction
#         Represents an app command interaction.
#     """
#     await inter.followup.send(*args, **kwargs)
# # https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
# class FollowupIcons:
#     error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
#     success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
#     trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
#     info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
#     added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
#     assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"

# class PaginationView(View):
#     """A Discord UI View that adds pagination to an embed."""
#     def __init__(
#         self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
#         formatdata: Callable, pagesize: int, initpage: int=1
#     ):
#         """_summary_
#         Args:
#             bot (commands.Bot) The discord bot
#             inter (Interaction): Represents a discord command interaction.
#             embed (Embed): The base embed to paginate.
#             getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
#             formatdata (Callable): A formatter function that determines how the data is displayed.
#             pagesize (int): The size of each page.
#             initpage (int, optional): The inital page. Defaults to 1.
#         """
#         self.bot = bot
#         self.inter = inter
#         self.embed = embed
#         self.getdata = getdata
#         self.formatdata = formatdata
#         self.maxpage = None
#         self.pagesize = pagesize
#         self.index = initpage
#         # emoji reference
#         self.start_emoji = bot.get_emoji(1204542364073463818)
#         self.end_emoji = bot.get_emoji(1204542367752003624)
#         super().__init__(timeout=100)

#     async def check_user_is_author(self, inter: Interaction) -> bool:
#         """Ensure the user is the author of the original command."""
#         if inter.user == self.inter.user:
#             return True
#         await inter.response.defer()
#         await (
#             Followup(None, "Only the author can interact with this.")
#             .error()
#             .send(inter, ephemeral=True)
#         )
#         return False

#     async def on_timeout(self):
#         """Erase the controls on timeout."""
#         message = await self.inter.original_response()
#         await message.edit(view=None)

#     @staticmethod
#     def calc_total_pages(results: int, max_pagesize: int) -> int:
#         result = ((results - 1) // max_pagesize) + 1
#         return result

#     def calc_dataitem_index(self, dataitem_index: int):
#         """Calculates a given index to be relative to the sum of all pages items
#         Example: dataitem_index = 6
#                  pagesize = 10
#         if page == 1 then return 6
#         else return 6 + 10 * (page - 1)"""
#         if self.index > 1:
#             dataitem_index += self.pagesize * (self.index - 1)
#         dataitem_index += 1
#         return dataitem_index

#     @button(emoji="◀️", style=ButtonStyle.blurple)
#     async def backward(self, inter: Interaction, button: Button):
#         """
#         Action the backwards button.
#         """
#         self.index -= 1
#         await inter.response.defer()
#         self.inter = inter
#         await self.navigate()

#     @button(emoji="▶️", style=ButtonStyle.blurple)
#     async def forward(self, inter: Interaction, button: Button):
#         """
#         Action the forwards button.
#         """
#         self.index += 1
#         await inter.response.defer()
#         self.inter = inter
#         await self.navigate()

#     @button(emoji="⏭️", style=ButtonStyle.blurple)
#     async def start_or_end(self, inter: Interaction, button: Button):
#         """
#         Action the start and end button.
#         This button becomes return to start if at end, otherwise skip to end.
#         """
#         # Determine if should skip to start or end
#         if self.index <= self.maxpage // 2:
#             self.index = self.maxpage
#         else:
#             self.index = 1
#         await inter.response.defer()
#         self.inter = inter
#         await self.navigate()

#     async def navigate(self):
#         """
#         Acts as an update method for the entire instance.
#         """
#         log.debug("navigating to page: %s", self.index)
#         self.update_buttons()
#         paged_embed = await self.create_paged_embed()
#         await self.inter.edit_original_response(embed=paged_embed, view=self)

#     async def create_paged_embed(self) -> Embed:
#         """
#         Returns a copy of the known embed, but with data from the current page.
#         """
#         embed = self.embed.copy()
#         try:
#             data, total_results = await self.getdata(self.index, self.pagesize)
#         except aiohttp.ClientResponseError as exc:
#             log.error(exc)
#             await (
#                 Followup(f"Error · {exc.message}",)
#                 .footer(f"HTTP {exc.code}")
#                 .error()
#                 .send(self.inter)
#             )
#             raise exc
#         self.maxpage = self.calc_total_pages(total_results, self.pagesize)
#         log.debug(f"{self.maxpage=!r}")
#         for i, item in enumerate(data):
#             i = self.calc_dataitem_index(i)
#             if asyncio.iscoroutinefunction(self.formatdata):
#                 key, value = await self.formatdata(i, item)
#             else:
#                 key, value = self.formatdata(i, item)
#             embed.add_field(name=key, value=value, inline=False)
#         if not total_results:
#             embed.description = "There are no results"
#         if self.maxpage > 1:
#             embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
#         return embed

#     def update_buttons(self):
#         if self.index >= self.maxpage:
#             self.children[2].emoji = self.start_emoji
#         else:
#             self.children[2].emoji = self.end_emoji
#         self.children[0].disabled = self.index == 1
#         self.children[1].disabled = self.index == self.maxpage

#     async def send(self):
#         """Send the pagination view. It may be important to defer before invoking this method."""
#         log.debug("sending pagination view")
#         embed = await self.create_paged_embed()
#         if self.maxpage <= 1:
#             await self.inter.edit_original_response(embed=embed)
#             return
#         self.update_buttons()
#         await self.inter.edit_original_response(embed=embed, view=self)

# class Followup:
#     """Wrapper for a discord embed to follow up an interaction."""
#     def __init__(
#         self,
#         title: str = None,
#         description: str = None,
#     ):
#         self._embed = Embed(
#             title=title,
#             description=description
#         )

#     async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
#         """"""
#         await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)

#     def fields(self, inline: bool = False, **fields: dict):
#         """"""
#         for key, value in fields.items():
#             self._embed.add_field(name=key, value=value, inline=inline)
#         return self

#     def image(self, url: str):
#         """"""
#         self._embed.set_image(url=url)
#         return self

#     def author(self, name: str, url: str=None, icon_url: str=None):
#         """"""
#         self._embed.set_author(name=name, url=url, icon_url=icon_url)
#         return self

#     def footer(self, text: str, icon_url: str = None):
#         """"""
#         self._embed.set_footer(text=text, icon_url=icon_url)
#         return self

#     def error(self):
#         """"""
#         self._embed.colour = Colour.red()
#         self._embed.set_thumbnail(url=FollowupIcons.error)
#         return self

#     def success(self):
#         """"""
#         self._embed.colour = Colour.green()
#         self._embed.set_thumbnail(url=FollowupIcons.success)
#         return self

#     def info(self):
#         """"""
#         self._embed.colour = Colour.blue()
#         self._embed.set_thumbnail(url=FollowupIcons.info)
#         return self

#     def added(self):
#         """"""
#         self._embed.colour = Colour.blue()
#         self._embed.set_thumbnail(url=FollowupIcons.added)
#         return self

#     def assign(self):
#         """"""
#         self._embed.colour = Colour.blue()
#         self._embed.set_thumbnail(url=FollowupIcons.assigned)
#         return self

#     def trash(self):
#         """"""
#         self._embed.colour = Colour.red()
#         self._embed.set_thumbnail(url=FollowupIcons.trash)
#         return self

# def extract_error_info(error: Exception) -> str:
#     class_name = error.__class__.__name__
#     desc = str(error)
#     return class_name, desc