PYRSS-Bot/src/utils.py
corbz 8c35f42a0e Improved pagination view
Removed double API call
added pagesize to API call
added calc_dataitem_index method for properly calculating the index of each data item, given the current page.
2024-01-30 19:29:42 +00:00

273 lines
7.9 KiB
Python

"""A collection of utility functions that can be used in various places."""
import aiohttp
import logging
import async_timeout
from typing import Callable
from discord import Interaction, Embed, Colour, ButtonStyle, Button
from discord.ui import View, button
# Module-level logger named after this module; PaginationView emits its
# debug messages (page counts, navigation) through it.
log = logging.getLogger(name=__name__)
async def fetch(session, url: str, timeout: float = 20) -> str:
    """GET *url* using *session* and return the response body as text.

    Args:
        session: An open aiohttp client session used to issue the request
            (e.g. the one created in ``get_unparsed_feed``).
        url (str): The address to request.
        timeout (float, optional): Seconds allowed for the whole request,
            including reading the body. Defaults to 20.

    Returns:
        str: The decoded response body.

    Raises:
        The timeout error raised by ``async_timeout`` when *timeout* elapses
        (presumably ``asyncio.TimeoutError`` — confirm for the installed
        async_timeout version).
    """
    async with async_timeout.timeout(timeout):
        async with session.get(url) as response:
            return await response.text()
async def get_unparsed_feed(url: str):
    """Download *url* and return its raw body as text.

    A throwaway aiohttp session is opened for this single request; the
    request itself (including its timeout) is delegated to ``fetch``.
    """
    async with aiohttp.ClientSession() as client:
        body = await fetch(client, url)
    return body
async def get_rss_data(url: str):
    """Fetch *url* and return a ``(body_text, http_status)`` tuple.

    No explicit timeout is applied here. The status code is returned
    alongside the body so callers can decide how to treat non-2xx responses.
    """
    async with aiohttp.ClientSession() as client:
        async with client.get(url) as resp:
            body = await resp.text()
            return body, resp.status
async def followup(inter: Interaction, *args, **kwargs):
    """Send a follow-up message on an interaction.

    Thin shorthand for ``inter.followup.send``; all positional and keyword
    arguments are forwarded unchanged and nothing is returned.

    Parameters
    ----------
    inter : Interaction
        The app command interaction to follow up on.
    """
    sender = inter.followup.send
    await sender(*args, **kwargs)
async def audit(cog, *args, **kwargs):
    """Forward an audit call to the bot that owns *cog*.

    Every argument is passed straight through to ``cog.bot.audit``; the
    result of that call is not returned.
    """
    owner = cog.bot
    await owner.audit(*args, **kwargs)
# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
class FollowupIcons:
    """Thumbnail URLs (icons8 "fluency-systems-filled", 48px) used by Followup."""

    # Icons tinted DC573C, paired with Colour.red() by Followup.error/trash.
    error: str = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
    trash: str = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"

    # Icon tinted 5BC873, paired with Colour.green() by Followup.success.
    success: str = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"

    # Icons tinted 4598DA, paired with Colour.blue() by Followup.info/added/assign.
    info: str = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
    added: str = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
    assigned: str = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
