Created Followup and move functions to feeds.py

This commit is contained in:
Corban-Lee Jones 2023-12-24 23:01:48 +00:00
parent 94b7972a38
commit e8d13ae26b
6 changed files with 422 additions and 200 deletions

View File

@ -4,8 +4,7 @@ All table classes should be suffixed with `Model`.
"""
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy import (
Column,
Integer,

View File

@ -1,4 +1,5 @@
class IllegalFeed(Exception):
pass
def __init__(self, message: str, **items):
super().__init__(message)
self.items = items

View File

@ -12,18 +12,25 @@ from discord.ext import commands
from discord import Interaction, Embed, Colour, TextChannel, Permissions
from discord.app_commands import Choice, Group, autocomplete, choices, rename
from sqlalchemy import insert, select, and_, delete
from sqlalchemy.exc import NoResultFound
from sqlalchemy.exc import NoResultFound, IntegrityError
from utils import get_rss_data, followup, audit, followup_error, extract_error_info # pylint: disable=E0401
from feed import Source # pylint: disable=E0401
from db import ( # pylint: disable=E0401
from feed import Source
from errors import IllegalFeed
from db import (
DatabaseManager,
SentArticleModel,
RssSourceModel,
FeedChannelModel,
AuditModel
)
from errors import IllegalFeed
from utils import (
Followup,
get_rss_data,
followup,
audit,
extract_error_info,
get_unparsed_feed
)
log = logging.getLogger(__name__)
@ -81,7 +88,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
return None, feed
async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str):
unparsed_feed = await self.bot.functions.get_unparsed_feed(rss_url)
unparsed_feed = await get_unparsed_feed(rss_url)
source = Source.from_parsed(parse(unparsed_feed))
articles = source.get_latest_articles()
@ -148,101 +155,89 @@ class FeedCog(commands.Cog):
feed_group = Group(
name="feed",
description="Commands for rss sources.",
guild_only=True, # We store guild IDs in the database, so guild only = True
default_permissions=Permissions.elevated()
default_permissions=Permissions.elevated(),
guild_only=True # We store guild IDs in the database, so guild only = True
)
@feed_group.command(name="add")
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
"""Add a new RSS source.
"""Add a new Feed for this server.
Parameters
----------
inter : Interaction
Represents an app command interaction.
nickname : str
A name used to identify the RSS source.
A name used to identify the Feed.
url : str
The RSS feed URL.
The Feed URL.
"""
await inter.response.defer()
try:
source = self.bot.functions.create_new_feed(nickname, url)
source = await self.bot.functions.create_new_feed(nickname, url, inter.guild_id)
except IllegalFeed as error:
title, desc = extract_error_info(error)
followup_error(inter, title=title, description=desc)
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
embed.add_field(name="Nickname", value=nickname)
embed.add_field(name="URL", value=url)
embed.set_thumbnail(url=source.thumb_url)
await followup(inter, embed=embed)
await Followup(title, desc).fields(**error.items).error().send(inter)
except IntegrityError as error:
await (
Followup(
"Duplicate Feed Error",
"A Feed with the same nickname already exist."
)
.fields(nickname=nickname)
.error()
.send(inter)
)
else:
await (
Followup("Feed Added")
.image(source.icon_url)
.fields(nickname=nickname, url=url)
.added()
.send(inter)
)
@feed_group.command(name="remove")
@rename(url="option")
@autocomplete(url=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, url: str):
"""Delete an existing RSS source.
"""Delete an existing Feed from this server.
Parameters
----------
inter : Interaction
Represents an app command interaction.
url : str
The RSS source to be removed. Autocomplete or enter the URL.
The Feed to be removed. Autocomplete or enter the URL.
"""
await inter.response.defer()
log.debug("Attempting to remove RSS source (url=%s)", url)
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == url
)
# We will select the item first, so we can reference it's nickname later.
select_query = select(RssSourceModel).filter(whereclause)
select_result = await database.session.execute(select_query)
try:
rss_source = select_result.scalars().one()
except NoResultFound:
await followup_error(inter,
title="Error Deleting Feed",
message=f"I couldn't find anything for `{url}`"
try:
source = await self.bot.functions.delete_feed(url, inter.guild_id)
except NoResultFound:
await (
Followup(
"Feed Not Found Error",
"A Feed with these parameters could not be found."
)
return
nickname = rss_source.nick
delete_query = delete(RssSourceModel).filter(whereclause)
delete_result = await database.session.execute(delete_query)
await audit(self,
f"Deleted RSS source ({nickname=}, {url=})",
inter.user.id, database=database
.error()
.send(inter)
)
else:
await (
Followup("Feed Deleted")
.image(source.icon_url)
.fields(url=url)
.trash()
.send(inter)
)
source = await Source.from_url(url)
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
embed.add_field(name="Nickname", value=nickname)
embed.add_field(name="URL", value=url)
embed.set_thumbnail(url=source.icon_url)
await followup(inter, embed=embed)
@feed_group.command(name="list")
@choices(sort=rss_list_sort_choices)
async def list_rss_sources(
self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False
):
"""Provides a with a list of RSS sources available for the current server.
async def list_rss_sources(self, inter: Interaction):
"""Provides a with a list of Feeds available for this server.
Parameters
----------
@ -252,55 +247,32 @@ class FeedCog(commands.Cog):
await inter.response.defer()
# Default to the first choice if not specified.
if isinstance(sort, Choice):
description = "Sort by "
description += "Nickname " if sort.value == 0 else "Date Added "
description += '\U000025BC' if sort_reverse else '\U000025B2'
else:
sort = rss_list_sort_choices[0]
description = ""
match sort.value, sort_reverse:
case 0, False:
order_by = RssSourceModel.nick.asc()
case 0, True:
order_by = RssSourceModel.nick.desc()
case 1, False:
order_by = RssSourceModel.created.desc()
case 1, True:
order_by = RssSourceModel.created.asc()
case _, _:
raise ValueError(f"Unknown sort: {sort}")
async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
query = select(RssSourceModel).where(whereclause).order_by(order_by)
result = await database.session.execute(query)
rss_sources = result.scalars().all()
rowcount = len(rss_sources)
if not rss_sources:
await followup_error(inter,
title="No Feeds Found",
message="I couldn't find any Feeds for this server."
try:
sources = await self.bot.functions.get_feeds(inter.guild_id)
except NoResultFound:
await (
Followup(
"Feeds Not Found Error",
"There are no available Feeds for this server.\n"
"Add a new feed with `/feed add`."
)
return
output = "\n".join([
f"{i}. **[{rss.nick}]({rss.rss_url})** "
for i, rss in enumerate(rss_sources)
.error()
.send()
)
else:
description = "\n".join([
f"{i}. **[{source.name}]({source.url})**"
for i, source in enumerate(sources)
])
await (
Followup(
f"Available Feeds in {inter.guild.name}",
description
)
.info()
.send(inter)
)
embed = Embed(
title="Saved RSS Feeds",
description=f"{description}\n\n{output}",
colour=Colour.blue()
)
embed.set_footer(text=f"Showing {rowcount} results")
await followup(inter, embed=embed)
# @feed_group.command(name="fetch")
# @rename(max_="max")
@ -431,7 +403,7 @@ class FeedCog(commands.Cog):
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
sources = [
Choice(name=rss.nick, value=rss.id)
Choice(name=rss.nick, value=rss.rss_url)
for rss in result.scalars().all()
]
@ -483,10 +455,10 @@ class FeedCog(commands.Cog):
# )
@feed_group.command(name="assign")
@rename(rss="feed")
@autocomplete(rss=autocomplete_rss_sources)
@rename(url="feed")
@autocomplete(url=autocomplete_rss_sources)
async def include_feed(
self, inter: Interaction, rss: int, channel: TextChannel = None, prevent_spam: bool = True
self, inter: Interaction, url: str, channel: TextChannel = None, prevent_spam: bool = True
):
"""Include a feed within the specified channel.
@ -494,7 +466,7 @@ class FeedCog(commands.Cog):
----------
inter : Interaction
Represents an app command interaction.
rss : int
url : int
The RSS feed to include.
channel : TextChannel
The channel to include the feed in.
@ -504,30 +476,41 @@ class FeedCog(commands.Cog):
channel = channel or inter.channel
async with DatabaseManager() as database:
select_query = select(RssSourceModel).where(and_(
RssSourceModel.id == rss,
RssSourceModel.discord_server_id == inter.guild_id
))
select_result = await database.session.execute(select_query)
rss_source = select_result.scalars().one()
nick, rss_url = rss_source.nick, rss_source.rss_url
insert_query = insert(FeedChannelModel).values(
discord_server_id = inter.guild_id,
discord_channel_id = channel.id,
rss_source_id=rss,
search_name=f"{nick} #{channel.name}"
try:
feed_id, source = await self.bot.functions.assign_feed(
url, channel.name, channel.id, inter.guild_id
)
except IntegrityError:
await (
Followup(
"Duplicate Assigned Feed Error",
f"This Feed has already been assigned to {channel.mention}"
)
.error()
.send(inter)
)
except NoResultFound:
await (
Followup(
"Feed Not Found Error",
"A Feed with these parameters could not be found."
)
.error()
.send(inter)
)
else:
await (
Followup(
"Feed Assigned",
f"I've assigned {channel.mention} to receive content from "
f"[{source.name}]({source.url})."
)
.assign()
.send(inter)
)
insert_result = await database.session.execute(insert_query)
feed_id = insert_result.inserted_primary_key.id
if prevent_spam:
await set_all_articles_as_sent(inter, channel, feed_id, rss_url)
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
await set_all_articles_as_sent(inter, channel, feed_id, url)
@feed_group.command(name="unassign")
@autocomplete(option=autocomplete_existing_feeds)
@ -544,20 +527,41 @@ class FeedCog(commands.Cog):
await inter.response.defer()
async with DatabaseManager() as database:
query = delete(FeedChannelModel).where(and_(
FeedChannelModel.id == option,
FeedChannelModel.discord_server_id == inter.guild_id
))
result = await database.session.execute(query)
if not result.rowcount:
await followup_error(inter,
title="Assigned Feed Not Found",
message=f"I couldn't find any assigned feeds for the option: {option}"
try:
await self.bot.functions.unassign_feed(option, inter.guild_id)
except NoResultFound:
await (
Followup(
"Assigned Feed Not Found",
"The assigned Feed doesn't exist."
)
.error()
.send(inter)
)
return
else:
await (
Followup(
"Unassigned Feed",
"Feed has been unassigned."
)
.trash()
.send(inter)
)
# async with DatabaseManager() as database:
# query = delete(FeedChannelModel).where(and_(
# FeedChannelModel.id == option,
# FeedChannelModel.discord_server_id == inter.guild_id
# ))
# result = await database.session.execute(query)
# if not result.rowcount:
# await followup_error(inter,
# title="Assigned Feed Not Found",
# message=f"I couldn't find any assigned feeds for the option: {option}"
# )
# return
await followup(inter, "I've removed this item (placeholder response)")

View File

@ -6,14 +6,20 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
import logging
from time import process_time
from feedparser import parse
from sqlalchemy import insert, select, and_
from discord import Interaction, TextChannel
from discord import TextChannel
from discord.ext import commands, tasks
from discord.errors import Forbidden
from sqlalchemy import insert, select, and_
from feedparser import parse
from feed import Source, Article # pylint disable=E0401
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401
from feed import Source, Article
from db import (
DatabaseManager,
FeedChannelModel,
RssSourceModel,
SentArticleModel
)
from utils import get_unparsed_feed
log = logging.getLogger(__name__)
@ -68,7 +74,7 @@ class TaskCog(commands.Cog):
channel = self.bot.get_channel(feed.discord_channel_id)
unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url)
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
parsed_feed = parse(unparsed_content)
source = Source.from_parsed(parsed_feed)
articles = source.get_latest_articles(5)

View File

@ -1,22 +1,23 @@
import json
import logging
import async_timeout
from dataclasses import dataclass
from datetime import datetime
from typing import Tuple
import aiohttp
import aiohttp
import validators
from textwrap import shorten
from markdownify import markdownify
from discord import Embed, Colour
from bs4 import BeautifulSoup as bs4
from feedparser import FeedParserDict, parse
from markdownify import markdownify
from sqlalchemy import select, insert, delete, and_
from sqlalchemy.exc import NoResultFound
from textwrap import shorten
from utils import audit
from errors import IllegalFeed
from db import DatabaseManager, RssSourceModel, FeedChannelModel
from utils import get_rss_data, get_unparsed_feed
log = logging.getLogger(__name__)
dumps = lambda _dict: json.dumps(_dict, indent=8)
@ -162,8 +163,8 @@ class Source:
@classmethod
async def from_url(cls, url: str):
unparsed_content = await Functions.get_unparsed_feed(url)
return
unparsed_content = await get_unparsed_feed(url)
return cls.from_parsed(parse(unparsed_content))
def get_latest_articles(self, max: int = 999) -> list[Article]:
"""Returns a list of Article objects.
@ -193,19 +194,26 @@ class Functions:
def __init__(self, bot):
self.bot = bot
@staticmethod
async def fetch(session, url: str) -> str:
async with async_timeout.timeout(20):
async with session.get(url) as response:
return await response.text()
@staticmethod
async def get_unparsed_feed(url: str):
async with aiohttp.ClientSession() as session:
return await self.fetch(session, url) # TODO: work from here
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
""""""
"""Validates a feed based on the given nickname and url.
Parameters
----------
nickname : str
Human readable nickname used to refer to the feed.
url : str
URL to fetch content from the feed.
Returns
-------
FeedParserDict
A Parsed Dictionary of the feed.
Raises
------
IllegalFeed
If the feed is invalid.
"""
# Ensure the URL is valid
if not validators.url(url):
@ -215,28 +223,47 @@ class Functions:
if validators.url(nickname):
raise IllegalFeed(
"It looks like the nickname you have entered is a URL.\n" \
f"For security reasons, this is not allowed.\n`{nickname=}`"
"For security reasons, this is not allowed.",
nickname=nickname
)
feed_data, status_code = await get_rss_data(url)
if status_code != 200:
raise IllegalFeed(
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}"
"The URL provided returned an invalid status code:",
url=url, status_code=status_code
)
# Check the contents is actually an RSS feed.
feed = parse(feed_data)
if not feed.version:
raise IllegalFeed(
f"The provided URL '{url}' does not seem to be a valid RSS feed."
"The provided URL does not seem to be a valid RSS feed.",
url=url
)
return feed
async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source:
""""""
"""Create a new Feed, and return it as a Source object.
Parameters
----------
nickname : str
Human readable nickname used to refer to the feed.
url : str
URL to fetch content from the feed.
guild_id : int
Discord Server ID associated with the feed.
Returns
-------
Source
Dataclass containing attributes of the feed.
"""
log.info("Creating new Feed: %s - %s", nickname, guild_id)
parsed_feed = await self.validate_feed(nickname, url)
@ -248,4 +275,110 @@ class Functions:
)
await database.session.execute(query)
log.info("Created Feed: %s - %s", nickname, guild_id)
return Source.from_parsed(parsed_feed)
async def delete_feed(self, url: str, guild_id: int) -> Source:
"""Delete an existing Feed, then return it as a Source object.
Parameters
----------
url : str
URL of the feed, used in the whereclause.
guild_id : int
Discord Server ID of the feed, used in the whereclause.
Returns
-------
Source
Dataclass containing attributes of the feed.
"""
log.info("Deleting Feed: %s - %s", url, guild_id)
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == guild_id,
RssSourceModel.rss_url == url
)
# Select the Feed entry, because an exception is raised if not found.
select_query = select(RssSourceModel).filter(whereclause)
select_result = await database.session.execute(select_query)
select_result.scalars().one()
delete_query = delete(RssSourceModel).filter(whereclause)
await database.session.execute(delete_query)
log.info("Deleted Feed: %s - %s", url, guild_id)
return await Source.from_url(url)
async def get_feeds(self, guild_id: int) -> list[Source]:
"""Returns a list of fetched Feed objects from the database.
Note: a request will be made too all found Feed URLs.
Parameters
----------
guild_id : int
The Discord Server ID, used to filter down the Feed query.
Returns
-------
list[Source]
List of Source objects, resulting from the query.
Raises
------
NoResultFound
Raised if no results are found.
"""
async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == guild_id)
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
rss_sources = result.scalars().all()
if not rss_sources:
raise NoResultFound
return [await Source.from_url(feed.rss_url) for feed in rss_sources]
async def assign_feed(
self, url: str, channel_name: str, channel_id: int, guild_id: int
) -> tuple[int, Source]:
""""""
async with DatabaseManager() as database:
select_query = select(RssSourceModel).where(and_(
RssSourceModel.rss_url == url,
RssSourceModel.discord_server_id == guild_id
))
select_result = await database.session.execute(select_query)
rss_source = select_result.scalars().one()
insert_query = insert(FeedChannelModel).values(
discord_server_id = guild_id,
discord_channel_id = channel_id,
rss_source_id=rss_source.id,
search_name=f"{rss_source.nick} #{channel_name}"
)
insert_result = await database.session.execute(insert_query)
return insert_result.inserted_primary_key.id, await Source.from_url(url)
async def unassign_feed( self, assigned_feed_id: int, guild_id: int):
""""""
async with DatabaseManager() as database:
query = delete(FeedChannelModel).where(and_(
FeedChannelModel.id == assigned_feed_id,
FeedChannelModel.discord_server_id == guild_id
))
result = await database.session.execute(query)
if not result.rowcount:
raise NoResultFound

View File

@ -2,11 +2,21 @@
import aiohttp
import logging
import async_timeout
from discord import Interaction, Embed, Colour
log = logging.getLogger(__name__)
async def fetch(session, url: str) -> str:
async with async_timeout.timeout(20):
async with session.get(url) as response:
return await response.text()
async def get_unparsed_feed(url: str):
async with aiohttp.ClientSession() as session:
return await fetch(session, url)
async def get_rss_data(url: str):
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
@ -30,25 +40,94 @@ async def audit(cog, *args, **kwargs):
await cog.bot.audit(*args, **kwargs)
async def followup_error(inter: Interaction, title: str, message: str, *args, **kwargs):
"""Shorthand for following up on an interaction, except returns an embed styled in
error colours.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
await inter.followup.send(
*args,
embed=Embed(
class FollowupIcons:
error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
class Followup:
"""Wrapper for a discord embed to follow up an interaction."""
def __init__(
self,
title: str = None,
description: str = None,
):
self._embed = Embed(
title=title,
description=message,
colour=Colour.red()
),
**kwargs
)
description=description
)
async def send(self, inter: Interaction, message: str = None):
""""""
await inter.followup.send(content=message, embed=self._embed)
def fields(self, inline: bool = False, **fields: dict):
""""""
for key, value in fields.items():
self._embed.add_field(name=key, value=value, inline=inline)
return self
def image(self, url: str):
""""""
self._embed.set_image(url=url)
return self
def error(self):
""""""
self._embed.colour = Colour.red()
self._embed.set_thumbnail(url=FollowupIcons.error)
return self
def success(self):
""""""
self._embed.colour = Colour.green()
self._embed.set_thumbnail(url=FollowupIcons.success)
return self
def info(self):
""""""
self._embed.colour = Colour.blue()
self._embed.set_thumbnail(url=FollowupIcons.info)
return self
def added(self):
""""""
self._embed.colour = Colour.blue()
self._embed.set_thumbnail(url=FollowupIcons.added)
return self
def assign(self):
""""""
self._embed.colour = Colour.blue()
self._embed.set_thumbnail(url=FollowupIcons.assigned)
return self
def trash(self):
""""""
self._embed.colour = Colour.red()
self._embed.set_thumbnail(url=FollowupIcons.trash)
return self
def extract_error_info(error: Exception) -> str:
class_name = error.__class__.__name__