Compare commits
36 Commits
Author | SHA1 | Date | |
---|---|---|---|
7d395d2001 | |||
b0b02aae9f | |||
ca2037528d | |||
a9c08dea4d | |||
8fcccc7ac9 | |||
efbb4b18ff | |||
81795feb65 | |||
1f9075ce60 | |||
6e153782c2 | |||
d86fc0eb71 | |||
08295dfea6 | |||
470f78c144 | |||
94b154742e | |||
1294909188 | |||
cf8fb34a29 | |||
eb97dca5c6 | |||
b8f1ffb8d9 | |||
ccfa35adda | |||
cc06d3e09f | |||
f9de8ff085 | |||
02a4917152 | |||
ff30e31cf1 | |||
50f7a62cd4 | |||
4c31fe69e2 | |||
bb3475b79d | |||
82fe6bea9a | |||
32b8092034 | |||
26f697cf78 | |||
ab472e1979 | |||
7515d6b86e | |||
a5c593bb14 | |||
48697c08a6 | |||
cbe15aceb8 | |||
e8a9c270e4 | |||
2a1aaa689e | |||
1a4f25ec97 |
@ -1,4 +1,4 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.2.0
|
current_version = 0.2.1
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
|
60
CHANGELOG.md
60
CHANGELOG.md
@ -1,14 +1,54 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
**v0.2.0**
|
All notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
- Fix: Fetch channels if not found in bot cache (error fix)
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||||
- Enhancement: command to test a channel's permissions allow for the Bot to function
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
- Enhancement: account for active state from a server's settings (`GuildSettings`)
|
|
||||||
- Enhancement: command to view tracked content from the server or a given subscription of the same server.
|
|
||||||
- Other: code optimisation & `GuildSettings` dataclass
|
|
||||||
- Other: Cleaned out many instances of unused code
|
|
||||||
|
|
||||||
**v0.1.1**
|
## [Unreleased]
|
||||||
|
|
||||||
- Docs: Start of changelog
|
### Added
|
||||||
- Enhancement: Versioning with tagged docker images
|
|
||||||
|
- Search and filter controls for the data viewing commands
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- RSS feeds without a build date would break the subscription task
|
||||||
|
- TypeError when an RSS item lacks a title or description
|
||||||
|
- Fix an issue with a missing field on the Subscription model
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Show whether a subscription is active or inactive when using a data view command
|
||||||
|
- Added `unique_content_rules` field to `Subscription` dataclass (support latest pyrss-website version)
|
||||||
|
- Update changelog to follow [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
|
||||||
|
|
||||||
|
## [0.2.0] - 2024-08-19
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Command to view tracked content from the relevant server
|
||||||
|
- Command to test the bot's permissions in a specified channel
|
||||||
|
- `GuildSettings` dataclass
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- channels are `NoneType` because they didn't exist in the cache, fixed by fetching from API
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Subscription task will ignore subscriptions flagged as 'inactive'
|
||||||
|
- Code optimisation
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- Unused and commented out code
|
||||||
|
|
||||||
|
## [0.1.1] - 2024-08-17
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Start of changelog
|
||||||
|
- Versioning with tagged docker images
|
||||||
|
|
||||||
|
## [0.1.0] - 2024-08-13
|
||||||
|
@ -4,4 +4,4 @@ An RSS driven Discord bot written in Python.
|
|||||||
|
|
||||||
Provides user commands for storing RSS feed URLs that can be assigned to any given discord channel.
|
Provides user commands for storing RSS feed URLs that can be assigned to any given discord channel.
|
||||||
|
|
||||||
Depends on the [web application](https://gitea.corbz.dev/corbz/PYRSS-Website). Check the releases for compatible versions.
|
Depends on the [web application](https://gitea.cor.bz/corbz/PYRSS-Website). Check the releases for compatible versions.
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
{
|
{
|
||||||
"version": 1,
|
"version": 1,
|
||||||
"disable_existing_loggers": false,
|
"disable_existing_loggers": true,
|
||||||
"formatters": {
|
"formatters": {
|
||||||
"simple": {
|
"simple": {
|
||||||
"format": "%(levelname)s %(message)s"
|
"format": "[%(module)s|%(message)s]"
|
||||||
},
|
},
|
||||||
"detail": {
|
"detail": {
|
||||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
|
"format": "[%(asctime)s] [%(levelname)s] [%(module)s]: %(message)s"
|
||||||
},
|
},
|
||||||
"complex": {
|
"complex": {
|
||||||
"format": "[%(levelname)s|%(module)s|L%(lineno)d] %(asctime)s %(message)s",
|
"format": "[%(levelname)s|%(name)s|L%(lineno)d] %(asctime)s %(message)s",
|
||||||
"datefmt": "%Y-%m-%dT%H:%M:%S%z"
|
"datefmt": "%Y-%m-%dT%H:%M:%S%z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -17,7 +17,7 @@
|
|||||||
"stdout": {
|
"stdout": {
|
||||||
"class": "logging.StreamHandler",
|
"class": "logging.StreamHandler",
|
||||||
"level": "DEBUG",
|
"level": "DEBUG",
|
||||||
"formatter": "simple",
|
"formatter": "detail",
|
||||||
"stream": "ext://sys.stdout"
|
"stream": "ext://sys.stdout"
|
||||||
},
|
},
|
||||||
"file": {
|
"file": {
|
||||||
@ -46,6 +46,14 @@
|
|||||||
},
|
},
|
||||||
"discord": {
|
"discord": {
|
||||||
"level": "INFO"
|
"level": "INFO"
|
||||||
|
},
|
||||||
|
"httpx": {
|
||||||
|
"level": "WARNING",
|
||||||
|
"propagate": false
|
||||||
|
},
|
||||||
|
"httpcore": {
|
||||||
|
"level": "WARNING",
|
||||||
|
"propagate": false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
filterwarnings =
|
||||||
|
ignore:'audioop' is deprecated and slated for removal in Python 3.13:DeprecationWarning
|
@ -1,29 +1,30 @@
|
|||||||
aiocache==0.12.2
|
aiohappyeyeballs==2.4.3
|
||||||
aiohttp==3.9.3
|
aiohttp==3.10.10
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
aiosqlite==0.19.0
|
anyio==4.6.2.post1
|
||||||
async-timeout==4.0.3
|
attrs==24.2.0
|
||||||
asyncpg==0.29.0
|
|
||||||
attrs==23.2.0
|
|
||||||
beautifulsoup4==4.12.3
|
beautifulsoup4==4.12.3
|
||||||
bump2version==1.0.1
|
bump2version==1.0.1
|
||||||
click==8.1.7
|
certifi==2024.8.30
|
||||||
discord.py==2.3.2
|
discord.py==2.3.2
|
||||||
feedparser==6.0.11
|
feedparser==6.0.11
|
||||||
frozenlist==1.4.1
|
frozenlist==1.5.0
|
||||||
greenlet==3.0.3
|
h11==0.14.0
|
||||||
idna==3.6
|
httpcore==1.0.6
|
||||||
|
httpx==0.27.2
|
||||||
|
idna==3.10
|
||||||
|
iniconfig==2.0.0
|
||||||
markdownify==0.11.6
|
markdownify==0.11.6
|
||||||
multidict==6.0.5
|
multidict==6.1.0
|
||||||
pip-chill==1.0.3
|
packaging==24.2
|
||||||
psycopg2-binary==2.9.9
|
pluggy==1.5.0
|
||||||
|
propcache==0.2.0
|
||||||
|
pytest==8.3.3
|
||||||
python-dotenv==1.0.0
|
python-dotenv==1.0.0
|
||||||
rapidfuzz==3.9.4
|
rapidfuzz==3.9.4
|
||||||
sgmllib3k==1.0.0
|
sgmllib3k==1.0.0
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
soupsieve==2.5
|
sniffio==1.3.1
|
||||||
SQLAlchemy==2.0.23
|
soupsieve==2.6
|
||||||
typing_extensions==4.10.0
|
|
||||||
uwuipy==0.1.9
|
uwuipy==0.1.9
|
||||||
validators==0.22.0
|
yarl==1.17.0
|
||||||
yarl==1.9.4
|
|
||||||
|
@ -157,7 +157,7 @@ class API:
|
|||||||
|
|
||||||
log.debug("getting tracked content")
|
log.debug("getting tracked content")
|
||||||
|
|
||||||
return await self._get_many(self.API_ENDPOINT + f"tracked-content/", filters)
|
return await self._get_many(self.API_ENDPOINT + "tracked-content/", filters)
|
||||||
|
|
||||||
async def get_filter(self, filter_id: int) -> dict:
|
async def get_filter(self, filter_id: int) -> dict:
|
||||||
"""
|
"""
|
||||||
@ -167,3 +167,10 @@ class API:
|
|||||||
log.debug("getting a filter")
|
log.debug("getting a filter")
|
||||||
|
|
||||||
return await self._get_one(f"{self.API_ENDPOINT}filter/{filter_id}")
|
return await self._get_one(f"{self.API_ENDPOINT}filter/{filter_id}")
|
||||||
|
|
||||||
|
async def get_filters(self, **filters) -> tuple[list[dict], int]:
|
||||||
|
"""
|
||||||
|
Get many instances of Filter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await self._get_many(self.API_ENDPOINT + "filter/", filters)
|
||||||
|
@ -7,76 +7,82 @@ import logging
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import aiohttp
|
# import aiohttp
|
||||||
import validators
|
# import validators
|
||||||
from feedparser import FeedParserDict, parse
|
from feedparser import FeedParserDict, parse
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from discord import Interaction, TextChannel, Embed, Colour
|
from discord import Interaction, TextChannel, Embed, Colour
|
||||||
from discord.app_commands import Choice, Group, autocomplete, rename, command
|
from discord.app_commands import Choice, choices, Group, autocomplete, rename, command
|
||||||
from discord.errors import Forbidden
|
from discord.errors import Forbidden
|
||||||
|
|
||||||
from api import API
|
# from api import API
|
||||||
from feed import Subscription, TrackedContent
|
# from feed import Subscription, TrackedContent, ContentFilter
|
||||||
from utils import (
|
# from utils import (
|
||||||
Followup,
|
# Followup,
|
||||||
PaginationView,
|
# PaginationView,
|
||||||
get_rss_data,
|
# get_rss_data,
|
||||||
)
|
# )
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
rss_list_sort_choices = [
|
# rss_list_sort_choices = [
|
||||||
Choice(name="Nickname", value=0),
|
# Choice(name="Nickname", value=0),
|
||||||
Choice(name="Date Added", value=1)
|
# Choice(name="Date Added", value=1)
|
||||||
]
|
# ]
|
||||||
channels_list_sort_choices=[
|
# channels_list_sort_choices=[
|
||||||
Choice(name="Feed Nickname", value=0),
|
# Choice(name="Feed Nickname", value=0),
|
||||||
Choice(name="Channel ID", value=1),
|
# Choice(name="Channel ID", value=1),
|
||||||
Choice(name="Date Added", value=2)
|
# Choice(name="Date Added", value=2)
|
||||||
]
|
# ]
|
||||||
|
|
||||||
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
|
# # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
|
||||||
# target resource. Run a period task to check this.
|
# # target resource. Run a period task to check this.
|
||||||
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
|
# async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
|
||||||
"""Validate a provided RSS source.
|
# """Validate a provided RSS source.
|
||||||
|
|
||||||
Parameters
|
# Parameters
|
||||||
----------
|
# ----------
|
||||||
nickname : str
|
# nickname : str
|
||||||
Nickname of the source. Must not contain URL.
|
# Nickname of the source. Must not contain URL.
|
||||||
url : str
|
# url : str
|
||||||
URL of the source. Must be URL with valid status code and be an RSS feed.
|
# URL of the source. Must be URL with valid status code and be an RSS feed.
|
||||||
|
|
||||||
Returns
|
# Returns
|
||||||
-------
|
# -------
|
||||||
str or None
|
# str or None
|
||||||
String invalid message if invalid, NoneType if valid.
|
# String invalid message if invalid, NoneType if valid.
|
||||||
FeedParserDict or None
|
# FeedParserDict or None
|
||||||
The feed parsed from the given URL or None if invalid.
|
# The feed parsed from the given URL or None if invalid.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
# Ensure the URL is valid
|
# # Ensure the URL is valid
|
||||||
if not validators.url(url):
|
# if not validators.url(url):
|
||||||
return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
|
# return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
|
||||||
|
|
||||||
# Check the nickname is not a URL
|
# # Check the nickname is not a URL
|
||||||
if validators.url(nickname):
|
# if validators.url(nickname):
|
||||||
return "It looks like the nickname you have entered is a URL.\n" \
|
# return "It looks like the nickname you have entered is a URL.\n" \
|
||||||
f"For security reasons, this is not allowed.\n`{nickname=}`", None
|
# f"For security reasons, this is not allowed.\n`{nickname=}`", None
|
||||||
|
|
||||||
|
|
||||||
feed_data, status_code = await get_rss_data(url)
|
# feed_data, status_code = await get_rss_data(url)
|
||||||
|
|
||||||
# Check the URL status code is valid
|
# # Check the URL status code is valid
|
||||||
if status_code != 200:
|
# if status_code != 200:
|
||||||
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
# return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
||||||
|
|
||||||
# Check the contents is actually an RSS feed.
|
# # Check the contents is actually an RSS feed.
|
||||||
feed = parse(feed_data)
|
# feed = parse(feed_data)
|
||||||
if not feed.version:
|
# if not feed.version:
|
||||||
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
# return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
||||||
|
|
||||||
return None, feed
|
# return None, feed
|
||||||
|
|
||||||
|
# tri_choices = [
|
||||||
|
# Choice(name="Yes", value=2),
|
||||||
|
# Choice(name="No (default)", value=1),
|
||||||
|
# Choice(name="All", value=0),
|
||||||
|
# ]
|
||||||
|
|
||||||
|
|
||||||
class CommandsCog(commands.Cog):
|
class CommandsCog(commands.Cog):
|
||||||
@ -94,135 +100,206 @@ class CommandsCog(commands.Cog):
|
|||||||
|
|
||||||
log.info("%s cog is ready", self.__class__.__name__)
|
log.info("%s cog is ready", self.__class__.__name__)
|
||||||
|
|
||||||
# Group for commands about viewing data
|
# # Group for commands about viewing data
|
||||||
view_group = Group(
|
# view_group = Group(
|
||||||
name="view",
|
# name="view",
|
||||||
description="View data.",
|
# description="View data.",
|
||||||
guild_only=True
|
# guild_only=True
|
||||||
)
|
# )
|
||||||
|
|
||||||
@view_group.command(name="subscriptions")
|
# @view_group.command(name="subscriptions")
|
||||||
async def cmd_list_subs(self, inter: Interaction, search: str = ""):
|
# async def cmd_list_subs(self, inter: Interaction, search: str = ""):
|
||||||
"""List Subscriptions from this server."""
|
# """List Subscriptions from this server."""
|
||||||
|
|
||||||
await inter.response.defer()
|
# await inter.response.defer()
|
||||||
|
|
||||||
def formatdata(index, item):
|
# def formatdata(index, item):
|
||||||
item = Subscription.from_dict(item)
|
# item = Subscription.from_dict(item)
|
||||||
|
|
||||||
channels = f"{item.channels_count}{' channels' if item.channels_count != 1 else ' channel'}"
|
|
||||||
filters = f"{len(item.filters)}{' filters' if len(item.filters) != 1 else ' filter'}"
|
|
||||||
notes = item.extra_notes[:25] + "..." if len(item.extra_notes) > 28 else item.extra_notes
|
|
||||||
links = f"[RSS Link]({item.url}) · [API Link]({API.API_EXTERNAL_ENDPOINT}subscription/{item.id}/)"
|
|
||||||
|
|
||||||
description = f"{channels}, {filters}\n"
|
# notes = item.extra_notes[:25] + "..." if len(item.extra_notes) > 28 else item.extra_notes
|
||||||
description += f"{notes}\n" if notes else ""
|
# links = f"[RSS Link]({item.url}) • [API Link]({API.API_EXTERNAL_ENDPOINT}subscription/{item.id}/)"
|
||||||
description += links
|
# activeness = "✅ `enabled`" if item.active else "🚫 `disabled`"
|
||||||
|
|
||||||
key = f"{index}. {item.name}"
|
# description = f"🆔 `{item.id}`\n{activeness}\n#️⃣ `{item.channels_count}` 🔽 `{len(item.filters)}`\n"
|
||||||
return key, description # key, value pair
|
# description = f"{notes}\n" + description if notes else description
|
||||||
|
# description += links
|
||||||
|
|
||||||
async def getdata(page: int, pagesize: int):
|
# key = f"{index}. {item.name}"
|
||||||
async with aiohttp.ClientSession() as session:
|
# return key, description # key, value pair
|
||||||
api = API(self.bot.api_token, session)
|
|
||||||
return await api.get_subscriptions(
|
|
||||||
guild_id=inter.guild.id,
|
|
||||||
page=page,
|
|
||||||
page_size=pagesize,
|
|
||||||
search=search
|
|
||||||
)
|
|
||||||
|
|
||||||
embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed
|
# async def getdata(page: int, pagesize: int):
|
||||||
pagination = PaginationView(
|
# async with aiohttp.ClientSession() as session:
|
||||||
self.bot,
|
# api = API(self.bot.api_token, session)
|
||||||
inter=inter,
|
# return await api.get_subscriptions(
|
||||||
embed=embed,
|
# guild_id=inter.guild.id,
|
||||||
getdata=getdata,
|
# page=page,
|
||||||
formatdata=formatdata,
|
# page_size=pagesize,
|
||||||
pagesize=10,
|
# search=search
|
||||||
initpage=1
|
# )
|
||||||
)
|
|
||||||
await pagination.send()
|
|
||||||
|
|
||||||
@view_group.command(name="tracked-content")
|
# embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed
|
||||||
async def cmd_list_tracked(self, inter: Interaction, search: str = ""):
|
# pagination = PaginationView(
|
||||||
"""List Tracked Content from this server, or a given sub"""
|
# self.bot,
|
||||||
|
# inter=inter,
|
||||||
|
# embed=embed,
|
||||||
|
# getdata=getdata,
|
||||||
|
# formatdata=formatdata,
|
||||||
|
# pagesize=10,
|
||||||
|
# initpage=1
|
||||||
|
# )
|
||||||
|
# await pagination.send()
|
||||||
|
|
||||||
await inter.response.defer()
|
# @view_group.command(name="tracked-content")
|
||||||
|
# @choices(blocked=tri_choices)
|
||||||
|
# async def cmd_list_tracked(self, inter: Interaction, search: str = "", blocked: Choice[int] = 1):
|
||||||
|
# """List Tracked Content from this server""" # TODO: , or a given sub
|
||||||
|
|
||||||
def formatdata(index, item):
|
# await inter.response.defer()
|
||||||
item = TrackedContent.from_dict(item)
|
|
||||||
sub = Subscription.from_dict(item.subscription)
|
|
||||||
|
|
||||||
links = f"[Content Link]({item.url}) · [Message Link](https://discord.com/channels/{sub.guild_id}/{item.channel_id}/{item.message_id}/)"
|
# # If the user picks an option it's an instance of `Choice` otherwise `str`
|
||||||
description = f"Subscription: {sub.name}\n{links}"
|
# # Can't figure a way to select a default choices, so blame discordpy for this mess.
|
||||||
|
# if isinstance(blocked, Choice):
|
||||||
|
# blocked = blocked.value
|
||||||
|
|
||||||
key = f"{item.id}. {item.title}"
|
# def formatdata(index, item):
|
||||||
return key, description
|
# item = TrackedContent.from_dict(item)
|
||||||
|
# sub = Subscription.from_dict(item.subscription)
|
||||||
|
|
||||||
async def getdata(page: int, pagesize: int):
|
# links = f"[Content Link]({item.url}) · [Message Link](https://discord.com/channels/{sub.guild_id}/{item.channel_id}/{item.message_id}/) · [API Link]({API.API_EXTERNAL_ENDPOINT}tracked-content/{item.id}/)"
|
||||||
async with aiohttp.ClientSession() as session:
|
# delivery_state = "✅ Delivered" if not item.blocked else "🚫 Blocked"
|
||||||
api = API(self.bot.api_token, session)
|
|
||||||
return await api.get_tracked_content(
|
|
||||||
subscription__guild_id=inter.guild_id,
|
|
||||||
page=page,
|
|
||||||
page_size=pagesize,
|
|
||||||
search=search
|
|
||||||
)
|
|
||||||
|
|
||||||
embed = Followup(f"Tracked Content in {inter.guild.name}").info()._embed
|
# description = f"🆔 `{item.id}`\n"
|
||||||
pagination = PaginationView(
|
# description += f"{delivery_state}\n" if blocked == 0 else ""
|
||||||
self.bot,
|
# description += f"➡️ *{sub.name}*\n{links}"
|
||||||
inter=inter,
|
|
||||||
embed=embed,
|
|
||||||
getdata=getdata,
|
|
||||||
formatdata=formatdata,
|
|
||||||
pagesize=10,
|
|
||||||
initpage=1
|
|
||||||
)
|
|
||||||
await pagination.send()
|
|
||||||
|
|
||||||
# Group for test related commands
|
# key = f"{index}. {item.title}"
|
||||||
test_group = Group(
|
# return key, description
|
||||||
name="test",
|
|
||||||
description="Commands to test Bot functionality.",
|
|
||||||
guild_only=True
|
|
||||||
)
|
|
||||||
|
|
||||||
@test_group.command(name="channel-permissions")
|
# def determine_blocked():
|
||||||
async def cmd_test_channel_perms(self, inter: Interaction):
|
# match blocked:
|
||||||
"""Test that the current channel's permissions allow for PYRSS to operate in it."""
|
# case 0: return ""
|
||||||
|
# case 1: return "false"
|
||||||
|
# case 2: return "true"
|
||||||
|
# case _: return ""
|
||||||
|
|
||||||
try:
|
# async def getdata(page: int, pagesize: int):
|
||||||
test_message = await inter.channel.send(content="... testing permissions ...")
|
# async with aiohttp.ClientSession() as session:
|
||||||
await self.test_channel_perms(inter.channel)
|
# api = API(self.bot.api_token, session)
|
||||||
except Exception as error:
|
# is_blocked = determine_blocked()
|
||||||
await inter.response.send_message(content=f"Failed: {error}")
|
# return await api.get_tracked_content(
|
||||||
return
|
# subscription__guild_id=inter.guild_id,
|
||||||
|
# blocked=is_blocked,
|
||||||
|
# page=page,
|
||||||
|
# page_size=pagesize,
|
||||||
|
# search=search,
|
||||||
|
# )
|
||||||
|
|
||||||
await test_message.delete()
|
# embed = Followup(f"Tracked Content in {inter.guild.name}").info()._embed
|
||||||
await inter.response.send_message(content="Success")
|
# pagination = PaginationView(
|
||||||
|
# self.bot,
|
||||||
|
# inter=inter,
|
||||||
|
# embed=embed,
|
||||||
|
# getdata=getdata,
|
||||||
|
# formatdata=formatdata,
|
||||||
|
# pagesize=10,
|
||||||
|
# initpage=1
|
||||||
|
# )
|
||||||
|
# await pagination.send()
|
||||||
|
|
||||||
async def test_channel_perms(self, channel: TextChannel):
|
# @view_group.command(name="filters")
|
||||||
|
# async def cmd_list_filters(self, inter: Interaction, search: str = ""):
|
||||||
|
# """List Filters from this server."""
|
||||||
|
|
||||||
# Test generic message and delete
|
# await inter.response.defer()
|
||||||
msg = await channel.send(content="test message")
|
|
||||||
await msg.delete()
|
|
||||||
|
|
||||||
# Test detailed embed
|
# def formatdata(index, item):
|
||||||
embed = Embed(
|
# item = ContentFilter.from_dict(item)
|
||||||
title="test title",
|
|
||||||
description="test description",
|
# matching_algorithm = get_algorithm_name(item.matching_algorithm)
|
||||||
colour=Colour.random(),
|
# whitelist = "Whitelist" if item.is_whitelist else "Blacklist"
|
||||||
timestamp=datetime.now(),
|
# sensitivity = "Case insensitive" if item.is_insensitive else "Case sensitive"
|
||||||
url="https://google.com"
|
|
||||||
)
|
# description = f"🆔 `{item.id}`\n"
|
||||||
embed.set_author(name="test author")
|
# description += f"🔄 `{matching_algorithm}`\n🟰 `{item.match}`\n"
|
||||||
embed.set_footer(text="test footer")
|
# description += f"✅ `{whitelist}` 🔠 `{sensitivity}`\n"
|
||||||
embed.set_thumbnail(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
# description += f"[API Link]({API.API_EXTERNAL_ENDPOINT}filter/{item.id}/)"
|
||||||
embed.set_image(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
|
||||||
embed_msg = await channel.send(embed=embed)
|
# key = f"{index}. {item.name}"
|
||||||
await embed_msg.delete()
|
# return key, description
|
||||||
|
|
||||||
|
# def get_algorithm_name(matching_algorithm: int):
|
||||||
|
# match matching_algorithm:
|
||||||
|
# case 0: return "None"
|
||||||
|
# case 1: return "Any word"
|
||||||
|
# case 2: return "All words"
|
||||||
|
# case 3: return "Exact match"
|
||||||
|
# case 4: return "Regex match"
|
||||||
|
# case 5: return "Fuzzy match"
|
||||||
|
# case _: return "unknown"
|
||||||
|
|
||||||
|
# async def getdata(page, pagesize):
|
||||||
|
# async with aiohttp.ClientSession() as session:
|
||||||
|
# api = API(self.bot.api_token, session)
|
||||||
|
# return await api.get_filters(
|
||||||
|
# guild_id=inter.guild_id,
|
||||||
|
# page=page,
|
||||||
|
# page_size=pagesize,
|
||||||
|
# search=search
|
||||||
|
# )
|
||||||
|
|
||||||
|
# embed = Followup(f"Filters in {inter.guild.name}").info()._embed
|
||||||
|
# pagination = PaginationView(
|
||||||
|
# self.bot,
|
||||||
|
# inter=inter,
|
||||||
|
# embed=embed,
|
||||||
|
# getdata=getdata,
|
||||||
|
# formatdata=formatdata,
|
||||||
|
# pagesize=10,
|
||||||
|
# initpage=1
|
||||||
|
# )
|
||||||
|
# await pagination.send()
|
||||||
|
|
||||||
|
# # Group for test related commands
|
||||||
|
# test_group = Group(
|
||||||
|
# name="test",
|
||||||
|
# description="Commands to test Bot functionality.",
|
||||||
|
# guild_only=True
|
||||||
|
# )
|
||||||
|
|
||||||
|
# @test_group.command(name="channel-permissions")
|
||||||
|
# async def cmd_test_channel_perms(self, inter: Interaction):
|
||||||
|
# """Test that the current channel's permissions allow for PYRSS to operate in it."""
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# test_message = await inter.channel.send(content="... testing permissions ...")
|
||||||
|
# await self.test_channel_perms(inter.channel)
|
||||||
|
# except Exception as error:
|
||||||
|
# await inter.response.send_message(content=f"Failed: {error}")
|
||||||
|
# return
|
||||||
|
|
||||||
|
# await test_message.delete()
|
||||||
|
# await inter.response.send_message(content="Success")
|
||||||
|
|
||||||
|
# async def test_channel_perms(self, channel: TextChannel):
|
||||||
|
|
||||||
|
# # Test generic message and delete
|
||||||
|
# msg = await channel.send(content="test message")
|
||||||
|
# await msg.delete()
|
||||||
|
|
||||||
|
# # Test detailed embed
|
||||||
|
# embed = Embed(
|
||||||
|
# title="test title",
|
||||||
|
# description="test description",
|
||||||
|
# colour=Colour.random(),
|
||||||
|
# timestamp=datetime.now(),
|
||||||
|
# url="https://google.com"
|
||||||
|
# )
|
||||||
|
# embed.set_author(name="test author")
|
||||||
|
# embed.set_footer(text="test footer")
|
||||||
|
# embed.set_thumbnail(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
||||||
|
# embed.set_image(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
||||||
|
# embed_msg = await channel.send(embed=embed)
|
||||||
|
# await embed_msg.delete()
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
|
@ -3,29 +3,37 @@ Extension for the `TaskCog`.
|
|||||||
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
|
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import datetime
|
import datetime
|
||||||
|
import traceback
|
||||||
from os import getenv
|
from os import getenv
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from textwrap import shorten
|
||||||
|
|
||||||
import aiohttp
|
# import aiohttp
|
||||||
from aiocache import Cache
|
import httpx
|
||||||
|
import feedparser
|
||||||
|
import discord
|
||||||
|
# from aiocache import Cache
|
||||||
from discord import TextChannel
|
from discord import TextChannel
|
||||||
from discord import app_commands
|
from discord import app_commands, Interaction
|
||||||
from discord.ext import commands, tasks
|
from discord.ext import commands, tasks
|
||||||
from discord.errors import Forbidden
|
from discord.errors import Forbidden
|
||||||
from feedparser import parse
|
from markdownify import markdownify
|
||||||
|
|
||||||
from feed import RSSFeed, Subscription, RSSItem, GuildSettings
|
import models
|
||||||
from utils import get_unparsed_feed
|
from utils import do_batch_job
|
||||||
from filters import match_text
|
# from feed import RSSFeed, Subscription, RSSItem, GuildSettings
|
||||||
|
# from utils import get_unparsed_feed
|
||||||
|
# from filters import match_text
|
||||||
from api import API
|
from api import API
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
cache = Cache(Cache.MEMORY)
|
# cache = Cache(Cache.MEMORY)
|
||||||
|
|
||||||
BATCH_SIZE = 100
|
BATCH_SIZE = 100
|
||||||
|
|
||||||
@ -35,7 +43,7 @@ subscription_task_times = [
|
|||||||
for hour in range(24)
|
for hour in range(24)
|
||||||
for minute in range(0, 60, int(TASK_INTERVAL_MINUTES))
|
for minute in range(0, 60, int(TASK_INTERVAL_MINUTES))
|
||||||
]
|
]
|
||||||
log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
|
log.info("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
|
||||||
|
|
||||||
|
|
||||||
class TaskCog(commands.Cog):
|
class TaskCog(commands.Cog):
|
||||||
@ -46,16 +54,23 @@ class TaskCog(commands.Cog):
|
|||||||
api: API | None = None
|
api: API | None = None
|
||||||
content_queue = deque()
|
content_queue = deque()
|
||||||
|
|
||||||
|
api_base_url: str
|
||||||
|
api_headers: dict
|
||||||
|
client: httpx.AsyncClient | None
|
||||||
|
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
self.api_base_url = "http://localhost:8000/api/"
|
||||||
|
self.api_headers = {"Authorization": f"Token {self.bot.api_token}"}
|
||||||
|
|
||||||
@commands.Cog.listener()
|
@commands.Cog.listener()
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
"""
|
"""
|
||||||
Instructions to execute when the cog is ready.
|
Instructions to execute when the cog is ready.
|
||||||
"""
|
"""
|
||||||
self.subscription_task.start()
|
# self.subscription_task.start()
|
||||||
|
self.do_task.start()
|
||||||
log.info("%s cog is ready", self.__class__.__name__)
|
log.info("%s cog is ready", self.__class__.__name__)
|
||||||
|
|
||||||
@commands.Cog.listener(name="cog_unload")
|
@commands.Cog.listener(name="cog_unload")
|
||||||
@ -72,219 +87,434 @@ class TaskCog(commands.Cog):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@group.command(name="trigger")
|
@group.command(name="trigger")
|
||||||
async def cmd_trigger_task(self, inter):
|
async def cmd_trigger_task(self, inter: Interaction):
|
||||||
await inter.response.defer()
|
await inter.response.defer()
|
||||||
start_time = perf_counter()
|
start_time = perf_counter()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.subscription_task()
|
await self.do_task()
|
||||||
except Exception as error:
|
except Exception as exc:
|
||||||
await inter.followup.send(str(error))
|
log.exception(exc)
|
||||||
|
await inter.followup.send(str(exc) or "unknown error")
|
||||||
finally:
|
finally:
|
||||||
end_time = perf_counter()
|
end_time = perf_counter()
|
||||||
await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
|
await inter.followup.send(f"completed command in {end_time - start_time:.4f} seconds")
|
||||||
|
|
||||||
@tasks.loop(time=subscription_task_times)
|
@tasks.loop(time=subscription_task_times)
|
||||||
async def subscription_task(self):
|
async def do_task(self):
|
||||||
"""
|
log.info("Running task")
|
||||||
Task for fetching and processing subscriptions.
|
|
||||||
"""
|
|
||||||
log.info("Running subscription task")
|
|
||||||
start_time = perf_counter()
|
start_time = perf_counter()
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with httpx.AsyncClient() as client:
|
||||||
self.api = API(self.bot.api_token, session)
|
self.client = client
|
||||||
await self.execute_task()
|
servers = await self.get_servers()
|
||||||
|
await do_batch_job(servers, self.process_server, 10)
|
||||||
|
|
||||||
end_time = perf_counter()
|
end_time = perf_counter()
|
||||||
log.debug(f"task completed in {end_time - start_time:.4f} seconds")
|
log.info(f"completed task in {end_time - start_time:.4f} seconds")
|
||||||
|
|
||||||
async def execute_task(self):
|
async def iterate_pages(self, url: str, params: dict={}):
|
||||||
"""Execute the task directly."""
|
|
||||||
|
|
||||||
# Filter out inactive guild IDs using related settings
|
for page_number, _ in enumerate(iterable=iter(int, 1), start=1):
|
||||||
guild_ids = [guild.id for guild in self.bot.guilds]
|
params.update({"page": page_number})
|
||||||
guild_settings = await self.get_guild_settings(guild_ids)
|
response = await self.client.get(
|
||||||
active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
|
self.api_base_url + url,
|
||||||
|
headers=self.api_headers,
|
||||||
subscriptions = await self.get_subscriptions(active_guild_ids)
|
params=params
|
||||||
await self.process_subscriptions(subscriptions)
|
|
||||||
|
|
||||||
async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
|
|
||||||
"""Returns a list of guild settings from the Bot's guilds, if they exist."""
|
|
||||||
|
|
||||||
guild_settings = []
|
|
||||||
|
|
||||||
# Iterate infinitely taking the iter no. as `page`
|
|
||||||
# data will be empty after last page reached.
|
|
||||||
for page, _ in enumerate(iter(int, 1)):
|
|
||||||
data = await self.get_guild_settings_page(guild_ids, page)
|
|
||||||
if not data:
|
|
||||||
break
|
|
||||||
|
|
||||||
guild_settings.extend(data[0])
|
|
||||||
|
|
||||||
# Only return active guild IDs
|
|
||||||
return GuildSettings.from_list(guild_settings)
|
|
||||||
|
|
||||||
async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
|
|
||||||
"""Returns an individual page of guild settings."""
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
|
|
||||||
except aiohttp.ClientResponseError as error:
|
|
||||||
self.handle_pagination_error(error)
|
|
||||||
return []
|
|
||||||
|
|
||||||
def handle_pagination_error(self, error: aiohttp.ClientResponseError):
|
|
||||||
"""Handle the error cases from pagination attempts."""
|
|
||||||
|
|
||||||
match error.status:
|
|
||||||
case 404:
|
|
||||||
log.debug("final page reached")
|
|
||||||
case 403:
|
|
||||||
log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
|
|
||||||
self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
|
|
||||||
case _:
|
|
||||||
log.debug(error)
|
|
||||||
|
|
||||||
async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
|
|
||||||
"""Get a list of `Subscription`s matching the given `guild_ids`."""
|
|
||||||
|
|
||||||
subscriptions = []
|
|
||||||
|
|
||||||
# Iterate infinitely taking the iter no. as `page`
|
|
||||||
# data will be empty after last page reached.
|
|
||||||
for page, _ in enumerate(iter(int, 1)):
|
|
||||||
data = await self.get_subs_page(guild_ids, page)
|
|
||||||
if not data:
|
|
||||||
break
|
|
||||||
|
|
||||||
subscriptions.extend(data[0])
|
|
||||||
|
|
||||||
return Subscription.from_list(subscriptions)
|
|
||||||
|
|
||||||
async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
|
|
||||||
"""Returns an individual page of subscriptions."""
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
|
|
||||||
except aiohttp.ClientResponseError as error:
|
|
||||||
self.handle_pagination_error(error)
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def process_subscriptions(self, subscriptions: list[Subscription]):
|
|
||||||
"""Process a given list of `Subscription`s."""
|
|
||||||
|
|
||||||
async def process_single_subscription(sub: Subscription):
|
|
||||||
log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
|
|
||||||
|
|
||||||
if not sub.active or not sub.channels_count:
|
|
||||||
return
|
|
||||||
|
|
||||||
unparsed_feed = await get_unparsed_feed(sub.url)
|
|
||||||
parsed_feed = parse(unparsed_feed)
|
|
||||||
|
|
||||||
rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
|
|
||||||
await self.process_items(sub, rss_feed)
|
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(10)
|
|
||||||
|
|
||||||
async def semaphore_process(sub: Subscription):
|
|
||||||
async with semaphore:
|
|
||||||
await process_single_subscription(sub)
|
|
||||||
|
|
||||||
await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
|
|
||||||
|
|
||||||
async def process_items(self, sub: Subscription, feed: RSSFeed):
|
|
||||||
log.debug("processing items")
|
|
||||||
|
|
||||||
channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
|
|
||||||
filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
|
|
||||||
|
|
||||||
for item in feed.items:
|
|
||||||
log.debug("processing item '%s'", item.guid)
|
|
||||||
|
|
||||||
if item.pub_date < sub.published_threshold:
|
|
||||||
log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
|
|
||||||
continue
|
|
||||||
|
|
||||||
blocked = any(self.filter_item(_filter, item) for _filter in filters)
|
|
||||||
mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
|
|
||||||
|
|
||||||
for channel in channels:
|
|
||||||
await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
|
|
||||||
|
|
||||||
async def fetch_or_get_channels(self, channels_data: list[dict]):
|
|
||||||
channels = []
|
|
||||||
|
|
||||||
for data in channels_data:
|
|
||||||
try:
|
|
||||||
channel = self.bot.get_channel(data.channel_id)
|
|
||||||
channels.append(channel or await self.bot.fetch_channel(data.channel_id))
|
|
||||||
except Forbidden:
|
|
||||||
log.error(f"Forbidden Channel '{data.channel_id}'")
|
|
||||||
|
|
||||||
return channels
|
|
||||||
|
|
||||||
def filter_item(self, _filter: dict, item: RSSItem) -> bool:
|
|
||||||
"""
|
|
||||||
Returns `True` if item should be ignored due to filters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
|
|
||||||
log.debug("filter match found? '%s'", match_found)
|
|
||||||
return match_found
|
|
||||||
|
|
||||||
async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
|
|
||||||
message_id = -1
|
|
||||||
|
|
||||||
log.debug("track and send func %s, %s", item.guid, item.title)
|
|
||||||
|
|
||||||
result = await self.api.get_tracked_content(guid=item.guid)
|
|
||||||
if result[1]:
|
|
||||||
log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
|
||||||
return
|
|
||||||
|
|
||||||
result = await self.api.get_tracked_content(url=item.link)
|
|
||||||
if result[1]:
|
|
||||||
log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not blocked:
|
|
||||||
try:
|
|
||||||
log.debug("sending '%s', exists '%s'", item.guid, result[1])
|
|
||||||
sendable_item = mutated_item or item
|
|
||||||
message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
|
|
||||||
message_id = message.id
|
|
||||||
except Forbidden:
|
|
||||||
log.error(f"Forbidden to send to channel {channel.id}")
|
|
||||||
|
|
||||||
await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
|
|
||||||
|
|
||||||
async def process_batch(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
|
|
||||||
try:
|
|
||||||
log.debug("marking as tracked")
|
|
||||||
await self.api.create_tracked_content(
|
|
||||||
guid=item.guid,
|
|
||||||
title=item.title,
|
|
||||||
url=item.link,
|
|
||||||
subscription=sub.id,
|
|
||||||
channel_id=channel_id,
|
|
||||||
message_id=message_id,
|
|
||||||
blocked=blocked
|
|
||||||
)
|
)
|
||||||
return True
|
response.raise_for_status()
|
||||||
except aiohttp.ClientResponseError as error:
|
content = response.json()
|
||||||
if error.status == 409:
|
|
||||||
log.debug(error)
|
|
||||||
else:
|
|
||||||
log.error(error)
|
|
||||||
|
|
||||||
return False
|
yield content.get("results", [])
|
||||||
|
|
||||||
|
if not content.get("next"):
|
||||||
|
break
|
||||||
|
|
||||||
|
async def get_servers(self) -> list[models.Server]:
|
||||||
|
servers = []
|
||||||
|
|
||||||
|
async for servers_batch in self.iterate_pages("servers/"):
|
||||||
|
if servers_batch:
|
||||||
|
servers.extend(servers_batch)
|
||||||
|
|
||||||
|
return models.Server.from_list(servers)
|
||||||
|
|
||||||
|
async def get_subscriptions(self, server: models.Server) -> list[models.Subscription]:
|
||||||
|
subscriptions = []
|
||||||
|
params = {"server": server.id, "active": True}
|
||||||
|
|
||||||
|
async for subscriptions_batch in self.iterate_pages("subscriptions/", params):
|
||||||
|
if subscriptions_batch:
|
||||||
|
subscriptions.extend(subscriptions_batch)
|
||||||
|
|
||||||
|
return models.Subscription.from_list(subscriptions)
|
||||||
|
|
||||||
|
async def get_contents(self, subscription: models.Subscription, raw_rss_content: dict):
|
||||||
|
contents = await models.Content.from_raw_rss(raw_rss_content, subscription, self.client)
|
||||||
|
duplicate_contents = []
|
||||||
|
|
||||||
|
async def check_duplicate_content(content: models.Content):
|
||||||
|
exists = await content.exists_via_api(
|
||||||
|
url=self.api_base_url + "content/",
|
||||||
|
headers=self.api_headers,
|
||||||
|
client=self.client
|
||||||
|
)
|
||||||
|
|
||||||
|
if exists:
|
||||||
|
log.debug(f"Removing duplicate {content}")
|
||||||
|
duplicate_contents.append(content)
|
||||||
|
|
||||||
|
await do_batch_job(contents, check_duplicate_content, 15)
|
||||||
|
|
||||||
|
log.debug(f"before removing duplicates: {len(contents)}")
|
||||||
|
for duplicate in duplicate_contents:
|
||||||
|
contents.remove(duplicate)
|
||||||
|
log.debug(f"after removing duplicates: {len(contents)}")
|
||||||
|
|
||||||
|
return contents
|
||||||
|
|
||||||
|
async def process_server(self, server: models.Server):
|
||||||
|
log.debug(f"processing server: {server.name}")
|
||||||
|
start_time = perf_counter()
|
||||||
|
|
||||||
|
subscriptions = await self.get_subscriptions(server)
|
||||||
|
for subscription in subscriptions:
|
||||||
|
subscription.server = server
|
||||||
|
|
||||||
|
await do_batch_job(subscriptions, self.process_subscription, 10)
|
||||||
|
|
||||||
|
end_time = perf_counter()
|
||||||
|
log.debug(f"Finished processing server: {server.name} in {end_time - start_time:.4f} seconds")
|
||||||
|
|
||||||
|
async def process_subscription(self, subscription: models.Subscription):
|
||||||
|
log.debug(f"processing subscription {subscription.name}")
|
||||||
|
start_time = perf_counter()
|
||||||
|
|
||||||
|
raw_rss_content = await subscription.get_rss_content(self.client)
|
||||||
|
if not raw_rss_content:
|
||||||
|
return
|
||||||
|
|
||||||
|
contents = await self.get_contents(subscription, raw_rss_content)
|
||||||
|
if not contents:
|
||||||
|
log.debug("no contents to process")
|
||||||
|
return
|
||||||
|
|
||||||
|
channels = await subscription.get_discord_channels(self.bot)
|
||||||
|
valid_contents, invalid_contents = subscription.filter_entries(contents)
|
||||||
|
|
||||||
|
async def send_content(channel: discord.TextChannel):
|
||||||
|
# BUG: I believe there are duplicate embeds here
|
||||||
|
# discord only shows 1 when urls are matching, but merges images from both into the 1
|
||||||
|
|
||||||
|
# embeds = [content.embed for content in valid_contents]
|
||||||
|
# batch_size = 10
|
||||||
|
# for i in range(0, len(embeds), batch_size):
|
||||||
|
# batch = embeds[i:i + batch_size]
|
||||||
|
# await channel.send(embeds=batch)
|
||||||
|
|
||||||
|
batch_size = 10
|
||||||
|
total_batches = (len(valid_contents) + batch_size - 1) // batch_size
|
||||||
|
for batch_number, i in enumerate(range(0, len(valid_contents), batch_size)):
|
||||||
|
contents_batch = valid_contents[i:i + batch_size]
|
||||||
|
embeds = await self.create_embeds(
|
||||||
|
contents=contents_batch,
|
||||||
|
subscription_name=subscription.name,
|
||||||
|
colour=subscription.message_style.colour,
|
||||||
|
batch_number=batch_number,
|
||||||
|
total_batches=total_batches
|
||||||
|
)
|
||||||
|
await channel.send(embeds=embeds)
|
||||||
|
|
||||||
|
await do_batch_job(channels, send_content, 5)
|
||||||
|
|
||||||
|
combined = valid_contents.copy()
|
||||||
|
combined.extend(invalid_contents)
|
||||||
|
|
||||||
|
tasks = [content.save(self.client, self.api_base_url, self.api_headers) for content in combined]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# TODO: mark invalid contents as blocked
|
||||||
|
|
||||||
|
end_time = perf_counter()
|
||||||
|
log.debug(f"Finished processing subscription: {subscription.name} in {end_time - start_time:.4f}")
|
||||||
|
|
||||||
|
async def create_embeds(self, contents: list[models.Content], subscription_name: str, colour: str, batch_number: int, total_batches: int):
|
||||||
|
discord_colour = discord.Colour.from_str(
|
||||||
|
colour if colour.startswith("#")
|
||||||
|
else f"#{colour}"
|
||||||
|
)
|
||||||
|
|
||||||
|
url = "https://pyrss.cor.bz"
|
||||||
|
title = subscription_name
|
||||||
|
|
||||||
|
if total_batches > 1:
|
||||||
|
title += f" [{batch_number+1}/{total_batches}]"
|
||||||
|
|
||||||
|
embed = discord.Embed(title=title, colour=discord_colour, url=url)
|
||||||
|
embeds = [embed]
|
||||||
|
|
||||||
|
for content in contents:
|
||||||
|
description = shorten(markdownify(content.item_description, strip=("img",)), 256)
|
||||||
|
description += f"\n[View Article]({content.item_url})"
|
||||||
|
embed.add_field(
|
||||||
|
name=shorten(markdownify(content.item_title, strip=("img", "a")), 256),
|
||||||
|
value=description,
|
||||||
|
inline=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# If there is only one content, set the main embed's image and return it.
|
||||||
|
# Otherwise progressing the normal way will create a wonky looking embed
|
||||||
|
# where the lone image is aligned left to an invisible non-existant right
|
||||||
|
# image.
|
||||||
|
if len(contents) == 1:
|
||||||
|
embed.set_image(url=content.item_image_url)
|
||||||
|
return embeds
|
||||||
|
|
||||||
|
if len(embeds) <= 5:
|
||||||
|
image_embed = discord.Embed(title="dummy", url=url)
|
||||||
|
image_embed.set_image(url=content.item_image_url)
|
||||||
|
embeds.append(image_embed)
|
||||||
|
|
||||||
|
return embeds
|
||||||
|
|
||||||
|
|
||||||
|
# async def process_valid_contents(
|
||||||
|
# self,
|
||||||
|
# contents: list[models.Content],
|
||||||
|
# channels: list[discord.TextChannel],
|
||||||
|
# client: httpx.AsyncClient
|
||||||
|
# ):
|
||||||
|
# semaphore = asyncio.Semaphore(5)
|
||||||
|
|
||||||
|
# async def batch_process(
|
||||||
|
# content: models.Content,
|
||||||
|
# channels: list[discord.TextChannel],
|
||||||
|
# client: httpx.AsyncClient
|
||||||
|
# ):
|
||||||
|
# async with semaphore: await self.process_valid_content(content, channels, client)
|
||||||
|
|
||||||
|
# tasks = [
|
||||||
|
# batch_process()
|
||||||
|
# ]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# @group.command(name="trigger")
|
||||||
|
# async def cmd_trigger_task(self, inter):
|
||||||
|
# await inter.response.defer()
|
||||||
|
# start_time = perf_counter()
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# await self.subscription_task()
|
||||||
|
# except Exception as error:
|
||||||
|
# await inter.followup.send(str(error))
|
||||||
|
# finally:
|
||||||
|
# end_time = perf_counter()
|
||||||
|
# await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
|
||||||
|
|
||||||
|
# @tasks.loop(time=subscription_task_times)
|
||||||
|
# async def subscription_task(self):
|
||||||
|
# """
|
||||||
|
# Task for fetching and processing subscriptions.
|
||||||
|
# """
|
||||||
|
# log.info("Running subscription task")
|
||||||
|
# start_time = perf_counter()
|
||||||
|
|
||||||
|
# async with aiohttp.ClientSession() as session:
|
||||||
|
# self.api = API(self.bot.api_token, session)
|
||||||
|
# await self.execute_task()
|
||||||
|
|
||||||
|
# end_time = perf_counter()
|
||||||
|
# log.debug(f"task completed in {end_time - start_time:.4f} seconds")
|
||||||
|
|
||||||
|
# async def execute_task(self):
|
||||||
|
# """Execute the task directly."""
|
||||||
|
|
||||||
|
# # Filter out inactive guild IDs using related settings
|
||||||
|
# guild_ids = [guild.id for guild in self.bot.guilds]
|
||||||
|
# guild_settings = await self.get_guild_settings(guild_ids)
|
||||||
|
# active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
|
||||||
|
|
||||||
|
# subscriptions = await self.get_subscriptions(active_guild_ids)
|
||||||
|
# await self.process_subscriptions(subscriptions)
|
||||||
|
|
||||||
|
# async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
|
||||||
|
# """Returns a list of guild settings from the Bot's guilds, if they exist."""
|
||||||
|
|
||||||
|
# guild_settings = []
|
||||||
|
|
||||||
|
# # Iterate infinitely taking the iter no. as `page`
|
||||||
|
# # data will be empty after last page reached.
|
||||||
|
# for page, _ in enumerate(iter(int, 1)):
|
||||||
|
# data = await self.get_guild_settings_page(guild_ids, page)
|
||||||
|
# if not data:
|
||||||
|
# break
|
||||||
|
|
||||||
|
# guild_settings.extend(data[0])
|
||||||
|
|
||||||
|
# # Only return active guild IDs
|
||||||
|
# return GuildSettings.from_list(guild_settings)
|
||||||
|
|
||||||
|
# async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
|
||||||
|
# """Returns an individual page of guild settings."""
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
|
||||||
|
# except aiohttp.ClientResponseError as error:
|
||||||
|
# self.handle_pagination_error(error)
|
||||||
|
# return []
|
||||||
|
|
||||||
|
# def handle_pagination_error(self, error: aiohttp.ClientResponseError):
|
||||||
|
# """Handle the error cases from pagination attempts."""
|
||||||
|
|
||||||
|
# match error.status:
|
||||||
|
# case 404:
|
||||||
|
# log.debug("final page reached")
|
||||||
|
# case 403:
|
||||||
|
# log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
|
||||||
|
# self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
|
||||||
|
# case _:
|
||||||
|
# log.debug(error)
|
||||||
|
|
||||||
|
# async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
|
||||||
|
# """Get a list of `Subscription`s matching the given `guild_ids`."""
|
||||||
|
|
||||||
|
# subscriptions = []
|
||||||
|
|
||||||
|
# # Iterate infinitely taking the iter no. as `page`
|
||||||
|
# # data will be empty after last page reached.
|
||||||
|
# for page, _ in enumerate(iter(int, 1)):
|
||||||
|
# data = await self.get_subs_page(guild_ids, page)
|
||||||
|
# if not data:
|
||||||
|
# break
|
||||||
|
|
||||||
|
# subscriptions.extend(data[0])
|
||||||
|
|
||||||
|
# return Subscription.from_list(subscriptions)
|
||||||
|
|
||||||
|
# async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
|
||||||
|
# """Returns an individual page of subscriptions."""
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
|
||||||
|
# except aiohttp.ClientResponseError as error:
|
||||||
|
# self.handle_pagination_error(error)
|
||||||
|
# return []
|
||||||
|
|
||||||
|
# async def process_subscriptions(self, subscriptions: list[Subscription]):
|
||||||
|
# """Process a given list of `Subscription`s."""
|
||||||
|
|
||||||
|
# async def process_single_subscription(sub: Subscription):
|
||||||
|
# log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
|
||||||
|
|
||||||
|
# if not sub.active or not sub.channels_count:
|
||||||
|
# return
|
||||||
|
|
||||||
|
# unparsed_feed = await get_unparsed_feed(sub.url)
|
||||||
|
# parsed_feed = parse(unparsed_feed)
|
||||||
|
|
||||||
|
# rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
|
||||||
|
# await self.process_items(sub, rss_feed)
|
||||||
|
|
||||||
|
# semaphore = asyncio.Semaphore(10)
|
||||||
|
|
||||||
|
# async def semaphore_process(sub: Subscription):
|
||||||
|
# async with semaphore:
|
||||||
|
# await process_single_subscription(sub)
|
||||||
|
|
||||||
|
# await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
|
||||||
|
|
||||||
|
# async def process_items(self, sub: Subscription, feed: RSSFeed):
|
||||||
|
# log.debug("processing items")
|
||||||
|
|
||||||
|
# channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
|
||||||
|
# filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
|
||||||
|
|
||||||
|
# for item in feed.items:
|
||||||
|
# log.debug("processing item '%s'", item.guid)
|
||||||
|
|
||||||
|
# if item.pub_date < sub.published_threshold:
|
||||||
|
# log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
|
||||||
|
# continue
|
||||||
|
|
||||||
|
# blocked = any(self.filter_item(_filter, item) for _filter in filters)
|
||||||
|
# mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
|
||||||
|
|
||||||
|
# for channel in channels:
|
||||||
|
# await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
|
||||||
|
|
||||||
|
# async def fetch_or_get_channels(self, channels_data: list[dict]):
|
||||||
|
# channels = []
|
||||||
|
|
||||||
|
# for data in channels_data:
|
||||||
|
# try:
|
||||||
|
# channel = self.bot.get_channel(data.channel_id)
|
||||||
|
# channels.append(channel or await self.bot.fetch_channel(data.channel_id))
|
||||||
|
# except Forbidden:
|
||||||
|
# log.error(f"Forbidden Channel '{data.channel_id}'")
|
||||||
|
|
||||||
|
# return channels
|
||||||
|
|
||||||
|
# def filter_item(self, _filter: dict, item: RSSItem) -> bool:
|
||||||
|
# """
|
||||||
|
# Returns `True` if item should be ignored due to filters.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
|
||||||
|
# log.debug("filter match found? '%s'", match_found)
|
||||||
|
# return match_found
|
||||||
|
|
||||||
|
# async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
|
||||||
|
# message_id = -1
|
||||||
|
|
||||||
|
# log.debug("track and send func %s, %s", item.guid, item.title)
|
||||||
|
|
||||||
|
# result = await self.api.get_tracked_content(guid=item.guid)
|
||||||
|
# if result[1]:
|
||||||
|
# log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
||||||
|
# return
|
||||||
|
|
||||||
|
# result = await self.api.get_tracked_content(url=item.link)
|
||||||
|
# if result[1]:
|
||||||
|
# log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
||||||
|
# return
|
||||||
|
|
||||||
|
# if not blocked:
|
||||||
|
# try:
|
||||||
|
# log.debug("sending '%s', exists '%s'", item.guid, result[1])
|
||||||
|
# sendable_item = mutated_item or item
|
||||||
|
# message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
|
||||||
|
# message_id = message.id
|
||||||
|
# except Forbidden:
|
||||||
|
# log.error(f"Forbidden to send to channel {channel.id}")
|
||||||
|
|
||||||
|
# await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
|
||||||
|
|
||||||
|
# async def process_batch(self):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
|
||||||
|
# try:
|
||||||
|
# log.debug("marking as tracked")
|
||||||
|
# await self.api.create_tracked_content(
|
||||||
|
# guid=item.guid,
|
||||||
|
# title=item.title,
|
||||||
|
# url=item.link,
|
||||||
|
# subscription=sub.id,
|
||||||
|
# channel_id=channel_id,
|
||||||
|
# message_id=message_id,
|
||||||
|
# blocked=blocked
|
||||||
|
# )
|
||||||
|
# return True
|
||||||
|
# except aiohttp.ClientResponseError as error:
|
||||||
|
# if error.status == 409:
|
||||||
|
# log.debug(error)
|
||||||
|
# else:
|
||||||
|
# log.error(error)
|
||||||
|
|
||||||
|
# return False
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
|
27
src/feed.py
27
src/feed.py
@ -199,7 +199,7 @@ class RSSFeed:
|
|||||||
description: str
|
description: str
|
||||||
link: str
|
link: str
|
||||||
lang: str
|
lang: str
|
||||||
last_build_date: datetime
|
last_build_date: datetime | None
|
||||||
image_href: str
|
image_href: str
|
||||||
items: list[RSSItem] = None
|
items: list[RSSItem] = None
|
||||||
|
|
||||||
@ -240,7 +240,8 @@ class RSSFeed:
|
|||||||
language = pf.feed.get('language', None)
|
language = pf.feed.get('language', None)
|
||||||
|
|
||||||
last_build_date = pf.feed.get('updated_parsed', None)
|
last_build_date = pf.feed.get('updated_parsed', None)
|
||||||
last_build_date = datetime(*last_build_date[0:-2] if last_build_date else None)
|
if last_build_date:
|
||||||
|
last_build_date = datetime(*last_build_date[0:-2])
|
||||||
|
|
||||||
image_href = pf.feed.get("image", {}).get("href")
|
image_href = pf.feed.get("image", {}).get("href")
|
||||||
|
|
||||||
@ -250,7 +251,7 @@ class RSSFeed:
|
|||||||
item = RSSItem.from_parsed_entry(entry)
|
item = RSSItem.from_parsed_entry(entry)
|
||||||
feed.add_item(item)
|
feed.add_item(item)
|
||||||
|
|
||||||
feed.items.reverse()
|
feed.items.reverse() # order so that older items are processed first
|
||||||
return feed
|
return feed
|
||||||
|
|
||||||
|
|
||||||
@ -301,6 +302,7 @@ class Subscription(DjangoDataModel):
|
|||||||
published_threshold: datetime
|
published_threshold: datetime
|
||||||
active: bool
|
active: bool
|
||||||
channels_count: int
|
channels_count: int
|
||||||
|
unique_content_rules: list
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parser(item: dict) -> dict:
|
def parser(item: dict) -> dict:
|
||||||
@ -311,6 +313,7 @@ class Subscription(DjangoDataModel):
|
|||||||
"description": item.pop("article_desc_mutators")
|
"description": item.pop("article_desc_mutators")
|
||||||
}
|
}
|
||||||
item["published_threshold"] = datetime.strptime(item["published_threshold"], "%Y-%m-%dT%H:%M:%S%z")
|
item["published_threshold"] = datetime.strptime(item["published_threshold"], "%Y-%m-%dT%H:%M:%S%z")
|
||||||
|
item["unique_content_rules"] = item.get("unique_content_rules", [])
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
@ -357,3 +360,21 @@ class TrackedContent(DjangoDataModel):
|
|||||||
|
|
||||||
item["creation_datetime"] = datetime.strptime(item["creation_datetime"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
item["creation_datetime"] = datetime.strptime(item["creation_datetime"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ContentFilter(DjangoDataModel):
|
||||||
|
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
matching_algorithm: int
|
||||||
|
match: str
|
||||||
|
is_insensitive: bool
|
||||||
|
is_whitelist: bool
|
||||||
|
guild_id: int
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parser(item: dict) -> dict:
|
||||||
|
|
||||||
|
item["guild_id"] = int(item["guild_id"]) # stored as str due to a django/sqlite bug, convert back to int
|
||||||
|
return item
|
||||||
|
527
src/models.py
Normal file
527
src/models.py
Normal file
@ -0,0 +1,527 @@
|
|||||||
|
import re
|
||||||
|
import logging
|
||||||
|
import hashlib
|
||||||
|
import asyncio
|
||||||
|
from enum import Enum
|
||||||
|
from time import perf_counter
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from textwrap import shorten
|
||||||
|
|
||||||
|
import feedparser.parsers
|
||||||
|
import httpx
|
||||||
|
import discord
|
||||||
|
import rapidfuzz
|
||||||
|
import feedparser
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
from markdownify import markdownify
|
||||||
|
|
||||||
|
from utils import do_batch_job
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DjangoDataModel(ABC):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def parser(item: dict) -> dict:
|
||||||
|
return item
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_list(cls, data: list[dict]) -> list:
|
||||||
|
return [cls(**cls.parser(item)) for item in data]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict):
|
||||||
|
return cls(**cls.parser(data))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class Server(DjangoDataModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
icon_hash: str
|
||||||
|
is_bot_operational: bool
|
||||||
|
active: bool
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parser(item: dict) -> dict:
|
||||||
|
item["id"] = int(item.pop("id"))
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
class MatchingAlgorithm(Enum):
|
||||||
|
NONE = 0
|
||||||
|
ANY = 1
|
||||||
|
ALL = 2
|
||||||
|
LITERAL = 3
|
||||||
|
REGEX = 4
|
||||||
|
FUZZY = 5
|
||||||
|
AUTO = 6
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_value(cls, value: int):
|
||||||
|
for member in cls:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
|
||||||
|
raise ValueError(f"No {cls.__class__.__name__} for value: {value}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ContentFilter(DjangoDataModel):
|
||||||
|
id: int
|
||||||
|
server_id: int
|
||||||
|
name: str
|
||||||
|
matching_pattern: str
|
||||||
|
matching_algorithm: MatchingAlgorithm
|
||||||
|
is_insensitive: bool
|
||||||
|
is_whitelist: bool
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parser(item: dict) -> dict:
|
||||||
|
item["id"] = item.pop("id")
|
||||||
|
item["server_id"] = item.pop("server")
|
||||||
|
item["matching_pattern"] = item.pop("match")
|
||||||
|
item["matching_algorithm"] = MatchingAlgorithm.from_value(item.pop("matching_algorithm"))
|
||||||
|
return item
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _regex_flags(self):
|
||||||
|
return re.IGNORECASE if self.is_insensitive else 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cleaned_matching_pattern(self):
|
||||||
|
"""
|
||||||
|
Splits the pattern to individual keywords, getting rid of unnecessary
|
||||||
|
spaces and grouping quoted words together.
|
||||||
|
|
||||||
|
"""
|
||||||
|
findterms = re.compile(r'"([^"]+)"|(\S+)').findall
|
||||||
|
normspace = re.compile(r"\s+").sub
|
||||||
|
return [
|
||||||
|
re.escape(normspace(" ", (t[0] or t[1]).strip())).replace(r"\ ", r"\s+")
|
||||||
|
for t in findterms(self.matching_pattern)
|
||||||
|
]
|
||||||
|
|
||||||
|
def _match_any(self, matching_against: str):
|
||||||
|
for word in self.cleaned_matching_pattern:
|
||||||
|
if re.search(rf"\b{word}\b", matching_against, self._regex_flags):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _match_all(self, matching_against: str):
|
||||||
|
for word in self.cleaned_matching_pattern:
|
||||||
|
if not re.search(rf"\b{word}\b", matching_against, self._regex_flags):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _match_literal(self, matching_against: str):
|
||||||
|
return bool(
|
||||||
|
re.search(
|
||||||
|
rf"\b{re.escape(self.matching_pattern)}\b",
|
||||||
|
matching_against,
|
||||||
|
self._regex_flags
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _match_regex(self, matching_against: str):
    """Treat the pattern as a raw regex; invalid patterns log and never match."""
    try:
        compiled = re.compile(self.matching_pattern, self._regex_flags)
    except re.error as exc:
        log.error(f"Filter regex error: {exc}")
        return False
    return compiled.search(matching_against) is not None
|
||||||
|
|
||||||
|
def _match_fuzzy(self, matching_against: str) -> bool:
    """Fuzzy partial match (similarity >= 90), ignoring punctuation.

    Parameters
    ----------
    matching_against : str
        The text to compare the filter pattern against.

    Returns
    -------
    bool
        Whether the pattern fuzzily matches the text.
    """
    matching_against = re.sub(r"[^\w\s]", "", matching_against)
    matching_pattern = re.sub(r"[^\w\s]", "", self.matching_pattern)
    if self.is_insensitive:
        matching_against = matching_against.lower()
        matching_pattern = matching_pattern.lower()

    # partial_ratio returns 0.0 below score_cutoff and the score otherwise;
    # coerce to bool so callers typed "-> bool" receive an actual boolean
    # instead of a float score.
    return bool(rapidfuzz.fuzz.partial_ratio(
        matching_against,
        matching_pattern,
        score_cutoff=90
    ))
|
||||||
|
|
||||||
|
def _get_algorithm_func(self):
    """Map the configured algorithm to its implementation (None when unknown/NONE)."""
    dispatch = {
        MatchingAlgorithm.ANY: self._match_any,
        MatchingAlgorithm.ALL: self._match_all,
        MatchingAlgorithm.LITERAL: self._match_literal,
        MatchingAlgorithm.REGEX: self._match_regex,
        MatchingAlgorithm.FUZZY: self._match_fuzzy,
    }
    return dispatch.get(self.matching_algorithm)
|
||||||
|
|
||||||
|
def matches(self, content) -> bool:
    """Apply this filter to a content item.

    A whitelist filter inverts the result: content passes only when it
    matches the pattern. An empty pattern never matches.
    """
    log.debug(f"applying filter: {self}")

    # blank patterns can never match anything
    if not self.matching_pattern.strip():
        return False

    if self.matching_algorithm == MatchingAlgorithm.ALL:
        # ALL must see every keyword somewhere across both fields combined
        combined = content.item_title + " " + content.item_description
        match_found = self._match_all(combined)
    else:
        algorithm_func = self._get_algorithm_func()
        if not algorithm_func:
            log.error(f"Bad algorithm function: {self.matching_algorithm}")
            return False

        match_found = algorithm_func(content.item_title) or algorithm_func(content.item_description)

    log.debug(f"filter match found: {match_found}")
    if self.is_whitelist:
        return not match_found
    return match_found
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class MessageMutator(DjangoDataModel):
    """A named text mutator record fetched from the backing API."""

    id: int     # primary key
    name: str   # display name
    value: str  # mutator value/identifier

    @staticmethod
    def parser(item: dict) -> dict:
        """Normalise a raw API payload (the "id" re-insert is a keep-for-parity no-op)."""
        item["id"] = item.pop("id")
        return item
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class MessageStyle(DjangoDataModel):
    """Per-server presentation settings for subscription messages."""

    id: int
    server_id: int
    name: str
    colour: str                       # hex colour string, without the leading '#'
    is_embed: bool
    is_hyperlinked: bool
    show_author: bool
    show_timestamp: bool
    show_images: bool
    fetch_images: bool                # when True, scrape the article page for an image
    title_mutator: dict | None
    description_mutator: dict | None
    auto_created: bool

    @staticmethod
    def parser(item: dict) -> dict:
        """Normalise a raw API payload into MessageStyle field names."""
        item["id"] = int(item.pop("id"))
        item["server_id"] = int(item.pop("server"))
        # nested mutator details come under "*_detail" keys from the API
        item["title_mutator"] = item.pop("title_mutator_detail")
        item["description_mutator"] = item.pop("description_mutator_detail")
        return item
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class DiscordChannel(DjangoDataModel):
    """A Discord channel reference stored alongside a subscription."""

    id: int       # Discord channel snowflake
    name: str
    is_nsfw: bool

    @staticmethod
    def parser(item: dict) -> dict:
        """Normalise a raw API payload into DiscordChannel field names."""
        item["id"] = int(item.pop("id"))
        return item
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class Subscription(DjangoDataModel):
    """An RSS subscription belonging to a server, together with its target
    channels, content filters and message style."""

    id: int
    server_id: int
    name: str
    url: str                        # RSS feed URL
    created_at: datetime
    updated_at: datetime
    extra_notes: str
    active: bool
    publish_threshold: datetime     # only entries published after this are processed
    channels: list[DiscordChannel]
    filters: list[ContentFilter]
    message_style: MessageStyle
    _server: Server | None = None   # attached lazily via the `server` property

    @staticmethod
    def parser(item: dict) -> dict:
        """Normalise a raw API payload: parse timestamps and nested details."""
        item["id"] = int(item.pop("id"))
        item["server_id"] = int(item.pop("server"))
        item["created_at"] = datetime.strptime(item.pop("created_at"), "%Y-%m-%dT%H:%M:%S.%f%z")
        item["updated_at"] = datetime.strptime(item.pop("updated_at"), "%Y-%m-%dT%H:%M:%S.%f%z")
        item["publish_threshold"] = datetime.strptime(item.pop("publish_threshold"), "%Y-%m-%dT%H:%M:%S%z")
        item["channels"] = DiscordChannel.from_list(item.pop("channels_detail"))
        item["filters"] = ContentFilter.from_list(item.pop("filters_detail"))
        item["message_style"] = MessageStyle.from_dict(item.pop("message_style_detail"))
        return item

    @property
    def server(self) -> Server:
        """The Server this subscription belongs to (set by the owner)."""
        return self._server

    @server.setter
    def server(self, server: Server):
        # fix: annotation previously read `server: server`, which referenced
        # the property object itself rather than the Server class
        self._server = server

    async def get_rss_content(self, client: httpx.AsyncClient) -> str | None:
        """Fetch the raw RSS document for this subscription.

        Returns
        -------
        str | None
            The response body, or None on HTTP errors or when the response
            is not XML. (Return annotation fixed: this can return None.)
        """
        start_time = perf_counter()

        try:
            response = await client.get(self.url)
            response.raise_for_status()
        except httpx.HTTPError as exc:
            log.error("(%s) HTTP Exception for %s - %s", type(exc), exc.request.url, exc)
            return None
        finally:
            log.debug(f"Got rss content in {perf_counter() - start_time:.4f} seconds")

        content_type = response.headers.get("Content-Type")
        # fix: also guard against a missing header — `"x" in None` raises TypeError
        if not content_type or "text/xml" not in content_type:
            log.warning("Invalid 'Content-Type' header: %s (must contain 'text/xml')", content_type)
            return None

        return response.text

    async def get_discord_channels(self, bot) -> list[discord.TextChannel]:
        """Resolve this subscription's channels through the bot cache,
        falling back to an API fetch; unresolvable channels are skipped."""
        start_time = perf_counter()
        channels = []

        for channel_detail in self.channels:
            try:
                channel = bot.get_channel(channel_detail.id)
                channels.append(channel or await bot.fetch_channel(channel_detail.id))
            except Exception as exc:  # best-effort: one bad channel must not break the rest
                channel_reference = f"({channel_detail.name}, {channel_detail.id})"
                server_reference = f"({self.server.name}, {self.server.id})"
                log.debug(f"Failed to get channel {channel_reference} from {server_reference}: {exc}")

        log.debug(f"Got channels in {perf_counter() - start_time:.4f} seconds")
        return channels

    def filter_entries(self, contents: list) -> tuple[list, list]:
        """Split contents into (valid, invalid) using this subscription's
        filters; entries rejected by any filter are flagged as blocked."""
        log.debug(f"filtering entries for {self.name} in {self.server.name}")

        valid_contents = []
        invalid_contents = []

        for content in contents:
            log.debug(f"filtering: '{content.item_title}'")
            if any(content_filter.matches(content) for content_filter in self.filters):
                content.blocked = True
                invalid_contents.append(content)
            else:
                valid_contents.append(content)

        log.debug(f"filtered content: valid:{len(valid_contents)}, invalid:{len(invalid_contents)}")
        return valid_contents, invalid_contents
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class Content(DjangoDataModel):
    """A single RSS entry, normalised for storage and Discord delivery."""

    id: int
    subscription_id: int
    item_id: str
    item_guid: str
    item_url: str
    item_title: str
    item_description: str
    item_content_hash: str              # sha256 of the raw description
    item_image_url: str | None
    item_thumbnail_url: str | None
    item_published: datetime | None
    item_author: str
    item_author_url: str | None
    item_feed_title: str
    item_feed_url: str
    _subscription: Subscription | None = None  # attached via the `subscription` property
    blocked: bool = False                      # set True when a filter rejects this entry

    @staticmethod
    def parser(item: dict) -> dict:
        """Normalise a raw API payload into Content field names."""
        item["id"] = item.pop("id")
        item["subscription_id"] = item.pop("subscription")
        return item

    async def exists_via_api(self, url: str, headers: dict, client: httpx.AsyncClient):
        """Query the API for already-stored duplicates of this entry.

        Returns the list of matching results (possibly empty), or False when
        the request fails — callers treat any falsy value as "not a duplicate".
        """
        log.debug(f"checking if {self.item_content_hash} exists via API")
        params = {
            "match_any": True,  # allows any param to match, instead of needing all
            "item_id": self.item_id,
            "item_guid": self.item_guid,
            "item_url": self.item_url,
            "item_title": self.item_title,
            "item_content_hash": self.item_content_hash,
            "subscription": self.subscription_id
        }

        log.debug(f"params: {params}")

        try:
            response = await client.get(
                url=url,
                headers=headers,
                params=params
            )
            response.raise_for_status()
        except httpx.HTTPError as exc:
            # fail open: treat API failures as "not a duplicate"
            log.error(f"assuming not duplicate due to error: {exc}")
            return False

        return response.json().get("results", [])

    def is_duplicate(self, other) -> bool:
        """Return True when any identifying field equals the other entry's.

        Raises:
            ValueError: if `other` is not a Content instance.
        """
        if not isinstance(other, Content):
            raise ValueError(f"Expected Content, received {type(other)}")

        other_details = other.duplicate_details
        return any(
            other_details.get(key) == value
            for key, value in self.duplicate_details.items()
        )

    @property
    def duplicate_details(self) -> dict:
        """The subset of fields used to detect duplicate entries."""
        keys = [
            "item_id",
            "item_guid",
            "item_url",
            "item_title",
            "item_content_hash"
        ]
        data = asdict(self)
        return { key: data[key] for key in keys }

    async def save(self, client: httpx.AsyncClient, base_url: str, headers: dict):
        """POST this entry to the API's content endpoint.

        Raises:
            httpx.HTTPStatusError: when the API rejects the payload.
        """
        log.debug(f"saving content {self.item_content_hash}")

        data = asdict(self)
        data.pop("id")  # the API assigns the primary key
        data["subscription"] = data.pop("subscription_id")
        item_published = data.pop("item_published")
        data["item_published"] = item_published.strftime("%Y-%m-%d") if item_published else None
        data.pop("_subscription")  # runtime-only reference, not part of the payload

        # NOTE(review): `data=` sends a form-encoded body; confirm the API does
        # not expect JSON (`json=`) here.
        response = await client.post(
            url=base_url + "content/",
            headers=headers,
            data=data
        )
        response.raise_for_status()
        log.debug(f"save success for {self.item_content_hash}")

    @classmethod
    async def from_raw_rss(cls, rss: str, subscription: Subscription, client: httpx.AsyncClient):
        """Parse a raw RSS document into de-duplicated Content instances.

        Entries older than the subscription's publish threshold, entries
        without a publish date, and in-batch duplicates are dropped. The
        result is sorted oldest-first by publish date.
        """
        style = subscription.message_style
        parsed_rss = feedparser.parse(rss)
        contents = []

        async def create_content(entry: feedparser.FeedParserDict):
            published_struct = entry.get("published_parsed")
            if not published_struct:
                # fix: `datetime(*None, ...)` previously raised TypeError here;
                # undated entries cannot be compared against the threshold, so skip.
                log.debug("skipping entry without a publish date")
                return
            published = datetime(*published_struct[0:6], tzinfo=timezone.utc)

            if published < subscription.publish_threshold:
                log.debug("skipping due to publish threshold")
                return

            content_hash = hashlib.new("sha256")
            content_hash.update(entry.get("description", "").encode())

            item_url = entry.get("link", "")
            item_image_url = entry.get("media_thumbnail", [{}])[0].get("url")
            if style.fetch_images:
                item_image_url = await cls.get_image_url(item_url, client)

            # fix: feeds without an <image> element used to raise AttributeError
            # on `parsed_rss.feed.image.href`
            feed_thumbnail = parsed_rss.feed.get("image", {}).get("href") or None

            content = Content.from_dict({
                "id": -1,  # placeholder; the API assigns the real id on save
                "subscription": subscription.id,
                "item_id": entry.get("id", ""),
                "item_guid": entry.get("guid", ""),
                "item_url": item_url,
                "item_title": entry.get("title", ""),
                "item_description": entry.get("description", ""),
                "item_content_hash": content_hash.hexdigest(),
                "item_image_url": item_image_url,
                "item_thumbnail_url": feed_thumbnail,
                "item_published": published,
                "item_author": entry.get("author", ""),
                "item_author_url": entry.get("author_detail", {}).get("href"),
                "item_feed_title": parsed_rss.get("feed", {}).get("title"),
                "item_feed_url": parsed_rss.get("feed", {}).get("link")
            })

            # Weed out duplicates already collected in this batch
            log.debug("weeding out duplicates")
            if any(content.is_duplicate(other) for other in contents):
                log.debug("found duplicate while loading rss data")
                return

            content.subscription = subscription
            contents.append(content)

        await do_batch_job(parsed_rss.entries, create_content, 15)
        contents.sort(key=lambda k: k.item_published)
        return contents

    @staticmethod
    async def get_image_url(url: str, client: httpx.AsyncClient) -> str | None:
        """Scrape the page's Open Graph image URL, or None on failure."""
        log.debug("Fetching image url")

        try:
            response = await client.get(url, timeout=15)
        except httpx.HTTPError:
            return None

        soup = BeautifulSoup(response.text, "html.parser")
        image_element = soup.select_one("meta[property='og:image']")
        if not image_element:
            return None

        return image_element.get("content")

    @property
    def subscription(self) -> Subscription:
        """The Subscription this entry came from (set externally)."""
        return self._subscription

    @subscription.setter
    def subscription(self, subscription: Subscription):
        self._subscription = subscription

    @property
    def embed(self):
        """Build the Discord embed representation of this entry."""
        colour = discord.Colour.from_str(
            f"#{self.subscription.message_style.colour}"
        )

        # ensure content fits within character limits
        title = shorten(markdownify(self.item_title, strip=("img", "a")), 256)
        description = shorten(markdownify(self.item_description, strip=("img",)), 4096)
        author = self.item_author or self.item_feed_title

        # NOTE(review): author length is counted twice and the description is
        # re-shortened by the overflow amount, not a target length — verify
        # both against shorten()'s contract and Discord's 6000-char embed cap.
        combined_length = len(title) + len(description) + (len(author) * 2)
        cutoff = combined_length - 6000
        description = shorten(description, cutoff) if cutoff > 0 else description

        embed = discord.Embed(
            title=title,
            description=description,
            url=self.item_url,
            colour=colour,
            timestamp=self.item_published
        )

        embed.set_image(url=self.item_image_url)
        embed.set_thumbnail(url=self.item_thumbnail_url)
        embed.set_author(
            name=author,
            url=self.item_author_url or self.item_feed_url
        )
        embed.set_footer(text=self.subscription.name)

        log.debug(f"created embed: {embed.to_dict()}")

        return embed
|
0
src/tests/__init__.py
Normal file
0
src/tests/__init__.py
Normal file
67
src/tests/test_content.py
Normal file
67
src/tests/test_content.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from models import MatchingAlgorithm, ContentFilter, Content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def content() -> Content:
    """A minimal Content instance with a realistic title and description."""
    field_values = dict(
        id=0,
        subscription_id=0,
        item_id="",
        item_guid="",
        item_url="",
        item_title="This week in the papers:",
        item_description="The price of petrol has risen by over 2% since the previous financial report. Read the full article here.",
        item_content_hash="",
        item_image_url=None,
        item_thumbnail_url=None,
        item_published=None,
        item_author="",
        item_author_url=None,
        item_feed_title="",
        item_feed_url="",
    )
    return Content(**field_values)
|
||||||
|
|
||||||
|
@pytest.fixture
def content_filter() -> ContentFilter:
    """A blank, case-insensitive blacklist filter; tests set pattern/algorithm."""
    field_values = dict(
        id=0,
        server_id=0,
        name="Test Content Filter",
        matching_pattern="",
        matching_algorithm=MatchingAlgorithm.NONE,
        is_insensitive=True,
        is_whitelist=False,
    )
    return ContentFilter(**field_values)
|
||||||
|
|
||||||
|
def test_content_filter_any(content: Content, content_filter: ContentFilter):
    """ANY matches because at least one keyword ("report") appears."""
    content_filter.matching_algorithm = MatchingAlgorithm.ANY
    content_filter.matching_pattern = "france twenty report grass lately"
    assert content_filter.matches(content)
|
||||||
|
|
||||||
|
def test_content_filter_all(content: Content, content_filter: ContentFilter):
    """ALL matches because every keyword appears across title + description."""
    content_filter.matching_algorithm = MatchingAlgorithm.ALL
    content_filter.matching_pattern = "week petrol risen since"
    assert content_filter.matches(content)
|
||||||
|
|
||||||
|
def test_content_filter_literal(content: Content, content_filter: ContentFilter):
    """LITERAL matches the whole phrase (case-insensitively) in the title."""
    content_filter.matching_algorithm = MatchingAlgorithm.LITERAL
    content_filter.matching_pattern = "this week in the papers"
    assert content_filter.matches(content)
|
||||||
|
|
||||||
|
def test_content_filter_regex(content: Content, content_filter: ContentFilter):
    """REGEX matches via the alternation's first branch ("The Papers")."""
    content_filter.matching_algorithm = MatchingAlgorithm.REGEX
    content_filter.matching_pattern = r"\b(The Papers|weekly quiz)\b"
    assert content_filter.matches(content)
|
||||||
|
|
||||||
|
# def test_content_filter_fuzzy(content: Content, content_filter: ContentFilter):
|
||||||
|
# content_filter.matching_pattern = "this week in the papers"
|
||||||
|
# content_filter.matching_algorithm = MatchingAlgorithm.FUZZY
|
||||||
|
# assert content_filter.matches(content)
|
||||||
|
|
||||||
|
def test_content_duplicates(content: Content):
    """A field-for-field copy must be detected as a duplicate."""
    duplicate = replace(content)
    assert content.is_duplicate(duplicate)
|
532
src/utils.py
532
src/utils.py
@ -1,10 +1,10 @@
|
|||||||
"""A collection of utility functions that can be used in various places."""
|
"""A collection of utility functions that can be used in various places."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
# import aiohttp
|
||||||
import logging
|
import logging
|
||||||
import async_timeout
|
# import async_timeout
|
||||||
from typing import Callable
|
# from typing import Callable
|
||||||
|
|
||||||
from discord import Interaction, Embed, Colour, ButtonStyle, Button
|
from discord import Interaction, Embed, Colour, ButtonStyle, Button
|
||||||
from discord.ui import View, button
|
from discord.ui import View, button
|
||||||
@ -12,325 +12,335 @@ from discord.ext.commands import Bot
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def do_batch_job(iterable: list, func, batch_size: int):
    """Run ``func`` on every item, with at most ``batch_size`` running concurrently."""
    semaphore = asyncio.Semaphore(batch_size)

    async def limited(item):
        # the semaphore caps how many func() calls are in flight at once
        async with semaphore:
            await func(item)

    pending = [limited(item) for item in iterable]
    await asyncio.gather(*pending)
|
||||||
|
|
||||||
async def get_rss_data(url: str):
|
# async def fetch(session, url: str) -> str:
|
||||||
async with aiohttp.ClientSession() as session:
|
# async with async_timeout.timeout(20):
|
||||||
async with session.get(url) as response:
|
# async with session.get(url) as response:
|
||||||
items = await response.text(), response.status
|
# return await response.text()
|
||||||
|
|
||||||
return items
|
# async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None):
|
||||||
|
# if session is not None:
|
||||||
|
# return await fetch(session, url)
|
||||||
|
|
||||||
async def followup(inter: Interaction, *args, **kwargs):
|
# async with aiohttp.ClientSession() as session:
|
||||||
"""Shorthand for following up on an interaction.
|
# return await fetch(session, url)
|
||||||
|
|
||||||
Parameters
|
# async def get_rss_data(url: str):
|
||||||
----------
|
# async with aiohttp.ClientSession() as session:
|
||||||
inter : Interaction
|
# async with session.get(url) as response:
|
||||||
Represents an app command interaction.
|
# items = await response.text(), response.status
|
||||||
"""
|
|
||||||
|
|
||||||
await inter.followup.send(*args, **kwargs)
|
# return items
|
||||||
|
|
||||||
|
# async def followup(inter: Interaction, *args, **kwargs):
|
||||||
|
# """Shorthand for following up on an interaction.
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
# ----------
|
||||||
|
# inter : Interaction
|
||||||
|
# Represents an app command interaction.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# await inter.followup.send(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
|
# # https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
|
||||||
|
|
||||||
class FollowupIcons:
|
# class FollowupIcons:
|
||||||
error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
|
# error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
|
||||||
success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
|
# success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
|
||||||
trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
|
# trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
|
||||||
info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
|
# info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
|
||||||
added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
|
# added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
|
||||||
assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
|
# assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
|
||||||
|
|
||||||
|
|
||||||
class PaginationView(View):
|
# class PaginationView(View):
|
||||||
"""A Discord UI View that adds pagination to an embed."""
|
# """A Discord UI View that adds pagination to an embed."""
|
||||||
|
|
||||||
def __init__(
|
# def __init__(
|
||||||
self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
|
# self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
|
||||||
formatdata: Callable, pagesize: int, initpage: int=1
|
# formatdata: Callable, pagesize: int, initpage: int=1
|
||||||
):
|
# ):
|
||||||
"""_summary_
|
# """_summary_
|
||||||
|
|
||||||
Args:
|
# Args:
|
||||||
bot (commands.Bot) The discord bot
|
# bot (commands.Bot) The discord bot
|
||||||
inter (Interaction): Represents a discord command interaction.
|
# inter (Interaction): Represents a discord command interaction.
|
||||||
embed (Embed): The base embed to paginate.
|
# embed (Embed): The base embed to paginate.
|
||||||
getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
|
# getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
|
||||||
formatdata (Callable): A formatter function that determines how the data is displayed.
|
# formatdata (Callable): A formatter function that determines how the data is displayed.
|
||||||
pagesize (int): The size of each page.
|
# pagesize (int): The size of each page.
|
||||||
initpage (int, optional): The inital page. Defaults to 1.
|
# initpage (int, optional): The inital page. Defaults to 1.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
self.bot = bot
|
# self.bot = bot
|
||||||
self.inter = inter
|
# self.inter = inter
|
||||||
self.embed = embed
|
# self.embed = embed
|
||||||
self.getdata = getdata
|
# self.getdata = getdata
|
||||||
self.formatdata = formatdata
|
# self.formatdata = formatdata
|
||||||
self.maxpage = None
|
# self.maxpage = None
|
||||||
self.pagesize = pagesize
|
# self.pagesize = pagesize
|
||||||
self.index = initpage
|
# self.index = initpage
|
||||||
|
|
||||||
# emoji reference
|
# # emoji reference
|
||||||
self.start_emoji = bot.get_emoji(1204542364073463818)
|
# self.start_emoji = bot.get_emoji(1204542364073463818)
|
||||||
self.end_emoji = bot.get_emoji(1204542367752003624)
|
# self.end_emoji = bot.get_emoji(1204542367752003624)
|
||||||
|
|
||||||
super().__init__(timeout=100)
|
# super().__init__(timeout=100)
|
||||||
|
|
||||||
async def check_user_is_author(self, inter: Interaction) -> bool:
|
# async def check_user_is_author(self, inter: Interaction) -> bool:
|
||||||
"""Ensure the user is the author of the original command."""
|
# """Ensure the user is the author of the original command."""
|
||||||
|
|
||||||
if inter.user == self.inter.user:
|
# if inter.user == self.inter.user:
|
||||||
return True
|
# return True
|
||||||
|
|
||||||
await inter.response.defer()
|
# await inter.response.defer()
|
||||||
await (
|
# await (
|
||||||
Followup(None, "Only the author can interact with this.")
|
# Followup(None, "Only the author can interact with this.")
|
||||||
.error()
|
# .error()
|
||||||
.send(inter, ephemeral=True)
|
# .send(inter, ephemeral=True)
|
||||||
)
|
# )
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
async def on_timeout(self):
|
# async def on_timeout(self):
|
||||||
"""Erase the controls on timeout."""
|
# """Erase the controls on timeout."""
|
||||||
|
|
||||||
message = await self.inter.original_response()
|
# message = await self.inter.original_response()
|
||||||
await message.edit(view=None)
|
# await message.edit(view=None)
|
||||||
|
|
||||||
@staticmethod
|
# @staticmethod
|
||||||
def calc_total_pages(results: int, max_pagesize: int) -> int:
|
# def calc_total_pages(results: int, max_pagesize: int) -> int:
|
||||||
result = ((results - 1) // max_pagesize) + 1
|
# result = ((results - 1) // max_pagesize) + 1
|
||||||
return result
|
# return result
|
||||||
|
|
||||||
def calc_dataitem_index(self, dataitem_index: int):
|
# def calc_dataitem_index(self, dataitem_index: int):
|
||||||
"""Calculates a given index to be relative to the sum of all pages items
|
# """Calculates a given index to be relative to the sum of all pages items
|
||||||
|
|
||||||
Example: dataitem_index = 6
|
# Example: dataitem_index = 6
|
||||||
pagesize = 10
|
# pagesize = 10
|
||||||
if page == 1 then return 6
|
# if page == 1 then return 6
|
||||||
else return 6 + 10 * (page - 1)"""
|
# else return 6 + 10 * (page - 1)"""
|
||||||
|
|
||||||
if self.index > 1:
|
# if self.index > 1:
|
||||||
dataitem_index += self.pagesize * (self.index - 1)
|
# dataitem_index += self.pagesize * (self.index - 1)
|
||||||
|
|
||||||
dataitem_index += 1
|
# dataitem_index += 1
|
||||||
return dataitem_index
|
# return dataitem_index
|
||||||
|
|
||||||
@button(emoji="◀️", style=ButtonStyle.blurple)
|
# @button(emoji="◀️", style=ButtonStyle.blurple)
|
||||||
async def backward(self, inter: Interaction, button: Button):
|
# async def backward(self, inter: Interaction, button: Button):
|
||||||
"""
|
# """
|
||||||
Action the backwards button.
|
# Action the backwards button.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
self.index -= 1
|
# self.index -= 1
|
||||||
await inter.response.defer()
|
# await inter.response.defer()
|
||||||
self.inter = inter
|
# self.inter = inter
|
||||||
await self.navigate()
|
# await self.navigate()
|
||||||
|
|
||||||
@button(emoji="▶️", style=ButtonStyle.blurple)
|
# @button(emoji="▶️", style=ButtonStyle.blurple)
|
||||||
async def forward(self, inter: Interaction, button: Button):
|
# async def forward(self, inter: Interaction, button: Button):
|
||||||
"""
|
# """
|
||||||
Action the forwards button.
|
# Action the forwards button.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
self.index += 1
|
# self.index += 1
|
||||||
await inter.response.defer()
|
# await inter.response.defer()
|
||||||
self.inter = inter
|
# self.inter = inter
|
||||||
await self.navigate()
|
# await self.navigate()
|
||||||
|
|
||||||
@button(emoji="⏭️", style=ButtonStyle.blurple)
|
# @button(emoji="⏭️", style=ButtonStyle.blurple)
|
||||||
async def start_or_end(self, inter: Interaction, button: Button):
|
# async def start_or_end(self, inter: Interaction, button: Button):
|
||||||
"""
|
# """
|
||||||
Action the start and end button.
|
# Action the start and end button.
|
||||||
This button becomes return to start if at end, otherwise skip to end.
|
# This button becomes return to start if at end, otherwise skip to end.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
# Determine if should skip to start or end
|
# # Determine if should skip to start or end
|
||||||
if self.index <= self.maxpage // 2:
|
# if self.index <= self.maxpage // 2:
|
||||||
self.index = self.maxpage
|
# self.index = self.maxpage
|
||||||
else:
|
# else:
|
||||||
self.index = 1
|
# self.index = 1
|
||||||
|
|
||||||
await inter.response.defer()
|
# await inter.response.defer()
|
||||||
self.inter = inter
|
# self.inter = inter
|
||||||
await self.navigate()
|
# await self.navigate()
|
||||||
|
|
||||||
async def navigate(self):
|
# async def navigate(self):
|
||||||
"""
|
# """
|
||||||
Acts as an update method for the entire instance.
|
# Acts as an update method for the entire instance.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
log.debug("navigating to page: %s", self.index)
|
# log.debug("navigating to page: %s", self.index)
|
||||||
|
|
||||||
self.update_buttons()
|
# self.update_buttons()
|
||||||
paged_embed = await self.create_paged_embed()
|
# paged_embed = await self.create_paged_embed()
|
||||||
await self.inter.edit_original_response(embed=paged_embed, view=self)
|
# await self.inter.edit_original_response(embed=paged_embed, view=self)
|
||||||
|
|
||||||
async def create_paged_embed(self) -> Embed:
|
# async def create_paged_embed(self) -> Embed:
|
||||||
"""
|
# """
|
||||||
Returns a copy of the known embed, but with data from the current page.
|
# Returns a copy of the known embed, but with data from the current page.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
embed = self.embed.copy()
|
# embed = self.embed.copy()
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
data, total_results = await self.getdata(self.index, self.pagesize)
|
# data, total_results = await self.getdata(self.index, self.pagesize)
|
||||||
except aiohttp.ClientResponseError as exc:
|
# except aiohttp.ClientResponseError as exc:
|
||||||
log.error(exc)
|
# log.error(exc)
|
||||||
await (
|
# await (
|
||||||
Followup(f"Error · {exc.message}",)
|
# Followup(f"Error · {exc.message}",)
|
||||||
.footer(f"HTTP {exc.code}")
|
# .footer(f"HTTP {exc.code}")
|
||||||
.error()
|
# .error()
|
||||||
.send(self.inter)
|
# .send(self.inter)
|
||||||
)
|
# )
|
||||||
raise exc
|
# raise exc
|
||||||
|
|
||||||
self.maxpage = self.calc_total_pages(total_results, self.pagesize)
|
# self.maxpage = self.calc_total_pages(total_results, self.pagesize)
|
||||||
log.debug(f"{self.maxpage=!r}")
|
# log.debug(f"{self.maxpage=!r}")
|
||||||
|
|
||||||
for i, item in enumerate(data):
|
# for i, item in enumerate(data):
|
||||||
i = self.calc_dataitem_index(i)
|
# i = self.calc_dataitem_index(i)
|
||||||
if asyncio.iscoroutinefunction(self.formatdata):
|
# if asyncio.iscoroutinefunction(self.formatdata):
|
||||||
key, value = await self.formatdata(i, item)
|
# key, value = await self.formatdata(i, item)
|
||||||
else:
|
# else:
|
||||||
key, value = self.formatdata(i, item)
|
# key, value = self.formatdata(i, item)
|
||||||
|
|
||||||
embed.add_field(name=key, value=value, inline=False)
|
# embed.add_field(name=key, value=value, inline=False)
|
||||||
|
|
||||||
if not total_results:
|
# if not total_results:
|
||||||
embed.description = "There are no results"
|
# embed.description = "There are no results"
|
||||||
|
|
||||||
if self.maxpage > 1:
|
# if self.maxpage > 1:
|
||||||
embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
|
# embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
|
||||||
|
|
||||||
return embed
|
# return embed
|
||||||
|
|
||||||
def update_buttons(self):
|
# def update_buttons(self):
|
||||||
if self.index >= self.maxpage:
|
# if self.index >= self.maxpage:
|
||||||
self.children[2].emoji = self.start_emoji
|
# self.children[2].emoji = self.start_emoji
|
||||||
else:
|
# else:
|
||||||
self.children[2].emoji = self.end_emoji
|
# self.children[2].emoji = self.end_emoji
|
||||||
|
|
||||||
self.children[0].disabled = self.index == 1
|
# self.children[0].disabled = self.index == 1
|
||||||
self.children[1].disabled = self.index == self.maxpage
|
# self.children[1].disabled = self.index == self.maxpage
|
||||||
|
|
||||||
async def send(self):
|
# async def send(self):
|
||||||
"""Send the pagination view. It may be important to defer before invoking this method."""
|
# """Send the pagination view. It may be important to defer before invoking this method."""
|
||||||
|
|
||||||
log.debug("sending pagination view")
|
# log.debug("sending pagination view")
|
||||||
embed = await self.create_paged_embed()
|
# embed = await self.create_paged_embed()
|
||||||
|
|
||||||
if self.maxpage <= 1:
|
# if self.maxpage <= 1:
|
||||||
await self.inter.edit_original_response(embed=embed)
|
# await self.inter.edit_original_response(embed=embed)
|
||||||
return
|
# return
|
||||||
|
|
||||||
self.update_buttons()
|
# self.update_buttons()
|
||||||
await self.inter.edit_original_response(embed=embed, view=self)
|
# await self.inter.edit_original_response(embed=embed, view=self)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Followup:
|
# class Followup:
|
||||||
"""Wrapper for a discord embed to follow up an interaction."""
|
# """Wrapper for a discord embed to follow up an interaction."""
|
||||||
|
|
||||||
def __init__(
|
# def __init__(
|
||||||
self,
|
# self,
|
||||||
title: str = None,
|
# title: str = None,
|
||||||
description: str = None,
|
# description: str = None,
|
||||||
):
|
# ):
|
||||||
self._embed = Embed(
|
# self._embed = Embed(
|
||||||
title=title,
|
# title=title,
|
||||||
description=description
|
# description=description
|
||||||
)
|
# )
|
||||||
|
|
||||||
async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
|
# async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)
|
# await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)
|
||||||
|
|
||||||
def fields(self, inline: bool = False, **fields: dict):
|
# def fields(self, inline: bool = False, **fields: dict):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
for key, value in fields.items():
|
# for key, value in fields.items():
|
||||||
self._embed.add_field(name=key, value=value, inline=inline)
|
# self._embed.add_field(name=key, value=value, inline=inline)
|
||||||
|
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def image(self, url: str):
|
# def image(self, url: str):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.set_image(url=url)
|
# self._embed.set_image(url=url)
|
||||||
|
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def author(self, name: str, url: str=None, icon_url: str=None):
|
# def author(self, name: str, url: str=None, icon_url: str=None):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.set_author(name=name, url=url, icon_url=icon_url)
|
# self._embed.set_author(name=name, url=url, icon_url=icon_url)
|
||||||
|
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def footer(self, text: str, icon_url: str = None):
|
# def footer(self, text: str, icon_url: str = None):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.set_footer(text=text, icon_url=icon_url)
|
# self._embed.set_footer(text=text, icon_url=icon_url)
|
||||||
|
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def error(self):
|
# def error(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.red()
|
# self._embed.colour = Colour.red()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.error)
|
# self._embed.set_thumbnail(url=FollowupIcons.error)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def success(self):
|
# def success(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.green()
|
# self._embed.colour = Colour.green()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.success)
|
# self._embed.set_thumbnail(url=FollowupIcons.success)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def info(self):
|
# def info(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.blue()
|
# self._embed.colour = Colour.blue()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.info)
|
# self._embed.set_thumbnail(url=FollowupIcons.info)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def added(self):
|
# def added(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.blue()
|
# self._embed.colour = Colour.blue()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.added)
|
# self._embed.set_thumbnail(url=FollowupIcons.added)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def assign(self):
|
# def assign(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.blue()
|
# self._embed.colour = Colour.blue()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.assigned)
|
# self._embed.set_thumbnail(url=FollowupIcons.assigned)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def trash(self):
|
# def trash(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.red()
|
# self._embed.colour = Colour.red()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.trash)
|
# self._embed.set_thumbnail(url=FollowupIcons.trash)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
|
|
||||||
def extract_error_info(error: Exception) -> str:
|
# def extract_error_info(error: Exception) -> str:
|
||||||
class_name = error.__class__.__name__
|
# class_name = error.__class__.__name__
|
||||||
desc = str(error)
|
# desc = str(error)
|
||||||
return class_name, desc
|
# return class_name, desc
|
||||||
|
Loading…
x
Reference in New Issue
Block a user