This commit is contained in:
Corban-Lee Jones 2024-02-07 09:25:32 +00:00
commit 83a889cf61
8 changed files with 475 additions and 198 deletions

View File

@ -1,17 +1,7 @@
# NewsBot # PYRSS
Bot delivering news articles to discord servers. An RSS driven Discord bot written in Python.
Plans Provides user commands for storing RSS feed URLs that can be assigned to any given discord channel.
- Multiple news providers Content is shared every 10 minutes as an Embed.
- Choose how much of each provider should be delivered
- Check for duplicate articles between providers, and only deliver preferred provider article
## Dev Notes:
For the sake of development, the following defintions apply:
- Feed - An RSS feed stored within the database, submitted by a user.
- Assigned Feed - A discord channel set to receive content from a Feed.

127
src/api.py Normal file
View File

@ -0,0 +1,127 @@
import logging
import aiohttp
log = logging.getLogger(__name__)
class APIException(Exception):
pass
class NotCreatedException(APIException):
pass
class BadStatusException(APIException):
pass
class API:
"""Interactions with the API."""
API_HOST = "http://localhost:8000/"
API_ENDPOINT = API_HOST + "api/"
RSS_FEED_ENDPOINT = API_ENDPOINT + "rssfeed/"
FEED_CHANNEL_ENDPOINT = API_ENDPOINT + "feedchannel/"
def __init__(self, api_token: str, session: aiohttp.ClientSession):
log.debug("API session initialised")
self.session = session
self.token_headers = {"Authorization": f"Token {api_token}"}
async def make_request(self, method: str, url: str, **kwargs) -> dict:
"""Make a request to the given API endpoint.
Args:
method (str): The request method to use, examples: GET, POST, DELETE...
url (str): The API endpoint to request to.
**kwargs: Passed into self.session.request.
Returns:
dict: Dictionary containing status code, json or text.
"""
async with self.session.request(method, url, headers=self.token_headers, **kwargs) as response:
response.raise_for_status()
try:
json = await response.json()
text = None
except aiohttp.ContentTypeError:
json = None
text = await response.text()
status = response.status
return {"json": json, "text": text, "status": status}
async def create_new_rssfeed(self, name: str, url: str, image_url: str, discord_server_id: int) -> dict:
"""Create a new RSS Feed.
Args:
name (str): Name of the RSS Feed.
url (str): URL for the RSS Feed.
image_url (str): URL of the image representation of the RSS Feed.
discord_server_id (int): ID of the discord server behind this item.
Returns:
dict: JSON representation of the newly created RSS Feed.
"""
log.debug("creating rssfeed: %s %s %s %s", name, url, image_url, discord_server_id)
async with self.session.get(image_url) as response:
image_data = await response.read()
# Using formdata to make the image transfer easier.
form = aiohttp.FormData({
"name": name,
"url": url,
"discord_server_id": discord_server_id
})
form.add_field("image", image_data, filename="file.jpg")
data = (await self.make_response("POST", self.RSS_FEED_ENDPOINT, data=form))["json"]
return data
async def get_rssfeed(self, uuid: str) -> dict:
"""Get a particular RSS Feed given it's UUID.
Args:
uuid (str): Identifier of the desired RSS Feed.
Returns:
dict: A JSON representation of the RSS Feed.
"""
log.debug("getting rssfeed: %s", uuid)
endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/"
data = (await self.make_request("GET", endpoint))["json"]
return data
async def get_rssfeed_list(self, **filters) -> tuple[list[dict], int]:
"""Get all RSS Feeds with the associated filters.
Returns:
tuple[list[dict], int] list contains dictionaries of each item, int is total items.
"""
log.debug("getting list of rss feeds with filters: %s", filters)
data = (await self.make_request("GET", self.RSS_FEED_ENDPOINT, params=filters))["json"]
return data["results"], data["count"]
async def delete_rssfeed(self, uuid: str) -> int:
"""Delete a specified RSS Feed.
Args:
uuid (str): Identifier of the RSS Feed to delete.
Returns:
int: Status code of the response.
"""
log.debug("deleting rssfeed: %s", uuid)
endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/"
status = (await self.make_request("DELETE", endpoint))["status"]
return status

View File

