diff --git a/src/db/models.py b/src/db/models.py index d4477a0..10b9b1d 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -74,3 +74,8 @@ class FeedChannelModel(Base): discord_server_id = Column(BigInteger, nullable=False) search_name = Column(String, nullable=False) rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False) + + # the rss source must be unique, but only within the same discord channel + __table_args__ = ( + UniqueConstraint('rss_source_id', 'discord_channel_id', name='uq_rss_discord_channel'), + ) diff --git a/src/extensions/channels.py b/src/extensions/channels.py index 5f0876b..dc1a34d 100644 --- a/src/extensions/channels.py +++ b/src/extensions/channels.py @@ -5,12 +5,12 @@ Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the import logging -from sqlalchemy import select, and_ +from sqlalchemy import select, insert, delete, and_ from discord import Interaction, TextChannel from discord.ext import commands from discord.app_commands import Group, Choice, autocomplete -from db import DatabaseManager, FeedChannelModel +from db import DatabaseManager, FeedChannelModel, RssSourceModel from utils import followup log = logging.getLogger(__name__) @@ -29,6 +29,25 @@ class ChannelCog(commands.Cog): async def on_ready(self): log.info(f"{self.__class__.__name__} cog is ready") + async def autocomplete_rss_sources(self, inter: Interaction, nickname: str): + """""" + + async with DatabaseManager() as database: + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.nick.ilike(f"%{nickname}%") + ) + query = select(RssSourceModel).where(whereclause) + result = await database.session.execute(query) + sources = [ + Choice(name=rss.nick, value=rss.id) + for rss in result.scalars().all() + ] + + log.debug(f"Autocomplete rss_sources returned {len(sources)} results") + + return sources + async def autocomplete_existing_feeds(self, inter: Interaction, current: str): """Returns a list of existing RSS + Channel feeds. @@ -52,17 +71,20 @@ class ChannelCog(commands.Cog): for feed in result.scalars().all() ] + log.debug(f"Autocomplete existing_feeds returned {len(feeds)} results") + return feeds # All RSS commands belong to this group. channel_group = Group( name="channel", description="Commands for channel assignment.", - guild_only=True # We store guild IDs in the database, so guild only = True + guild_only=True # These commands belong to channels of ) - channel_group.command(name="include-feed") - async def include_feed(self, inter: Interaction, rss: str, channel: TextChannel): + @channel_group.command(name="include-feed") + @autocomplete(rss=autocomplete_rss_sources) + async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None): """Include a feed within the specified channel. Parameters @@ -76,9 +98,32 @@ class ChannelCog(commands.Cog): """ await inter.response.defer() - await followup(inter, "Ping") - channel_group.command(name="exclude-feed") + channel = channel or inter.channel + + async with DatabaseManager() as database: + select_query = select(RssSourceModel).where(and_( + RssSourceModel.id == rss, + RssSourceModel.discord_server_id == inter.guild_id + )) + + select_result = await database.session.execute(select_query) + rss_source = select_result.scalars().one() + nick, rss_url = rss_source.nick, rss_source.rss_url + + insert_query = insert(FeedChannelModel).values( + discord_server_id = inter.guild_id, + discord_channel_id = channel.id, + rss_source_id=rss, + search_name=f"{nick} #{channel.name}" + ) + + insert_result = await database.session.execute(insert_query) + + + await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}") + + @channel_group.command(name="exclude-feed") @autocomplete(option=autocomplete_existing_feeds) async def exclude_feed(self, inter: Interaction, option: int): """Undo command for the `/channel include-feed` command. @@ -92,7 +137,20 @@ class ChannelCog(commands.Cog): """ await inter.response.defer() - await followup(inter, "Pong") + + async with DatabaseManager() as database: + query = delete(FeedChannelModel).where(and_( + FeedChannelModel.id == option, + FeedChannelModel.discord_server_id == inter.guild_id + )) + + result = await database.session.execute(query) + + if not result.rowcount: + await followup(inter, "I couldn't find any items under that ID (placeholder response)") + return + + await followup(inter, "I've removed this item (placeholder response)") async def setup(bot): diff --git a/src/extensions/rss.py b/src/extensions/rss.py index 2ab7bb7..b133795 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -299,26 +299,44 @@ class RssCog(commands.Cog): await followup(inter, "Sorry, I couldn't find any articles from this feed.") return + # TODO: embed rules + # + # https://discord.com/safety/using-webhooks-and-embeds + # + # 1. Embed titles are limited to 256 characters + # 2. Embed descriptions are limited to 2048 characters + # 3. There can be up to 25 fields + # 4. The name of a field is limited to 256 characters and its value to 1024 characters + # 5. The footer text is limited to 2048 characters + # 6. The author name is limited to 256 characters + # 7. In addition, the sum of all characters in an embed structure must not exceed 6000 characters + # 8. The title cannot hyperlink URLs + # 10. The embed and author URL must be valid, these will be checked. + embeds = [] for article in articles: md_description = markdownify(article.description, strip=("img",)) article_description = textwrap.shorten(md_description, 4096) + md_title = markdownify(article.title, strip=("img", "a")) + article_title = textwrap.shorten(md_title, 256) + embed = Embed( - title=article.title, + title=article_title, description=article_description, - url=article.url, + url=article.url if validators.url(article.url) else None, timestamp=article.published, colour=Colour.brand_red() ) + thumbail_url = await article.get_thumbnail_url() thumbail_url = thumbail_url if validators.url(thumbail_url) else None - embed.set_thumbnail(url=source.icon_url) + embed.set_thumbnail(url=source.icon_url if validators.url(source.icon_url) else None) embed.set_image(url=thumbail_url) embed.set_footer(text=article.author) embed.set_author( name=source.name, - url=source.url, + url=source.url if validators.url(source.url) else None, ) embeds.append(embed) diff --git a/src/feed.py b/src/feed.py index c4f12db..c08011d 100644 --- a/src/feed.py +++ b/src/feed.py @@ -64,9 +64,13 @@ class Article: log.debug("Fetching thumbnail for article: %s", self) - async with aiohttp.ClientSession() as session: - async with session.get(self.url) as response: - html = await response.text() + try: + async with aiohttp.ClientSession() as session: + async with session.get(self.url) as response: + html = await response.text() + except aiohttp.InvalidURL as error: + log.error(error) + return None soup = bs4(html, "html.parser") image_element = soup.select_one("meta[property='og:image']")