moved channels.py into rss.py

This commit is contained in:
Corban-Lee Jones 2023-12-19 19:11:44 +00:00
parent 1279bec6aa
commit df60652da6
2 changed files with 218 additions and 249 deletions

View File

@ -1,222 +0,0 @@
"""
Extension for the `ChannelCog`.
Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot.
"""
import logging
from sqlalchemy.orm import aliased
from sqlalchemy import select, insert, delete, and_
from discord import Interaction, TextChannel, Embed, Colour
from discord.ext import commands
from discord.app_commands import Group, Choice, autocomplete, choices
from db import DatabaseManager, FeedChannelModel, RssSourceModel
from utils import followup
log = logging.getLogger(__name__)
class ChannelCog(commands.Cog):
"""
Command cog.
"""
def __init__(self, bot):
super().__init__()
self.bot = bot
@commands.Cog.listener()
async def on_ready(self):
log.info(f"{self.__class__.__name__} cog is ready")
async def autocomplete_rss_sources(self, inter: Interaction, nickname: str):
""""""
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.nick.ilike(f"%{nickname}%")
)
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
sources = [
Choice(name=rss.nick, value=rss.id)
for rss in result.scalars().all()
]
log.debug(f"Autocomplete rss_sources returned {len(sources)} results")
return sources
async def autocomplete_existing_feeds(self, inter: Interaction, current: str):
"""Returns a list of existing RSS + Channel feeds.
Parameters
----------
inter : Interaction
Represents an app command interaction.
current : str
The current text entered for the autocomplete.
"""
async with DatabaseManager() as database:
whereclause = and_(
FeedChannelModel.discord_server_id == inter.guild_id,
FeedChannelModel.search_name.ilike(f"%{current}%"), # is this secure from SQL Injection atk ?
RssSourceModel.id == FeedChannelModel.rss_source_id
)
query = (
select(FeedChannelModel, RssSourceModel)
.where(whereclause)
.join(RssSourceModel)
.order_by(FeedChannelModel.discord_channel_id)
)
result = await database.session.execute(query)
feeds = []
for feed in result.scalars().all():
channel = inter.guild.get_channel(feed.discord_channel_id)
feeds.append(Choice(name=f"# {channel.name} | {feed.rss_source.nick}", value=feed.id))
log.debug(f"Autocomplete existing_feeds returned {len(feeds)} results")
return feeds
# All RSS commands belong to this group.
channel_group = Group(
name="channels",
description="Commands for channel assignment.",
guild_only=True # These commands belong to channels of
)
@channel_group.command(name="include-feed")
@autocomplete(rss=autocomplete_rss_sources)
async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None):
"""Include a feed within the specified channel.
Parameters
----------
inter : Interaction
Represents an app command interaction.
rss : str
The RSS feed to include.
channel : TextChannel
The channel to include the feed in.
"""
await inter.response.defer()
channel = channel or inter.channel
async with DatabaseManager() as database:
select_query = select(RssSourceModel).where(and_(
RssSourceModel.id == rss,
RssSourceModel.discord_server_id == inter.guild_id
))
select_result = await database.session.execute(select_query)
rss_source = select_result.scalars().one()
nick, rss_url = rss_source.nick, rss_source.rss_url
insert_query = insert(FeedChannelModel).values(
discord_server_id = inter.guild_id,
discord_channel_id = channel.id,
rss_source_id=rss,
search_name=f"{nick} #{channel.name}"
)
insert_result = await database.session.execute(insert_query)
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
@channel_group.command(name="exclude-feed")
@autocomplete(option=autocomplete_existing_feeds)
async def exclude_feed(self, inter: Interaction, option: int):
"""Undo command for the `/channel include-feed` command.
Parameters
----------
inter : Interaction
Represents an app command interaction.
option : str
The RSS feed and channel to exclude.
"""
await inter.response.defer()
async with DatabaseManager() as database:
query = delete(FeedChannelModel).where(and_(
FeedChannelModel.id == option,
FeedChannelModel.discord_server_id == inter.guild_id
))
result = await database.session.execute(query)
if not result.rowcount:
await followup(inter, "I couldn't find any items under that ID (placeholder response)")
return
await followup(inter, "I've removed this item (placeholder response)")
@channel_group.command(name="list")
# @choices(sort=[
# Choice(name="RSS Nickname", value=0),
# Choice(name="Channel ID", value=1),
# Choice(name="Date Added", value=2)
# ])
async def list_feeds(self, inter: Interaction): # sort: int
"""List all of the channels and their respective included feeds.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
await inter.response.defer()
async with DatabaseManager() as database:
whereclause = and_(
FeedChannelModel.discord_server_id == inter.guild_id,
RssSourceModel.id == FeedChannelModel.rss_source_id
)
query = (
select(FeedChannelModel, RssSourceModel)
.where(whereclause)
.join(RssSourceModel)
.order_by(FeedChannelModel.discord_channel_id)
)
result = await database.session.execute(query)
feed_channels = result.scalars().all()
rowcount = len(feed_channels)
if not feed_channels:
await followup(inter, "It looks like there are no feed channels available.")
return
output = "\n".join([
f"{i}. <#{feed.discord_channel_id}> · [{feed.rss_source.nick}]({feed.rss_source.rss_url})"
for i, feed in enumerate(feed_channels)
])
embed = Embed(
title="Saved Feed Channels",
description=f"{output}",
colour=Colour.blue()
)
embed.set_footer(text=f"Showing {rowcount} results")
await followup(inter, embed=embed)
async def setup(bot):
"""
Setup function for this extension.
Adds `ChannelCog` to the bot.
"""
cog = ChannelCog(bot)
await bot.add_cog(cog)
log.info(f"Added {cog.__class__.__name__} cog")

View File

@ -4,20 +4,23 @@ Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot
"""
import logging
import validators
from typing import Tuple
import textwrap
import feedparser
from markdownify import markdownify
from discord import Interaction, Embed, Colour
import validators
from feedparser import FeedParserDict, parse
from discord.ext import commands
from discord.app_commands import Choice, Group, autocomplete, choices
from discord import Interaction, Embed, Colour, TextChannel
from discord.app_commands import Choice, Group, autocomplete, choices, rename
from sqlalchemy import insert, select, and_, delete
from utils import get_rss_data, followup, audit
from feed import get_source, Source
from db import DatabaseManager, SentArticleModel, RssSourceModel
from utils import get_rss_data, followup, audit # pylint: disable=E0401
from feed import get_source, Source # pylint: disable=E0401
from db import ( # pylint: disable=E0401
DatabaseManager,
SentArticleModel,
RssSourceModel,
FeedChannelModel
)
log = logging.getLogger(__name__)
@ -26,9 +29,9 @@ rss_list_sort_choices = [
Choice(name="Date Added", value=1)
]
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource.
# Run a period task to check this.
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]:
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
# target resource. Run a period task to check this.
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, FeedParserDict | None]:
"""Validate a provided RSS source.
Parameters
@ -63,7 +66,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feed
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
# Check the contents is actually an RSS feed.
feed = feedparser.parse(feed_data)
feed = parse(feed_data)
if not feed.version:
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
@ -81,7 +84,9 @@ class RssCog(commands.Cog):
@commands.Cog.listener()
async def on_ready(self):
log.info(f"{self.__class__.__name__} cog is ready")
"""Instructions to call when the cog is ready."""
log.info("%s cog is ready", self.__class__.__name__)
async def source_autocomplete(self, inter: Interaction, nickname: str):
"""Provides RSS source autocomplete functionality for commands.
@ -114,13 +119,13 @@ class RssCog(commands.Cog):
return sources
# All RSS commands belong to this group.
rss_group = Group(
name="rss",
feed_group = Group(
name="feed",
description="Commands for rss sources.",
guild_only=True # We store guild IDs in the database, so guild only = True
)
@rss_group.command(name="add")
@feed_group.command(name="add")
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
"""Add a new RSS source.
@ -163,7 +168,7 @@ class RssCog(commands.Cog):
await followup(inter, embed=embed)
@rss_group.command(name="remove")
@feed_group.command(name="remove")
@autocomplete(url=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, url: str):
"""Delete an existing RSS source.
@ -178,7 +183,7 @@ class RssCog(commands.Cog):
await inter.response.defer()
log.debug(f"Attempting to remove RSS source ({url=})")
log.debug("Attempting to remove RSS source (url=%s)", url)
async with DatabaseManager() as database:
select_result = await database.session.execute(
@ -219,7 +224,7 @@ class RssCog(commands.Cog):
await followup(inter, embed=embed)
@rss_group.command(name="list")
@feed_group.command(name="list")
@choices(sort=rss_list_sort_choices)
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
"""Provides a with a list of RSS sources available for the current server.
@ -278,9 +283,10 @@ class RssCog(commands.Cog):
await followup(inter, embed=embed)
@rss_group.command(name="fetch")
@feed_group.command(name="fetch")
@rename(max_="max")
@autocomplete(rss=source_autocomplete)
async def fetch_rss(self, inter: Interaction, rss: str, max: int=1):
async def fetch_rss(self, inter: Interaction, rss: str, max_: int=1):
"""Fetch an item from the specified RSS feed.
Parameters
@ -289,13 +295,13 @@ class RssCog(commands.Cog):
Represents an app command interaction.
rss : str
The RSS feed to fetch from.
max : int, optional
max_ : int, optional
Maximum number of items to fetch, by default 1, limits at 5.
"""
await inter.response.defer()
if max > 5:
if max_ > 5:
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
return
@ -305,7 +311,7 @@ class RssCog(commands.Cog):
return
source = Source.from_parsed(feed)
articles = source.get_latest_articles(max)
articles = source.get_latest_articles(max_)
if not articles:
await followup(inter, "Sorry, I couldn't find any articles from this feed.")
@ -324,11 +330,196 @@ class RssCog(commands.Cog):
for article in articles
])
await database.session.execute(query)
await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database)
await audit(self, f"User is requesting {max_} articles from {source.name}", inter.user.id, database=database)
await followup(inter, embeds=embeds)
# Channels ---- ---- ----
async def autocomplete_rss_sources(self, inter: Interaction, nickname: str):
""""""
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.nick.ilike(f"%{nickname}%")
)
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
sources = [
Choice(name=rss.nick, value=rss.id)
for rss in result.scalars().all()
]
log.debug("Autocomplete rss_sources returned %s results", len(sources))
return sources
async def autocomplete_existing_feeds(self, inter: Interaction, current: str):
"""Returns a list of existing RSS + Channel feeds.
Parameters
----------
inter : Interaction
Represents an app command interaction.
current : str
The current text entered for the autocomplete.
"""
async with DatabaseManager() as database:
whereclause = and_(
FeedChannelModel.discord_server_id == inter.guild_id,
FeedChannelModel.search_name.ilike(f"%{current}%"), # is this secure from SQL Injection atk ?
RssSourceModel.id == FeedChannelModel.rss_source_id
)
query = (
select(FeedChannelModel, RssSourceModel)
.where(whereclause)
.join(RssSourceModel)
.order_by(FeedChannelModel.discord_channel_id)
)
result = await database.session.execute(query)
feeds = []
for feed in result.scalars().all():
channel = inter.guild.get_channel(feed.discord_channel_id)
feeds.append(Choice(name=f"# {channel.name} | {feed.rss_source.nick}", value=feed.id))
log.debug("Autocomplete existing_feeds returned %s results", len(feeds))
return feeds
# # All RSS commands belong to this group.
# channel_group = Group(
# name="channels",
# description="Commands for channel assignment.",
# guild_only=True # These commands belong to channels of
# )
@feed_group.command(name="assign")
@autocomplete(rss=autocomplete_rss_sources)
async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None):
"""Include a feed within the specified channel.
Parameters
----------
inter : Interaction
Represents an app command interaction.
rss : str
The RSS feed to include.
channel : TextChannel
The channel to include the feed in.
"""
await inter.response.defer()
channel = channel or inter.channel
async with DatabaseManager() as database:
select_query = select(RssSourceModel).where(and_(
RssSourceModel.id == rss,
RssSourceModel.discord_server_id == inter.guild_id
))
select_result = await database.session.execute(select_query)
rss_source = select_result.scalars().one()
nick, rss_url = rss_source.nick, rss_source.rss_url
insert_query = insert(FeedChannelModel).values(
discord_server_id = inter.guild_id,
discord_channel_id = channel.id,
rss_source_id=rss,
search_name=f"{nick} #{channel.name}"
)
await database.session.execute(insert_query)
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
@feed_group.command(name="unassign")
@autocomplete(option=autocomplete_existing_feeds)
async def exclude_feed(self, inter: Interaction, option: int):
"""Undo command for the `/channel include-feed` command.
Parameters
----------
inter : Interaction
Represents an app command interaction.
option : str
The RSS feed and channel to exclude.
"""
await inter.response.defer()
async with DatabaseManager() as database:
query = delete(FeedChannelModel).where(and_(
FeedChannelModel.id == option,
FeedChannelModel.discord_server_id == inter.guild_id
))
result = await database.session.execute(query)
if not result.rowcount:
await followup(inter, "I couldn't find any items under that ID (placeholder response)")
return
await followup(inter, "I've removed this item (placeholder response)")
@feed_group.command(name="channels")
# @choices(sort=[
# Choice(name="RSS Nickname", value=0),
# Choice(name="Channel ID", value=1),
# Choice(name="Date Added", value=2)
# ])
async def list_feeds(self, inter: Interaction): # sort: int
"""List all of the channels and their respective included feeds.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
await inter.response.defer()
async with DatabaseManager() as database:
whereclause = and_(
FeedChannelModel.discord_server_id == inter.guild_id,
RssSourceModel.id == FeedChannelModel.rss_source_id
)
query = (
select(FeedChannelModel, RssSourceModel)
.where(whereclause)
.join(RssSourceModel)
.order_by(FeedChannelModel.discord_channel_id)
)
result = await database.session.execute(query)
feed_channels = result.scalars().all()
rowcount = len(feed_channels)
if not feed_channels:
await followup(inter, "It looks like there are no feed channels available.")
return
output = "\n".join([
f"{i}. <#{feed.discord_channel_id}> · [{feed.rss_source.nick}]({feed.rss_source.rss_url})"
for i, feed in enumerate(feed_channels)
])
embed = Embed(
title="Saved Feed Channels",
description=f"{output}",
colour=Colour.blue()
)
embed.set_footer(text=f"Showing {rowcount} results")
await followup(inter, embed=embed)
async def setup(bot):
"""
Setup function for this extension.
@ -337,4 +528,4 @@ async def setup(bot):
cog = RssCog(bot)
await bot.add_cog(cog)
log.info(f"Added {cog.__class__.__name__} cog")
log.info("Added %s cog", cog.__class__.__name__)