@ -5,7 +5,7 @@ The discord bot for the application.
import logging import logging
from pathlib import Path from pathlib import Path
from discord import Intents from discord import Intents, Game
from discord.ext import commands from discord.ext import commands
from sqlalchemy import insert from sqlalchemy import insert
@ -17,11 +17,13 @@ log = logging.getLogger(__name__)
class DiscordBot(commands.Bot): class DiscordBot(commands.Bot):
def __init__(self, BASE_DIR: Path, developing: bool): def __init__(self, BASE_DIR: Path, developing: bool, api_token: str):
super().__init__(command_prefix="-", intents=Intents.all()) activity = Game("Indev") if developing else None
self.functions = Functions(self) super().__init__(command_prefix="-", intents=Intents.all(), activity=activity)
self.functions = Functions(self, api_token)
self.BASE_DIR = BASE_DIR self.BASE_DIR = BASE_DIR
self.developing = developing self.developing = developing
self.api_token = api_token
log.info("developing=%s", developing) log.info("developing=%s", developing)

View File

@ -6,6 +6,7 @@ Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bo
import logging import logging
from typing import Tuple from typing import Tuple
import aiohttp
import validators import validators
from feedparser import FeedParserDict, parse from feedparser import FeedParserDict, parse
from discord.ext import commands from discord.ext import commands
@ -14,7 +15,8 @@ from discord.app_commands import Choice, Group, autocomplete, choices, rename
from sqlalchemy import insert, select, and_, delete from sqlalchemy import insert, select, and_, delete
from sqlalchemy.exc import NoResultFound, IntegrityError from sqlalchemy.exc import NoResultFound, IntegrityError
from feed import Source from api import API
from feed import Source, RSSFeed
from errors import IllegalFeed from errors import IllegalFeed
from db import ( from db import (
DatabaseManager, DatabaseManager,
@ -25,6 +27,7 @@ from db import (
) )
from utils import ( from utils import (
Followup, Followup,
PaginationView,
get_rss_data, get_rss_data,
followup, followup,
audit, audit,
@ -121,6 +124,21 @@ class FeedCog(commands.Cog):
log.info("%s cog is ready", self.__class__.__name__) log.info("%s cog is ready", self.__class__.__name__)
async def autocomplete_rssfeed(self, inter: Interaction, name: str) -> list[Choice]:
async with aiohttp.ClientSession() as session:
data, _ = await API(self.bot.api_token, session).get_rssfeed_list(
discord_server_id=inter.guild_id
)
rssfeeds = RSSFeed.from_list(data)
choices = [
Choice(name=item.name, value=item.uuid)
for item in rssfeeds
]
return choices
async def source_autocomplete(self, inter: Interaction, nickname: str): async def source_autocomplete(self, inter: Interaction, nickname: str):
"""Provides RSS source autocomplete functionality for commands. """Provides RSS source autocomplete functionality for commands.
@ -154,69 +172,55 @@ class FeedCog(commands.Cog):
# All RSS commands belong to this group. # All RSS commands belong to this group.
feed_group = Group( feed_group = Group(
name="feed", name="feed",
description="Commands for rss sources.", description="Commands for RSS sources.",
default_permissions=Permissions.elevated(), default_permissions=Permissions.elevated(),
guild_only=True # We store guild IDs in the database, so guild only = True guild_only=True # We store guild IDs in the database, so guild only = True
) )
@feed_group.command(name="add") @feed_group.command(name="new")
async def add_rss_source(self, inter: Interaction, nickname: str, url: str): async def add_rssfeed(self, inter: Interaction, name: str, url: str):
"""Add a new Feed for this server. """Add a new RSS Feed for this server.
Parameters Args:
---------- inter (Interaction): Represents the discord command interaction.
inter : Interaction name (str): A nickname used to refer to this RSS Feed.
Represents an app command interaction. url (str): The URL of the RSS Feed.
nickname : str
A name used to identify the Feed.
url : str
The Feed URL.
""" """
await inter.response.defer() await inter.response.defer()
try: try:
source = await self.bot.functions.create_new_feed(nickname, url, inter.guild_id) rssfeed = await self.bot.functions.create_new_rssfeed(name, url, inter.guild_id)
except IllegalFeed as error: except Exception as exc:
title, desc = extract_error_info(error)
await Followup(title, desc).fields(**error.items).error().send(inter)
except IntegrityError as error:
await ( await (
Followup( Followup(exc.__class__.__name__, str(exc))
"Duplicate Feed Error",
"A Feed with the same nickname already exist."
)
.fields(nickname=nickname)
.error() .error()
.send(inter) .send(inter)
) )
else: else:
await ( await (
Followup("Feed Added") Followup("New RSS Feed")
.image(source.icon_url) .image(rssfeed.image)
.fields(nickname=nickname, url=url) .fields(uuid=rssfeed.uuid, name=name, url=url)
.added() .added()
.send(inter) .send(inter)
) )
@feed_group.command(name="remove") @feed_group.command(name="delete")
@rename(url="option") @autocomplete(uuid=autocomplete_rssfeed)
@autocomplete(url=source_autocomplete) @rename(uuid="rssfeed")
async def remove_rss_source(self, inter: Interaction, url: str): async def delete_rssfeed(self, inter: Interaction, uuid: str):
"""Delete an existing Feed from this server. """Delete an existing RSS Feed for this server.
Parameters Args:
---------- inter (Interaction): Represents the discord command interaction.
inter : Interaction uuid (str): The UUID of the
Represents an app command interaction.
url : str
The Feed to be removed. Autocomplete or enter the URL.
""" """
await inter.response.defer() await inter.response.defer()
try: try:
source = await self.bot.functions.delete_feed(url, inter.guild_id) rssfeed = await self.bot.functions.delete_rssfeed(uuid)
except NoResultFound: except NoResultFound:
await ( await (
Followup( Followup(
@ -229,47 +233,43 @@ class FeedCog(commands.Cog):
else: else:
await ( await (
Followup("Feed Deleted") Followup("Feed Deleted")
.image(source.icon_url) .image(rssfeed.image)
.fields(url=url) .fields(uuid=rssfeed.uuid, name=rssfeed.name, url=rssfeed.url)
.trash() .trash()
.send(inter) .send(inter)
) )
@feed_group.command(name="list") @feed_group.command(name="list")
async def list_rss_sources(self, inter: Interaction): async def list_rssfeeds(self, inter: Interaction):
"""Provides a with a list of Feeds available for this server. """Provides a list of all RSS Feeds
Parameters Args:
---------- inter (Interaction): Represents the discord command interaction.
inter : Interaction
Represents an app command interaction.
""" """
await inter.response.defer() await inter.response.defer()
page = 1
pagesize = 10
try: try:
feeds = await self.bot.functions.get_feeds(inter.guild_id) def formatdata(index, item):
except NoResultFound: key = f"{index}. {item.name}"
value = f"[RSS]({item.url}) · [API](http://localhost:8000/api/rssfeed/{item.uuid}/)"
return key, value
async def getdata(page):
data, count = await self.bot.functions.get_rssfeeds(inter.guild_id, page, pagesize)
return data, count
embed = Followup(f"Available RSS Feeds in {inter.guild.name}").info()._embed
pagination = PaginationView(self.bot, inter, embed, getdata, formatdata, pagesize, 1)
await pagination.send()
except Exception as exc:
await ( await (
Followup( Followup(exc.__class__.__name__, str(exc))
"Feeds Not Found Error",
"There are no available Feeds for this server.\n"
"Add a new feed with `/feed add`."
)
.error() .error()
.send()
)
else:
description = "\n".join([
f"{i}. **[{info[0]}]({info[1]})**" # info = (nick, url)
for i, info in enumerate(feeds)
])
await (
Followup(
f"Available Feeds in {inter.guild.name}",
description
)
.info()
.send(inter) .send(inter)
) )

View File

@ -10,12 +10,13 @@ from time import process_time
import aiohttp import aiohttp
from discord import TextChannel from discord import TextChannel
from discord import app_commands
from discord.ext import commands, tasks from discord.ext import commands, tasks
from discord.errors import Forbidden from discord.errors import Forbidden
from sqlalchemy import insert, select, and_ from sqlalchemy import insert, select, and_
from feedparser import parse from feedparser import parse
from feed import Source, Article from feed import Source, Article, RSSFeed
from db import ( from db import (
DatabaseManager, DatabaseManager,
FeedChannelModel, FeedChannelModel,
@ -23,11 +24,13 @@ from db import (
SentArticleModel SentArticleModel
) )
from utils import get_unparsed_feed from utils import get_unparsed_feed
from api import API
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES") TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES")
# task trigger times : must be of type list
times = [ times = [
datetime.time(hour, minute, tzinfo=datetime.timezone.utc) datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
for hour in range(24) for hour in range(24)
@ -62,6 +65,12 @@ class TaskCog(commands.Cog):
self.rss_task.cancel() self.rss_task.cancel()
@app_commands.command(name="debug-trigger-task")
async def debug_trigger_task(self, inter):
await inter.response.defer()
await self.rss_task()
await inter.followup.send("done")
@tasks.loop(time=times) @tasks.loop(time=times)
async def rss_task(self): async def rss_task(self):
"""Automated task responsible for processing rss feeds.""" """Automated task responsible for processing rss feeds."""
@ -69,13 +78,24 @@ class TaskCog(commands.Cog):
log.info("Running rss task") log.info("Running rss task")
time = process_time() time = process_time()
async with DatabaseManager() as database: # async with DatabaseManager() as database:
query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel) # query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
result = await database.session.execute(query) # result = await database.session.execute(query)
feeds = result.scalars().all() # feeds = result.scalars().all()
# for feed in feeds:
# await self.process_feed(feed, database)
guild_ids = [guild.id for guild in self.bot.guilds]
async with aiohttp.ClientSession() as session:
api = API(self.bot.api_token, session)
data, count = await api.get_rssfeed_list(discord_server_id__in=guild_ids)
rssfeeds = RSSFeed.from_list(data)
for item in rssfeeds:
log.info(item.name)
for feed in feeds:
await self.process_feed(feed, database)
log.info("Finished rss task, time elapsed: %s", process_time() - time) log.info("Finished rss task, time elapsed: %s", process_time() - time)

View File

@ -1,4 +1,5 @@
import ssl
import json import json
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
@ -18,6 +19,7 @@ from textwrap import shorten
from errors import IllegalFeed from errors import IllegalFeed
from db import DatabaseManager, RssSourceModel, FeedChannelModel from db import DatabaseManager, RssSourceModel, FeedChannelModel
from utils import get_rss_data, get_unparsed_feed from utils import get_rss_data, get_unparsed_feed
from api import API
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
dumps = lambda _dict: json.dumps(_dict, indent=8) dumps = lambda _dict: json.dumps(_dict, indent=8)
@ -140,7 +142,7 @@ class Article:
@dataclass @dataclass
class Source: class Source:
"""Represents an RSS source.""" """Represents an RSS source."""
name: str | None name: str | None
url: str | None url: str | None
icon_url: str | None icon_url: str | None
@ -173,7 +175,9 @@ class Source:
@classmethod @classmethod
async def from_url(cls, url: str): async def from_url(cls, url: str):
unparsed_content = await get_unparsed_feed(url) unparsed_content = await get_unparsed_feed(url)
return cls.from_parsed(parse(unparsed_content)) source = cls.from_parsed(parse(unparsed_content))
source.url = url
return source
def get_latest_articles(self, max: int = 999) -> list[Article]: def get_latest_articles(self, max: int = 999) -> list[Article]:
"""Returns a list of Article objects. """Returns a list of Article objects.
@ -198,10 +202,37 @@ class Source:
] ]
@dataclass
class RSSFeed:
uuid: str
name: str
url: str
image: str
discord_server_id: id
created_at: str
@classmethod
def from_list(cls, data: list) -> list:
result = []
for item in data:
key = "discord_server_id"
item[key] = int(item.get(key))
result.append(cls(**item))
return result
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
class Functions: class Functions:
def __init__(self, bot): def __init__(self, bot, api_token: str):
self.bot = bot self.bot = bot
self.api_token = api_token
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict: async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
"""Validates a feed based on the given nickname and url. """Validates a feed based on the given nickname and url.
@ -254,106 +285,50 @@ class Functions:
return feed return feed
async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source: async def create_new_rssfeed(self, name: str, url: str, guild_id: int) -> RSSFeed:
"""Create a new Feed, and return it as a Source object.
Parameters
----------
nickname : str
Human readable nickname used to refer to the feed.
url : str
URL to fetch content from the feed.
guild_id : int
Discord Server ID associated with the feed.
Returns log.info("Creating new Feed: %s", name)
-------
Source parsed_feed = await self.validate_feed(name, url)
Dataclass containing attributes of the feed. source = Source.from_parsed(parsed_feed)
async with aiohttp.ClientSession() as session:
data = await API(self.api_token, session).create_new_rssfeed(
name, url, source.icon_url, guild_id
)
return RSSFeed.from_dict(data)
async def delete_rssfeed(self, uuid: str) -> RSSFeed:
log.info("Deleting Feed '%s'", uuid)
async with aiohttp.ClientSession() as session:
api = API(self.api_token, session)
data = await api.get_rssfeed(uuid)
await api.delete_rssfeed(uuid)
return RSSFeed.from_dict(data)
async def get_rssfeeds(self, guild_id: int, page: int, pagesize: int) -> list[RSSFeed]:
"""Get a list of RSS Feeds.
Args:
guild_id (int): The guild_id to filter by.
Returns:
list[RSSFeed]: Resulting list of RSS Feeds
""" """
log.info("Creating new Feed: %s - %s", nickname, guild_id) async with aiohttp.ClientSession() as session:
data, count = await API(self.api_token, session).get_rssfeed_list(
parsed_feed = await self.validate_feed(nickname, url)
async with DatabaseManager() as database:
query = insert(RssSourceModel).values(
discord_server_id=guild_id, discord_server_id=guild_id,
rss_url=url, page=page,
nick=nickname page_size=pagesize
)
await database.session.execute(query)
log.info("Created Feed: %s - %s", nickname, guild_id)
return Source.from_parsed(parsed_feed)
async def delete_feed(self, url: str, guild_id: int) -> Source:
"""Delete an existing Feed, then return it as a Source object.
Parameters
----------
url : str
URL of the feed, used in the whereclause.
guild_id : int
Discord Server ID of the feed, used in the whereclause.
Returns
-------
Source
Dataclass containing attributes of the feed.
"""
log.info("Deleting Feed: %s - %s", url, guild_id)
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == guild_id,
RssSourceModel.rss_url == url
) )
# Select the Feed entry, because an exception is raised if not found. return RSSFeed.from_list(data), count
select_query = select(RssSourceModel).filter(whereclause)
select_result = await database.session.execute(select_query)
select_result.scalars().one()
delete_query = delete(RssSourceModel).filter(whereclause)
await database.session.execute(delete_query)
log.info("Deleted Feed: %s - %s", url, guild_id)
return await Source.from_url(url)
async def get_feeds(self, guild_id: int) -> list[tuple[str, str]]:
"""Returns a list of fetched Feed objects from the database.
Note: a request will be made too all found Feed UR Ls.
Parameters
----------
guild_id : int
The Discord Server ID, used to filter down the Feed query.
Returns
-------
list[tuple[str, str]]
List of Source objects, resulting from the query.
Raises
------
NoResultFound
Raised if no results are found.
"""
async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == guild_id)
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
rss_sources = result.scalars().all()
if not rss_sources:
raise NoResultFound
return [(feed.nick, feed.rss_url) for feed in rss_sources]
async def assign_feed( async def assign_feed(
self, url: str, channel_name: str, channel_id: int, guild_id: int self, url: str, channel_name: str, channel_id: int, guild_id: int

View File

@ -9,9 +9,9 @@ from os import getenv
from pathlib import Path from pathlib import Path
# it's important to load environment variables before # it's important to load environment variables before
# importing the packages that depend on them. # importing the modules that depend on them.
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv(override=True)
from bot import DiscordBot from bot import DiscordBot
from logs import LogSetup from logs import LogSetup
@ -26,12 +26,17 @@ async def main():
# Grab the token before anything else, because if there is no token # Grab the token before anything else, because if there is no token
# available then the bot cannot be started anyways. # available then the bot cannot be started anyways.
token = getenv("BOT_TOKEN") bot_token = getenv("BOT_TOKEN")
if not bot_token:
raise ValueError("Bot Token is empty")
if not token: # ^ same story for the API token. Without it the API cannot be
raise ValueError("Token is empty") # interacted with, so grab it first.
api_token = getenv("API_TOKEN")
if not api_token:
raise ValueError("API Token is empty")
developing = getenv("DEVELOPING") == "True" developing = getenv("DEVELOPING", "False") == "True"
# Setup logging settings and mute spammy loggers # Setup logging settings and mute spammy loggers
logsetup = LogSetup(BASE_DIR / "logs/") logsetup = LogSetup(BASE_DIR / "logs/")
@ -41,9 +46,10 @@ async def main():
level=logging.WARNING level=logging.WARNING
) )
async with DiscordBot(BASE_DIR, developing=developing) as bot:
async with DiscordBot(BASE_DIR, developing=developing, api_token=api_token) as bot:
await bot.load_extensions() await bot.load_extensions()
await bot.start(token, reconnect=True) await bot.start(bot_token, reconnect=True)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@ -3,8 +3,11 @@
import aiohttp import aiohttp
import logging import logging
import async_timeout import async_timeout
from typing import Callable
from discord import Interaction, Embed, Colour from discord import Interaction, Embed, Colour, ButtonStyle, Button
from discord.ui import View, button
from discord.ext.commands import Bot
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -56,6 +59,153 @@ class FollowupIcons:
assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png" assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
class PaginationView(View):
"""A Discord UI View that adds pagination to an embed."""
def __init__(
self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
formatdata: Callable, pagesize: int, initpage: int=1
):
"""_summary_
Args:
bot (commands.Bot) The discord bot
inter (Interaction): Represents a discord command interaction.
embed (Embed): The base embed to paginate.
getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
formatdata (Callable): A formatter function that determines how the data is displayed.
pagesize (int): The size of each page.
initpage (int, optional): The inital page. Defaults to 1.
"""
self.bot = bot
self.inter = inter
self.embed = embed
self.getdata = getdata
self.formatdata = formatdata
self.maxpage = None
self.pagesize = pagesize
self.index = initpage
# emoji reference
next_emoji = bot.get_emoji(1204542366602502265)
prev_emoji = bot.get_emoji(1204542365432422470)
self.start_emoji = bot.get_emoji(1204542364073463818)
self.end_emoji = bot.get_emoji(1204542367752003624)
super().__init__(timeout=100)
async def check_user_is_author(self, inter: Interaction) -> bool:
"""Ensure the user is the author of the original command."""
if inter.user == self.inter.user:
return True
await inter.response.defer()
await (
Followup(None, "Only the author can interact with this.")
.error()
.send(inter, ephemeral=True)
)
return False
async def on_timeout(self):
"""Erase the controls on timeout."""
message = await self.inter.original_response()
await message.edit(view=None)
@staticmethod
def calc_total_pages(results: int, max_pagesize: int) -> int:
result = ((results - 1) // max_pagesize) + 1
log.debug("total pages calculated: %s", result)
return result
def calc_dataitem_index(self, dataitem_index: int):
"""Calculates a given index to be relative to the sum of all pages items
Example: dataitem_index = 6
pagesize = 10
if page == 1 then return 6
else return 6 + 10 * (page - 1)"""
if self.index > 1:
dataitem_index += self.pagesize * (self.index - 1)
dataitem_index += 1
return dataitem_index
@button(emoji="◀️", style=ButtonStyle.blurple)
async def backward(self, inter: Interaction, button: Button):
self.index -= 1
await inter.response.defer()
self.inter = inter
await self.navigate()
@button(emoji="▶️", style=ButtonStyle.blurple)
async def forward(self, inter: Interaction, button: Button):
self.index += 1
await inter.response.defer()
self.inter = inter
await self.navigate()
@button(emoji="⏭️", style=ButtonStyle.blurple)
async def start_or_end(self, inter: Interaction, button: Button):
if self.index <= self.maxpage // 2:
self.index = self.maxpage
else:
self.index = 1
await inter.response.defer()
self.inter = inter
await self.navigate()
async def navigate(self):
log.debug("navigating to page: %s", self.index)
self.update_buttons()
paged_embed = await self.create_paged_embed()
await self.inter.edit_original_response(embed=paged_embed, view=self)
async def create_paged_embed(self) -> Embed:
embed = self.embed.copy()
data, total_results = await self.getdata(self.index)
self.maxpage = self.calc_total_pages(total_results, self.pagesize)
for i, item in enumerate(data):
i = self.calc_dataitem_index(i)
key, value = self.formatdata(i, item)
embed.add_field(name=key, value=value, inline=False)
if not total_results:
embed.description = "There are no results"
if self.maxpage > 1:
embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
return embed
def update_buttons(self):
if self.index >= self.maxpage:
self.children[2].emoji = self.start_emoji
else:
self.children[2].emoji = self.end_emoji
self.children[0].disabled = self.index == 1
self.children[1].disabled = self.index == self.maxpage
async def send(self):
embed = await self.create_paged_embed()
if self.maxpage <= 1:
await self.inter.edit_original_response(embed=embed)
return
self.update_buttons()
await self.inter.edit_original_response(embed=embed, view=self)
class Followup: class Followup:
"""Wrapper for a discord embed to follow up an interaction.""" """Wrapper for a discord embed to follow up an interaction."""
@ -69,10 +219,10 @@ class Followup:
description=description description=description
) )
async def send(self, inter: Interaction, message: str = None): async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
"""""" """"""
await inter.followup.send(content=message, embed=self._embed) await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)
def fields(self, inline: bool = False, **fields: dict): def fields(self, inline: bool = False, **fields: dict):
"""""" """"""
@ -89,6 +239,13 @@ class Followup:
return self return self
def footer(self, text: str, icon_url: str = None):
""""""
self._embed.set_footer(text=text, icon_url=icon_url)
return self
def error(self): def error(self):
"""""" """"""