moved channels.py into rss.py
This commit is contained in:
parent
1279bec6aa
commit
df60652da6
@ -1,222 +0,0 @@
|
||||
"""
|
||||
Extension for the `ChannelCog`.
|
||||
Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy import select, insert, delete, and_
|
||||
from discord import Interaction, TextChannel, Embed, Colour
|
||||
from discord.ext import commands
|
||||
from discord.app_commands import Group, Choice, autocomplete, choices
|
||||
|
||||
from db import DatabaseManager, FeedChannelModel, RssSourceModel
|
||||
from utils import followup
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChannelCog(commands.Cog):
|
||||
"""
|
||||
Command cog.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
log.info(f"{self.__class__.__name__} cog is ready")
|
||||
|
||||
async def autocomplete_rss_sources(self, inter: Interaction, nickname: str):
|
||||
""""""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.nick.ilike(f"%{nickname}%")
|
||||
)
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
sources = [
|
||||
Choice(name=rss.nick, value=rss.id)
|
||||
for rss in result.scalars().all()
|
||||
]
|
||||
|
||||
log.debug(f"Autocomplete rss_sources returned {len(sources)} results")
|
||||
|
||||
return sources
|
||||
|
||||
async def autocomplete_existing_feeds(self, inter: Interaction, current: str):
|
||||
"""Returns a list of existing RSS + Channel feeds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
current : str
|
||||
The current text entered for the autocomplete.
|
||||
"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
FeedChannelModel.discord_server_id == inter.guild_id,
|
||||
FeedChannelModel.search_name.ilike(f"%{current}%"), # is this secure from SQL Injection atk ?
|
||||
RssSourceModel.id == FeedChannelModel.rss_source_id
|
||||
)
|
||||
query = (
|
||||
select(FeedChannelModel, RssSourceModel)
|
||||
.where(whereclause)
|
||||
.join(RssSourceModel)
|
||||
.order_by(FeedChannelModel.discord_channel_id)
|
||||
)
|
||||
result = await database.session.execute(query)
|
||||
feeds = []
|
||||
for feed in result.scalars().all():
|
||||
channel = inter.guild.get_channel(feed.discord_channel_id)
|
||||
feeds.append(Choice(name=f"# {channel.name} | {feed.rss_source.nick}", value=feed.id))
|
||||
|
||||
log.debug(f"Autocomplete existing_feeds returned {len(feeds)} results")
|
||||
|
||||
return feeds
|
||||
|
||||
# All RSS commands belong to this group.
|
||||
channel_group = Group(
|
||||
name="channels",
|
||||
description="Commands for channel assignment.",
|
||||
guild_only=True # These commands belong to channels of
|
||||
)
|
||||
|
||||
@channel_group.command(name="include-feed")
|
||||
@autocomplete(rss=autocomplete_rss_sources)
|
||||
async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None):
|
||||
"""Include a feed within the specified channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
rss : str
|
||||
The RSS feed to include.
|
||||
channel : TextChannel
|
||||
The channel to include the feed in.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
channel = channel or inter.channel
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
select_query = select(RssSourceModel).where(and_(
|
||||
RssSourceModel.id == rss,
|
||||
RssSourceModel.discord_server_id == inter.guild_id
|
||||
))
|
||||
|
||||
select_result = await database.session.execute(select_query)
|
||||
rss_source = select_result.scalars().one()
|
||||
nick, rss_url = rss_source.nick, rss_source.rss_url
|
||||
|
||||
insert_query = insert(FeedChannelModel).values(
|
||||
discord_server_id = inter.guild_id,
|
||||
discord_channel_id = channel.id,
|
||||
rss_source_id=rss,
|
||||
search_name=f"{nick} #{channel.name}"
|
||||
)
|
||||
|
||||
insert_result = await database.session.execute(insert_query)
|
||||
|
||||
|
||||
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
|
||||
|
||||
@channel_group.command(name="exclude-feed")
|
||||
@autocomplete(option=autocomplete_existing_feeds)
|
||||
async def exclude_feed(self, inter: Interaction, option: int):
|
||||
"""Undo command for the `/channel include-feed` command.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
option : str
|
||||
The RSS feed and channel to exclude.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = delete(FeedChannelModel).where(and_(
|
||||
FeedChannelModel.id == option,
|
||||
FeedChannelModel.discord_server_id == inter.guild_id
|
||||
))
|
||||
|
||||
result = await database.session.execute(query)
|
||||
|
||||
if not result.rowcount:
|
||||
await followup(inter, "I couldn't find any items under that ID (placeholder response)")
|
||||
return
|
||||
|
||||
await followup(inter, "I've removed this item (placeholder response)")
|
||||
|
||||
@channel_group.command(name="list")
|
||||
# @choices(sort=[
|
||||
# Choice(name="RSS Nickname", value=0),
|
||||
# Choice(name="Channel ID", value=1),
|
||||
# Choice(name="Date Added", value=2)
|
||||
# ])
|
||||
async def list_feeds(self, inter: Interaction): # sort: int
|
||||
"""List all of the channels and their respective included feeds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
FeedChannelModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.id == FeedChannelModel.rss_source_id
|
||||
)
|
||||
query = (
|
||||
select(FeedChannelModel, RssSourceModel)
|
||||
.where(whereclause)
|
||||
.join(RssSourceModel)
|
||||
.order_by(FeedChannelModel.discord_channel_id)
|
||||
)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
feed_channels = result.scalars().all()
|
||||
rowcount = len(feed_channels)
|
||||
|
||||
if not feed_channels:
|
||||
await followup(inter, "It looks like there are no feed channels available.")
|
||||
return
|
||||
|
||||
output = "\n".join([
|
||||
f"{i}. <#{feed.discord_channel_id}> · [{feed.rss_source.nick}]({feed.rss_source.rss_url})"
|
||||
for i, feed in enumerate(feed_channels)
|
||||
])
|
||||
|
||||
embed = Embed(
|
||||
title="Saved Feed Channels",
|
||||
description=f"{output}",
|
||||
colour=Colour.blue()
|
||||
)
|
||||
embed.set_footer(text=f"Showing {rowcount} results")
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
Adds `ChannelCog` to the bot.
|
||||
"""
|
||||
|
||||
cog = ChannelCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
@ -4,20 +4,23 @@ Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot
|
||||
"""
|
||||
|
||||
import logging
|
||||
import validators
|
||||
from typing import Tuple
|
||||
|
||||
import textwrap
|
||||
import feedparser
|
||||
from markdownify import markdownify
|
||||
from discord import Interaction, Embed, Colour
|
||||
import validators
|
||||
from feedparser import FeedParserDict, parse
|
||||
from discord.ext import commands
|
||||
from discord.app_commands import Choice, Group, autocomplete, choices
|
||||
from discord import Interaction, Embed, Colour, TextChannel
|
||||
from discord.app_commands import Choice, Group, autocomplete, choices, rename
|
||||
from sqlalchemy import insert, select, and_, delete
|
||||
|
||||
from utils import get_rss_data, followup, audit
|
||||
from feed import get_source, Source
|
||||
from db import DatabaseManager, SentArticleModel, RssSourceModel
|
||||
from utils import get_rss_data, followup, audit # pylint: disable=E0401
|
||||
from feed import get_source, Source # pylint: disable=E0401
|
||||
from db import ( # pylint: disable=E0401
|
||||
DatabaseManager,
|
||||
SentArticleModel,
|
||||
RssSourceModel,
|
||||
FeedChannelModel
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -26,9 +29,9 @@ rss_list_sort_choices = [
|
||||
Choice(name="Date Added", value=1)
|
||||
]
|
||||
|
||||
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource.
|
||||
# Run a period task to check this.
|
||||
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]:
|
||||
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
|
||||
# target resource. Run a period task to check this.
|
||||
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
|
||||
"""Validate a provided RSS source.
|
||||
|
||||
Parameters
|
||||
@ -63,7 +66,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feed
|
||||
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
||||
|
||||
# Check the contents is actually an RSS feed.
|
||||
feed = feedparser.parse(feed_data)
|
||||
feed = parse(feed_data)
|
||||
if not feed.version:
|
||||
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
||||
|
||||
@ -81,7 +84,9 @@ class RssCog(commands.Cog):
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
log.info(f"{self.__class__.__name__} cog is ready")
|
||||
"""Instructions to call when the cog is ready."""
|
||||
|
||||
log.info("%s cog is ready", self.__class__.__name__)
|
||||
|
||||
async def source_autocomplete(self, inter: Interaction, nickname: str):
|
||||
"""Provides RSS source autocomplete functionality for commands.
|
||||
@ -114,13 +119,13 @@ class RssCog(commands.Cog):
|
||||
return sources
|
||||
|
||||
# All RSS commands belong to this group.
|
||||
rss_group = Group(
|
||||
name="rss",
|
||||
feed_group = Group(
|
||||
name="feed",
|
||||
description="Commands for rss sources.",
|
||||
guild_only=True # We store guild IDs in the database, so guild only = True
|
||||
)
|
||||
|
||||
@rss_group.command(name="add")
|
||||
@feed_group.command(name="add")
|
||||
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
|
||||
"""Add a new RSS source.
|
||||
|
||||
@ -163,7 +168,7 @@ class RssCog(commands.Cog):
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@rss_group.command(name="remove")
|
||||
@feed_group.command(name="remove")
|
||||
@autocomplete(url=source_autocomplete)
|
||||
async def remove_rss_source(self, inter: Interaction, url: str):
|
||||
"""Delete an existing RSS source.
|
||||
@ -178,7 +183,7 @@ class RssCog(commands.Cog):
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
log.debug(f"Attempting to remove RSS source ({url=})")
|
||||
log.debug("Attempting to remove RSS source (url=%s)", url)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
select_result = await database.session.execute(
|
||||
@ -219,7 +224,7 @@ class RssCog(commands.Cog):
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@rss_group.command(name="list")
|
||||
@feed_group.command(name="list")
|
||||
@choices(sort=rss_list_sort_choices)
|
||||
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
|
||||
"""Provides a with a list of RSS sources available for the current server.
|
||||
@ -278,9 +283,10 @@ class RssCog(commands.Cog):
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@rss_group.command(name="fetch")
|
||||
@feed_group.command(name="fetch")
|
||||
@rename(max_="max")
|
||||
@autocomplete(rss=source_autocomplete)
|
||||
async def fetch_rss(self, inter: Interaction, rss: str, max: int=1):
|
||||
async def fetch_rss(self, inter: Interaction, rss: str, max_: int=1):
|
||||
"""Fetch an item from the specified RSS feed.
|
||||
|
||||
Parameters
|
||||
@ -289,13 +295,13 @@ class RssCog(commands.Cog):
|
||||
Represents an app command interaction.
|
||||
rss : str
|
||||
The RSS feed to fetch from.
|
||||
max : int, optional
|
||||
max_ : int, optional
|
||||
Maximum number of items to fetch, by default 1, limits at 5.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
if max > 5:
|
||||
if max_ > 5:
|
||||
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
|
||||
return
|
||||
|
||||
@ -305,7 +311,7 @@ class RssCog(commands.Cog):
|
||||
return
|
||||
|
||||
source = Source.from_parsed(feed)
|
||||
articles = source.get_latest_articles(max)
|
||||
articles = source.get_latest_articles(max_)
|
||||
|
||||
if not articles:
|
||||
await followup(inter, "Sorry, I couldn't find any articles from this feed.")
|
||||
@ -324,11 +330,196 @@ class RssCog(commands.Cog):
|
||||
for article in articles
|
||||
])
|
||||
await database.session.execute(query)
|
||||
await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database)
|
||||
await audit(self, f"User is requesting {max_} articles from {source.name}", inter.user.id, database=database)
|
||||
|
||||
await followup(inter, embeds=embeds)
|
||||
|
||||
|
||||
# Channels ---- ---- ----
|
||||
|
||||
|
||||
async def autocomplete_rss_sources(self, inter: Interaction, nickname: str):
|
||||
""""""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.nick.ilike(f"%{nickname}%")
|
||||
)
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
sources = [
|
||||
Choice(name=rss.nick, value=rss.id)
|
||||
for rss in result.scalars().all()
|
||||
]
|
||||
|
||||
log.debug("Autocomplete rss_sources returned %s results", len(sources))
|
||||
|
||||
return sources
|
||||
|
||||
async def autocomplete_existing_feeds(self, inter: Interaction, current: str):
|
||||
"""Returns a list of existing RSS + Channel feeds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
current : str
|
||||
The current text entered for the autocomplete.
|
||||
"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
FeedChannelModel.discord_server_id == inter.guild_id,
|
||||
FeedChannelModel.search_name.ilike(f"%{current}%"), # is this secure from SQL Injection atk ?
|
||||
RssSourceModel.id == FeedChannelModel.rss_source_id
|
||||
)
|
||||
query = (
|
||||
select(FeedChannelModel, RssSourceModel)
|
||||
.where(whereclause)
|
||||
.join(RssSourceModel)
|
||||
.order_by(FeedChannelModel.discord_channel_id)
|
||||
)
|
||||
result = await database.session.execute(query)
|
||||
feeds = []
|
||||
for feed in result.scalars().all():
|
||||
channel = inter.guild.get_channel(feed.discord_channel_id)
|
||||
feeds.append(Choice(name=f"# {channel.name} | {feed.rss_source.nick}", value=feed.id))
|
||||
|
||||
log.debug("Autocomplete existing_feeds returned %s results", len(feeds))
|
||||
|
||||
return feeds
|
||||
|
||||
# # All RSS commands belong to this group.
|
||||
# channel_group = Group(
|
||||
# name="channels",
|
||||
# description="Commands for channel assignment.",
|
||||
# guild_only=True # These commands belong to channels of
|
||||
# )
|
||||
|
||||
@feed_group.command(name="assign")
|
||||
@autocomplete(rss=autocomplete_rss_sources)
|
||||
async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None):
|
||||
"""Include a feed within the specified channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
rss : str
|
||||
The RSS feed to include.
|
||||
channel : TextChannel
|
||||
The channel to include the feed in.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
channel = channel or inter.channel
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
select_query = select(RssSourceModel).where(and_(
|
||||
RssSourceModel.id == rss,
|
||||
RssSourceModel.discord_server_id == inter.guild_id
|
||||
))
|
||||
|
||||
select_result = await database.session.execute(select_query)
|
||||
rss_source = select_result.scalars().one()
|
||||
nick, rss_url = rss_source.nick, rss_source.rss_url
|
||||
|
||||
insert_query = insert(FeedChannelModel).values(
|
||||
discord_server_id = inter.guild_id,
|
||||
discord_channel_id = channel.id,
|
||||
rss_source_id=rss,
|
||||
search_name=f"{nick} #{channel.name}"
|
||||
)
|
||||
|
||||
await database.session.execute(insert_query)
|
||||
|
||||
|
||||
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
|
||||
|
||||
@feed_group.command(name="unassign")
|
||||
@autocomplete(option=autocomplete_existing_feeds)
|
||||
async def exclude_feed(self, inter: Interaction, option: int):
|
||||
"""Undo command for the `/channel include-feed` command.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
option : str
|
||||
The RSS feed and channel to exclude.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = delete(FeedChannelModel).where(and_(
|
||||
FeedChannelModel.id == option,
|
||||
FeedChannelModel.discord_server_id == inter.guild_id
|
||||
))
|
||||
|
||||
result = await database.session.execute(query)
|
||||
|
||||
if not result.rowcount:
|
||||
await followup(inter, "I couldn't find any items under that ID (placeholder response)")
|
||||
return
|
||||
|
||||
await followup(inter, "I've removed this item (placeholder response)")
|
||||
|
||||
@feed_group.command(name="channels")
|
||||
# @choices(sort=[
|
||||
# Choice(name="RSS Nickname", value=0),
|
||||
# Choice(name="Channel ID", value=1),
|
||||
# Choice(name="Date Added", value=2)
|
||||
# ])
|
||||
async def list_feeds(self, inter: Interaction): # sort: int
|
||||
"""List all of the channels and their respective included feeds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
FeedChannelModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.id == FeedChannelModel.rss_source_id
|
||||
)
|
||||
query = (
|
||||
select(FeedChannelModel, RssSourceModel)
|
||||
.where(whereclause)
|
||||
.join(RssSourceModel)
|
||||
.order_by(FeedChannelModel.discord_channel_id)
|
||||
)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
feed_channels = result.scalars().all()
|
||||
rowcount = len(feed_channels)
|
||||
|
||||
if not feed_channels:
|
||||
await followup(inter, "It looks like there are no feed channels available.")
|
||||
return
|
||||
|
||||
output = "\n".join([
|
||||
f"{i}. <#{feed.discord_channel_id}> · [{feed.rss_source.nick}]({feed.rss_source.rss_url})"
|
||||
for i, feed in enumerate(feed_channels)
|
||||
])
|
||||
|
||||
embed = Embed(
|
||||
title="Saved Feed Channels",
|
||||
description=f"{output}",
|
||||
colour=Colour.blue()
|
||||
)
|
||||
embed.set_footer(text=f"Showing {rowcount} results")
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
@ -337,4 +528,4 @@ async def setup(bot):
|
||||
|
||||
cog = RssCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
||||
log.info("Added %s cog", cog.__class__.__name__)
|
||||
|
Loading…
x
Reference in New Issue
Block a user