efficiency and function for tasks
Some checks failed
Build and Push Docker Image / build (push) Failing after 7m7s
Some checks failed
Build and Push Docker Image / build (push) Failing after 7m7s
This commit is contained in:
parent
470f78c144
commit
08295dfea6
@ -23,6 +23,7 @@ from discord.ext import commands, tasks
|
|||||||
from discord.errors import Forbidden
|
from discord.errors import Forbidden
|
||||||
|
|
||||||
import models
|
import models
|
||||||
|
from utils import do_batch_job
|
||||||
# from feed import RSSFeed, Subscription, RSSItem, GuildSettings
|
# from feed import RSSFeed, Subscription, RSSItem, GuildSettings
|
||||||
# from utils import get_unparsed_feed
|
# from utils import get_unparsed_feed
|
||||||
# from filters import match_text
|
# from filters import match_text
|
||||||
@ -53,6 +54,7 @@ class TaskCog(commands.Cog):
|
|||||||
|
|
||||||
api_base_url: str
|
api_base_url: str
|
||||||
api_headers: dict
|
api_headers: dict
|
||||||
|
client: httpx.AsyncClient | None
|
||||||
|
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -100,17 +102,18 @@ class TaskCog(commands.Cog):
|
|||||||
start_time = perf_counter()
|
start_time = perf_counter()
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
servers = await self.get_servers(client)
|
self.client = client
|
||||||
await self.process_servers(servers, client)
|
servers = await self.get_servers()
|
||||||
|
await do_batch_job(servers, self.process_server, 10)
|
||||||
|
|
||||||
end_time = perf_counter()
|
end_time = perf_counter()
|
||||||
log.debug(f"completed task in {end_time - start_time:.4f} seconds")
|
log.debug(f"completed task in {end_time - start_time:.4f} seconds")
|
||||||
|
|
||||||
async def iterate_pages(self, client: httpx.AsyncClient, url: str, params: dict={}):
|
async def iterate_pages(self, url: str, params: dict={}):
|
||||||
|
|
||||||
for page_number, _ in enumerate(iterable=iter(int, 1), start=1):
|
for page_number, _ in enumerate(iterable=iter(int, 1), start=1):
|
||||||
params.update({"page": page_number})
|
params.update({"page": page_number})
|
||||||
response = await client.get(
|
response = await self.client.get(
|
||||||
self.api_base_url + url,
|
self.api_base_url + url,
|
||||||
headers=self.api_headers,
|
headers=self.api_headers,
|
||||||
params=params
|
params=params
|
||||||
@ -123,82 +126,84 @@ class TaskCog(commands.Cog):
|
|||||||
if not content.get("next"):
|
if not content.get("next"):
|
||||||
break
|
break
|
||||||
|
|
||||||
async def get_servers(self, client: httpx.AsyncClient) -> list[models.Server]:
|
async def get_servers(self) -> list[models.Server]:
|
||||||
servers = []
|
servers = []
|
||||||
|
|
||||||
async for servers_batch in self.iterate_pages(client, "servers/"):
|
async for servers_batch in self.iterate_pages("servers/"):
|
||||||
if servers_batch:
|
if servers_batch:
|
||||||
servers.extend(servers_batch)
|
servers.extend(servers_batch)
|
||||||
|
|
||||||
return models.Server.from_list(servers)
|
return models.Server.from_list(servers)
|
||||||
|
|
||||||
async def get_subscriptions(self, server: models.Server, client: httpx.AsyncClient) -> list[models.Subscription]:
|
async def get_subscriptions(self, server: models.Server) -> list[models.Subscription]:
|
||||||
subscriptions = []
|
subscriptions = []
|
||||||
params = {"server": server.id}
|
params = {"server": server.id, "active": True}
|
||||||
|
|
||||||
async for subscriptions_batch in self.iterate_pages(client, "subscriptions/", params):
|
async for subscriptions_batch in self.iterate_pages("subscriptions/", params):
|
||||||
if subscriptions_batch:
|
if subscriptions_batch:
|
||||||
subscriptions.extend(subscriptions_batch)
|
subscriptions.extend(subscriptions_batch)
|
||||||
|
|
||||||
return models.Subscription.from_list(subscriptions)
|
return models.Subscription.from_list(subscriptions)
|
||||||
|
|
||||||
async def process_servers(self, servers: list[models.Server], client: httpx.AsyncClient):
|
async def process_server(self, server: models.Server):
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(10)
|
|
||||||
|
|
||||||
async def batch_process(server: models.Server, client: httpx.AsyncClient):
|
|
||||||
async with semaphore: await self.process_server(server, client)
|
|
||||||
|
|
||||||
tasks = [batch_process(server, client) for server in servers if server.active]
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
async def process_server(self, server: models.Server, client: httpx.AsyncClient):
|
|
||||||
log.debug(f"processing server: {server.name}")
|
log.debug(f"processing server: {server.name}")
|
||||||
start_time = perf_counter()
|
start_time = perf_counter()
|
||||||
|
|
||||||
subscriptions = await self.get_subscriptions(server, client)
|
subscriptions = await self.get_subscriptions(server)
|
||||||
for subscription in subscriptions:
|
for subscription in subscriptions:
|
||||||
subscription.server = server
|
subscription.server = server
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(10)
|
await do_batch_job(subscriptions, self.process_subscription, 10)
|
||||||
|
|
||||||
async def batch_process(subscription: models.Subscription, client: httpx.AsyncClient):
|
|
||||||
async with semaphore: await self.process_subscription(subscription, client)
|
|
||||||
|
|
||||||
tasks = [
|
|
||||||
batch_process(subscription, client)
|
|
||||||
for subscription in subscriptions
|
|
||||||
if subscription.active
|
|
||||||
]
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
end_time = perf_counter()
|
end_time = perf_counter()
|
||||||
log.debug(f"Finished processing server: {server.name} in {end_time - start_time:.4f} seconds")
|
log.debug(f"Finished processing server: {server.name} in {end_time - start_time:.4f} seconds")
|
||||||
|
|
||||||
async def process_subscription(self, subscription: models.Subscription, client: httpx.AsyncClient):
|
async def process_subscription(self, subscription: models.Subscription):
|
||||||
log.debug(f"processing subscription {subscription.name}")
|
log.debug(f"processing subscription {subscription.name}")
|
||||||
start_time = perf_counter()
|
start_time = perf_counter()
|
||||||
|
|
||||||
raw_rss_content = await subscription.get_rss_content(client)
|
raw_rss_content = await subscription.get_rss_content(self.client)
|
||||||
if not raw_rss_content:
|
if not raw_rss_content:
|
||||||
return
|
return
|
||||||
|
|
||||||
channels = await subscription.get_discord_channels(self.bot)
|
channels = await subscription.get_discord_channels(self.bot)
|
||||||
contents = models.Content.from_raw_rss(raw_rss_content, subscription)
|
contents = await models.Content.from_raw_rss(raw_rss_content, subscription, self.client)
|
||||||
valid_contents, invalid_contents = subscription.filter_entries(contents)
|
valid_contents, invalid_contents = subscription.filter_entries(contents)
|
||||||
|
|
||||||
for content in valid_contents:
|
async def send_content(channel: discord.TextChannel):
|
||||||
await self.process_content(content, channels)
|
embeds = [content.embed for content in valid_contents]
|
||||||
tasks = [channel.send(content.item_title) for channel in channels]
|
batch_size = 10
|
||||||
asyncio.gather(*tasks)
|
for i in range(0, len(embeds), batch_size):
|
||||||
|
batch = embeds[i:i + batch_size]
|
||||||
|
await channel.send(embeds=batch)
|
||||||
|
|
||||||
|
await do_batch_job(channels, send_content, 5)
|
||||||
|
|
||||||
|
# TODO: mark invalid contents as blocked
|
||||||
|
|
||||||
end_time = perf_counter()
|
end_time = perf_counter()
|
||||||
log.debug(f"Finished processing subscription: {subscription.name} in {end_time - start_time:.4f}")
|
log.debug(f"Finished processing subscription: {subscription.name} in {end_time - start_time:.4f}")
|
||||||
|
|
||||||
async def process_valid_contents(contents: list[models.Content], channels: list[discord.TextChannel], client: httpx.AsyncClient):
|
# async def process_valid_contents(
|
||||||
semaphore = asyncio.Semaphore(5)
|
# self,
|
||||||
|
# contents: list[models.Content],
|
||||||
|
# channels: list[discord.TextChannel],
|
||||||
|
# client: httpx.AsyncClient
|
||||||
|
# ):
|
||||||
|
# semaphore = asyncio.Semaphore(5)
|
||||||
|
|
||||||
|
# async def batch_process(
|
||||||
|
# content: models.Content,
|
||||||
|
# channels: list[discord.TextChannel],
|
||||||
|
# client: httpx.AsyncClient
|
||||||
|
# ):
|
||||||
|
# async with semaphore: await self.process_valid_content(content, channels, client)
|
||||||
|
|
||||||
|
# tasks = [
|
||||||
|
# batch_process()
|
||||||
|
# ]
|
||||||
|
|
||||||
|
|
||||||
async def batch_process(content: models.Content, )
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
109
src/models.py
109
src/models.py
@ -1,15 +1,23 @@
|
|||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import asyncio
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from datetime import datetime
|
from time import perf_counter
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from textwrap import shorten
|
||||||
|
|
||||||
|
import feedparser.parsers
|
||||||
import httpx
|
import httpx
|
||||||
import discord
|
import discord
|
||||||
import rapidfuzz
|
import rapidfuzz
|
||||||
import feedparser
|
import feedparser
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
from markdownify import markdownify
|
||||||
|
|
||||||
|
from utils import do_batch_job
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -272,12 +280,16 @@ class Subscription(DjangoDataModel):
|
|||||||
self._server = server
|
self._server = server
|
||||||
|
|
||||||
async def get_rss_content(self, client: httpx.AsyncClient) -> str:
|
async def get_rss_content(self, client: httpx.AsyncClient) -> str:
|
||||||
|
start_time = perf_counter()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.get(self.url)
|
response = await client.get(self.url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except httpx.HTTPError as exc:
|
except httpx.HTTPError as exc:
|
||||||
log.error("(%s) HTTP Exception for %s - %s", type(exc), exc.request.url, exc)
|
log.error("(%s) HTTP Exception for %s - %s", type(exc), exc.request.url, exc)
|
||||||
return
|
return
|
||||||
|
finally:
|
||||||
|
log.debug(f"Got rss content in {perf_counter() - start_time:.4f} seconds")
|
||||||
|
|
||||||
content_type = response.headers.get("Content-Type")
|
content_type = response.headers.get("Content-Type")
|
||||||
if not "text/xml" in content_type:
|
if not "text/xml" in content_type:
|
||||||
@ -286,16 +298,20 @@ class Subscription(DjangoDataModel):
|
|||||||
|
|
||||||
return response.text
|
return response.text
|
||||||
|
|
||||||
async def get_discord_channels(self, bot) -> list:
|
async def get_discord_channels(self, bot) -> list[discord.TextChannel]:
|
||||||
|
start_time = perf_counter()
|
||||||
channels = []
|
channels = []
|
||||||
|
|
||||||
for channel_detail in self.channels:
|
for channel_detail in self.channels:
|
||||||
try:
|
try:
|
||||||
channel = bot.get_channel(channel_detail.id)
|
channel = bot.get_channel(channel_detail.id)
|
||||||
channels.append(channel or await bot.fetch_channel(channel_detail.id))
|
channels.append(channel or await bot.fetch_channel(channel_detail.id))
|
||||||
except discord.Forbidden:
|
except Exception as exc:
|
||||||
log.error(f"Forbidden channel: ({channel.name}, {channel.id}) from ({self.server.name}, {self.server.id})")
|
channel_reference = f"({channel_detail.name}, {channel_detail.id})"
|
||||||
|
server_reference = f"({self.server.name}, {self.server.id})"
|
||||||
|
log.debug(f"Failed to get channel {channel_reference} from {server_reference}: {exc}")
|
||||||
|
|
||||||
|
log.debug(f"Got channels in {perf_counter() - start_time:.4f} seconds")
|
||||||
return channels
|
return channels
|
||||||
|
|
||||||
def filter_entries(self, contents: list) -> tuple[list, list]:
|
def filter_entries(self, contents: list) -> tuple[list, list]:
|
||||||
@ -324,6 +340,13 @@ class Content(DjangoDataModel):
|
|||||||
item_url: str
|
item_url: str
|
||||||
item_title: str
|
item_title: str
|
||||||
item_description: str
|
item_description: str
|
||||||
|
item_image_url: str | None
|
||||||
|
item_thumbnail_url: str | None
|
||||||
|
item_published: datetime | None
|
||||||
|
item_author: str
|
||||||
|
item_author_url: str
|
||||||
|
item_feed_title: str
|
||||||
|
item_feed_url: str
|
||||||
_subscription: Subscription | None = None
|
_subscription: Subscription | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -333,31 +356,60 @@ class Content(DjangoDataModel):
|
|||||||
return item
|
return item
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_raw_rss(cls, raw_rss_content: str, subscription: Subscription):
|
async def from_raw_rss(cls, rss: str, subscription: Subscription, client: httpx.AsyncClient):
|
||||||
parsed_rss = feedparser.parse(raw_rss_content)
|
style = subscription.message_style
|
||||||
|
parsed_rss = feedparser.parse(rss)
|
||||||
contents = []
|
contents = []
|
||||||
|
|
||||||
for entry in parsed_rss.entries:
|
async def create_content(entry: feedparser.FeedParserDict):
|
||||||
# content_hash = hashlib.new("sha256")
|
# content_hash = hashlib.new("sha256")
|
||||||
# content_hash.update(entry.get("description", "").encode())
|
# content_hash.update(entry.get("description", "").encode())
|
||||||
# content_hash.hexdigest()
|
# content_hash.hexdigest()
|
||||||
|
|
||||||
data = {
|
item_url = entry.get("link", "")
|
||||||
|
item_image_url = entry.get("media_thumbnail", [{}])[0].get("url")
|
||||||
|
if style.fetch_images:
|
||||||
|
item_image_url = await cls.get_image_url(item_url, client)
|
||||||
|
|
||||||
|
published = entry.get("published_parsed")
|
||||||
|
published = datetime(*published[0:6] if published else None, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
content = Content.from_dict({
|
||||||
"id": -1,
|
"id": -1,
|
||||||
"subscription": subscription.id,
|
"subscription": subscription.id,
|
||||||
"item_id": entry.get("id", ""),
|
"item_id": entry.get("id", ""),
|
||||||
"item_guid": entry.get("guid", ""),
|
"item_guid": entry.get("guid", ""),
|
||||||
"item_url": entry.get("link", ""),
|
"item_url": item_url,
|
||||||
"item_title": entry.get("title", ""),
|
"item_title": entry.get("title", ""),
|
||||||
"item_description": entry.get("description", "")
|
"item_description": entry.get("description", ""),
|
||||||
}
|
"item_image_url": item_image_url,
|
||||||
|
"item_thumbnail_url": parsed_rss.feed.image.href or None,
|
||||||
|
"item_published": published,
|
||||||
|
"item_author": entry.get("author", ""),
|
||||||
|
"item_author_url": entry.get("author_detail", {}).get("href", ""),
|
||||||
|
"item_feed_title": parsed_rss.get("feed", {}).get("title"),
|
||||||
|
"item_feed_url": parsed_rss.get("feed", {}).get("link")
|
||||||
|
})
|
||||||
|
|
||||||
content = Content.from_dict(data)
|
|
||||||
content.subscription = subscription
|
content.subscription = subscription
|
||||||
contents.append(content)
|
contents.append(content)
|
||||||
|
|
||||||
|
await do_batch_job(parsed_rss.entries, create_content, 15)
|
||||||
|
contents.sort(key=lambda k: k.item_published)
|
||||||
return contents
|
return contents
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_image_url(url: str, client: httpx.AsyncClient) -> str | None:
|
||||||
|
log.debug("Fetching image url")
|
||||||
|
|
||||||
|
response = await client.get(url, timeout=15)
|
||||||
|
soup = BeautifulSoup(response.text, "html.parser")
|
||||||
|
image_element = soup.select_one("meta[property='og:image']")
|
||||||
|
if not image_element:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return image_element.get("content")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def subscription(self) -> Subscription:
|
def subscription(self) -> Subscription:
|
||||||
return self._subscription
|
return self._subscription
|
||||||
@ -365,3 +417,36 @@ class Content(DjangoDataModel):
|
|||||||
@subscription.setter
|
@subscription.setter
|
||||||
def subscription(self, subscription: Subscription):
|
def subscription(self, subscription: Subscription):
|
||||||
self._subscription = subscription
|
self._subscription = subscription
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embed(self):
|
||||||
|
colour=discord.Colour.from_str(
|
||||||
|
f"#{self.subscription.message_style.colour}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ensure content fits within character limits
|
||||||
|
title = shorten(markdownify(self.item_title, strip=("img", "a")), 256)
|
||||||
|
description = shorten(markdownify(self.item_description, strip=("img",)), 4096)
|
||||||
|
author = self.item_author or self.item_feed_title
|
||||||
|
|
||||||
|
combined_length = len(title) + len(description) + (len(author) * 2)
|
||||||
|
cutoff = combined_length - 6000
|
||||||
|
description = shorten(description, cutoff) if cutoff > 0 else description
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
url=self.item_url,
|
||||||
|
colour=colour,
|
||||||
|
timestamp=self.item_published
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.set_image(url=self.item_image_url)
|
||||||
|
embed.set_thumbnail(url=self.item_thumbnail_url)
|
||||||
|
embed.set_author(
|
||||||
|
name=author,
|
||||||
|
url=self.item_author_url or self.item_feed_url
|
||||||
|
)
|
||||||
|
embed.set_footer(text=self.subscription.name)
|
||||||
|
|
||||||
|
return embed
|
||||||
|
532
src/utils.py
532
src/utils.py
@ -1,10 +1,10 @@
|
|||||||
"""A collection of utility functions that can be used in various places."""
|
"""A collection of utility functions that can be used in various places."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
# import aiohttp
|
||||||
import logging
|
import logging
|
||||||
import async_timeout
|
# import async_timeout
|
||||||
from typing import Callable
|
# from typing import Callable
|
||||||
|
|
||||||
from discord import Interaction, Embed, Colour, ButtonStyle, Button
|
from discord import Interaction, Embed, Colour, ButtonStyle, Button
|
||||||
from discord.ui import View, button
|
from discord.ui import View, button
|
||||||
@ -12,325 +12,335 @@ from discord.ext.commands import Bot
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def fetch(session, url: str) -> str:
|
async def do_batch_job(iterable: list, func, batch_size: int):
|
||||||
async with async_timeout.timeout(20):
|
semaphore = asyncio.Semaphore(batch_size)
|
||||||
async with session.get(url) as response:
|
|
||||||
return await response.text()
|
|
||||||
|
|
||||||
async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None):
|
async def batch_job(item):
|
||||||
if session is not None:
|
async with semaphore:
|
||||||
return await fetch(session, url)
|
await func(item)
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
tasks = [batch_job(item) for item in iterable]
|
||||||
return await fetch(session, url)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
async def get_rss_data(url: str):
|
# async def fetch(session, url: str) -> str:
|
||||||
async with aiohttp.ClientSession() as session:
|
# async with async_timeout.timeout(20):
|
||||||
async with session.get(url) as response:
|
# async with session.get(url) as response:
|
||||||
items = await response.text(), response.status
|
# return await response.text()
|
||||||
|
|
||||||
return items
|
# async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None):
|
||||||
|
# if session is not None:
|
||||||
|
# return await fetch(session, url)
|
||||||
|
|
||||||
async def followup(inter: Interaction, *args, **kwargs):
|
# async with aiohttp.ClientSession() as session:
|
||||||
"""Shorthand for following up on an interaction.
|
# return await fetch(session, url)
|
||||||
|
|
||||||
Parameters
|
# async def get_rss_data(url: str):
|
||||||
----------
|
# async with aiohttp.ClientSession() as session:
|
||||||
inter : Interaction
|
# async with session.get(url) as response:
|
||||||
Represents an app command interaction.
|
# items = await response.text(), response.status
|
||||||
"""
|
|
||||||
|
|
||||||
await inter.followup.send(*args, **kwargs)
|
# return items
|
||||||
|
|
||||||
|
# async def followup(inter: Interaction, *args, **kwargs):
|
||||||
|
# """Shorthand for following up on an interaction.
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
# ----------
|
||||||
|
# inter : Interaction
|
||||||
|
# Represents an app command interaction.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# await inter.followup.send(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
|
# # https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
|
||||||
|
|
||||||
class FollowupIcons:
|
# class FollowupIcons:
|
||||||
error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
|
# error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
|
||||||
success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
|
# success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
|
||||||
trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
|
# trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
|
||||||
info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
|
# info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
|
||||||
added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
|
# added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
|
||||||
assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
|
# assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
|
||||||
|
|
||||||
|
|
||||||
class PaginationView(View):
|
# class PaginationView(View):
|
||||||
"""A Discord UI View that adds pagination to an embed."""
|
# """A Discord UI View that adds pagination to an embed."""
|
||||||
|
|
||||||
def __init__(
|
# def __init__(
|
||||||
self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
|
# self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
|
||||||
formatdata: Callable, pagesize: int, initpage: int=1
|
# formatdata: Callable, pagesize: int, initpage: int=1
|
||||||
):
|
# ):
|
||||||
"""_summary_
|
# """_summary_
|
||||||
|
|
||||||
Args:
|
# Args:
|
||||||
bot (commands.Bot) The discord bot
|
# bot (commands.Bot) The discord bot
|
||||||
inter (Interaction): Represents a discord command interaction.
|
# inter (Interaction): Represents a discord command interaction.
|
||||||
embed (Embed): The base embed to paginate.
|
# embed (Embed): The base embed to paginate.
|
||||||
getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
|
# getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
|
||||||
formatdata (Callable): A formatter function that determines how the data is displayed.
|
# formatdata (Callable): A formatter function that determines how the data is displayed.
|
||||||
pagesize (int): The size of each page.
|
# pagesize (int): The size of each page.
|
||||||
initpage (int, optional): The inital page. Defaults to 1.
|
# initpage (int, optional): The inital page. Defaults to 1.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
self.bot = bot
|
# self.bot = bot
|
||||||
self.inter = inter
|
# self.inter = inter
|
||||||
self.embed = embed
|
# self.embed = embed
|
||||||
self.getdata = getdata
|
# self.getdata = getdata
|
||||||
self.formatdata = formatdata
|
# self.formatdata = formatdata
|
||||||
self.maxpage = None
|
# self.maxpage = None
|
||||||
self.pagesize = pagesize
|
# self.pagesize = pagesize
|
||||||
self.index = initpage
|
# self.index = initpage
|
||||||
|
|
||||||
# emoji reference
|
# # emoji reference
|
||||||
self.start_emoji = bot.get_emoji(1204542364073463818)
|
# self.start_emoji = bot.get_emoji(1204542364073463818)
|
||||||
self.end_emoji = bot.get_emoji(1204542367752003624)
|
# self.end_emoji = bot.get_emoji(1204542367752003624)
|
||||||
|
|
||||||
super().__init__(timeout=100)
|
# super().__init__(timeout=100)
|
||||||
|
|
||||||
async def check_user_is_author(self, inter: Interaction) -> bool:
|
# async def check_user_is_author(self, inter: Interaction) -> bool:
|
||||||
"""Ensure the user is the author of the original command."""
|
# """Ensure the user is the author of the original command."""
|
||||||
|
|
||||||
if inter.user == self.inter.user:
|
# if inter.user == self.inter.user:
|
||||||
return True
|
# return True
|
||||||
|
|
||||||
await inter.response.defer()
|
# await inter.response.defer()
|
||||||
await (
|
# await (
|
||||||
Followup(None, "Only the author can interact with this.")
|
# Followup(None, "Only the author can interact with this.")
|
||||||
.error()
|
# .error()
|
||||||
.send(inter, ephemeral=True)
|
# .send(inter, ephemeral=True)
|
||||||
)
|
# )
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
async def on_timeout(self):
|
# async def on_timeout(self):
|
||||||
"""Erase the controls on timeout."""
|
# """Erase the controls on timeout."""
|
||||||
|
|
||||||
message = await self.inter.original_response()
|
# message = await self.inter.original_response()
|
||||||
await message.edit(view=None)
|
# await message.edit(view=None)
|
||||||
|
|
||||||
@staticmethod
|
# @staticmethod
|
||||||
def calc_total_pages(results: int, max_pagesize: int) -> int:
|
# def calc_total_pages(results: int, max_pagesize: int) -> int:
|
||||||
result = ((results - 1) // max_pagesize) + 1
|
# result = ((results - 1) // max_pagesize) + 1
|
||||||
return result
|
# return result
|
||||||
|
|
||||||
def calc_dataitem_index(self, dataitem_index: int):
|
# def calc_dataitem_index(self, dataitem_index: int):
|
||||||
"""Calculates a given index to be relative to the sum of all pages items
|
# """Calculates a given index to be relative to the sum of all pages items
|
||||||
|
|
||||||
Example: dataitem_index = 6
|
# Example: dataitem_index = 6
|
||||||
pagesize = 10
|
# pagesize = 10
|
||||||
if page == 1 then return 6
|
# if page == 1 then return 6
|
||||||
else return 6 + 10 * (page - 1)"""
|
# else return 6 + 10 * (page - 1)"""
|
||||||
|
|
||||||
if self.index > 1:
|
# if self.index > 1:
|
||||||
dataitem_index += self.pagesize * (self.index - 1)
|
# dataitem_index += self.pagesize * (self.index - 1)
|
||||||
|
|
||||||
dataitem_index += 1
|
# dataitem_index += 1
|
||||||
return dataitem_index
|
# return dataitem_index
|
||||||
|
|
||||||
@button(emoji="◀️", style=ButtonStyle.blurple)
|
# @button(emoji="◀️", style=ButtonStyle.blurple)
|
||||||
async def backward(self, inter: Interaction, button: Button):
|
# async def backward(self, inter: Interaction, button: Button):
|
||||||
"""
|
# """
|
||||||
Action the backwards button.
|
# Action the backwards button.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
self.index -= 1
|
# self.index -= 1
|
||||||
await inter.response.defer()
|
# await inter.response.defer()
|
||||||
self.inter = inter
|
# self.inter = inter
|
||||||
await self.navigate()
|
# await self.navigate()
|
||||||
|
|
||||||
@button(emoji="▶️", style=ButtonStyle.blurple)
|
# @button(emoji="▶️", style=ButtonStyle.blurple)
|
||||||
async def forward(self, inter: Interaction, button: Button):
|
# async def forward(self, inter: Interaction, button: Button):
|
||||||
"""
|
# """
|
||||||
Action the forwards button.
|
# Action the forwards button.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
self.index += 1
|
# self.index += 1
|
||||||
await inter.response.defer()
|
# await inter.response.defer()
|
||||||
self.inter = inter
|
# self.inter = inter
|
||||||
await self.navigate()
|
# await self.navigate()
|
||||||
|
|
||||||
@button(emoji="⏭️", style=ButtonStyle.blurple)
|
# @button(emoji="⏭️", style=ButtonStyle.blurple)
|
||||||
async def start_or_end(self, inter: Interaction, button: Button):
|
# async def start_or_end(self, inter: Interaction, button: Button):
|
||||||
"""
|
# """
|
||||||
Action the start and end button.
|
# Action the start and end button.
|
||||||
This button becomes return to start if at end, otherwise skip to end.
|
# This button becomes return to start if at end, otherwise skip to end.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
# Determine if should skip to start or end
|
# # Determine if should skip to start or end
|
||||||
if self.index <= self.maxpage // 2:
|
# if self.index <= self.maxpage // 2:
|
||||||
self.index = self.maxpage
|
# self.index = self.maxpage
|
||||||
else:
|
# else:
|
||||||
self.index = 1
|
# self.index = 1
|
||||||
|
|
||||||
await inter.response.defer()
|
# await inter.response.defer()
|
||||||
self.inter = inter
|
# self.inter = inter
|
||||||
await self.navigate()
|
# await self.navigate()
|
||||||
|
|
||||||
async def navigate(self):
|
# async def navigate(self):
|
||||||
"""
|
# """
|
||||||
Acts as an update method for the entire instance.
|
# Acts as an update method for the entire instance.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
log.debug("navigating to page: %s", self.index)
|
# log.debug("navigating to page: %s", self.index)
|
||||||
|
|
||||||
self.update_buttons()
|
# self.update_buttons()
|
||||||
paged_embed = await self.create_paged_embed()
|
# paged_embed = await self.create_paged_embed()
|
||||||
await self.inter.edit_original_response(embed=paged_embed, view=self)
|
# await self.inter.edit_original_response(embed=paged_embed, view=self)
|
||||||
|
|
||||||
async def create_paged_embed(self) -> Embed:
|
# async def create_paged_embed(self) -> Embed:
|
||||||
"""
|
# """
|
||||||
Returns a copy of the known embed, but with data from the current page.
|
# Returns a copy of the known embed, but with data from the current page.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
embed = self.embed.copy()
|
# embed = self.embed.copy()
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
data, total_results = await self.getdata(self.index, self.pagesize)
|
# data, total_results = await self.getdata(self.index, self.pagesize)
|
||||||
except aiohttp.ClientResponseError as exc:
|
# except aiohttp.ClientResponseError as exc:
|
||||||
log.error(exc)
|
# log.error(exc)
|
||||||
await (
|
# await (
|
||||||
Followup(f"Error · {exc.message}",)
|
# Followup(f"Error · {exc.message}",)
|
||||||
.footer(f"HTTP {exc.code}")
|
# .footer(f"HTTP {exc.code}")
|
||||||
.error()
|
# .error()
|
||||||
.send(self.inter)
|
# .send(self.inter)
|
||||||
)
|
# )
|
||||||
raise exc
|
# raise exc
|
||||||
|
|
||||||
self.maxpage = self.calc_total_pages(total_results, self.pagesize)
|
# self.maxpage = self.calc_total_pages(total_results, self.pagesize)
|
||||||
log.debug(f"{self.maxpage=!r}")
|
# log.debug(f"{self.maxpage=!r}")
|
||||||
|
|
||||||
for i, item in enumerate(data):
|
# for i, item in enumerate(data):
|
||||||
i = self.calc_dataitem_index(i)
|
# i = self.calc_dataitem_index(i)
|
||||||
if asyncio.iscoroutinefunction(self.formatdata):
|
# if asyncio.iscoroutinefunction(self.formatdata):
|
||||||
key, value = await self.formatdata(i, item)
|
# key, value = await self.formatdata(i, item)
|
||||||
else:
|
# else:
|
||||||
key, value = self.formatdata(i, item)
|
# key, value = self.formatdata(i, item)
|
||||||
|
|
||||||
embed.add_field(name=key, value=value, inline=False)
|
# embed.add_field(name=key, value=value, inline=False)
|
||||||
|
|
||||||
if not total_results:
|
# if not total_results:
|
||||||
embed.description = "There are no results"
|
# embed.description = "There are no results"
|
||||||
|
|
||||||
if self.maxpage > 1:
|
# if self.maxpage > 1:
|
||||||
embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
|
# embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
|
||||||
|
|
||||||
return embed
|
# return embed
|
||||||
|
|
||||||
def update_buttons(self):
|
# def update_buttons(self):
|
||||||
if self.index >= self.maxpage:
|
# if self.index >= self.maxpage:
|
||||||
self.children[2].emoji = self.start_emoji
|
# self.children[2].emoji = self.start_emoji
|
||||||
else:
|
# else:
|
||||||
self.children[2].emoji = self.end_emoji
|
# self.children[2].emoji = self.end_emoji
|
||||||
|
|
||||||
self.children[0].disabled = self.index == 1
|
# self.children[0].disabled = self.index == 1
|
||||||
self.children[1].disabled = self.index == self.maxpage
|
# self.children[1].disabled = self.index == self.maxpage
|
||||||
|
|
||||||
async def send(self):
|
# async def send(self):
|
||||||
"""Send the pagination view. It may be important to defer before invoking this method."""
|
# """Send the pagination view. It may be important to defer before invoking this method."""
|
||||||
|
|
||||||
log.debug("sending pagination view")
|
# log.debug("sending pagination view")
|
||||||
embed = await self.create_paged_embed()
|
# embed = await self.create_paged_embed()
|
||||||
|
|
||||||
if self.maxpage <= 1:
|
# if self.maxpage <= 1:
|
||||||
await self.inter.edit_original_response(embed=embed)
|
# await self.inter.edit_original_response(embed=embed)
|
||||||
return
|
# return
|
||||||
|
|
||||||
self.update_buttons()
|
# self.update_buttons()
|
||||||
await self.inter.edit_original_response(embed=embed, view=self)
|
# await self.inter.edit_original_response(embed=embed, view=self)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Followup:
|
# class Followup:
|
||||||
"""Wrapper for a discord embed to follow up an interaction."""
|
# """Wrapper for a discord embed to follow up an interaction."""
|
||||||
|
|
||||||
def __init__(
|
# def __init__(
|
||||||
self,
|
# self,
|
||||||
title: str = None,
|
# title: str = None,
|
||||||
description: str = None,
|
# description: str = None,
|
||||||
):
|
# ):
|
||||||
self._embed = Embed(
|
# self._embed = Embed(
|
||||||
title=title,
|
# title=title,
|
||||||
description=description
|
# description=description
|
||||||
)
|
# )
|
||||||
|
|
||||||
async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
|
# async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)
|
# await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)
|
||||||
|
|
||||||
def fields(self, inline: bool = False, **fields: dict):
|
# def fields(self, inline: bool = False, **fields: dict):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
for key, value in fields.items():
|
# for key, value in fields.items():
|
||||||
self._embed.add_field(name=key, value=value, inline=inline)
|
# self._embed.add_field(name=key, value=value, inline=inline)
|
||||||
|
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def image(self, url: str):
|
# def image(self, url: str):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.set_image(url=url)
|
# self._embed.set_image(url=url)
|
||||||
|
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def author(self, name: str, url: str=None, icon_url: str=None):
|
# def author(self, name: str, url: str=None, icon_url: str=None):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.set_author(name=name, url=url, icon_url=icon_url)
|
# self._embed.set_author(name=name, url=url, icon_url=icon_url)
|
||||||
|
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def footer(self, text: str, icon_url: str = None):
|
# def footer(self, text: str, icon_url: str = None):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.set_footer(text=text, icon_url=icon_url)
|
# self._embed.set_footer(text=text, icon_url=icon_url)
|
||||||
|
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def error(self):
|
# def error(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.red()
|
# self._embed.colour = Colour.red()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.error)
|
# self._embed.set_thumbnail(url=FollowupIcons.error)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def success(self):
|
# def success(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.green()
|
# self._embed.colour = Colour.green()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.success)
|
# self._embed.set_thumbnail(url=FollowupIcons.success)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def info(self):
|
# def info(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.blue()
|
# self._embed.colour = Colour.blue()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.info)
|
# self._embed.set_thumbnail(url=FollowupIcons.info)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def added(self):
|
# def added(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.blue()
|
# self._embed.colour = Colour.blue()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.added)
|
# self._embed.set_thumbnail(url=FollowupIcons.added)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def assign(self):
|
# def assign(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.blue()
|
# self._embed.colour = Colour.blue()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.assigned)
|
# self._embed.set_thumbnail(url=FollowupIcons.assigned)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
def trash(self):
|
# def trash(self):
|
||||||
""""""
|
# """"""
|
||||||
|
|
||||||
self._embed.colour = Colour.red()
|
# self._embed.colour = Colour.red()
|
||||||
self._embed.set_thumbnail(url=FollowupIcons.trash)
|
# self._embed.set_thumbnail(url=FollowupIcons.trash)
|
||||||
return self
|
# return self
|
||||||
|
|
||||||
|
|
||||||
def extract_error_info(error: Exception) -> str:
|
# def extract_error_info(error: Exception) -> str:
|
||||||
class_name = error.__class__.__name__
|
# class_name = error.__class__.__name__
|
||||||
desc = str(error)
|
# desc = str(error)
|
||||||
return class_name, desc
|
# return class_name, desc
|
||||||
|
Loading…
x
Reference in New Issue
Block a user