diff --git a/src/extensions/test.py b/src/extensions/test.py index 18c2722..7651254 100644 --- a/src/extensions/test.py +++ b/src/extensions/test.py @@ -9,9 +9,9 @@ import textwrap from markdownify import markdownify from discord import app_commands, Interaction, Embed from discord.ext import commands, tasks -from sqlalchemy import insert, select +from sqlalchemy import insert, select, and_ -from db import DatabaseManager, AuditModel, SentArticleModel +from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel from feed import Feeds, get_source log = logging.getLogger(__name__) @@ -31,19 +31,39 @@ class Test(commands.Cog): async def on_ready(self): log.info(f"{self.__class__.__name__} cog is ready") + async def source_autocomplete(self, inter: Interaction, current: str): + """ + + """ + + async with DatabaseManager() as database: + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url.ilike(f"%{current}%") + ) + query = select(RssSourceModel).where(whereclause) + result = await database.session.execute(query) + sources = [ + app_commands.Choice(name=rss.rss_url, value=rss.rss_url) + for rss in result.scalars().all() + ] + + return sources + @app_commands.command(name="test-latest-article") - # @app_commands.choices(source=[ - # app_commands.Choice(name="The Babylon Bee", value=Feeds.THE_BABYLON_BEE), - # app_commands.Choice(name="The Upper Lip", value=Feeds.THE_UPPER_LIP), - # app_commands.Choice(name="BBC News", value=Feeds.BBC_NEWS), - # ]) - async def test_bee(self, inter: Interaction, source: Feeds): + @app_commands.autocomplete(source=source_autocomplete) + async def test_news(self, inter: Interaction, source: str): await inter.response.defer() await self.bot.audit("Requesting latest article.", inter.user.id) - source = get_source(source) - article = source.get_latest_article() + try: + source = get_source(source) + article = source.get_latest_article() + except IndexError as e: + log.error(e) + await inter.followup.send("An error occured, it's possible that the source provided was bad.") + return md_description = markdownify(article.description, strip=("img",)) article_description = textwrap.shorten(md_description, 4096) @@ -74,9 +94,6 @@ class Test(commands.Cog): await inter.followup.send(embed=embed) - - - async def setup(bot): """ Setup function for this extension.