Compare commits
36 Commits
Author | SHA1 | Date | |
---|---|---|---|
7d395d2001 | |||
b0b02aae9f | |||
ca2037528d | |||
a9c08dea4d | |||
8fcccc7ac9 | |||
efbb4b18ff | |||
81795feb65 | |||
1f9075ce60 | |||
6e153782c2 | |||
d86fc0eb71 | |||
08295dfea6 | |||
470f78c144 | |||
94b154742e | |||
1294909188 | |||
cf8fb34a29 | |||
eb97dca5c6 | |||
b8f1ffb8d9 | |||
ccfa35adda | |||
cc06d3e09f | |||
f9de8ff085 | |||
02a4917152 | |||
ff30e31cf1 | |||
50f7a62cd4 | |||
4c31fe69e2 | |||
bb3475b79d | |||
82fe6bea9a | |||
32b8092034 | |||
26f697cf78 | |||
ab472e1979 | |||
7515d6b86e | |||
a5c593bb14 | |||
48697c08a6 | |||
cbe15aceb8 | |||
e8a9c270e4 | |||
2a1aaa689e | |||
1a4f25ec97 |
@ -1,4 +1,4 @@
|
||||
[bumpversion]
|
||||
current_version = 0.2.0
|
||||
current_version = 0.2.1
|
||||
commit = True
|
||||
tag = True
|
||||
|
60
CHANGELOG.md
60
CHANGELOG.md
@ -1,14 +1,54 @@
|
||||
# Changelog
|
||||
|
||||
**v0.2.0**
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
- Fix: Fetch channels if not found in bot cache (error fix)
|
||||
- Enhancement: command to test a channel's permissions allow for the Bot to function
|
||||
- Enhancement: account for active state from a server's settings (`GuildSettings`)
|
||||
- Enhancement: command to view tracked content from the server or a given subscription of the same server.
|
||||
- Other: code optimisation & `GuildSettings` dataclass
|
||||
- Other: Cleaned out many instances of unused code
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
**v0.1.1**
|
||||
## [Unreleased]
|
||||
|
||||
- Docs: Start of changelog
|
||||
- Enhancement: Versioning with tagged docker images
|
||||
### Added
|
||||
|
||||
- Search and filter controls for the data viewing commands
|
||||
|
||||
### Fixed
|
||||
|
||||
- RSS feeds without a build date would break the subscription task
|
||||
- TypeError when an RSS item lacks a title or description
|
||||
- Fix an issue with a missing field on the Subscription model
|
||||
|
||||
### Changed
|
||||
|
||||
- Show whether a subscription is active or inactive when using a data view command
|
||||
- Added `unique_content_rules` field to `Subscription` dataclass (support latest pyrss-website version)
|
||||
- Update changelog to follow [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
|
||||
|
||||
## [0.2.0] - 2024-08-19
|
||||
|
||||
### Added
|
||||
|
||||
- Command to view tracked content from the relevant server
|
||||
- Command to test the bot's permissions in a specified channel
|
||||
- `GuildSettings` dataclass
|
||||
|
||||
### Fixed
|
||||
|
||||
- channels are `NoneType` because they didn't exist in the cache, fixed by fetching from API
|
||||
|
||||
### Changed
|
||||
|
||||
- Subscription task will ignore subscriptions flagged as 'inactive'
|
||||
- Code optimisation
|
||||
|
||||
### Removed
|
||||
|
||||
- Unused and commented out code
|
||||
|
||||
## [0.1.1] - 2024-08-17
|
||||
|
||||
### Added
|
||||
|
||||
- Start of changelog
|
||||
- Versioning with tagged docker images
|
||||
|
||||
## [0.1.0] - 2024-08-13
|
||||
|
@ -4,4 +4,4 @@ An RSS driven Discord bot written in Python.
|
||||
|
||||
Provides user commands for storing RSS feed URLs that can be assigned to any given discord channel.
|
||||
|
||||
Depends on the [web application](https://gitea.corbz.dev/corbz/PYRSS-Website). Check the releases for compatible versions.
|
||||
Depends on the [web application](https://gitea.cor.bz/corbz/PYRSS-Website). Check the releases for compatible versions.
|
||||
|
@ -1,15 +1,15 @@
|
||||
{
|
||||
"version": 1,
|
||||
"disable_existing_loggers": false,
|
||||
"disable_existing_loggers": true,
|
||||
"formatters": {
|
||||
"simple": {
|
||||
"format": "%(levelname)s %(message)s"
|
||||
"format": "[%(module)s|%(message)s]"
|
||||
},
|
||||
"detail": {
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(module)s]: %(message)s"
|
||||
},
|
||||
"complex": {
|
||||
"format": "[%(levelname)s|%(module)s|L%(lineno)d] %(asctime)s %(message)s",
|
||||
"format": "[%(levelname)s|%(name)s|L%(lineno)d] %(asctime)s %(message)s",
|
||||
"datefmt": "%Y-%m-%dT%H:%M:%S%z"
|
||||
}
|
||||
},
|
||||
@ -17,7 +17,7 @@
|
||||
"stdout": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"formatter": "simple",
|
||||
"formatter": "detail",
|
||||
"stream": "ext://sys.stdout"
|
||||
},
|
||||
"file": {
|
||||
@ -46,6 +46,14 @@
|
||||
},
|
||||
"discord": {
|
||||
"level": "INFO"
|
||||
},
|
||||
"httpx": {
|
||||
"level": "WARNING",
|
||||
"propagate": false
|
||||
},
|
||||
"httpcore": {
|
||||
"level": "WARNING",
|
||||
"propagate": false
|
||||
}
|
||||
}
|
||||
}
|
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@ -0,0 +1,3 @@
|
||||
[pytest]
|
||||
filterwarnings =
|
||||
ignore:'audioop' is deprecated and slated for removal in Python 3.13:DeprecationWarning
|
@ -1,29 +1,30 @@
|
||||
aiocache==0.12.2
|
||||
aiohttp==3.9.3
|
||||
aiohappyeyeballs==2.4.3
|
||||
aiohttp==3.10.10
|
||||
aiosignal==1.3.1
|
||||
aiosqlite==0.19.0
|
||||
async-timeout==4.0.3
|
||||
asyncpg==0.29.0
|
||||
attrs==23.2.0
|
||||
anyio==4.6.2.post1
|
||||
attrs==24.2.0
|
||||
beautifulsoup4==4.12.3
|
||||
bump2version==1.0.1
|
||||
click==8.1.7
|
||||
certifi==2024.8.30
|
||||
discord.py==2.3.2
|
||||
feedparser==6.0.11
|
||||
frozenlist==1.4.1
|
||||
greenlet==3.0.3
|
||||
idna==3.6
|
||||
frozenlist==1.5.0
|
||||
h11==0.14.0
|
||||
httpcore==1.0.6
|
||||
httpx==0.27.2
|
||||
idna==3.10
|
||||
iniconfig==2.0.0
|
||||
markdownify==0.11.6
|
||||
multidict==6.0.5
|
||||
pip-chill==1.0.3
|
||||
psycopg2-binary==2.9.9
|
||||
multidict==6.1.0
|
||||
packaging==24.2
|
||||
pluggy==1.5.0
|
||||
propcache==0.2.0
|
||||
pytest==8.3.3
|
||||
python-dotenv==1.0.0
|
||||
rapidfuzz==3.9.4
|
||||
sgmllib3k==1.0.0
|
||||
six==1.16.0
|
||||
soupsieve==2.5
|
||||
SQLAlchemy==2.0.23
|
||||
typing_extensions==4.10.0
|
||||
sniffio==1.3.1
|
||||
soupsieve==2.6
|
||||
uwuipy==0.1.9
|
||||
validators==0.22.0
|
||||
yarl==1.9.4
|
||||
yarl==1.17.0
|
||||
|
@ -157,7 +157,7 @@ class API:
|
||||
|
||||
log.debug("getting tracked content")
|
||||
|
||||
return await self._get_many(self.API_ENDPOINT + f"tracked-content/", filters)
|
||||
return await self._get_many(self.API_ENDPOINT + "tracked-content/", filters)
|
||||
|
||||
async def get_filter(self, filter_id: int) -> dict:
|
||||
"""
|
||||
@ -167,3 +167,10 @@ class API:
|
||||
log.debug("getting a filter")
|
||||
|
||||
return await self._get_one(f"{self.API_ENDPOINT}filter/{filter_id}")
|
||||
|
||||
async def get_filters(self, **filters) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get many instances of Filter.
|
||||
"""
|
||||
|
||||
return await self._get_many(self.API_ENDPOINT + "filter/", filters)
|
||||
|
@ -7,76 +7,82 @@ import logging
|
||||
from typing import Tuple
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
import validators
|
||||
# import aiohttp
|
||||
# import validators
|
||||
from feedparser import FeedParserDict, parse
|
||||
from discord.ext import commands
|
||||
from discord import Interaction, TextChannel, Embed, Colour
|
||||
from discord.app_commands import Choice, Group, autocomplete, rename, command
|
||||
from discord.app_commands import Choice, choices, Group, autocomplete, rename, command
|
||||
from discord.errors import Forbidden
|
||||
|
||||
from api import API
|
||||
from feed import Subscription, TrackedContent
|
||||
from utils import (
|
||||
Followup,
|
||||
PaginationView,
|
||||
get_rss_data,
|
||||
)
|
||||
# from api import API
|
||||
# from feed import Subscription, TrackedContent, ContentFilter
|
||||
# from utils import (
|
||||
# Followup,
|
||||
# PaginationView,
|
||||
# get_rss_data,
|
||||
# )
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
rss_list_sort_choices = [
|
||||
Choice(name="Nickname", value=0),
|
||||
Choice(name="Date Added", value=1)
|
||||
]
|
||||
channels_list_sort_choices=[
|
||||
Choice(name="Feed Nickname", value=0),
|
||||
Choice(name="Channel ID", value=1),
|
||||
Choice(name="Date Added", value=2)
|
||||
]
|
||||
# rss_list_sort_choices = [
|
||||
# Choice(name="Nickname", value=0),
|
||||
# Choice(name="Date Added", value=1)
|
||||
# ]
|
||||
# channels_list_sort_choices=[
|
||||
# Choice(name="Feed Nickname", value=0),
|
||||
# Choice(name="Channel ID", value=1),
|
||||
# Choice(name="Date Added", value=2)
|
||||
# ]
|
||||
|
||||
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
|
||||
# target resource. Run a period task to check this.
|
||||
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
|
||||
"""Validate a provided RSS source.
|
||||
# # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
|
||||
# # target resource. Run a period task to check this.
|
||||
# async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
|
||||
# """Validate a provided RSS source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nickname : str
|
||||
Nickname of the source. Must not contain URL.
|
||||
url : str
|
||||
URL of the source. Must be URL with valid status code and be an RSS feed.
|
||||
# Parameters
|
||||
# ----------
|
||||
# nickname : str
|
||||
# Nickname of the source. Must not contain URL.
|
||||
# url : str
|
||||
# URL of the source. Must be URL with valid status code and be an RSS feed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
String invalid message if invalid, NoneType if valid.
|
||||
FeedParserDict or None
|
||||
The feed parsed from the given URL or None if invalid.
|
||||
"""
|
||||
# Returns
|
||||
# -------
|
||||
# str or None
|
||||
# String invalid message if invalid, NoneType if valid.
|
||||
# FeedParserDict or None
|
||||
# The feed parsed from the given URL or None if invalid.
|
||||
# """
|
||||
|
||||
# Ensure the URL is valid
|
||||
if not validators.url(url):
|
||||
return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
|
||||
# # Ensure the URL is valid
|
||||
# if not validators.url(url):
|
||||
# return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
|
||||
|
||||
# Check the nickname is not a URL
|
||||
if validators.url(nickname):
|
||||
return "It looks like the nickname you have entered is a URL.\n" \
|
||||
f"For security reasons, this is not allowed.\n`{nickname=}`", None
|
||||
# # Check the nickname is not a URL
|
||||
# if validators.url(nickname):
|
||||
# return "It looks like the nickname you have entered is a URL.\n" \
|
||||
# f"For security reasons, this is not allowed.\n`{nickname=}`", None
|
||||
|
||||
|
||||
feed_data, status_code = await get_rss_data(url)
|
||||
# feed_data, status_code = await get_rss_data(url)
|
||||
|
||||
# Check the URL status code is valid
|
||||
if status_code != 200:
|
||||
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
||||
# # Check the URL status code is valid
|
||||
# if status_code != 200:
|
||||
# return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
||||
|
||||
# Check the contents is actually an RSS feed.
|
||||
feed = parse(feed_data)
|
||||
if not feed.version:
|
||||
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
||||
# # Check the contents is actually an RSS feed.
|
||||
# feed = parse(feed_data)
|
||||
# if not feed.version:
|
||||
# return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
||||
|
||||
return None, feed
|
||||
# return None, feed
|
||||
|
||||
# tri_choices = [
|
||||
# Choice(name="Yes", value=2),
|
||||
# Choice(name="No (default)", value=1),
|
||||
# Choice(name="All", value=0),
|
||||
# ]
|
||||
|
||||
|
||||
class CommandsCog(commands.Cog):
|
||||
@ -94,135 +100,206 @@ class CommandsCog(commands.Cog):
|
||||
|
||||
log.info("%s cog is ready", self.__class__.__name__)
|
||||
|
||||
# Group for commands about viewing data
|
||||
view_group = Group(
|
||||
name="view",
|
||||
description="View data.",
|
||||
guild_only=True
|
||||
)
|
||||
# # Group for commands about viewing data
|
||||
# view_group = Group(
|
||||
# name="view",
|
||||
# description="View data.",
|
||||
# guild_only=True
|
||||
# )
|
||||
|
||||
@view_group.command(name="subscriptions")
|
||||
async def cmd_list_subs(self, inter: Interaction, search: str = ""):
|
||||
"""List Subscriptions from this server."""
|
||||
# @view_group.command(name="subscriptions")
|
||||
# async def cmd_list_subs(self, inter: Interaction, search: str = ""):
|
||||
# """List Subscriptions from this server."""
|
||||
|
||||
await inter.response.defer()
|
||||
# await inter.response.defer()
|
||||
|
||||
def formatdata(index, item):
|
||||
item = Subscription.from_dict(item)
|
||||
|
||||
channels = f"{item.channels_count}{' channels' if item.channels_count != 1 else ' channel'}"
|
||||
filters = f"{len(item.filters)}{' filters' if len(item.filters) != 1 else ' filter'}"
|
||||
notes = item.extra_notes[:25] + "..." if len(item.extra_notes) > 28 else item.extra_notes
|
||||
links = f"[RSS Link]({item.url}) · [API Link]({API.API_EXTERNAL_ENDPOINT}subscription/{item.id}/)"
|
||||
# def formatdata(index, item):
|
||||
# item = Subscription.from_dict(item)
|
||||
|
||||
description = f"{channels}, {filters}\n"
|
||||
description += f"{notes}\n" if notes else ""
|
||||
description += links
|
||||
# notes = item.extra_notes[:25] + "..." if len(item.extra_notes) > 28 else item.extra_notes
|
||||
# links = f"[RSS Link]({item.url}) • [API Link]({API.API_EXTERNAL_ENDPOINT}subscription/{item.id}/)"
|
||||
# activeness = "✅ `enabled`" if item.active else "🚫 `disabled`"
|
||||
|
||||
key = f"{index}. {item.name}"
|
||||
return key, description # key, value pair
|
||||
# description = f"🆔 `{item.id}`\n{activeness}\n#️⃣ `{item.channels_count}` 🔽 `{len(item.filters)}`\n"
|
||||
# description = f"{notes}\n" + description if notes else description
|
||||
# description += links
|
||||
|
||||
async def getdata(page: int, pagesize: int):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
api = API(self.bot.api_token, session)
|
||||
return await api.get_subscriptions(
|
||||
guild_id=inter.guild.id,
|
||||
page=page,
|
||||
page_size=pagesize,
|
||||
search=search
|
||||
)
|
||||
# key = f"{index}. {item.name}"
|
||||
# return key, description # key, value pair
|
||||
|
||||
embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed
|
||||
pagination = PaginationView(
|
||||
self.bot,
|
||||
inter=inter,
|
||||
embed=embed,
|
||||
getdata=getdata,
|
||||
formatdata=formatdata,
|
||||
pagesize=10,
|
||||
initpage=1
|
||||
)
|
||||
await pagination.send()
|
||||
# async def getdata(page: int, pagesize: int):
|
||||
# async with aiohttp.ClientSession() as session:
|
||||
# api = API(self.bot.api_token, session)
|
||||
# return await api.get_subscriptions(
|
||||
# guild_id=inter.guild.id,
|
||||
# page=page,
|
||||
# page_size=pagesize,
|
||||
# search=search
|
||||
# )
|
||||
|
||||
@view_group.command(name="tracked-content")
|
||||
async def cmd_list_tracked(self, inter: Interaction, search: str = ""):
|
||||
"""List Tracked Content from this server, or a given sub"""
|
||||
# embed = Followup(f"Subscriptions in {inter.guild.name}").info()._embed
|
||||
# pagination = PaginationView(
|
||||
# self.bot,
|
||||
# inter=inter,
|
||||
# embed=embed,
|
||||
# getdata=getdata,
|
||||
# formatdata=formatdata,
|
||||
# pagesize=10,
|
||||
# initpage=1
|
||||
# )
|
||||
# await pagination.send()
|
||||
|
||||
await inter.response.defer()
|
||||
# @view_group.command(name="tracked-content")
|
||||
# @choices(blocked=tri_choices)
|
||||
# async def cmd_list_tracked(self, inter: Interaction, search: str = "", blocked: Choice[int] = 1):
|
||||
# """List Tracked Content from this server""" # TODO: , or a given sub
|
||||
|
||||
def formatdata(index, item):
|
||||
item = TrackedContent.from_dict(item)
|
||||
sub = Subscription.from_dict(item.subscription)
|
||||
# await inter.response.defer()
|
||||
|
||||
links = f"[Content Link]({item.url}) · [Message Link](https://discord.com/channels/{sub.guild_id}/{item.channel_id}/{item.message_id}/)"
|
||||
description = f"Subscription: {sub.name}\n{links}"
|
||||
# # If the user picks an option it's an instance of `Choice` otherwise `str`
|
||||
# # Can't figure a way to select a default choices, so blame discordpy for this mess.
|
||||
# if isinstance(blocked, Choice):
|
||||
# blocked = blocked.value
|
||||
|
||||
key = f"{item.id}. {item.title}"
|
||||
return key, description
|
||||
# def formatdata(index, item):
|
||||
# item = TrackedContent.from_dict(item)
|
||||
# sub = Subscription.from_dict(item.subscription)
|
||||
|
||||
async def getdata(page: int, pagesize: int):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
api = API(self.bot.api_token, session)
|
||||
return await api.get_tracked_content(
|
||||
subscription__guild_id=inter.guild_id,
|
||||
page=page,
|
||||
page_size=pagesize,
|
||||
search=search
|
||||
)
|
||||
# links = f"[Content Link]({item.url}) · [Message Link](https://discord.com/channels/{sub.guild_id}/{item.channel_id}/{item.message_id}/) · [API Link]({API.API_EXTERNAL_ENDPOINT}tracked-content/{item.id}/)"
|
||||
# delivery_state = "✅ Delivered" if not item.blocked else "🚫 Blocked"
|
||||
|
||||
embed = Followup(f"Tracked Content in {inter.guild.name}").info()._embed
|
||||
pagination = PaginationView(
|
||||
self.bot,
|
||||
inter=inter,
|
||||
embed=embed,
|
||||
getdata=getdata,
|
||||
formatdata=formatdata,
|
||||
pagesize=10,
|
||||
initpage=1
|
||||
)
|
||||
await pagination.send()
|
||||
# description = f"🆔 `{item.id}`\n"
|
||||
# description += f"{delivery_state}\n" if blocked == 0 else ""
|
||||
# description += f"➡️ *{sub.name}*\n{links}"
|
||||
|
||||
# Group for test related commands
|
||||
test_group = Group(
|
||||
name="test",
|
||||
description="Commands to test Bot functionality.",
|
||||
guild_only=True
|
||||
)
|
||||
# key = f"{index}. {item.title}"
|
||||
# return key, description
|
||||
|
||||
@test_group.command(name="channel-permissions")
|
||||
async def cmd_test_channel_perms(self, inter: Interaction):
|
||||
"""Test that the current channel's permissions allow for PYRSS to operate in it."""
|
||||
# def determine_blocked():
|
||||
# match blocked:
|
||||
# case 0: return ""
|
||||
# case 1: return "false"
|
||||
# case 2: return "true"
|
||||
# case _: return ""
|
||||
|
||||
try:
|
||||
test_message = await inter.channel.send(content="... testing permissions ...")
|
||||
await self.test_channel_perms(inter.channel)
|
||||
except Exception as error:
|
||||
await inter.response.send_message(content=f"Failed: {error}")
|
||||
return
|
||||
# async def getdata(page: int, pagesize: int):
|
||||
# async with aiohttp.ClientSession() as session:
|
||||
# api = API(self.bot.api_token, session)
|
||||
# is_blocked = determine_blocked()
|
||||
# return await api.get_tracked_content(
|
||||
# subscription__guild_id=inter.guild_id,
|
||||
# blocked=is_blocked,
|
||||
# page=page,
|
||||
# page_size=pagesize,
|
||||
# search=search,
|
||||
# )
|
||||
|
||||
await test_message.delete()
|
||||
await inter.response.send_message(content="Success")
|
||||
# embed = Followup(f"Tracked Content in {inter.guild.name}").info()._embed
|
||||
# pagination = PaginationView(
|
||||
# self.bot,
|
||||
# inter=inter,
|
||||
# embed=embed,
|
||||
# getdata=getdata,
|
||||
# formatdata=formatdata,
|
||||
# pagesize=10,
|
||||
# initpage=1
|
||||
# )
|
||||
# await pagination.send()
|
||||
|
||||
async def test_channel_perms(self, channel: TextChannel):
|
||||
# @view_group.command(name="filters")
|
||||
# async def cmd_list_filters(self, inter: Interaction, search: str = ""):
|
||||
# """List Filters from this server."""
|
||||
|
||||
# Test generic message and delete
|
||||
msg = await channel.send(content="test message")
|
||||
await msg.delete()
|
||||
# await inter.response.defer()
|
||||
|
||||
# Test detailed embed
|
||||
embed = Embed(
|
||||
title="test title",
|
||||
description="test description",
|
||||
colour=Colour.random(),
|
||||
timestamp=datetime.now(),
|
||||
url="https://google.com"
|
||||
)
|
||||
embed.set_author(name="test author")
|
||||
embed.set_footer(text="test footer")
|
||||
embed.set_thumbnail(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
||||
embed.set_image(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
||||
embed_msg = await channel.send(embed=embed)
|
||||
await embed_msg.delete()
|
||||
# def formatdata(index, item):
|
||||
# item = ContentFilter.from_dict(item)
|
||||
|
||||
# matching_algorithm = get_algorithm_name(item.matching_algorithm)
|
||||
# whitelist = "Whitelist" if item.is_whitelist else "Blacklist"
|
||||
# sensitivity = "Case insensitive" if item.is_insensitive else "Case sensitive"
|
||||
|
||||
# description = f"🆔 `{item.id}`\n"
|
||||
# description += f"🔄 `{matching_algorithm}`\n🟰 `{item.match}`\n"
|
||||
# description += f"✅ `{whitelist}` 🔠 `{sensitivity}`\n"
|
||||
# description += f"[API Link]({API.API_EXTERNAL_ENDPOINT}filter/{item.id}/)"
|
||||
|
||||
# key = f"{index}. {item.name}"
|
||||
# return key, description
|
||||
|
||||
# def get_algorithm_name(matching_algorithm: int):
|
||||
# match matching_algorithm:
|
||||
# case 0: return "None"
|
||||
# case 1: return "Any word"
|
||||
# case 2: return "All words"
|
||||
# case 3: return "Exact match"
|
||||
# case 4: return "Regex match"
|
||||
# case 5: return "Fuzzy match"
|
||||
# case _: return "unknown"
|
||||
|
||||
# async def getdata(page, pagesize):
|
||||
# async with aiohttp.ClientSession() as session:
|
||||
# api = API(self.bot.api_token, session)
|
||||
# return await api.get_filters(
|
||||
# guild_id=inter.guild_id,
|
||||
# page=page,
|
||||
# page_size=pagesize,
|
||||
# search=search
|
||||
# )
|
||||
|
||||
# embed = Followup(f"Filters in {inter.guild.name}").info()._embed
|
||||
# pagination = PaginationView(
|
||||
# self.bot,
|
||||
# inter=inter,
|
||||
# embed=embed,
|
||||
# getdata=getdata,
|
||||
# formatdata=formatdata,
|
||||
# pagesize=10,
|
||||
# initpage=1
|
||||
# )
|
||||
# await pagination.send()
|
||||
|
||||
# # Group for test related commands
|
||||
# test_group = Group(
|
||||
# name="test",
|
||||
# description="Commands to test Bot functionality.",
|
||||
# guild_only=True
|
||||
# )
|
||||
|
||||
# @test_group.command(name="channel-permissions")
|
||||
# async def cmd_test_channel_perms(self, inter: Interaction):
|
||||
# """Test that the current channel's permissions allow for PYRSS to operate in it."""
|
||||
|
||||
# try:
|
||||
# test_message = await inter.channel.send(content="... testing permissions ...")
|
||||
# await self.test_channel_perms(inter.channel)
|
||||
# except Exception as error:
|
||||
# await inter.response.send_message(content=f"Failed: {error}")
|
||||
# return
|
||||
|
||||
# await test_message.delete()
|
||||
# await inter.response.send_message(content="Success")
|
||||
|
||||
# async def test_channel_perms(self, channel: TextChannel):
|
||||
|
||||
# # Test generic message and delete
|
||||
# msg = await channel.send(content="test message")
|
||||
# await msg.delete()
|
||||
|
||||
# # Test detailed embed
|
||||
# embed = Embed(
|
||||
# title="test title",
|
||||
# description="test description",
|
||||
# colour=Colour.random(),
|
||||
# timestamp=datetime.now(),
|
||||
# url="https://google.com"
|
||||
# )
|
||||
# embed.set_author(name="test author")
|
||||
# embed.set_footer(text="test footer")
|
||||
# embed.set_thumbnail(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
||||
# embed.set_image(url="https://www.google.com/images/branding/googlelogo/2x/googlelogo_light_color_272x92dp.png")
|
||||
# embed_msg = await channel.send(embed=embed)
|
||||
# await embed_msg.delete()
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
|
@ -3,29 +3,37 @@ Extension for the `TaskCog`.
|
||||
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import datetime
|
||||
import traceback
|
||||
from os import getenv
|
||||
from time import perf_counter
|
||||
from collections import deque
|
||||
from textwrap import shorten
|
||||
|
||||
import aiohttp
|
||||
from aiocache import Cache
|
||||
# import aiohttp
|
||||
import httpx
|
||||
import feedparser
|
||||
import discord
|
||||
# from aiocache import Cache
|
||||
from discord import TextChannel
|
||||
from discord import app_commands
|
||||
from discord import app_commands, Interaction
|
||||
from discord.ext import commands, tasks
|
||||
from discord.errors import Forbidden
|
||||
from feedparser import parse
|
||||
from markdownify import markdownify
|
||||
|
||||
from feed import RSSFeed, Subscription, RSSItem, GuildSettings
|
||||
from utils import get_unparsed_feed
|
||||
from filters import match_text
|
||||
import models
|
||||
from utils import do_batch_job
|
||||
# from feed import RSSFeed, Subscription, RSSItem, GuildSettings
|
||||
# from utils import get_unparsed_feed
|
||||
# from filters import match_text
|
||||
from api import API
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
cache = Cache(Cache.MEMORY)
|
||||
# cache = Cache(Cache.MEMORY)
|
||||
|
||||
BATCH_SIZE = 100
|
||||
|
||||
@ -35,7 +43,7 @@ subscription_task_times = [
|
||||
for hour in range(24)
|
||||
for minute in range(0, 60, int(TASK_INTERVAL_MINUTES))
|
||||
]
|
||||
log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
|
||||
log.info("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
|
||||
|
||||
|
||||
class TaskCog(commands.Cog):
|
||||
@ -46,16 +54,23 @@ class TaskCog(commands.Cog):
|
||||
api: API | None = None
|
||||
content_queue = deque()
|
||||
|
||||
api_base_url: str
|
||||
api_headers: dict
|
||||
client: httpx.AsyncClient | None
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
self.api_base_url = "http://localhost:8000/api/"
|
||||
self.api_headers = {"Authorization": f"Token {self.bot.api_token}"}
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
"""
|
||||
Instructions to execute when the cog is ready.
|
||||
"""
|
||||
self.subscription_task.start()
|
||||
# self.subscription_task.start()
|
||||
self.do_task.start()
|
||||
log.info("%s cog is ready", self.__class__.__name__)
|
||||
|
||||
@commands.Cog.listener(name="cog_unload")
|
||||
@ -72,219 +87,434 @@ class TaskCog(commands.Cog):
|
||||
)
|
||||
|
||||
@group.command(name="trigger")
|
||||
async def cmd_trigger_task(self, inter):
|
||||
async def cmd_trigger_task(self, inter: Interaction):
|
||||
await inter.response.defer()
|
||||
start_time = perf_counter()
|
||||
|
||||
try:
|
||||
await self.subscription_task()
|
||||
except Exception as error:
|
||||
await inter.followup.send(str(error))
|
||||
await self.do_task()
|
||||
except Exception as exc:
|
||||
log.exception(exc)
|
||||
await inter.followup.send(str(exc) or "unknown error")
|
||||
finally:
|
||||
end_time = perf_counter()
|
||||
await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
|
||||
await inter.followup.send(f"completed command in {end_time - start_time:.4f} seconds")
|
||||
|
||||
@tasks.loop(time=subscription_task_times)
|
||||
async def subscription_task(self):
|
||||
"""
|
||||
Task for fetching and processing subscriptions.
|
||||
"""
|
||||
log.info("Running subscription task")
|
||||
async def do_task(self):
|
||||
log.info("Running task")
|
||||
start_time = perf_counter()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
self.api = API(self.bot.api_token, session)
|
||||
await self.execute_task()
|
||||
async with httpx.AsyncClient() as client:
|
||||
self.client = client
|
||||
servers = await self.get_servers()
|
||||
await do_batch_job(servers, self.process_server, 10)
|
||||
|
||||
end_time = perf_counter()
|
||||
log.debug(f"task completed in {end_time - start_time:.4f} seconds")
|
||||
log.info(f"completed task in {end_time - start_time:.4f} seconds")
|
||||
|
||||
async def execute_task(self):
|
||||
"""Execute the task directly."""
|
||||
async def iterate_pages(self, url: str, params: dict={}):
|
||||
|
||||
# Filter out inactive guild IDs using related settings
|
||||
guild_ids = [guild.id for guild in self.bot.guilds]
|
||||
guild_settings = await self.get_guild_settings(guild_ids)
|
||||
active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
|
||||
|
||||
subscriptions = await self.get_subscriptions(active_guild_ids)
|
||||
await self.process_subscriptions(subscriptions)
|
||||
|
||||
async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
|
||||
"""Returns a list of guild settings from the Bot's guilds, if they exist."""
|
||||
|
||||
guild_settings = []
|
||||
|
||||
# Iterate infinitely taking the iter no. as `page`
|
||||
# data will be empty after last page reached.
|
||||
for page, _ in enumerate(iter(int, 1)):
|
||||
data = await self.get_guild_settings_page(guild_ids, page)
|
||||
if not data:
|
||||
break
|
||||
|
||||
guild_settings.extend(data[0])
|
||||
|
||||
# Only return active guild IDs
|
||||
return GuildSettings.from_list(guild_settings)
|
||||
|
||||
async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
|
||||
"""Returns an individual page of guild settings."""
|
||||
|
||||
try:
|
||||
return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
|
||||
except aiohttp.ClientResponseError as error:
|
||||
self.handle_pagination_error(error)
|
||||
return []
|
||||
|
||||
def handle_pagination_error(self, error: aiohttp.ClientResponseError):
|
||||
"""Handle the error cases from pagination attempts."""
|
||||
|
||||
match error.status:
|
||||
case 404:
|
||||
log.debug("final page reached")
|
||||
case 403:
|
||||
log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
|
||||
self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
|
||||
case _:
|
||||
log.debug(error)
|
||||
|
||||
async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
|
||||
"""Get a list of `Subscription`s matching the given `guild_ids`."""
|
||||
|
||||
subscriptions = []
|
||||
|
||||
# Iterate infinitely taking the iter no. as `page`
|
||||
# data will be empty after last page reached.
|
||||
for page, _ in enumerate(iter(int, 1)):
|
||||
data = await self.get_subs_page(guild_ids, page)
|
||||
if not data:
|
||||
break
|
||||
|
||||
subscriptions.extend(data[0])
|
||||
|
||||
return Subscription.from_list(subscriptions)
|
||||
|
||||
async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
|
||||
"""Returns an individual page of subscriptions."""
|
||||
|
||||
try:
|
||||
return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
|
||||
except aiohttp.ClientResponseError as error:
|
||||
self.handle_pagination_error(error)
|
||||
return []
|
||||
|
||||
async def process_subscriptions(self, subscriptions: list[Subscription]):
|
||||
"""Process a given list of `Subscription`s."""
|
||||
|
||||
async def process_single_subscription(sub: Subscription):
|
||||
log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
|
||||
|
||||
if not sub.active or not sub.channels_count:
|
||||
return
|
||||
|
||||
unparsed_feed = await get_unparsed_feed(sub.url)
|
||||
parsed_feed = parse(unparsed_feed)
|
||||
|
||||
rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
|
||||
await self.process_items(sub, rss_feed)
|
||||
|
||||
semaphore = asyncio.Semaphore(10)
|
||||
|
||||
async def semaphore_process(sub: Subscription):
|
||||
async with semaphore:
|
||||
await process_single_subscription(sub)
|
||||
|
||||
await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
|
||||
|
||||
async def process_items(self, sub: Subscription, feed: RSSFeed):
|
||||
log.debug("processing items")
|
||||
|
||||
channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
|
||||
filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
|
||||
|
||||
for item in feed.items:
|
||||
log.debug("processing item '%s'", item.guid)
|
||||
|
||||
if item.pub_date < sub.published_threshold:
|
||||
log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
|
||||
continue
|
||||
|
||||
blocked = any(self.filter_item(_filter, item) for _filter in filters)
|
||||
mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
|
||||
|
||||
for channel in channels:
|
||||
await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
|
||||
|
||||
async def fetch_or_get_channels(self, channels_data: list[dict]):
|
||||
channels = []
|
||||
|
||||
for data in channels_data:
|
||||
try:
|
||||
channel = self.bot.get_channel(data.channel_id)
|
||||
channels.append(channel or await self.bot.fetch_channel(data.channel_id))
|
||||
except Forbidden:
|
||||
log.error(f"Forbidden Channel '{data.channel_id}'")
|
||||
|
||||
return channels
|
||||
|
||||
def filter_item(self, _filter: dict, item: RSSItem) -> bool:
|
||||
"""
|
||||
Returns `True` if item should be ignored due to filters.
|
||||
"""
|
||||
|
||||
match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
|
||||
log.debug("filter match found? '%s'", match_found)
|
||||
return match_found
|
||||
|
||||
async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
|
||||
message_id = -1
|
||||
|
||||
log.debug("track and send func %s, %s", item.guid, item.title)
|
||||
|
||||
result = await self.api.get_tracked_content(guid=item.guid)
|
||||
if result[1]:
|
||||
log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
||||
return
|
||||
|
||||
result = await self.api.get_tracked_content(url=item.link)
|
||||
if result[1]:
|
||||
log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
||||
return
|
||||
|
||||
if not blocked:
|
||||
try:
|
||||
log.debug("sending '%s', exists '%s'", item.guid, result[1])
|
||||
sendable_item = mutated_item or item
|
||||
message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
|
||||
message_id = message.id
|
||||
except Forbidden:
|
||||
log.error(f"Forbidden to send to channel {channel.id}")
|
||||
|
||||
await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
|
||||
|
||||
async def process_batch(self):
|
||||
pass
|
||||
|
||||
async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
|
||||
try:
|
||||
log.debug("marking as tracked")
|
||||
await self.api.create_tracked_content(
|
||||
guid=item.guid,
|
||||
title=item.title,
|
||||
url=item.link,
|
||||
subscription=sub.id,
|
||||
channel_id=channel_id,
|
||||
message_id=message_id,
|
||||
blocked=blocked
|
||||
for page_number, _ in enumerate(iterable=iter(int, 1), start=1):
|
||||
params.update({"page": page_number})
|
||||
response = await self.client.get(
|
||||
self.api_base_url + url,
|
||||
headers=self.api_headers,
|
||||
params=params
|
||||
)
|
||||
return True
|
||||
except aiohttp.ClientResponseError as error:
|
||||
if error.status == 409:
|
||||
log.debug(error)
|
||||
else:
|
||||
log.error(error)
|
||||
response.raise_for_status()
|
||||
content = response.json()
|
||||
|
||||
return False
|
||||
yield content.get("results", [])
|
||||
|
||||
if not content.get("next"):
|
||||
break
|
||||
|
||||
async def get_servers(self) -> list[models.Server]:
|
||||
servers = []
|
||||
|
||||
async for servers_batch in self.iterate_pages("servers/"):
|
||||
if servers_batch:
|
||||
servers.extend(servers_batch)
|
||||
|
||||
return models.Server.from_list(servers)
|
||||
|
||||
async def get_subscriptions(self, server: models.Server) -> list[models.Subscription]:
|
||||
subscriptions = []
|
||||
params = {"server": server.id, "active": True}
|
||||
|
||||
async for subscriptions_batch in self.iterate_pages("subscriptions/", params):
|
||||
if subscriptions_batch:
|
||||
subscriptions.extend(subscriptions_batch)
|
||||
|
||||
return models.Subscription.from_list(subscriptions)
|
||||
|
||||
async def get_contents(self, subscription: models.Subscription, raw_rss_content: dict):
|
||||
contents = await models.Content.from_raw_rss(raw_rss_content, subscription, self.client)
|
||||
duplicate_contents = []
|
||||
|
||||
async def check_duplicate_content(content: models.Content):
|
||||
exists = await content.exists_via_api(
|
||||
url=self.api_base_url + "content/",
|
||||
headers=self.api_headers,
|
||||
client=self.client
|
||||
)
|
||||
|
||||
if exists:
|
||||
log.debug(f"Removing duplicate {content}")
|
||||
duplicate_contents.append(content)
|
||||
|
||||
await do_batch_job(contents, check_duplicate_content, 15)
|
||||
|
||||
log.debug(f"before removing duplicates: {len(contents)}")
|
||||
for duplicate in duplicate_contents:
|
||||
contents.remove(duplicate)
|
||||
log.debug(f"after removing duplicates: {len(contents)}")
|
||||
|
||||
return contents
|
||||
|
||||
async def process_server(self, server: models.Server):
|
||||
log.debug(f"processing server: {server.name}")
|
||||
start_time = perf_counter()
|
||||
|
||||
subscriptions = await self.get_subscriptions(server)
|
||||
for subscription in subscriptions:
|
||||
subscription.server = server
|
||||
|
||||
await do_batch_job(subscriptions, self.process_subscription, 10)
|
||||
|
||||
end_time = perf_counter()
|
||||
log.debug(f"Finished processing server: {server.name} in {end_time - start_time:.4f} seconds")
|
||||
|
||||
async def process_subscription(self, subscription: models.Subscription):
|
||||
log.debug(f"processing subscription {subscription.name}")
|
||||
start_time = perf_counter()
|
||||
|
||||
raw_rss_content = await subscription.get_rss_content(self.client)
|
||||
if not raw_rss_content:
|
||||
return
|
||||
|
||||
contents = await self.get_contents(subscription, raw_rss_content)
|
||||
if not contents:
|
||||
log.debug("no contents to process")
|
||||
return
|
||||
|
||||
channels = await subscription.get_discord_channels(self.bot)
|
||||
valid_contents, invalid_contents = subscription.filter_entries(contents)
|
||||
|
||||
async def send_content(channel: discord.TextChannel):
|
||||
# BUG: I believe there are duplicate embeds here
|
||||
# discord only shows 1 when urls are matching, but merges images from both into the 1
|
||||
|
||||
# embeds = [content.embed for content in valid_contents]
|
||||
# batch_size = 10
|
||||
# for i in range(0, len(embeds), batch_size):
|
||||
# batch = embeds[i:i + batch_size]
|
||||
# await channel.send(embeds=batch)
|
||||
|
||||
batch_size = 10
|
||||
total_batches = (len(valid_contents) + batch_size - 1) // batch_size
|
||||
for batch_number, i in enumerate(range(0, len(valid_contents), batch_size)):
|
||||
contents_batch = valid_contents[i:i + batch_size]
|
||||
embeds = await self.create_embeds(
|
||||
contents=contents_batch,
|
||||
subscription_name=subscription.name,
|
||||
colour=subscription.message_style.colour,
|
||||
batch_number=batch_number,
|
||||
total_batches=total_batches
|
||||
)
|
||||
await channel.send(embeds=embeds)
|
||||
|
||||
await do_batch_job(channels, send_content, 5)
|
||||
|
||||
combined = valid_contents.copy()
|
||||
combined.extend(invalid_contents)
|
||||
|
||||
tasks = [content.save(self.client, self.api_base_url, self.api_headers) for content in combined]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# TODO: mark invalid contents as blocked
|
||||
|
||||
end_time = perf_counter()
|
||||
log.debug(f"Finished processing subscription: {subscription.name} in {end_time - start_time:.4f}")
|
||||
|
||||
async def create_embeds(self, contents: list[models.Content], subscription_name: str, colour: str, batch_number: int, total_batches: int):
|
||||
discord_colour = discord.Colour.from_str(
|
||||
colour if colour.startswith("#")
|
||||
else f"#{colour}"
|
||||
)
|
||||
|
||||
url = "https://pyrss.cor.bz"
|
||||
title = subscription_name
|
||||
|
||||
if total_batches > 1:
|
||||
title += f" [{batch_number+1}/{total_batches}]"
|
||||
|
||||
embed = discord.Embed(title=title, colour=discord_colour, url=url)
|
||||
embeds = [embed]
|
||||
|
||||
for content in contents:
|
||||
description = shorten(markdownify(content.item_description, strip=("img",)), 256)
|
||||
description += f"\n[View Article]({content.item_url})"
|
||||
embed.add_field(
|
||||
name=shorten(markdownify(content.item_title, strip=("img", "a")), 256),
|
||||
value=description,
|
||||
inline=False
|
||||
)
|
||||
|
||||
# If there is only one content, set the main embed's image and return it.
|
||||
# Otherwise progressing the normal way will create a wonky looking embed
|
||||
# where the lone image is aligned left to an invisible non-existant right
|
||||
# image.
|
||||
if len(contents) == 1:
|
||||
embed.set_image(url=content.item_image_url)
|
||||
return embeds
|
||||
|
||||
if len(embeds) <= 5:
|
||||
image_embed = discord.Embed(title="dummy", url=url)
|
||||
image_embed.set_image(url=content.item_image_url)
|
||||
embeds.append(image_embed)
|
||||
|
||||
return embeds
|
||||
|
||||
|
||||
# async def process_valid_contents(
|
||||
# self,
|
||||
# contents: list[models.Content],
|
||||
# channels: list[discord.TextChannel],
|
||||
# client: httpx.AsyncClient
|
||||
# ):
|
||||
# semaphore = asyncio.Semaphore(5)
|
||||
|
||||
# async def batch_process(
|
||||
# content: models.Content,
|
||||
# channels: list[discord.TextChannel],
|
||||
# client: httpx.AsyncClient
|
||||
# ):
|
||||
# async with semaphore: await self.process_valid_content(content, channels, client)
|
||||
|
||||
# tasks = [
|
||||
# batch_process()
|
||||
# ]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# @group.command(name="trigger")
|
||||
# async def cmd_trigger_task(self, inter):
|
||||
# await inter.response.defer()
|
||||
# start_time = perf_counter()
|
||||
|
||||
# try:
|
||||
# await self.subscription_task()
|
||||
# except Exception as error:
|
||||
# await inter.followup.send(str(error))
|
||||
# finally:
|
||||
# end_time = perf_counter()
|
||||
# await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
|
||||
|
||||
# @tasks.loop(time=subscription_task_times)
|
||||
# async def subscription_task(self):
|
||||
# """
|
||||
# Task for fetching and processing subscriptions.
|
||||
# """
|
||||
# log.info("Running subscription task")
|
||||
# start_time = perf_counter()
|
||||
|
||||
# async with aiohttp.ClientSession() as session:
|
||||
# self.api = API(self.bot.api_token, session)
|
||||
# await self.execute_task()
|
||||
|
||||
# end_time = perf_counter()
|
||||
# log.debug(f"task completed in {end_time - start_time:.4f} seconds")
|
||||
|
||||
# async def execute_task(self):
|
||||
# """Execute the task directly."""
|
||||
|
||||
# # Filter out inactive guild IDs using related settings
|
||||
# guild_ids = [guild.id for guild in self.bot.guilds]
|
||||
# guild_settings = await self.get_guild_settings(guild_ids)
|
||||
# active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
|
||||
|
||||
# subscriptions = await self.get_subscriptions(active_guild_ids)
|
||||
# await self.process_subscriptions(subscriptions)
|
||||
|
||||
# async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
|
||||
# """Returns a list of guild settings from the Bot's guilds, if they exist."""
|
||||
|
||||
# guild_settings = []
|
||||
|
||||
# # Iterate infinitely taking the iter no. as `page`
|
||||
# # data will be empty after last page reached.
|
||||
# for page, _ in enumerate(iter(int, 1)):
|
||||
# data = await self.get_guild_settings_page(guild_ids, page)
|
||||
# if not data:
|
||||
# break
|
||||
|
||||
# guild_settings.extend(data[0])
|
||||
|
||||
# # Only return active guild IDs
|
||||
# return GuildSettings.from_list(guild_settings)
|
||||
|
||||
# async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
|
||||
# """Returns an individual page of guild settings."""
|
||||
|
||||
# try:
|
||||
# return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
|
||||
# except aiohttp.ClientResponseError as error:
|
||||
# self.handle_pagination_error(error)
|
||||
# return []
|
||||
|
||||
# def handle_pagination_error(self, error: aiohttp.ClientResponseError):
|
||||
# """Handle the error cases from pagination attempts."""
|
||||
|
||||
# match error.status:
|
||||
# case 404:
|
||||
# log.debug("final page reached")
|
||||
# case 403:
|
||||
# log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
|
||||
# self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
|
||||
# case _:
|
||||
# log.debug(error)
|
||||
|
||||
# async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
|
||||
# """Get a list of `Subscription`s matching the given `guild_ids`."""
|
||||
|
||||
# subscriptions = []
|
||||
|
||||
# # Iterate infinitely taking the iter no. as `page`
|
||||
# # data will be empty after last page reached.
|
||||
# for page, _ in enumerate(iter(int, 1)):
|
||||
# data = await self.get_subs_page(guild_ids, page)
|
||||
# if not data:
|
||||
# break
|
||||
|
||||
# subscriptions.extend(data[0])
|
||||
|
||||
# return Subscription.from_list(subscriptions)
|
||||
|
||||
# async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
|
||||
# """Returns an individual page of subscriptions."""
|
||||
|
||||
# try:
|
||||
# return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
|
||||
# except aiohttp.ClientResponseError as error:
|
||||
# self.handle_pagination_error(error)
|
||||
# return []
|
||||
|
||||
# async def process_subscriptions(self, subscriptions: list[Subscription]):
|
||||
# """Process a given list of `Subscription`s."""
|
||||
|
||||
# async def process_single_subscription(sub: Subscription):
|
||||
# log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
|
||||
|
||||
# if not sub.active or not sub.channels_count:
|
||||
# return
|
||||
|
||||
# unparsed_feed = await get_unparsed_feed(sub.url)
|
||||
# parsed_feed = parse(unparsed_feed)
|
||||
|
||||
# rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
|
||||
# await self.process_items(sub, rss_feed)
|
||||
|
||||
# semaphore = asyncio.Semaphore(10)
|
||||
|
||||
# async def semaphore_process(sub: Subscription):
|
||||
# async with semaphore:
|
||||
# await process_single_subscription(sub)
|
||||
|
||||
# await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
|
||||
|
||||
# async def process_items(self, sub: Subscription, feed: RSSFeed):
|
||||
# log.debug("processing items")
|
||||
|
||||
# channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
|
||||
# filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
|
||||
|
||||
# for item in feed.items:
|
||||
# log.debug("processing item '%s'", item.guid)
|
||||
|
||||
# if item.pub_date < sub.published_threshold:
|
||||
# log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
|
||||
# continue
|
||||
|
||||
# blocked = any(self.filter_item(_filter, item) for _filter in filters)
|
||||
# mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
|
||||
|
||||
# for channel in channels:
|
||||
# await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
|
||||
|
||||
# async def fetch_or_get_channels(self, channels_data: list[dict]):
|
||||
# channels = []
|
||||
|
||||
# for data in channels_data:
|
||||
# try:
|
||||
# channel = self.bot.get_channel(data.channel_id)
|
||||
# channels.append(channel or await self.bot.fetch_channel(data.channel_id))
|
||||
# except Forbidden:
|
||||
# log.error(f"Forbidden Channel '{data.channel_id}'")
|
||||
|
||||
# return channels
|
||||
|
||||
# def filter_item(self, _filter: dict, item: RSSItem) -> bool:
|
||||
# """
|
||||
# Returns `True` if item should be ignored due to filters.
|
||||
# """
|
||||
|
||||
# match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
|
||||
# log.debug("filter match found? '%s'", match_found)
|
||||
# return match_found
|
||||
|
||||
# async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
|
||||
# message_id = -1
|
||||
|
||||
# log.debug("track and send func %s, %s", item.guid, item.title)
|
||||
|
||||
# result = await self.api.get_tracked_content(guid=item.guid)
|
||||
# if result[1]:
|
||||
# log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
||||
# return
|
||||
|
||||
# result = await self.api.get_tracked_content(url=item.link)
|
||||
# if result[1]:
|
||||
# log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
||||
# return
|
||||
|
||||
# if not blocked:
|
||||
# try:
|
||||
# log.debug("sending '%s', exists '%s'", item.guid, result[1])
|
||||
# sendable_item = mutated_item or item
|
||||
# message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
|
||||
# message_id = message.id
|
||||
# except Forbidden:
|
||||
# log.error(f"Forbidden to send to channel {channel.id}")
|
||||
|
||||
# await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
|
||||
|
||||
# async def process_batch(self):
|
||||
# pass
|
||||
|
||||
# async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
|
||||
# try:
|
||||
# log.debug("marking as tracked")
|
||||
# await self.api.create_tracked_content(
|
||||
# guid=item.guid,
|
||||
# title=item.title,
|
||||
# url=item.link,
|
||||
# subscription=sub.id,
|
||||
# channel_id=channel_id,
|
||||
# message_id=message_id,
|
||||
# blocked=blocked
|
||||
# )
|
||||
# return True
|
||||
# except aiohttp.ClientResponseError as error:
|
||||
# if error.status == 409:
|
||||
# log.debug(error)
|
||||
# else:
|
||||
# log.error(error)
|
||||
|
||||
# return False
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
|
27
src/feed.py
27
src/feed.py
@ -199,7 +199,7 @@ class RSSFeed:
|
||||
description: str
|
||||
link: str
|
||||
lang: str
|
||||
last_build_date: datetime
|
||||
last_build_date: datetime | None
|
||||
image_href: str
|
||||
items: list[RSSItem] = None
|
||||
|
||||
@ -240,7 +240,8 @@ class RSSFeed:
|
||||
language = pf.feed.get('language', None)
|
||||
|
||||
last_build_date = pf.feed.get('updated_parsed', None)
|
||||
last_build_date = datetime(*last_build_date[0:-2] if last_build_date else None)
|
||||
if last_build_date:
|
||||
last_build_date = datetime(*last_build_date[0:-2])
|
||||
|
||||
image_href = pf.feed.get("image", {}).get("href")
|
||||
|
||||
@ -250,7 +251,7 @@ class RSSFeed:
|
||||
item = RSSItem.from_parsed_entry(entry)
|
||||
feed.add_item(item)
|
||||
|
||||
feed.items.reverse()
|
||||
feed.items.reverse() # order so that older items are processed first
|
||||
return feed
|
||||
|
||||
|
||||
@ -301,6 +302,7 @@ class Subscription(DjangoDataModel):
|
||||
published_threshold: datetime
|
||||
active: bool
|
||||
channels_count: int
|
||||
unique_content_rules: list
|
||||
|
||||
@staticmethod
|
||||
def parser(item: dict) -> dict:
|
||||
@ -311,6 +313,7 @@ class Subscription(DjangoDataModel):
|
||||
"description": item.pop("article_desc_mutators")
|
||||
}
|
||||
item["published_threshold"] = datetime.strptime(item["published_threshold"], "%Y-%m-%dT%H:%M:%S%z")
|
||||
item["unique_content_rules"] = item.get("unique_content_rules", [])
|
||||
|
||||
return item
|
||||
|
||||
@ -357,3 +360,21 @@ class TrackedContent(DjangoDataModel):
|
||||
|
||||
item["creation_datetime"] = datetime.strptime(item["creation_datetime"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
return item
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ContentFilter(DjangoDataModel):
|
||||
|
||||
id: int
|
||||
name: str
|
||||
matching_algorithm: int
|
||||
match: str
|
||||
is_insensitive: bool
|
||||
is_whitelist: bool
|
||||
guild_id: int
|
||||
|
||||
@staticmethod
|
||||
def parser(item: dict) -> dict:
|
||||
|
||||
item["guild_id"] = int(item["guild_id"]) # stored as str due to a django/sqlite bug, convert back to int
|
||||
return item
|
||||
|
527
src/models.py
Normal file
527
src/models.py
Normal file
@ -0,0 +1,527 @@
|
||||
import re
|
||||
import logging
|
||||
import hashlib
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from time import perf_counter
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime, timezone
|
||||
from textwrap import shorten
|
||||
|
||||
import feedparser.parsers
|
||||
import httpx
|
||||
import discord
|
||||
import rapidfuzz
|
||||
import feedparser
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import markdownify
|
||||
|
||||
from utils import do_batch_job
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DjangoDataModel(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def parser(item: dict) -> dict:
|
||||
return item
|
||||
|
||||
@classmethod
|
||||
def from_list(cls, data: list[dict]) -> list:
|
||||
return [cls(**cls.parser(item)) for item in data]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
return cls(**cls.parser(data))
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Server(DjangoDataModel):
|
||||
id: int
|
||||
name: str
|
||||
icon_hash: str
|
||||
is_bot_operational: bool
|
||||
active: bool
|
||||
|
||||
@staticmethod
|
||||
def parser(item: dict) -> dict:
|
||||
item["id"] = int(item.pop("id"))
|
||||
return item
|
||||
|
||||
|
||||
class MatchingAlgorithm(Enum):
|
||||
NONE = 0
|
||||
ANY = 1
|
||||
ALL = 2
|
||||
LITERAL = 3
|
||||
REGEX = 4
|
||||
FUZZY = 5
|
||||
AUTO = 6
|
||||
|
||||
@classmethod
|
||||
def from_value(cls, value: int):
|
||||
for member in cls:
|
||||
if member.value == value:
|
||||
return member
|
||||
|
||||
raise ValueError(f"No {cls.__class__.__name__} for value: {value}")
|
||||
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ContentFilter(DjangoDataModel):
|
||||
id: int
|
||||
server_id: int
|
||||
name: str
|
||||
matching_pattern: str
|
||||
matching_algorithm: MatchingAlgorithm
|
||||
is_insensitive: bool
|
||||
is_whitelist: bool
|
||||
|
||||
@staticmethod
|
||||
def parser(item: dict) -> dict:
|
||||
item["id"] = item.pop("id")
|
||||
item["server_id"] = item.pop("server")
|
||||
item["matching_pattern"] = item.pop("match")
|
||||
item["matching_algorithm"] = MatchingAlgorithm.from_value(item.pop("matching_algorithm"))
|
||||
return item
|
||||
|
||||
@property
|
||||
def _regex_flags(self):
|
||||
return re.IGNORECASE if self.is_insensitive else 0
|
||||
|
||||
@property
|
||||
def cleaned_matching_pattern(self):
|
||||
"""
|
||||
Splits the pattern to individual keywords, getting rid of unnecessary
|
||||
spaces and grouping quoted words together.
|
||||
|
||||
"""
|
||||
findterms = re.compile(r'"([^"]+)"|(\S+)').findall
|
||||
normspace = re.compile(r"\s+").sub
|
||||
return [
|
||||
re.escape(normspace(" ", (t[0] or t[1]).strip())).replace(r"\ ", r"\s+")
|
||||
for t in findterms(self.matching_pattern)
|
||||
]
|
||||
|
||||
def _match_any(self, matching_against: str):
|
||||
for word in self.cleaned_matching_pattern:
|
||||
if re.search(rf"\b{word}\b", matching_against, self._regex_flags):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _match_all(self, matching_against: str):
|
||||
for word in self.cleaned_matching_pattern:
|
||||
if not re.search(rf"\b{word}\b", matching_against, self._regex_flags):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _match_literal(self, matching_against: str):
|
||||
return bool(
|
||||
re.search(
|
||||
rf"\b{re.escape(self.matching_pattern)}\b",
|
||||
matching_against,
|
||||
self._regex_flags
|
||||
)
|
||||
)
|
||||
|
||||
def _match_regex(self, matching_against: str):
|
||||
try:
|
||||
return bool(re.search(
|
||||
re.compile(self.matching_pattern, self._regex_flags),
|
||||
matching_against
|
||||
))
|
||||
except re.error as exc:
|
||||
log.error(f"Filter regex error: {exc}")
|
||||
return False
|
||||
|
||||
def _match_fuzzy(self, matching_against: str):
|
||||
matching_against = re.sub(r"[^\w\s]", "", matching_against)
|
||||
matching_pattern = re.sub(r"[^\w\s]", "", self.matching_pattern)
|
||||
if self.is_insensitive:
|
||||
matching_against = matching_against.lower()
|
||||
matching_pattern = matching_pattern.lower()
|
||||
|
||||
return rapidfuzz.fuzz.partial_ratio(
|
||||
matching_against,
|
||||
matching_pattern,
|
||||
score_cutoff=90
|
||||
)
|
||||
|
||||
def _get_algorithm_func(self):
|
||||
match self.matching_algorithm:
|
||||
case MatchingAlgorithm.NONE: return
|
||||
case MatchingAlgorithm.ANY: return self._match_any
|
||||
case MatchingAlgorithm.ALL: return self._match_all
|
||||
case MatchingAlgorithm.LITERAL: return self._match_literal
|
||||
case MatchingAlgorithm.REGEX: return self._match_regex
|
||||
case MatchingAlgorithm.FUZZY: return self._match_fuzzy
|
||||
case _: return
|
||||
|
||||
def matches(self, content) -> bool:
|
||||
log.debug(f"applying filter: {self}")
|
||||
|
||||
if not self.matching_pattern.strip():
|
||||
return False
|
||||
|
||||
if self.matching_algorithm == MatchingAlgorithm.ALL:
|
||||
match_found = self._match_all(content.item_title + " " + content.item_description)
|
||||
else:
|
||||
algorithm_func = self._get_algorithm_func()
|
||||
if not algorithm_func:
|
||||
log.error(f"Bad algorithm function: {self.matching_algorithm}")
|
||||
return False
|
||||
|
||||
match_found = algorithm_func(content.item_title) or algorithm_func(content.item_description)
|
||||
|
||||
log.debug(f"filter match found: {match_found}")
|
||||
return not match_found if self.is_whitelist else match_found
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MessageMutator(DjangoDataModel):
|
||||
id: int
|
||||
name: str
|
||||
value: str
|
||||
|
||||
@staticmethod
|
||||
def parser(item: dict) -> dict:
|
||||
item["id"] = item.pop("id")
|
||||
return item
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MessageStyle(DjangoDataModel):
|
||||
id: int
|
||||
server_id: int
|
||||
name: str
|
||||
colour: str
|
||||
is_embed: bool
|
||||
is_hyperlinked: bool
|
||||
show_author: bool
|
||||
show_timestamp: bool
|
||||
show_images: bool
|
||||
fetch_images: bool
|
||||
title_mutator: dict | None
|
||||
description_mutator: dict | None
|
||||
auto_created: bool
|
||||
|
||||
@staticmethod
|
||||
def parser(item: dict) -> dict:
|
||||
item["id"] = int(item.pop("id"))
|
||||
item["server_id"] = int(item.pop("server"))
|
||||
item["title_mutator"] = item.pop("title_mutator_detail")
|
||||
item["description_mutator"] = item.pop("description_mutator_detail")
|
||||
return item
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DiscordChannel(DjangoDataModel):
|
||||
id: int
|
||||
name: str
|
||||
is_nsfw: bool
|
||||
|
||||
@staticmethod
|
||||
def parser(item: dict) -> dict:
|
||||
item["id"] = int(item.pop("id"))
|
||||
return item
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Subscription(DjangoDataModel):
|
||||
id: int
|
||||
server_id: int
|
||||
name: str
|
||||
url: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
extra_notes: str
|
||||
active: bool
|
||||
publish_threshold: datetime
|
||||
channels: list[DiscordChannel]
|
||||
filters: list[ContentFilter]
|
||||
message_style: MessageStyle
|
||||
_server: Server | None = None
|
||||
|
||||
@staticmethod
|
||||
def parser(item: dict) -> dict:
|
||||
item["id"] = int(item.pop("id"))
|
||||
item["server_id"] = int(item.pop("server"))
|
||||
item["created_at"] = datetime.strptime(item.pop("created_at"), "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
item["updated_at"] = datetime.strptime(item.pop("updated_at"), "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
item["publish_threshold"] = datetime.strptime(item.pop("publish_threshold"), "%Y-%m-%dT%H:%M:%S%z")
|
||||
item["channels"] = DiscordChannel.from_list(item.pop("channels_detail"))
|
||||
item["filters"] = ContentFilter.from_list(item.pop("filters_detail"))
|
||||
item["message_style"] = MessageStyle.from_dict(item.pop("message_style_detail"))
|
||||
return item
|
||||
|
||||
@property
|
||||
def server(self) -> Server:
|
||||
return self._server
|
||||
|
||||
@server.setter
|
||||
def server(self, server: server):
|
||||
self._server = server
|
||||
|
||||
async def get_rss_content(self, client: httpx.AsyncClient) -> str:
|
||||
start_time = perf_counter()
|
||||
|
||||
try:
|
||||
response = await client.get(self.url)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as exc:
|
||||
log.error("(%s) HTTP Exception for %s - %s", type(exc), exc.request.url, exc)
|
||||
return
|
||||
finally:
|
||||
log.debug(f"Got rss content in {perf_counter() - start_time:.4f} seconds")
|
||||
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if not "text/xml" in content_type:
|
||||
log.warning("Invalid 'Content-Type' header: %s (must contain 'text/xml')", content_type)
|
||||
return
|
||||
|
||||
return response.text
|
||||
|
||||
async def get_discord_channels(self, bot) -> list[discord.TextChannel]:
|
||||
start_time = perf_counter()
|
||||
channels = []
|
||||
|
||||
for channel_detail in self.channels:
|
||||
try:
|
||||
channel = bot.get_channel(channel_detail.id)
|
||||
channels.append(channel or await bot.fetch_channel(channel_detail.id))
|
||||
except Exception as exc:
|
||||
channel_reference = f"({channel_detail.name}, {channel_detail.id})"
|
||||
server_reference = f"({self.server.name}, {self.server.id})"
|
||||
log.debug(f"Failed to get channel {channel_reference} from {server_reference}: {exc}")
|
||||
|
||||
log.debug(f"Got channels in {perf_counter() - start_time:.4f} seconds")
|
||||
return channels
|
||||
|
||||
def filter_entries(self, contents: list) -> tuple[list, list]:
|
||||
log.debug(f"filtering entries for {self.name} in {self.server.name}")
|
||||
|
||||
valid_contents = []
|
||||
invalid_contents = []
|
||||
|
||||
for content in contents:
|
||||
log.debug(f"filtering: '{content.item_title}'")
|
||||
if any(content_filter.matches(content) for content_filter in self.filters):
|
||||
content.blocked = True
|
||||
invalid_contents.append(content)
|
||||
else:
|
||||
valid_contents.append(content)
|
||||
|
||||
log.debug(f"filtered content: valid:{len(valid_contents)}, invalid:{len(invalid_contents)}")
|
||||
return valid_contents, invalid_contents
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Content(DjangoDataModel):
|
||||
id: int
|
||||
subscription_id: int
|
||||
item_id: str
|
||||
item_guid: str
|
||||
item_url: str
|
||||
item_title: str
|
||||
item_description: str
|
||||
item_content_hash: str
|
||||
item_image_url: str | None
|
||||
item_thumbnail_url: str | None
|
||||
item_published: datetime | None
|
||||
item_author: str
|
||||
item_author_url: str | None
|
||||
item_feed_title: str
|
||||
item_feed_url: str
|
||||
_subscription: Subscription | None = None
|
||||
blocked: bool = False
|
||||
|
||||
@staticmethod
|
||||
def parser(item: dict) -> dict:
|
||||
item["id"] = item.pop("id")
|
||||
item["subscription_id"] = item.pop("subscription")
|
||||
return item
|
||||
|
||||
async def exists_via_api(self, url: str, headers: dict, client: httpx.AsyncClient):
|
||||
log.debug(f"checking if {self.item_content_hash} exists via API")
|
||||
params = {
|
||||
"match_any": True, # allows any param to match, instead of needing all
|
||||
"item_id": self.item_id,
|
||||
"item_guid": self.item_guid,
|
||||
"item_url": self.item_url,
|
||||
"item_title": self.item_title,
|
||||
"item_content_hash": self.item_content_hash,
|
||||
"subscription": self.subscription_id
|
||||
}
|
||||
|
||||
log.debug(f"params: {params}")
|
||||
|
||||
try:
|
||||
response = await client.get(
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=params
|
||||
)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as exc:
|
||||
log.error(f"assuming not duplicate due to error: {exc}")
|
||||
return False
|
||||
|
||||
return response.json().get("results", [])
|
||||
|
||||
def is_duplicate(self, other):
|
||||
if not isinstance(other, Content):
|
||||
raise ValueError(f"Expected Content, received {type(other)}")
|
||||
|
||||
other_details = other.duplicate_details
|
||||
return any(
|
||||
other_details.get(key) == value
|
||||
for key, value in self.duplicate_details.items()
|
||||
)
|
||||
|
||||
@property
|
||||
def duplicate_details(self):
|
||||
keys = [
|
||||
"item_id",
|
||||
"item_guid",
|
||||
"item_url",
|
||||
"item_title",
|
||||
"item_content_hash"
|
||||
]
|
||||
data = asdict(self)
|
||||
return { key: data[key] for key in keys }
|
||||
|
||||
async def save(self, client: httpx.AsyncClient, base_url: str, headers: dict):
|
||||
log.debug(f"saving content {self.item_content_hash}")
|
||||
|
||||
data = asdict(self)
|
||||
data.pop("id")
|
||||
data["subscription"] = data.pop("subscription_id")
|
||||
item_published = data.pop("item_published")
|
||||
data["item_published"] = item_published.strftime("%Y-%m-%d") if item_published else None
|
||||
data.pop("_subscription")
|
||||
|
||||
response = await client.post(
|
||||
url=base_url + "content/",
|
||||
headers=headers,
|
||||
data=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
log.debug(f"save success for {self.item_content_hash}")
|
||||
|
||||
@classmethod
|
||||
async def from_raw_rss(cls, rss: str, subscription: Subscription, client: httpx.AsyncClient):
|
||||
style = subscription.message_style
|
||||
parsed_rss = feedparser.parse(rss)
|
||||
contents = []
|
||||
|
||||
async def create_content(entry: feedparser.FeedParserDict):
|
||||
published = entry.get("published_parsed")
|
||||
published = datetime(*published[0:6] if published else None, tzinfo=timezone.utc)
|
||||
|
||||
if published < subscription.publish_threshold:
|
||||
log.debug("skipping due to publish threshold")
|
||||
return
|
||||
|
||||
content_hash = hashlib.new("sha256")
|
||||
content_hash.update(entry.get("description", "").encode())
|
||||
|
||||
item_url = entry.get("link", "")
|
||||
item_image_url = entry.get("media_thumbnail", [{}])[0].get("url")
|
||||
if style.fetch_images:
|
||||
item_image_url = await cls.get_image_url(item_url, client)
|
||||
|
||||
content = Content.from_dict({
|
||||
"id": -1,
|
||||
"subscription": subscription.id,
|
||||
"item_id": entry.get("id", ""),
|
||||
"item_guid": entry.get("guid", ""),
|
||||
"item_url": item_url,
|
||||
"item_title": entry.get("title", ""),
|
||||
"item_description": entry.get("description", ""),
|
||||
"item_content_hash": content_hash.hexdigest(),
|
||||
"item_image_url": item_image_url,
|
||||
"item_thumbnail_url": parsed_rss.feed.image.href or None,
|
||||
"item_published": published,
|
||||
"item_author": entry.get("author", ""),
|
||||
"item_author_url": entry.get("author_detail", {}).get("href"),
|
||||
"item_feed_title": parsed_rss.get("feed", {}).get("title"),
|
||||
"item_feed_url": parsed_rss.get("feed", {}).get("link")
|
||||
})
|
||||
|
||||
# Weed out duplicates
|
||||
log.debug("weeding out duplicates")
|
||||
if any(content.is_duplicate(other) for other in contents):
|
||||
log.debug("found duplicate while loading rss data")
|
||||
return
|
||||
|
||||
content.subscription = subscription
|
||||
contents.append(content)
|
||||
|
||||
await do_batch_job(parsed_rss.entries, create_content, 15)
|
||||
contents.sort(key=lambda k: k.item_published)
|
||||
return contents
|
||||
|
||||
@staticmethod
|
||||
async def get_image_url(url: str, client: httpx.AsyncClient) -> str | None:
|
||||
log.debug("Fetching image url")
|
||||
|
||||
try:
|
||||
response = await client.get(url, timeout=15)
|
||||
except httpx.HTTPError:
|
||||
return None
|
||||
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
image_element = soup.select_one("meta[property='og:image']")
|
||||
if not image_element:
|
||||
return None
|
||||
|
||||
return image_element.get("content")
|
||||
|
||||
@property
|
||||
def subscription(self) -> Subscription:
|
||||
return self._subscription
|
||||
|
||||
@subscription.setter
|
||||
def subscription(self, subscription: Subscription):
|
||||
self._subscription = subscription
|
||||
|
||||
@property
|
||||
def embed(self):
|
||||
colour=discord.Colour.from_str(
|
||||
f"#{self.subscription.message_style.colour}"
|
||||
)
|
||||
|
||||
# ensure content fits within character limits
|
||||
title = shorten(markdownify(self.item_title, strip=("img", "a")), 256)
|
||||
description = shorten(markdownify(self.item_description, strip=("img",)), 4096)
|
||||
author = self.item_author or self.item_feed_title
|
||||
|
||||
combined_length = len(title) + len(description) + (len(author) * 2)
|
||||
cutoff = combined_length - 6000
|
||||
description = shorten(description, cutoff) if cutoff > 0 else description
|
||||
|
||||
embed = discord.Embed(
|
||||
title=title,
|
||||
description=description,
|
||||
url=self.item_url,
|
||||
colour=colour,
|
||||
timestamp=self.item_published
|
||||
)
|
||||
|
||||
embed.set_image(url=self.item_image_url)
|
||||
embed.set_thumbnail(url=self.item_thumbnail_url)
|
||||
embed.set_author(
|
||||
name=author,
|
||||
url=self.item_author_url or self.item_feed_url
|
||||
)
|
||||
embed.set_footer(text=self.subscription.name)
|
||||
|
||||
log.debug(f"created embed: {embed.to_dict()}")
|
||||
|
||||
return embed
|
0
src/tests/__init__.py
Normal file
0
src/tests/__init__.py
Normal file
67
src/tests/test_content.py
Normal file
67
src/tests/test_content.py
Normal file
@ -0,0 +1,67 @@
|
||||
from dataclasses import replace
|
||||
|
||||
import pytest
|
||||
|
||||
from models import MatchingAlgorithm, ContentFilter, Content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def content() -> Content:
|
||||
return Content(
|
||||
id=0,
|
||||
subscription_id=0,
|
||||
item_id="",
|
||||
item_guid="",
|
||||
item_url="",
|
||||
item_title="This week in the papers:",
|
||||
item_description="The price of petrol has risen by over 2% since the previous financial report. Read the full article here.",
|
||||
item_content_hash="",
|
||||
item_image_url=None,
|
||||
item_thumbnail_url=None,
|
||||
item_published=None,
|
||||
item_author="",
|
||||
item_author_url=None,
|
||||
item_feed_title="",
|
||||
item_feed_url=""
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def content_filter() -> ContentFilter:
|
||||
return ContentFilter(
|
||||
id=0,
|
||||
server_id=0,
|
||||
name="Test Content Filter",
|
||||
matching_pattern="",
|
||||
matching_algorithm=MatchingAlgorithm.NONE,
|
||||
is_insensitive=True,
|
||||
is_whitelist=False
|
||||
)
|
||||
|
||||
def test_content_filter_any(content: Content, content_filter: ContentFilter):
|
||||
content_filter.matching_pattern = "france twenty report grass lately"
|
||||
content_filter.matching_algorithm = MatchingAlgorithm.ANY
|
||||
assert content_filter.matches(content)
|
||||
|
||||
def test_content_filter_all(content: Content, content_filter: ContentFilter):
|
||||
content_filter.matching_pattern = "week petrol risen since"
|
||||
content_filter.matching_algorithm = MatchingAlgorithm.ALL
|
||||
assert content_filter.matches(content)
|
||||
|
||||
def test_content_filter_literal(content: Content, content_filter: ContentFilter):
|
||||
content_filter.matching_pattern = "this week in the papers"
|
||||
content_filter.matching_algorithm = MatchingAlgorithm.LITERAL
|
||||
assert content_filter.matches(content)
|
||||
|
||||
def test_content_filter_regex(content: Content, content_filter: ContentFilter):
|
||||
content_filter.matching_pattern = r"\b(The Papers|weekly quiz)\b"
|
||||
content_filter.matching_algorithm = MatchingAlgorithm.REGEX
|
||||
assert content_filter.matches(content)
|
||||
|
||||
# def test_content_filter_fuzzy(content: Content, content_filter: ContentFilter):
|
||||
# content_filter.matching_algorithm = "this week in the papers"
|
||||
# content_filter.matching_algorithm = MatchingAlgorithm.FUZZY
|
||||
# assert content_filter.matches(content)
|
||||
|
||||
def test_content_duplicates(content: Content):
|
||||
copy_of_content = replace(content)
|
||||
assert content.is_duplicate(copy_of_content)
|
532
src/utils.py
532
src/utils.py
@ -1,10 +1,10 @@
|
||||
"""A collection of utility functions that can be used in various places."""
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
# import aiohttp
|
||||
import logging
|
||||
import async_timeout
|
||||
from typing import Callable
|
||||
# import async_timeout
|
||||
# from typing import Callable
|
||||
|
||||
from discord import Interaction, Embed, Colour, ButtonStyle, Button
|
||||
from discord.ui import View, button
|
||||
@ -12,325 +12,335 @@ from discord.ext.commands import Bot
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
async def fetch(session, url: str) -> str:
|
||||
async with async_timeout.timeout(20):
|
||||
async with session.get(url) as response:
|
||||
return await response.text()
|
||||
async def do_batch_job(iterable: list, func, batch_size: int):
|
||||
semaphore = asyncio.Semaphore(batch_size)
|
||||
|
||||
async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None):
|
||||
if session is not None:
|
||||
return await fetch(session, url)
|
||||
async def batch_job(item):
|
||||
async with semaphore:
|
||||
await func(item)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
return await fetch(session, url)
|
||||
tasks = [batch_job(item) for item in iterable]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def get_rss_data(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
items = await response.text(), response.status
|
||||
# async def fetch(session, url: str) -> str:
|
||||
# async with async_timeout.timeout(20):
|
||||
# async with session.get(url) as response:
|
||||
# return await response.text()
|
||||
|
||||
return items
|
||||
# async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None):
|
||||
# if session is not None:
|
||||
# return await fetch(session, url)
|
||||
|
||||
async def followup(inter: Interaction, *args, **kwargs):
|
||||
"""Shorthand for following up on an interaction.
|
||||
# async with aiohttp.ClientSession() as session:
|
||||
# return await fetch(session, url)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
# async def get_rss_data(url: str):
|
||||
# async with aiohttp.ClientSession() as session:
|
||||
# async with session.get(url) as response:
|
||||
# items = await response.text(), response.status
|
||||
|
||||
await inter.followup.send(*args, **kwargs)
|
||||
# return items
|
||||
|
||||
# async def followup(inter: Interaction, *args, **kwargs):
|
||||
# """Shorthand for following up on an interaction.
|
||||
|
||||
# Parameters
|
||||
# ----------
|
||||
# inter : Interaction
|
||||
# Represents an app command interaction.
|
||||
# """
|
||||
|
||||
# await inter.followup.send(*args, **kwargs)
|
||||
|
||||
|
||||
# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
|
||||
# # https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
|
||||
|
||||
class FollowupIcons:
|
||||
error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
|
||||
success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
|
||||
trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
|
||||
info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
|
||||
added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
|
||||
assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
|
||||
# class FollowupIcons:
|
||||
# error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
|
||||
# success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
|
||||
# trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
|
||||
# info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
|
||||
# added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
|
||||
# assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
|
||||
|
||||
|
||||
class PaginationView(View):
|
||||
"""A Discord UI View that adds pagination to an embed."""
|
||||
# class PaginationView(View):
|
||||
# """A Discord UI View that adds pagination to an embed."""
|
||||
|
||||
def __init__(
|
||||
self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
|
||||
formatdata: Callable, pagesize: int, initpage: int=1
|
||||
):
|
||||
"""_summary_
|
||||
# def __init__(
|
||||
# self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
|
||||
# formatdata: Callable, pagesize: int, initpage: int=1
|
||||
# ):
|
||||
# """_summary_
|
||||
|
||||
Args:
|
||||
bot (commands.Bot) The discord bot
|
||||
inter (Interaction): Represents a discord command interaction.
|
||||
embed (Embed): The base embed to paginate.
|
||||
getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
|
||||
formatdata (Callable): A formatter function that determines how the data is displayed.
|
||||
pagesize (int): The size of each page.
|
||||
initpage (int, optional): The inital page. Defaults to 1.
|
||||
"""
|
||||
# Args:
|
||||
# bot (commands.Bot) The discord bot
|
||||
# inter (Interaction): Represents a discord command interaction.
|
||||
# embed (Embed): The base embed to paginate.
|
||||
# getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
|
||||
# formatdata (Callable): A formatter function that determines how the data is displayed.
|
||||
# pagesize (int): The size of each page.
|
||||
# initpage (int, optional): The inital page. Defaults to 1.
|
||||
# """
|
||||
|
||||
self.bot = bot
|
||||
self.inter = inter
|
||||
self.embed = embed
|
||||
self.getdata = getdata
|
||||
self.formatdata = formatdata
|
||||
self.maxpage = None
|
||||
self.pagesize = pagesize
|
||||
self.index = initpage
|
||||
|
||||
# emoji reference
|
||||
self.start_emoji = bot.get_emoji(1204542364073463818)
|
||||
self.end_emoji = bot.get_emoji(1204542367752003624)
|
||||
|
||||
super().__init__(timeout=100)
|
||||
|
||||
async def check_user_is_author(self, inter: Interaction) -> bool:
|
||||
"""Ensure the user is the author of the original command."""
|
||||
|
||||
if inter.user == self.inter.user:
|
||||
return True
|
||||
|
||||
await inter.response.defer()
|
||||
await (
|
||||
Followup(None, "Only the author can interact with this.")
|
||||
.error()
|
||||
.send(inter, ephemeral=True)
|
||||
)
|
||||
return False
|
||||
|
||||
async def on_timeout(self):
|
||||
"""Erase the controls on timeout."""
|
||||
|
||||
message = await self.inter.original_response()
|
||||
await message.edit(view=None)
|
||||
|
||||
@staticmethod
|
||||
def calc_total_pages(results: int, max_pagesize: int) -> int:
|
||||
result = ((results - 1) // max_pagesize) + 1
|
||||
return result
|
||||
|
||||
def calc_dataitem_index(self, dataitem_index: int):
|
||||
"""Calculates a given index to be relative to the sum of all pages items
|
||||
|
||||
Example: dataitem_index = 6
|
||||
pagesize = 10
|
||||
if page == 1 then return 6
|
||||
else return 6 + 10 * (page - 1)"""
|
||||
|
||||
if self.index > 1:
|
||||
dataitem_index += self.pagesize * (self.index - 1)
|
||||
|
||||
dataitem_index += 1
|
||||
return dataitem_index
|
||||
|
||||
@button(emoji="◀️", style=ButtonStyle.blurple)
|
||||
async def backward(self, inter: Interaction, button: Button):
|
||||
"""
|
||||
Action the backwards button.
|
||||
"""
|
||||
|
||||
self.index -= 1
|
||||
await inter.response.defer()
|
||||
self.inter = inter
|
||||
await self.navigate()
|
||||
|
||||
@button(emoji="▶️", style=ButtonStyle.blurple)
|
||||
async def forward(self, inter: Interaction, button: Button):
|
||||
"""
|
||||
Action the forwards button.
|
||||
"""
|
||||
|
||||
self.index += 1
|
||||
await inter.response.defer()
|
||||
self.inter = inter
|
||||
await self.navigate()
|
||||
|
||||
@button(emoji="⏭️", style=ButtonStyle.blurple)
|
||||
async def start_or_end(self, inter: Interaction, button: Button):
|
||||
"""
|
||||
Action the start and end button.
|
||||
This button becomes return to start if at end, otherwise skip to end.
|
||||
"""
|
||||
|
||||
# Determine if should skip to start or end
|
||||
if self.index <= self.maxpage // 2:
|
||||
self.index = self.maxpage
|
||||
else:
|
||||
self.index = 1
|
||||
|
||||
await inter.response.defer()
|
||||
self.inter = inter
|
||||
await self.navigate()
|
||||
|
||||
async def navigate(self):
|
||||
"""
|
||||
Acts as an update method for the entire instance.
|
||||
"""
|
||||
# self.bot = bot
|
||||
# self.inter = inter
|
||||
# self.embed = embed
|
||||
# self.getdata = getdata
|
||||
# self.formatdata = formatdata
|
||||
# self.maxpage = None
|
||||
# self.pagesize = pagesize
|
||||
# self.index = initpage
|
||||
|
||||
# # emoji reference
|
||||
# self.start_emoji = bot.get_emoji(1204542364073463818)
|
||||
# self.end_emoji = bot.get_emoji(1204542367752003624)
|
||||
|
||||
# super().__init__(timeout=100)
|
||||
|
||||
# async def check_user_is_author(self, inter: Interaction) -> bool:
|
||||
# """Ensure the user is the author of the original command."""
|
||||
|
||||
# if inter.user == self.inter.user:
|
||||
# return True
|
||||
|
||||
# await inter.response.defer()
|
||||
# await (
|
||||
# Followup(None, "Only the author can interact with this.")
|
||||
# .error()
|
||||
# .send(inter, ephemeral=True)
|
||||
# )
|
||||
# return False
|
||||
|
||||
# async def on_timeout(self):
|
||||
# """Erase the controls on timeout."""
|
||||
|
||||
# message = await self.inter.original_response()
|
||||
# await message.edit(view=None)
|
||||
|
||||
# @staticmethod
|
||||
# def calc_total_pages(results: int, max_pagesize: int) -> int:
|
||||
# result = ((results - 1) // max_pagesize) + 1
|
||||
# return result
|
||||
|
||||
# def calc_dataitem_index(self, dataitem_index: int):
|
||||
# """Calculates a given index to be relative to the sum of all pages items
|
||||
|
||||
# Example: dataitem_index = 6
|
||||
# pagesize = 10
|
||||
# if page == 1 then return 6
|
||||
# else return 6 + 10 * (page - 1)"""
|
||||
|
||||
# if self.index > 1:
|
||||
# dataitem_index += self.pagesize * (self.index - 1)
|
||||
|
||||
# dataitem_index += 1
|
||||
# return dataitem_index
|
||||
|
||||
# @button(emoji="◀️", style=ButtonStyle.blurple)
|
||||
# async def backward(self, inter: Interaction, button: Button):
|
||||
# """
|
||||
# Action the backwards button.
|
||||
# """
|
||||
|
||||
# self.index -= 1
|
||||
# await inter.response.defer()
|
||||
# self.inter = inter
|
||||
# await self.navigate()
|
||||
|
||||
# @button(emoji="▶️", style=ButtonStyle.blurple)
|
||||
# async def forward(self, inter: Interaction, button: Button):
|
||||
# """
|
||||
# Action the forwards button.
|
||||
# """
|
||||
|
||||
# self.index += 1
|
||||
# await inter.response.defer()
|
||||
# self.inter = inter
|
||||
# await self.navigate()
|
||||
|
||||
# @button(emoji="⏭️", style=ButtonStyle.blurple)
|
||||
# async def start_or_end(self, inter: Interaction, button: Button):
|
||||
# """
|
||||
# Action the start and end button.
|
||||
# This button becomes return to start if at end, otherwise skip to end.
|
||||
# """
|
||||
|
||||
# # Determine if should skip to start or end
|
||||
# if self.index <= self.maxpage // 2:
|
||||
# self.index = self.maxpage
|
||||
# else:
|
||||
# self.index = 1
|
||||
|
||||
# await inter.response.defer()
|
||||
# self.inter = inter
|
||||
# await self.navigate()
|
||||
|
||||
# async def navigate(self):
|
||||
# """
|
||||
# Acts as an update method for the entire instance.
|
||||
# """
|
||||
|
||||
log.debug("navigating to page: %s", self.index)
|
||||
# log.debug("navigating to page: %s", self.index)
|
||||
|
||||
self.update_buttons()
|
||||
paged_embed = await self.create_paged_embed()
|
||||
await self.inter.edit_original_response(embed=paged_embed, view=self)
|
||||
# self.update_buttons()
|
||||
# paged_embed = await self.create_paged_embed()
|
||||
# await self.inter.edit_original_response(embed=paged_embed, view=self)
|
||||
|
||||
async def create_paged_embed(self) -> Embed:
|
||||
"""
|
||||
Returns a copy of the known embed, but with data from the current page.
|
||||
"""
|
||||
# async def create_paged_embed(self) -> Embed:
|
||||
# """
|
||||
# Returns a copy of the known embed, but with data from the current page.
|
||||
# """
|
||||
|
||||
embed = self.embed.copy()
|
||||
# embed = self.embed.copy()
|
||||
|
||||
try:
|
||||
data, total_results = await self.getdata(self.index, self.pagesize)
|
||||
except aiohttp.ClientResponseError as exc:
|
||||
log.error(exc)
|
||||
await (
|
||||
Followup(f"Error · {exc.message}",)
|
||||
.footer(f"HTTP {exc.code}")
|
||||
.error()
|
||||
.send(self.inter)
|
||||
)
|
||||
raise exc
|
||||
# try:
|
||||
# data, total_results = await self.getdata(self.index, self.pagesize)
|
||||
# except aiohttp.ClientResponseError as exc:
|
||||
# log.error(exc)
|
||||
# await (
|
||||
# Followup(f"Error · {exc.message}",)
|
||||
# .footer(f"HTTP {exc.code}")
|
||||
# .error()
|
||||
# .send(self.inter)
|
||||
# )
|
||||
# raise exc
|
||||
|
||||
self.maxpage = self.calc_total_pages(total_results, self.pagesize)
|
||||
log.debug(f"{self.maxpage=!r}")
|
||||
# self.maxpage = self.calc_total_pages(total_results, self.pagesize)
|
||||
# log.debug(f"{self.maxpage=!r}")
|
||||
|
||||
for i, item in enumerate(data):
|
||||
i = self.calc_dataitem_index(i)
|
||||
if asyncio.iscoroutinefunction(self.formatdata):
|
||||
key, value = await self.formatdata(i, item)
|
||||
else:
|
||||
key, value = self.formatdata(i, item)
|
||||
# for i, item in enumerate(data):
|
||||
# i = self.calc_dataitem_index(i)
|
||||
# if asyncio.iscoroutinefunction(self.formatdata):
|
||||
# key, value = await self.formatdata(i, item)
|
||||
# else:
|
||||
# key, value = self.formatdata(i, item)
|
||||
|
||||
embed.add_field(name=key, value=value, inline=False)
|
||||
# embed.add_field(name=key, value=value, inline=False)
|
||||
|
||||
if not total_results:
|
||||
embed.description = "There are no results"
|
||||
# if not total_results:
|
||||
# embed.description = "There are no results"
|
||||
|
||||
if self.maxpage > 1:
|
||||
embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
|
||||
# if self.maxpage > 1:
|
||||
# embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
|
||||
|
||||
return embed
|
||||
# return embed
|
||||
|
||||
def update_buttons(self):
|
||||
if self.index >= self.maxpage:
|
||||
self.children[2].emoji = self.start_emoji
|
||||
else:
|
||||
self.children[2].emoji = self.end_emoji
|
||||
# def update_buttons(self):
|
||||
# if self.index >= self.maxpage:
|
||||
# self.children[2].emoji = self.start_emoji
|
||||
# else:
|
||||
# self.children[2].emoji = self.end_emoji
|
||||
|
||||
self.children[0].disabled = self.index == 1
|
||||
self.children[1].disabled = self.index == self.maxpage
|
||||
# self.children[0].disabled = self.index == 1
|
||||
# self.children[1].disabled = self.index == self.maxpage
|
||||
|
||||
async def send(self):
|
||||
"""Send the pagination view. It may be important to defer before invoking this method."""
|
||||
# async def send(self):
|
||||
# """Send the pagination view. It may be important to defer before invoking this method."""
|
||||
|
||||
log.debug("sending pagination view")
|
||||
embed = await self.create_paged_embed()
|
||||
# log.debug("sending pagination view")
|
||||
# embed = await self.create_paged_embed()
|
||||
|
||||
if self.maxpage <= 1:
|
||||
await self.inter.edit_original_response(embed=embed)
|
||||
return
|
||||
# if self.maxpage <= 1:
|
||||
# await self.inter.edit_original_response(embed=embed)
|
||||
# return
|
||||
|
||||
self.update_buttons()
|
||||
await self.inter.edit_original_response(embed=embed, view=self)
|
||||
# self.update_buttons()
|
||||
# await self.inter.edit_original_response(embed=embed, view=self)
|
||||
|
||||
|
||||
|
||||
class Followup:
|
||||
"""Wrapper for a discord embed to follow up an interaction."""
|
||||
# class Followup:
|
||||
# """Wrapper for a discord embed to follow up an interaction."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
title: str = None,
|
||||
description: str = None,
|
||||
):
|
||||
self._embed = Embed(
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
# def __init__(
|
||||
# self,
|
||||
# title: str = None,
|
||||
# description: str = None,
|
||||
# ):
|
||||
# self._embed = Embed(
|
||||
# title=title,
|
||||
# description=description
|
||||
# )
|
||||
|
||||
async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
|
||||
""""""
|
||||
# async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
|
||||
# """"""
|
||||
|
||||
await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)
|
||||
# await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)
|
||||
|
||||
def fields(self, inline: bool = False, **fields: dict):
|
||||
""""""
|
||||
# def fields(self, inline: bool = False, **fields: dict):
|
||||
# """"""
|
||||
|
||||
for key, value in fields.items():
|
||||
self._embed.add_field(name=key, value=value, inline=inline)
|
||||
# for key, value in fields.items():
|
||||
# self._embed.add_field(name=key, value=value, inline=inline)
|
||||
|
||||
return self
|
||||
# return self
|
||||
|
||||
def image(self, url: str):
|
||||
""""""
|
||||
# def image(self, url: str):
|
||||
# """"""
|
||||
|
||||
self._embed.set_image(url=url)
|
||||
# self._embed.set_image(url=url)
|
||||
|
||||
return self
|
||||
# return self
|
||||
|
||||
def author(self, name: str, url: str=None, icon_url: str=None):
|
||||
""""""
|
||||
# def author(self, name: str, url: str=None, icon_url: str=None):
|
||||
# """"""
|
||||
|
||||
self._embed.set_author(name=name, url=url, icon_url=icon_url)
|
||||
# self._embed.set_author(name=name, url=url, icon_url=icon_url)
|
||||
|
||||
return self
|
||||
# return self
|
||||
|
||||
def footer(self, text: str, icon_url: str = None):
|
||||
""""""
|
||||
# def footer(self, text: str, icon_url: str = None):
|
||||
# """"""
|
||||
|
||||
self._embed.set_footer(text=text, icon_url=icon_url)
|
||||
# self._embed.set_footer(text=text, icon_url=icon_url)
|
||||
|
||||
return self
|
||||
# return self
|
||||
|
||||
def error(self):
|
||||
""""""
|
||||
# def error(self):
|
||||
# """"""
|
||||
|
||||
self._embed.colour = Colour.red()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.error)
|
||||
return self
|
||||
# self._embed.colour = Colour.red()
|
||||
# self._embed.set_thumbnail(url=FollowupIcons.error)
|
||||
# return self
|
||||
|
||||
def success(self):
|
||||
""""""
|
||||
# def success(self):
|
||||
# """"""
|
||||
|
||||
self._embed.colour = Colour.green()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.success)
|
||||
return self
|
||||
# self._embed.colour = Colour.green()
|
||||
# self._embed.set_thumbnail(url=FollowupIcons.success)
|
||||
# return self
|
||||
|
||||
def info(self):
|
||||
""""""
|
||||
# def info(self):
|
||||
# """"""
|
||||
|
||||
self._embed.colour = Colour.blue()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.info)
|
||||
return self
|
||||
# self._embed.colour = Colour.blue()
|
||||
# self._embed.set_thumbnail(url=FollowupIcons.info)
|
||||
# return self
|
||||
|
||||
def added(self):
|
||||
""""""
|
||||
# def added(self):
|
||||
# """"""
|
||||
|
||||
self._embed.colour = Colour.blue()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.added)
|
||||
return self
|
||||
# self._embed.colour = Colour.blue()
|
||||
# self._embed.set_thumbnail(url=FollowupIcons.added)
|
||||
# return self
|
||||
|
||||
def assign(self):
|
||||
""""""
|
||||
# def assign(self):
|
||||
# """"""
|
||||
|
||||
self._embed.colour = Colour.blue()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.assigned)
|
||||
return self
|
||||
# self._embed.colour = Colour.blue()
|
||||
# self._embed.set_thumbnail(url=FollowupIcons.assigned)
|
||||
# return self
|
||||
|
||||
def trash(self):
|
||||
""""""
|
||||
# def trash(self):
|
||||
# """"""
|
||||
|
||||
self._embed.colour = Colour.red()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.trash)
|
||||
return self
|
||||
# self._embed.colour = Colour.red()
|
||||
# self._embed.set_thumbnail(url=FollowupIcons.trash)
|
||||
# return self
|
||||
|
||||
|
||||
def extract_error_info(error: Exception) -> str:
|
||||
class_name = error.__class__.__name__
|
||||
desc = str(error)
|
||||
return class_name, desc
|
||||
# def extract_error_info(error: Exception) -> str:
|
||||
# class_name = error.__class__.__name__
|
||||
# desc = str(error)
|
||||
# return class_name, desc
|
||||
|
Loading…
x
Reference in New Issue
Block a user