diff --git a/src/db/models.py b/src/db/models.py index e44731a..d4477a0 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`. from enum import Enum, auto -from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint +from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint, ForeignKey from sqlalchemy.sql import func from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base @@ -54,6 +54,8 @@ class RssSourceModel(Base): rss_url = Column(String, nullable=False) created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + feed_channels = relationship("FeedChannelModel", cascade="all, delete") + # the nickname must be unique, but only within the same discord server __table_args__ = ( UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'), @@ -69,3 +71,6 @@ class FeedChannelModel(Base): id = Column(Integer, primary_key=True, autoincrement=True) discord_channel_id = Column(BigInteger, nullable=False) + discord_server_id = Column(BigInteger, nullable=False) + search_name = Column(String, nullable=False) + rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False) diff --git a/src/extensions/channels.py b/src/extensions/channels.py new file mode 100644 index 0000000..5f0876b --- /dev/null +++ b/src/extensions/channels.py @@ -0,0 +1,106 @@ +""" +Extension for the `ChannelCog`. +Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot. +""" + +import logging + +from sqlalchemy import select, and_ +from discord import Interaction, TextChannel +from discord.ext import commands +from discord.app_commands import Group, Choice, autocomplete + +from db import DatabaseManager, FeedChannelModel +from utils import followup + +log = logging.getLogger(__name__) + + +class ChannelCog(commands.Cog): + """ + Command cog. + """ + + def __init__(self, bot): + super().__init__() + self.bot = bot + + @commands.Cog.listener() + async def on_ready(self): + log.info(f"{self.__class__.__name__} cog is ready") + + async def autocomplete_existing_feeds(self, inter: Interaction, current: str): + """Returns a list of existing RSS + Channel feeds. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + current : str + The current text entered for the autocomplete. + """ + + async with DatabaseManager() as database: + whereclause = and_( + FeedChannelModel.discord_server_id == inter.guild_id, + FeedChannelModel.search_name.ilike(f"%{current}%") # is this secure from SQL Injection atk ? + ) + query = select(FeedChannelModel).where(whereclause) + result = await database.session.execute(query) + feeds = [ + Choice(name=feed.search_name, value=feed.id) + for feed in result.scalars().all() + ] + + return feeds + + # All RSS commands belong to this group. + channel_group = Group( + name="channel", + description="Commands for channel assignment.", + guild_only=True # We store guild IDs in the database, so guild only = True + ) + + channel_group.command(name="include-feed") + async def include_feed(self, inter: Interaction, rss: str, channel: TextChannel): + """Include a feed within the specified channel. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + rss : str + The RSS feed to include. + channel : TextChannel + The channel to include the feed in. + """ + + await inter.response.defer() + await followup(inter, "Ping") + + channel_group.command(name="exclude-feed") + @autocomplete(option=autocomplete_existing_feeds) + async def exclude_feed(self, inter: Interaction, option: int): + """Undo command for the `/channel include-feed` command. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + option : str + The RSS feed and channel to exclude. + """ + + await inter.response.defer() + await followup(inter, "Pong") + + +async def setup(bot): + """ + Setup function for this extension. + Adds `ChannelCog` to the bot. + """ + + cog = ChannelCog(bot) + await bot.add_cog(cog) + log.info(f"Added {cog.__class__.__name__} cog")