197 lines
6.3 KiB
Python
197 lines
6.3 KiB
Python
"""
|
|
Extension for the `ChannelCog`.
|
|
Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot.
|
|
"""
|
|
|
|
import logging
|
|
|
|
from sqlalchemy.orm import aliased
|
|
from sqlalchemy import select, insert, delete, and_
|
|
from discord import Interaction, TextChannel, Embed, Colour
|
|
from discord.ext import commands
|
|
from discord.app_commands import Group, Choice, autocomplete
|
|
|
|
from db import DatabaseManager, FeedChannelModel, RssSourceModel
|
|
from utils import followup
|
|
|
|
# Module-level logger, named after this module via `__name__`.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class ChannelCog(commands.Cog):
    """
    Command cog for assigning RSS sources to text channels.

    Provides the `channels` command group (`include-feed`, `exclude-feed`
    and `list`) together with the autocomplete callbacks those commands use.
    """

    # Discord accepts at most 25 choices in a single autocomplete response.
    _MAX_AUTOCOMPLETE_CHOICES = 25

    def __init__(self, bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        """Log that this cog is ready."""
        log.info(f"{self.__class__.__name__} cog is ready")

    async def autocomplete_rss_sources(self, inter: Interaction, nickname: str):
        """Returns a list of RSS sources saved for the current server.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        nickname : str
            The current text entered for the autocomplete.

        Returns
        -------
        list of Choice
            At most 25 choices, each mapping a source nickname to its ID.
        """

        async with DatabaseManager() as database:
            # The search text travels as a bound parameter, so it cannot inject
            # SQL; `%` or `_` typed by the user merely act as LIKE wildcards.
            whereclause = and_(
                RssSourceModel.discord_server_id == inter.guild_id,
                RssSourceModel.nick.ilike(f"%{nickname}%")
            )
            query = (
                select(RssSourceModel)
                .where(whereclause)
                .order_by(RssSourceModel.nick)
                .limit(self._MAX_AUTOCOMPLETE_CHOICES)
            )
            result = await database.session.execute(query)
            sources = [
                Choice(name=rss.nick, value=rss.id)
                for rss in result.scalars().all()
            ]

        log.debug(f"Autocomplete rss_sources returned {len(sources)} results")

        return sources

    async def autocomplete_existing_feeds(self, inter: Interaction, current: str):
        """Returns a list of existing RSS + Channel feeds.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        current : str
            The current text entered for the autocomplete.

        Returns
        -------
        list of Choice
            At most 25 choices, each mapping a feed's search name to its ID.
        """

        async with DatabaseManager() as database:
            # The search text travels as a bound parameter, so it cannot inject
            # SQL; `%` or `_` typed by the user merely act as LIKE wildcards.
            whereclause = and_(
                FeedChannelModel.discord_server_id == inter.guild_id,
                FeedChannelModel.search_name.ilike(f"%{current}%")
            )
            query = (
                select(FeedChannelModel)
                .where(whereclause)
                .order_by(FeedChannelModel.search_name)
                .limit(self._MAX_AUTOCOMPLETE_CHOICES)
            )
            result = await database.session.execute(query)
            feeds = [
                Choice(name=feed.search_name, value=feed.id)
                for feed in result.scalars().all()
            ]

        log.debug(f"Autocomplete existing_feeds returned {len(feeds)} results")

        return feeds

    # All channel commands belong to this group.
    channel_group = Group(
        name="channels",
        description="Commands for channel assignment.",
        guild_only=True  # These commands can only be used inside a server.
    )

    @channel_group.command(name="include-feed")
    @autocomplete(rss=autocomplete_rss_sources)
    async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None):
        """Include a feed within the specified channel.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        rss : int
            The RSS feed to include.
        channel : TextChannel
            The channel to include the feed in.
        """

        await inter.response.defer()

        # Fall back to the channel the command was invoked in.
        channel = channel or inter.channel

        async with DatabaseManager() as database:
            select_query = select(RssSourceModel).where(and_(
                RssSourceModel.id == rss,
                RssSourceModel.discord_server_id == inter.guild_id
            ))

            select_result = await database.session.execute(select_query)

            # Autocomplete suggestions are not enforced, so any integer may
            # arrive here; answer explicitly instead of raising on no match.
            rss_source = select_result.scalars().one_or_none()
            if rss_source is None:
                await followup(inter, "I couldn't find an RSS source with that ID in this server.")
                return

            # Copy the values needed for the reply while the session is open.
            nick, rss_url = rss_source.nick, rss_source.rss_url

            insert_query = insert(FeedChannelModel).values(
                discord_server_id=inter.guild_id,
                discord_channel_id=channel.id,
                rss_source_id=rss,
                search_name=f"{nick} #{channel.name}"
            )

            await database.session.execute(insert_query)

        await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")

    @channel_group.command(name="exclude-feed")
    @autocomplete(option=autocomplete_existing_feeds)
    async def exclude_feed(self, inter: Interaction, option: int):
        """Undo command for the `/channels include-feed` command.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        option : int
            The RSS feed and channel to exclude.
        """

        await inter.response.defer()

        async with DatabaseManager() as database:
            # Restricting by server ID keeps one server from deleting
            # another server's feed channels by guessing IDs.
            query = delete(FeedChannelModel).where(and_(
                FeedChannelModel.id == option,
                FeedChannelModel.discord_server_id == inter.guild_id
            ))

            result = await database.session.execute(query)
            deleted_rows = result.rowcount

        if not deleted_rows:
            await followup(inter, "I couldn't find any items under that ID (placeholder response)")
            return

        await followup(inter, "I've removed this item (placeholder response)")

    @channel_group.command(name="list")
    async def list_feeds(self, inter: Interaction):
        """List the feed channels saved for this server.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        """

        await inter.response.defer()

        async with DatabaseManager() as database:
            query = (
                select(FeedChannelModel)
                .where(FeedChannelModel.discord_server_id == inter.guild_id)
                .order_by(FeedChannelModel.search_name)
            )
            result = await database.session.execute(query)

            # Read the ORM attributes while the session is still open.
            # Numbering starts at 1 because the list is shown to users.
            lines = [
                f"{i}. <#{feed.discord_channel_id}> · {feed.search_name}"
                for i, feed in enumerate(result.scalars().all(), start=1)
            ]

        if not lines:
            await followup(inter, "It looks like there are no feed channels available.")
            return

        output = "\n".join(lines)

        embed = Embed(
            title="Saved Feed Channels",
            description=f"placeholder, add rss hyperlink for each item using a sql join\n\n{output}",
            colour=Colour.lighter_grey()
        )

        await followup(inter, embed=embed)
|
|
|
|
|
|
async def setup(bot):
    """
    Entry point for this extension.

    Creates a `ChannelCog` for the given bot, registers it through
    `bot.add_cog` and logs the class name of the cog that was added.
    """

    channel_cog = ChannelCog(bot)
    await bot.add_cog(channel_cog)

    cog_name = channel_cog.__class__.__name__
    log.info(f"Added {cog_name} cog")
|