diff --git a/README.md b/README.md index bdd0618..07e5819 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,7 @@ -# NewsBot +# PYRSS -Bot delivering news articles to discord servers. +An RSS driven Discord bot written in Python. -Plans +Provides user commands for storing RSS feed URLs that can be assigned to any given discord channel. -- Multiple news providers -- Choose how much of each provider should be delivered -- Check for duplicate articles between providers, and only deliver preferred provider article - - -## Dev Notes: - -For the sake of development, the following defintions apply: - -- Feed - An RSS feed stored within the database, submitted by a user. -- Assigned Feed - A discord channel set to receive content from a Feed. \ No newline at end of file +Content is shared every 10 minutes as an Embed. \ No newline at end of file diff --git a/src/api.py b/src/api.py new file mode 100644 index 0000000..3b22e43 --- /dev/null +++ b/src/api.py @@ -0,0 +1,127 @@ + +import logging + +import aiohttp + +log = logging.getLogger(__name__) + + +class APIException(Exception): + pass + + +class NotCreatedException(APIException): + pass + +class BadStatusException(APIException): + pass + + +class API: + """Interactions with the API.""" + + API_HOST = "http://localhost:8000/" + API_ENDPOINT = API_HOST + "api/" + + RSS_FEED_ENDPOINT = API_ENDPOINT + "rssfeed/" + FEED_CHANNEL_ENDPOINT = API_ENDPOINT + "feedchannel/" + + def __init__(self, api_token: str, session: aiohttp.ClientSession): + log.debug("API session initialised") + self.session = session + self.token_headers = {"Authorization": f"Token {api_token}"} + + async def make_request(self, method: str, url: str, **kwargs) -> dict: + """Make a request to the given API endpoint. + + Args: + method (str): The request method to use, examples: GET, POST, DELETE... + url (str): The API endpoint to request to. + **kwargs: Passed into self.session.request. + + Returns: + dict: Dictionary containing status code, json or text. + """ + + async with self.session.request(method, url, headers=self.token_headers, **kwargs) as response: + response.raise_for_status() + try: + json = await response.json() + text = None + except aiohttp.ContentTypeError: + json = None + text = await response.text() + + status = response.status + + return {"json": json, "text": text, "status": status} + + async def create_new_rssfeed(self, name: str, url: str, image_url: str, discord_server_id: int) -> dict: + """Create a new RSS Feed. + + Args: + name (str): Name of the RSS Feed. + url (str): URL for the RSS Feed. + image_url (str): URL of the image representation of the RSS Feed. + discord_server_id (int): ID of the discord server behind this item. + + Returns: + dict: JSON representation of the newly created RSS Feed. + """ + + log.debug("creating rssfeed: %s %s %s %s", name, url, image_url, discord_server_id) + + async with self.session.get(image_url) as response: + image_data = await response.read() + + # Using formdata to make the image transfer easier. + form = aiohttp.FormData({ + "name": name, + "url": url, + "discord_server_id": discord_server_id + }) + form.add_field("image", image_data, filename="file.jpg") + + data = (await self.make_response("POST", self.RSS_FEED_ENDPOINT, data=form))["json"] + return data + + async def get_rssfeed(self, uuid: str) -> dict: + """Get a particular RSS Feed given it's UUID. + + Args: + uuid (str): Identifier of the desired RSS Feed. + + Returns: + dict: A JSON representation of the RSS Feed. + """ + + log.debug("getting rssfeed: %s", uuid) + endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/" + data = (await self.make_request("GET", endpoint))["json"] + return data + + async def get_rssfeed_list(self, **filters) -> tuple[list[dict], int]: + """Get all RSS Feeds with the associated filters. + + Returns: + tuple[list[dict], int] list contains dictionaries of each item, int is total items. + """ + + log.debug("getting list of rss feeds with filters: %s", filters) + data = (await self.make_request("GET", self.RSS_FEED_ENDPOINT, params=filters))["json"] + return data["results"], data["count"] + + async def delete_rssfeed(self, uuid: str) -> int: + """Delete a specified RSS Feed. + + Args: + uuid (str): Identifier of the RSS Feed to delete. + + Returns: + int: Status code of the response. + """ + + log.debug("deleting rssfeed: %s", uuid) + endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/" + status = (await self.make_request("DELETE", endpoint))["status"] + return status diff --git a/src/bot.py b/src/bot.py index e7ff73e..d3697d3 100644 --- a/src/bot.py +++ b/src/bot.py @@ -5,7 +5,7 @@ The discord bot for the application. import logging from pathlib import Path -from discord import Intents +from discord import Intents, Game from discord.ext import commands from sqlalchemy import insert @@ -17,11 +17,13 @@ log = logging.getLogger(__name__) class DiscordBot(commands.Bot): - def __init__(self, BASE_DIR: Path, developing: bool): - super().__init__(command_prefix="-", intents=Intents.all()) - self.functions = Functions(self) + def __init__(self, BASE_DIR: Path, developing: bool, api_token: str): + activity = Game("Indev") if developing else None + super().__init__(command_prefix="-", intents=Intents.all(), activity=activity) + self.functions = Functions(self, api_token) self.BASE_DIR = BASE_DIR self.developing = developing + self.api_token = api_token log.info("developing=%s", developing) diff --git a/src/extensions/rss.py b/src/extensions/rss.py index 2b45e85..ffe6b1a 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -6,6 +6,7 @@ Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bo import logging from typing import Tuple +import aiohttp import validators from feedparser import FeedParserDict, parse from discord.ext import commands @@ -14,7 +15,8 @@ from discord.app_commands import Choice, Group, autocomplete, choices, rename from sqlalchemy import insert, select, and_, delete from sqlalchemy.exc import NoResultFound, IntegrityError -from feed import Source +from api import API +from feed import Source, RSSFeed from errors import IllegalFeed from db import ( DatabaseManager, @@ -25,6 +27,7 @@ from db import ( ) from utils import ( Followup, + PaginationView, get_rss_data, followup, audit, @@ -121,6 +124,21 @@ class FeedCog(commands.Cog): log.info("%s cog is ready", self.__class__.__name__) + async def autocomplete_rssfeed(self, inter: Interaction, name: str) -> list[Choice]: + + async with aiohttp.ClientSession() as session: + data, _ = await API(self.bot.api_token, session).get_rssfeed_list( + discord_server_id=inter.guild_id + ) + rssfeeds = RSSFeed.from_list(data) + + choices = [ + Choice(name=item.name, value=item.uuid) + for item in rssfeeds + ] + + return choices + async def source_autocomplete(self, inter: Interaction, nickname: str): """Provides RSS source autocomplete functionality for commands. @@ -154,69 +172,55 @@ class FeedCog(commands.Cog): # All RSS commands belong to this group. feed_group = Group( name="feed", - description="Commands for rss sources.", + description="Commands for RSS sources.", default_permissions=Permissions.elevated(), guild_only=True # We store guild IDs in the database, so guild only = True ) - @feed_group.command(name="add") - async def add_rss_source(self, inter: Interaction, nickname: str, url: str): - """Add a new Feed for this server. + @feed_group.command(name="new") + async def add_rssfeed(self, inter: Interaction, name: str, url: str): + """Add a new RSS Feed for this server. - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - nickname : str - A name used to identify the Feed. - url : str - The Feed URL. + Args: + inter (Interaction): Represents the discord command interaction. + name (str): A nickname used to refer to this RSS Feed. + url (str): The URL of the RSS Feed. """ await inter.response.defer() try: - source = await self.bot.functions.create_new_feed(nickname, url, inter.guild_id) - except IllegalFeed as error: - title, desc = extract_error_info(error) - await Followup(title, desc).fields(**error.items).error().send(inter) - except IntegrityError as error: + rssfeed = await self.bot.functions.create_new_rssfeed(name, url, inter.guild_id) + except Exception as exc: await ( - Followup( - "Duplicate Feed Error", - "A Feed with the same nickname already exist." - ) - .fields(nickname=nickname) + Followup(exc.__class__.__name__, str(exc)) .error() .send(inter) ) else: await ( - Followup("Feed Added") - .image(source.icon_url) - .fields(nickname=nickname, url=url) + Followup("New RSS Feed") + .image(rssfeed.image) + .fields(uuid=rssfeed.uuid, name=name, url=url) .added() .send(inter) ) - @feed_group.command(name="remove") - @rename(url="option") - @autocomplete(url=source_autocomplete) - async def remove_rss_source(self, inter: Interaction, url: str): - """Delete an existing Feed from this server. + @feed_group.command(name="delete") + @autocomplete(uuid=autocomplete_rssfeed) + @rename(uuid="rssfeed") + async def delete_rssfeed(self, inter: Interaction, uuid: str): + """Delete an existing RSS Feed for this server. - Parameters - ---------- - inter : Interaction - Represents an app command interaction. - url : str - The Feed to be removed. Autocomplete or enter the URL. + Args: + inter (Interaction): Represents the discord command interaction. + uuid (str): The UUID of the """ await inter.response.defer() try: - source = await self.bot.functions.delete_feed(url, inter.guild_id) + rssfeed = await self.bot.functions.delete_rssfeed(uuid) except NoResultFound: await ( Followup( @@ -229,47 +233,43 @@ class FeedCog(commands.Cog): else: await ( Followup("Feed Deleted") - .image(source.icon_url) - .fields(url=url) + .image(rssfeed.image) + .fields(uuid=rssfeed.uuid, name=rssfeed.name, url=rssfeed.url) .trash() .send(inter) ) @feed_group.command(name="list") - async def list_rss_sources(self, inter: Interaction): - """Provides a with a list of Feeds available for this server. + async def list_rssfeeds(self, inter: Interaction): + """Provides a list of all RSS Feeds - Parameters - ---------- - inter : Interaction - Represents an app command interaction. + Args: + inter (Interaction): Represents the discord command interaction. """ await inter.response.defer() + page = 1 + pagesize = 10 + try: - feeds = await self.bot.functions.get_feeds(inter.guild_id) - except NoResultFound: + def formatdata(index, item): + key = f"{index}. {item.name}" + value = f"[RSS]({item.url}) · [API](http://localhost:8000/api/rssfeed/{item.uuid}/)" + return key, value + + async def getdata(page): + data, count = await self.bot.functions.get_rssfeeds(inter.guild_id, page, pagesize) + return data, count + + embed = Followup(f"Available RSS Feeds in {inter.guild.name}").info()._embed + pagination = PaginationView(self.bot, inter, embed, getdata, formatdata, pagesize, 1) + await pagination.send() + + except Exception as exc: await ( - Followup( - "Feeds Not Found Error", - "There are no available Feeds for this server.\n" - "Add a new feed with `/feed add`." - ) + Followup(exc.__class__.__name__, str(exc)) .error() - .send() - ) - else: - description = "\n".join([ - f"{i}. **[{info[0]}]({info[1]})**" # info = (nick, url) - for i, info in enumerate(feeds) - ]) - await ( - Followup( - f"Available Feeds in {inter.guild.name}", - description - ) - .info() .send(inter) ) diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py index c7fcfa0..b0e06eb 100644 --- a/src/extensions/tasks.py +++ b/src/extensions/tasks.py @@ -10,12 +10,13 @@ from time import process_time import aiohttp from discord import TextChannel +from discord import app_commands from discord.ext import commands, tasks from discord.errors import Forbidden from sqlalchemy import insert, select, and_ from feedparser import parse -from feed import Source, Article +from feed import Source, Article, RSSFeed from db import ( DatabaseManager, FeedChannelModel, @@ -23,11 +24,13 @@ from db import ( SentArticleModel ) from utils import get_unparsed_feed +from api import API log = logging.getLogger(__name__) TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES") +# task trigger times : must be of type list times = [ datetime.time(hour, minute, tzinfo=datetime.timezone.utc) for hour in range(24) @@ -62,6 +65,12 @@ class TaskCog(commands.Cog): self.rss_task.cancel() + @app_commands.command(name="debug-trigger-task") + async def debug_trigger_task(self, inter): + await inter.response.defer() + await self.rss_task() + await inter.followup.send("done") + @tasks.loop(time=times) async def rss_task(self): """Automated task responsible for processing rss feeds.""" @@ -69,13 +78,24 @@ class TaskCog(commands.Cog): log.info("Running rss task") time = process_time() - async with DatabaseManager() as database: - query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel) - result = await database.session.execute(query) - feeds = result.scalars().all() + # async with DatabaseManager() as database: + # query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel) + # result = await database.session.execute(query) + # feeds = result.scalars().all() + + # for feed in feeds: + # await self.process_feed(feed, database) + + guild_ids = [guild.id for guild in self.bot.guilds] + + async with aiohttp.ClientSession() as session: + api = API(self.bot.api_token, session) + data, count = await api.get_rssfeed_list(discord_server_id__in=guild_ids) + rssfeeds = RSSFeed.from_list(data) + for item in rssfeeds: + log.info(item.name) + - for feed in feeds: - await self.process_feed(feed, database) log.info("Finished rss task, time elapsed: %s", process_time() - time) diff --git a/src/feed.py b/src/feed.py index 7b45bac..afc1818 100644 --- a/src/feed.py +++ b/src/feed.py @@ -1,4 +1,5 @@ +import ssl import json import logging from dataclasses import dataclass @@ -18,6 +19,7 @@ from textwrap import shorten from errors import IllegalFeed from db import DatabaseManager, RssSourceModel, FeedChannelModel from utils import get_rss_data, get_unparsed_feed +from api import API log = logging.getLogger(__name__) dumps = lambda _dict: json.dumps(_dict, indent=8) @@ -140,7 +142,7 @@ class Article: @dataclass class Source: """Represents an RSS source.""" - + name: str | None url: str | None icon_url: str | None @@ -173,7 +175,9 @@ class Source: @classmethod async def from_url(cls, url: str): unparsed_content = await get_unparsed_feed(url) - return cls.from_parsed(parse(unparsed_content)) + source = cls.from_parsed(parse(unparsed_content)) + source.url = url + return source def get_latest_articles(self, max: int = 999) -> list[Article]: """Returns a list of Article objects. @@ -198,10 +202,37 @@ class Source: ] +@dataclass +class RSSFeed: + + uuid: str + name: str + url: str + image: str + discord_server_id: id + created_at: str + + @classmethod + def from_list(cls, data: list) -> list: + result = [] + + for item in data: + key = "discord_server_id" + item[key] = int(item.get(key)) + result.append(cls(**item)) + + return result + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + class Functions: - def __init__(self, bot): + def __init__(self, bot, api_token: str): self.bot = bot + self.api_token = api_token async def validate_feed(self, nickname: str, url: str) -> FeedParserDict: """Validates a feed based on the given nickname and url. @@ -254,106 +285,50 @@ class Functions: return feed - async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source: - """Create a new Feed, and return it as a Source object. + async def create_new_rssfeed(self, name: str, url: str, guild_id: int) -> RSSFeed: - Parameters - ---------- - nickname : str - Human readable nickname used to refer to the feed. - url : str - URL to fetch content from the feed. - guild_id : int - Discord Server ID associated with the feed. - Returns - ------- - Source - Dataclass containing attributes of the feed. + log.info("Creating new Feed: %s", name) + + parsed_feed = await self.validate_feed(name, url) + source = Source.from_parsed(parsed_feed) + + async with aiohttp.ClientSession() as session: + data = await API(self.api_token, session).create_new_rssfeed( + name, url, source.icon_url, guild_id + ) + + return RSSFeed.from_dict(data) + + async def delete_rssfeed(self, uuid: str) -> RSSFeed: + + log.info("Deleting Feed '%s'", uuid) + + async with aiohttp.ClientSession() as session: + api = API(self.api_token, session) + data = await api.get_rssfeed(uuid) + await api.delete_rssfeed(uuid) + + return RSSFeed.from_dict(data) + + async def get_rssfeeds(self, guild_id: int, page: int, pagesize: int) -> list[RSSFeed]: + """Get a list of RSS Feeds. + + Args: + guild_id (int): The guild_id to filter by. + + Returns: + list[RSSFeed]: Resulting list of RSS Feeds """ - log.info("Creating new Feed: %s - %s", nickname, guild_id) - - parsed_feed = await self.validate_feed(nickname, url) - - async with DatabaseManager() as database: - query = insert(RssSourceModel).values( + async with aiohttp.ClientSession() as session: + data, count = await API(self.api_token, session).get_rssfeed_list( discord_server_id=guild_id, - rss_url=url, - nick=nickname - ) - await database.session.execute(query) - - log.info("Created Feed: %s - %s", nickname, guild_id) - - return Source.from_parsed(parsed_feed) - - async def delete_feed(self, url: str, guild_id: int) -> Source: - """Delete an existing Feed, then return it as a Source object. - - Parameters - ---------- - url : str - URL of the feed, used in the whereclause. - guild_id : int - Discord Server ID of the feed, used in the whereclause. - - Returns - ------- - Source - Dataclass containing attributes of the feed. - """ - - log.info("Deleting Feed: %s - %s", url, guild_id) - - async with DatabaseManager() as database: - whereclause = and_( - RssSourceModel.discord_server_id == guild_id, - RssSourceModel.rss_url == url + page=page, + page_size=pagesize ) - # Select the Feed entry, because an exception is raised if not found. - select_query = select(RssSourceModel).filter(whereclause) - select_result = await database.session.execute(select_query) - select_result.scalars().one() - - delete_query = delete(RssSourceModel).filter(whereclause) - await database.session.execute(delete_query) - - log.info("Deleted Feed: %s - %s", url, guild_id) - - return await Source.from_url(url) - - async def get_feeds(self, guild_id: int) -> list[tuple[str, str]]: - """Returns a list of fetched Feed objects from the database. - Note: a request will be made too all found Feed UR Ls. - - Parameters - ---------- - guild_id : int - The Discord Server ID, used to filter down the Feed query. - - Returns - ------- - list[tuple[str, str]] - List of Source objects, resulting from the query. - - Raises - ------ - NoResultFound - Raised if no results are found. - """ - - async with DatabaseManager() as database: - whereclause = and_(RssSourceModel.discord_server_id == guild_id) - query = select(RssSourceModel).where(whereclause) - result = await database.session.execute(query) - rss_sources = result.scalars().all() - - if not rss_sources: - raise NoResultFound - - return [(feed.nick, feed.rss_url) for feed in rss_sources] + return RSSFeed.from_list(data), count async def assign_feed( self, url: str, channel_name: str, channel_id: int, guild_id: int diff --git a/src/main.py b/src/main.py index 0398eee..5881e14 100644 --- a/src/main.py +++ b/src/main.py @@ -9,9 +9,9 @@ from os import getenv from pathlib import Path # it's important to load environment variables before -# importing the packages that depend on them. +# importing the modules that depend on them. from dotenv import load_dotenv -load_dotenv() +load_dotenv(override=True) from bot import DiscordBot from logs import LogSetup @@ -26,12 +26,17 @@ async def main(): # Grab the token before anything else, because if there is no token # available then the bot cannot be started anyways. - token = getenv("BOT_TOKEN") + bot_token = getenv("BOT_TOKEN") + if not bot_token: + raise ValueError("Bot Token is empty") - if not token: - raise ValueError("Token is empty") + # ^ same story for the API token. Without it the API cannot be + # interacted with, so grab it first. + api_token = getenv("API_TOKEN") + if not api_token: + raise ValueError("API Token is empty") - developing = getenv("DEVELOPING") == "True" + developing = getenv("DEVELOPING", "False") == "True" # Setup logging settings and mute spammy loggers logsetup = LogSetup(BASE_DIR / "logs/") @@ -41,9 +46,10 @@ async def main(): level=logging.WARNING ) - async with DiscordBot(BASE_DIR, developing=developing) as bot: + + async with DiscordBot(BASE_DIR, developing=developing, api_token=api_token) as bot: await bot.load_extensions() - await bot.start(token, reconnect=True) + await bot.start(bot_token, reconnect=True) if __name__ == "__main__": asyncio.run(main()) diff --git a/src/utils.py b/src/utils.py index 98cb37d..4f68477 100644 --- a/src/utils.py +++ b/src/utils.py @@ -3,8 +3,11 @@ import aiohttp import logging import async_timeout +from typing import Callable -from discord import Interaction, Embed, Colour +from discord import Interaction, Embed, Colour, ButtonStyle, Button +from discord.ui import View, button +from discord.ext.commands import Bot log = logging.getLogger(__name__) @@ -56,6 +59,153 @@ class FollowupIcons: assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png" +class PaginationView(View): + """A Discord UI View that adds pagination to an embed.""" + + def __init__( + self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable, + formatdata: Callable, pagesize: int, initpage: int=1 + ): + """_summary_ + + Args: + bot (commands.Bot) The discord bot + inter (Interaction): Represents a discord command interaction. + embed (Embed): The base embed to paginate. + getdata (Callable): A function that provides data, must return Tuple[List[Any], int]. + formatdata (Callable): A formatter function that determines how the data is displayed. + pagesize (int): The size of each page. + initpage (int, optional): The inital page. Defaults to 1. + """ + + self.bot = bot + self.inter = inter + self.embed = embed + self.getdata = getdata + self.formatdata = formatdata + self.maxpage = None + self.pagesize = pagesize + self.index = initpage + + # emoji reference + next_emoji = bot.get_emoji(1204542366602502265) + prev_emoji = bot.get_emoji(1204542365432422470) + self.start_emoji = bot.get_emoji(1204542364073463818) + self.end_emoji = bot.get_emoji(1204542367752003624) + + super().__init__(timeout=100) + + async def check_user_is_author(self, inter: Interaction) -> bool: + """Ensure the user is the author of the original command.""" + + if inter.user == self.inter.user: + return True + + await inter.response.defer() + await ( + Followup(None, "Only the author can interact with this.") + .error() + .send(inter, ephemeral=True) + ) + return False + + async def on_timeout(self): + """Erase the controls on timeout.""" + + message = await self.inter.original_response() + await message.edit(view=None) + + @staticmethod + def calc_total_pages(results: int, max_pagesize: int) -> int: + result = ((results - 1) // max_pagesize) + 1 + log.debug("total pages calculated: %s", result) + return result + + def calc_dataitem_index(self, dataitem_index: int): + """Calculates a given index to be relative to the sum of all pages items + + Example: dataitem_index = 6 + pagesize = 10 + if page == 1 then return 6 + else return 6 + 10 * (page - 1)""" + + if self.index > 1: + dataitem_index += self.pagesize * (self.index - 1) + + dataitem_index += 1 + return dataitem_index + + @button(emoji="◀️", style=ButtonStyle.blurple) + async def backward(self, inter: Interaction, button: Button): + self.index -= 1 + await inter.response.defer() + self.inter = inter + await self.navigate() + + @button(emoji="▶️", style=ButtonStyle.blurple) + async def forward(self, inter: Interaction, button: Button): + self.index += 1 + await inter.response.defer() + self.inter = inter + await self.navigate() + + @button(emoji="⏭️", style=ButtonStyle.blurple) + async def start_or_end(self, inter: Interaction, button: Button): + if self.index <= self.maxpage // 2: + self.index = self.maxpage + else: + self.index = 1 + + await inter.response.defer() + self.inter = inter + await self.navigate() + + async def navigate(self): + log.debug("navigating to page: %s", self.index) + + self.update_buttons() + paged_embed = await self.create_paged_embed() + await self.inter.edit_original_response(embed=paged_embed, view=self) + + async def create_paged_embed(self) -> Embed: + embed = self.embed.copy() + data, total_results = await self.getdata(self.index) + self.maxpage = self.calc_total_pages(total_results, self.pagesize) + + for i, item in enumerate(data): + i = self.calc_dataitem_index(i) + key, value = self.formatdata(i, item) + embed.add_field(name=key, value=value, inline=False) + + if not total_results: + embed.description = "There are no results" + + if self.maxpage > 1: + embed.set_footer(text=f"Page {self.index}/{self.maxpage}") + + return embed + + def update_buttons(self): + if self.index >= self.maxpage: + self.children[2].emoji = self.start_emoji + else: + self.children[2].emoji = self.end_emoji + + self.children[0].disabled = self.index == 1 + self.children[1].disabled = self.index == self.maxpage + + async def send(self): + embed = await self.create_paged_embed() + + if self.maxpage <= 1: + await self.inter.edit_original_response(embed=embed) + return + + self.update_buttons() + await self.inter.edit_original_response(embed=embed, view=self) + + + class Followup: """Wrapper for a discord embed to follow up an interaction.""" @@ -69,10 +219,10 @@ class Followup: description=description ) - async def send(self, inter: Interaction, message: str = None): + async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False): """""" - await inter.followup.send(content=message, embed=self._embed) + await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral) def fields(self, inline: bool = False, **fields: dict): """""" @@ -89,6 +239,13 @@ class Followup: return self + def footer(self, text: str, icon_url: str = None): + """""" + + self._embed.set_footer(text=text, icon_url=icon_url) + + return self + def error(self): """"""