class PaginationView(View):
    """A Discord UI View that adds pagination to an embed.

    Three buttons are provided: previous page, next page, and a jump button
    that goes to whichever end of the result set is nearer. Only the user
    who ran the original command may press them (see ``interaction_check``).
    """

    def __init__(
        self, inter: Interaction, embed: Embed, getdata: Callable,
        formatdata: Callable, pagesize: int, initpage: int=1
    ):
        """Create a paginated view over data supplied by ``getdata``.

        Args:
            inter (Interaction): The command interaction whose original
                response is edited to display each page.
            embed (Embed): The base embed. It is copied for every page and
                never modified itself.
            getdata (Callable): An async callable taking the 1-based page
                number and returning ``(items_for_page, total_result_count)``.
            formatdata (Callable): Called as ``formatdata(position, item)``
                with the item's 1-based overall position; must return a
                ``(field_name, field_value)`` pair.
            pagesize (int): The maximum number of items per page (used to
                compute the page count and each item's overall position).
            initpage (int, optional): The initial page. Defaults to 1.
        """
        self.inter = inter
        self.embed = embed
        self.getdata = getdata
        self.formatdata = formatdata
        self.maxpage = None  # unknown until the first page has been fetched
        self.pagesize = pagesize
        self.index = initpage
        super().__init__(timeout=100)

    async def interaction_check(self, inter: Interaction) -> bool:
        """discord.py hook run before every button callback.

        Without this override ``check_user_is_author`` is never invoked and
        any user could page through someone else's results.
        """
        return await self.check_user_is_author(inter)

    async def check_user_is_author(self, inter: Interaction) -> bool:
        """Ensure the user is the author of the original command.

        Non-authors have their interaction deferred and receive an
        ephemeral error follow-up.

        Returns:
            bool: True if ``inter`` comes from the original command's user.
        """
        if inter.user == self.inter.user:
            return True

        await inter.response.defer()
        await (
            Followup(None, "Only the author can interact with this.")
            .error()
            .send(inter, ephemeral=True)
        )
        return False

    async def on_timeout(self):
        """Erase the controls once the view times out (timeout set in __init__)."""
        message = await self.inter.original_response()
        await message.edit(view=None)

    @staticmethod
    def calc_total_pages(results: int, max_pagesize: int) -> int:
        """Return how many pages are needed to show ``results`` items.

        The result is never below 1, so an empty result set renders as a
        single page instead of "Page 1/0" with an enabled next button.
        """
        result = max(1, ((results - 1) // max_pagesize) + 1)
        log.debug("total pages calculated: %s", result)
        return result

    def calc_dataitem_index(self, dataitem_index: int) -> int:
        """Map a 0-based index within the current page to a 1-based overall position."""
        offset = self.pagesize * max(self.index - 1, 0)
        return dataitem_index + offset + 1

    def _jumps_to_start(self) -> bool:
        """Whether the jump button goes to page 1 (True) or the last page (False).

        Shared by ``start_or_end`` and ``update_buttons`` so the emoji shown
        always matches what the button actually does.
        """
        return self.index > self.maxpage // 2

    @button(emoji="◀️", style=ButtonStyle.blurple)
    async def backward(self, inter: Interaction, button: Button):
        """Show the previous page."""
        self.index -= 1
        await inter.response.defer()
        self.inter = inter
        await self.navigate()

    @button(emoji="▶️", style=ButtonStyle.blurple)
    async def forward(self, inter: Interaction, button: Button):
        """Show the next page."""
        self.index += 1
        await inter.response.defer()
        self.inter = inter
        await self.navigate()

    @button(emoji="⏭️", style=ButtonStyle.blurple)
    async def start_or_end(self, inter: Interaction, button: Button):
        """Jump to whichever end (first or last page) is nearer."""
        self.index = 1 if self._jumps_to_start() else self.maxpage
        await inter.response.defer()
        self.inter = inter
        await self.navigate()

    async def navigate(self):
        """Render the page at ``self.index`` and refresh the controls."""
        log.debug("navigating to page: %s", self.index)
        # Build the page first: it refreshes self.maxpage, which
        # update_buttons reads. Doing it the other way round used stale data.
        paged_embed = await self.create_paged_embed()
        self.update_buttons()
        await self.inter.edit_original_response(embed=paged_embed, view=self)

    async def create_paged_embed(self) -> Embed:
        """Fetch the current page and build its embed.

        Also updates ``self.maxpage`` from the total result count returned
        by ``getdata``. A "Page x/y" footer is added only when there is more
        than one page.
        """
        embed = self.embed.copy()
        data, total_results = await self.getdata(self.index)
        self.maxpage = self.calc_total_pages(total_results, self.pagesize)

        for position, item in enumerate(data):
            key, value = self.formatdata(self.calc_dataitem_index(position), item)
            embed.add_field(name=key, value=value, inline=False)

        if self.maxpage != 1:
            embed.set_footer(text=f"Page {self.index}/{self.maxpage}")

        return embed

    def update_buttons(self):
        """Sync the buttons with the current page; requires ``maxpage`` to be set."""
        # children follow the declaration order of the @button methods above.
        prev_btn, next_btn, jump_btn = self.children[:3]
        jump_btn.emoji = "⏮️" if self._jumps_to_start() else "⏭️"
        prev_btn.disabled = self.index <= 1
        next_btn.disabled = self.index >= self.maxpage

    async def send(self):
        """Render the first page, attaching the controls only if there are several pages."""
        embed = await self.create_paged_embed()

        if self.maxpage == 1:
            await self.inter.edit_original_response(embed=embed)
            return

        self.update_buttons()
        await self.inter.edit_original_response(embed=embed, view=self)
class Followup:
    """Wrapper for a discord embed to follow up an interaction.

    The content and styling methods mutate the wrapped embed and return
    ``self``, so calls can be chained, e.g.
    ``await Followup("Title", "Body").error().send(inter)``.
    """

    def __init__(
        self,
        title: str = None,
        description: str = None,
    ):
        """Create the wrapped embed.

        Args:
            title (str, optional): Embed title. Defaults to None.
            description (str, optional): Embed description. Defaults to None.
        """
        self._embed = Embed(
            title=title,
            description=description
        )

    async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
        """Send the embed as a follow-up message on ``inter``.

        Args:
            inter (Interaction): The interaction to follow up on.
            message (str, optional): Plain-text content sent with the embed.
            ephemeral (bool, optional): Forwarded to ``followup.send``.
                Defaults to False.
        """
        await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)

    def fields(self, inline: bool = False, **fields: str) -> "Followup":
        """Add one embed field per keyword argument (field name = keyword).

        Args:
            inline (bool, optional): Applied to every added field.
            **fields: Field name/value pairs, added in the order given.
        """
        for name, value in fields.items():
            self._embed.add_field(name=name, value=value, inline=inline)
        return self

    def image(self, url: str) -> "Followup":
        """Set the embed's main image."""
        self._embed.set_image(url=url)
        return self

    def footer(self, text: str, icon_url: str = None) -> "Followup":
        """Set the embed footer text and optional footer icon."""
        self._embed.set_footer(text=text, icon_url=icon_url)
        return self

    def _style(self, colour: Colour, icon_url: str) -> "Followup":
        """Apply a colour and thumbnail; shared by the preset style methods below."""
        self._embed.colour = colour
        self._embed.set_thumbnail(url=icon_url)
        return self

    def error(self) -> "Followup":
        """Style as an error: red colour, warning-box thumbnail."""
        return self._style(Colour.red(), FollowupIcons.error)

    def success(self) -> "Followup":
        """Style as a success: green colour, check-mark thumbnail."""
        return self._style(Colour.green(), FollowupIcons.success)

    def info(self) -> "Followup":
        """Style as informational: blue colour, info thumbnail."""
        return self._style(Colour.blue(), FollowupIcons.info)

    def added(self) -> "Followup":
        """Style as an addition: blue colour, plus thumbnail."""
        return self._style(Colour.blue(), FollowupIcons.added)

    def assign(self) -> "Followup":
        """Style as an assignment: blue colour, hashtag thumbnail."""
        return self._style(Colour.blue(), FollowupIcons.assigned)

    def trash(self) -> "Followup":
        """Style as a deletion: red colour, trash-can thumbnail."""
        return self._style(Colour.red(), FollowupIcons.trash)
def extract_error_info(error: Exception) -> "tuple[str, str]":
    """Split an exception into its class name and message.

    Args:
        error (Exception): The exception to describe.

    Returns:
        tuple[str, str]: ``(class_name, str(error))``. The message is an
        empty string when the exception was raised without arguments.
    """
    # String annotation keeps the builtin-generic syntax safe on Python < 3.9.
    class_name = error.__class__.__name__
    desc = str(error)
    return class_name, desc