PYRSS-Bot/src/extensions/channels.py
2023-12-19 00:35:53 +00:00

197 lines
6.3 KiB
Python

"""
Extension for the `ChannelCog`.
Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot.
"""
import logging
from sqlalchemy.orm import aliased
from sqlalchemy import select, insert, delete, and_
from discord import Interaction, TextChannel, Embed, Colour
from discord.ext import commands
from discord.app_commands import Group, Choice, autocomplete
from db import DatabaseManager, FeedChannelModel, RssSourceModel
from utils import followup
# Logger for this extension module, named after the module itself.
log = logging.getLogger(name=__name__)
# Discord rejects autocomplete responses containing more than 25 choices,
# and any choice whose name is longer than 100 characters.
_MAX_AUTOCOMPLETE_CHOICES = 25
_MAX_CHOICE_NAME_LENGTH = 100
# Discord rejects embeds whose description exceeds 4096 characters.
_MAX_EMBED_DESCRIPTION_LENGTH = 4096
# Escape character used for LIKE/ILIKE patterns built by `_contains_pattern`.
_LIKE_ESCAPE = "\\"


def _contains_pattern(text) -> str:
    """Return an ILIKE pattern matching `text` anywhere inside a column value.

    The text is sent to the database as a bound parameter, so it cannot inject
    SQL. However, ``%`` and ``_`` would still act as LIKE wildcards, so they
    (and the escape character itself) are escaped with `_LIKE_ESCAPE`. Pass
    ``escape=_LIKE_ESCAPE`` to ``ilike`` when using the returned pattern.
    """
    escaped = (
        str(text)
        .replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


class ChannelCog(commands.Cog):
    """
    Command cog for assigning RSS sources to Discord channels.

    Provides the ``/channels`` command group with the ``include-feed``,
    ``exclude-feed`` and ``list`` subcommands.
    """

    def __init__(self, bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        log.info(f"{self.__class__.__name__} cog is ready")

    async def autocomplete_rss_sources(self, inter: Interaction, nickname: str):
        """Return this server's RSS sources whose nickname contains `nickname`.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        nickname : str
            The current text entered for the autocomplete.

        Returns
        -------
        list[Choice]
            At most 25 choices, named by source nickname, valued by source id.
        """
        async with DatabaseManager() as database:
            whereclause = and_(
                RssSourceModel.discord_server_id == inter.guild_id,
                RssSourceModel.nick.ilike(
                    _contains_pattern(nickname), escape=_LIKE_ESCAPE
                )
            )
            # Order before limiting so the 25 suggestions are deterministic.
            query = (
                select(RssSourceModel)
                .where(whereclause)
                .order_by(RssSourceModel.nick)
                .limit(_MAX_AUTOCOMPLETE_CHOICES)
            )
            result = await database.session.execute(query)
            # Build choices while the session is still open, so attribute
            # access never has to touch a closed/expired session.
            sources = [
                Choice(name=rss.nick[:_MAX_CHOICE_NAME_LENGTH], value=rss.id)
                for rss in result.scalars().all()
            ]

        log.debug(f"Autocomplete rss_sources returned {len(sources)} results")
        return sources

    async def autocomplete_existing_feeds(self, inter: Interaction, current: str):
        """Returns a list of existing RSS + Channel feeds.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        current : str
            The current text entered for the autocomplete.

        Returns
        -------
        list[Choice]
            At most 25 choices, named by the feed's search name, valued by
            the feed's id.
        """
        async with DatabaseManager() as database:
            # `current` is sent as a bound parameter (no SQL injection); the
            # helper additionally escapes LIKE wildcards typed by the user.
            whereclause = and_(
                FeedChannelModel.discord_server_id == inter.guild_id,
                FeedChannelModel.search_name.ilike(
                    _contains_pattern(current), escape=_LIKE_ESCAPE
                )
            )
            query = (
                select(FeedChannelModel)
                .where(whereclause)
                .order_by(FeedChannelModel.search_name)
                .limit(_MAX_AUTOCOMPLETE_CHOICES)
            )
            result = await database.session.execute(query)
            # search_name is "<nick> #<channel name>" and can exceed Discord's
            # 100-character choice-name limit, so it is truncated.
            feeds = [
                Choice(name=feed.search_name[:_MAX_CHOICE_NAME_LENGTH], value=feed.id)
                for feed in result.scalars().all()
            ]

        log.debug(f"Autocomplete existing_feeds returned {len(feeds)} results")
        return feeds

    # All channel assignment commands belong to this group.
    channel_group = Group(
        name="channels",
        description="Commands for channel assignment.",
        guild_only=True  # Channel assignment only makes sense inside a server.
    )

    @channel_group.command(name="include-feed")
    @autocomplete(rss=autocomplete_rss_sources)
    async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None):
        """Include a feed within the specified channel.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        rss : int
            The RSS feed to include.
        channel : TextChannel
            The channel to include the feed in. Defaults to the current channel.
        """
        await inter.response.defer()

        channel = channel or inter.channel

        async with DatabaseManager() as database:
            # Restrict to this server so ids from other guilds can't be used.
            select_query = select(RssSourceModel).where(and_(
                RssSourceModel.id == rss,
                RssSourceModel.discord_server_id == inter.guild_id
            ))
            select_result = await database.session.execute(select_query)
            rss_source = select_result.scalars().one_or_none()

            # The value may be typed by hand instead of picked from the
            # autocomplete, so a missing row is an expected user error. The
            # deferred interaction must always receive a followup.
            if rss_source is None:
                await followup(inter, "I couldn't find an RSS source under that ID (placeholder response)")
                return

            nick, rss_url = rss_source.nick, rss_source.rss_url

            insert_query = insert(FeedChannelModel).values(
                discord_server_id=inter.guild_id,
                discord_channel_id=channel.id,
                rss_source_id=rss,
                search_name=f"{nick} #{channel.name}"
            )
            await database.session.execute(insert_query)

        # Replying after the context exits means any commit done on exit
        # happens before success is reported.
        # NOTE(review): confirm DatabaseManager commits in __aexit__.
        await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")

    @channel_group.command(name="exclude-feed")
    @autocomplete(option=autocomplete_existing_feeds)
    async def exclude_feed(self, inter: Interaction, option: int):
        """Undo command for the `/channels include-feed` command.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        option : int
            The RSS feed and channel to exclude.
        """
        await inter.response.defer()

        async with DatabaseManager() as database:
            # The guild filter stops one server deleting another's feeds.
            query = delete(FeedChannelModel).where(and_(
                FeedChannelModel.id == option,
                FeedChannelModel.discord_server_id == inter.guild_id
            ))
            result = await database.session.execute(query)
            deleted_count = result.rowcount

        if not deleted_count:
            await followup(inter, "I couldn't find any items under that ID (placeholder response)")
            return

        await followup(inter, "I've removed this item (placeholder response)")

    @channel_group.command(name="list")
    async def list_feeds(self, inter: Interaction):
        """List the feed channels saved in this server.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        """
        await inter.response.defer()

        async with DatabaseManager() as database:
            query = (
                select(FeedChannelModel)
                .where(FeedChannelModel.discord_server_id == inter.guild_id)
                .order_by(FeedChannelModel.search_name)
            )
            result = await database.session.execute(query)
            feed_channels = result.scalars().all()

            # Read row attributes while the session is still open.
            # Numbering starts at 1 for a human-readable list.
            output = "\n".join(
                f"{i}. <#{feed.discord_channel_id}> · {feed.search_name}"
                for i, feed in enumerate(feed_channels, start=1)
            )

        if not feed_channels:
            await followup(inter, "It looks like there are no feed channels available.")
            return

        description = f"placeholder, add rss hyperlink for each item using a sql join\n\n{output}"
        # Discord rejects oversized embed descriptions, so long lists are cut.
        if len(description) > _MAX_EMBED_DESCRIPTION_LENGTH:
            description = description[:_MAX_EMBED_DESCRIPTION_LENGTH - 1] + "…"

        embed = Embed(
            title="Saved Feed Channels",
            description=description,
            colour=Colour.lighter_grey()
        )
        await followup(inter, embed=embed)
async def setup(bot):
    """
    Entry point invoked when this extension is loaded.

    Builds a `ChannelCog` bound to `bot`, registers it with `bot.add_cog`,
    and logs that the cog was added.
    """
    channel_cog = ChannelCog(bot)
    await bot.add_cog(channel_cog)
    log.info(f"Added {type(channel_cog).__name__} cog")