Created Followup and move functions to feeds.py
This commit is contained in:
parent
94b7972a38
commit
e8d13ae26b
@ -4,8 +4,7 @@ All table classes should be suffixed with `Model`.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship, declarative_base
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column,
|
Column,
|
||||||
Integer,
|
Integer,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
|
|
||||||
class IllegalFeed(Exception):
|
class IllegalFeed(Exception):
|
||||||
pass
|
def __init__(self, message: str, **items):
|
||||||
|
super().__init__(message)
|
||||||
|
self.items = items
|
||||||
|
@ -12,18 +12,25 @@ from discord.ext import commands
|
|||||||
from discord import Interaction, Embed, Colour, TextChannel, Permissions
|
from discord import Interaction, Embed, Colour, TextChannel, Permissions
|
||||||
from discord.app_commands import Choice, Group, autocomplete, choices, rename
|
from discord.app_commands import Choice, Group, autocomplete, choices, rename
|
||||||
from sqlalchemy import insert, select, and_, delete
|
from sqlalchemy import insert, select, and_, delete
|
||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound, IntegrityError
|
||||||
|
|
||||||
from utils import get_rss_data, followup, audit, followup_error, extract_error_info # pylint: disable=E0401
|
from feed import Source
|
||||||
from feed import Source # pylint: disable=E0401
|
from errors import IllegalFeed
|
||||||
from db import ( # pylint: disable=E0401
|
from db import (
|
||||||
DatabaseManager,
|
DatabaseManager,
|
||||||
SentArticleModel,
|
SentArticleModel,
|
||||||
RssSourceModel,
|
RssSourceModel,
|
||||||
FeedChannelModel,
|
FeedChannelModel,
|
||||||
AuditModel
|
AuditModel
|
||||||
)
|
)
|
||||||
from errors import IllegalFeed
|
from utils import (
|
||||||
|
Followup,
|
||||||
|
get_rss_data,
|
||||||
|
followup,
|
||||||
|
audit,
|
||||||
|
extract_error_info,
|
||||||
|
get_unparsed_feed
|
||||||
|
)
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -81,7 +88,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
|
|||||||
return None, feed
|
return None, feed
|
||||||
|
|
||||||
async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str):
|
async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str):
|
||||||
unparsed_feed = await self.bot.functions.get_unparsed_feed(rss_url)
|
unparsed_feed = await get_unparsed_feed(rss_url)
|
||||||
source = Source.from_parsed(parse(unparsed_feed))
|
source = Source.from_parsed(parse(unparsed_feed))
|
||||||
articles = source.get_latest_articles()
|
articles = source.get_latest_articles()
|
||||||
|
|
||||||
@ -148,101 +155,89 @@ class FeedCog(commands.Cog):
|
|||||||
feed_group = Group(
|
feed_group = Group(
|
||||||
name="feed",
|
name="feed",
|
||||||
description="Commands for rss sources.",
|
description="Commands for rss sources.",
|
||||||
guild_only=True, # We store guild IDs in the database, so guild only = True
|
default_permissions=Permissions.elevated(),
|
||||||
default_permissions=Permissions.elevated()
|
guild_only=True # We store guild IDs in the database, so guild only = True
|
||||||
)
|
)
|
||||||
|
|
||||||
@feed_group.command(name="add")
|
@feed_group.command(name="add")
|
||||||
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
|
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
|
||||||
"""Add a new RSS source.
|
"""Add a new Feed for this server.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
inter : Interaction
|
inter : Interaction
|
||||||
Represents an app command interaction.
|
Represents an app command interaction.
|
||||||
nickname : str
|
nickname : str
|
||||||
A name used to identify the RSS source.
|
A name used to identify the Feed.
|
||||||
url : str
|
url : str
|
||||||
The RSS feed URL.
|
The Feed URL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await inter.response.defer()
|
await inter.response.defer()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
source = self.bot.functions.create_new_feed(nickname, url)
|
source = await self.bot.functions.create_new_feed(nickname, url, inter.guild_id)
|
||||||
except IllegalFeed as error:
|
except IllegalFeed as error:
|
||||||
title, desc = extract_error_info(error)
|
title, desc = extract_error_info(error)
|
||||||
followup_error(inter, title=title, description=desc)
|
await Followup(title, desc).fields(**error.items).error().send(inter)
|
||||||
|
except IntegrityError as error:
|
||||||
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
|
await (
|
||||||
embed.add_field(name="Nickname", value=nickname)
|
Followup(
|
||||||
embed.add_field(name="URL", value=url)
|
"Duplicate Feed Error",
|
||||||
embed.set_thumbnail(url=source.thumb_url)
|
"A Feed with the same nickname already exist."
|
||||||
|
)
|
||||||
await followup(inter, embed=embed)
|
.fields(nickname=nickname)
|
||||||
|
.error()
|
||||||
|
.send(inter)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await (
|
||||||
|
Followup("Feed Added")
|
||||||
|
.image(source.icon_url)
|
||||||
|
.fields(nickname=nickname, url=url)
|
||||||
|
.added()
|
||||||
|
.send(inter)
|
||||||
|
)
|
||||||
|
|
||||||
@feed_group.command(name="remove")
|
@feed_group.command(name="remove")
|
||||||
@rename(url="option")
|
@rename(url="option")
|
||||||
@autocomplete(url=source_autocomplete)
|
@autocomplete(url=source_autocomplete)
|
||||||
async def remove_rss_source(self, inter: Interaction, url: str):
|
async def remove_rss_source(self, inter: Interaction, url: str):
|
||||||
"""Delete an existing RSS source.
|
"""Delete an existing Feed from this server.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
inter : Interaction
|
inter : Interaction
|
||||||
Represents an app command interaction.
|
Represents an app command interaction.
|
||||||
url : str
|
url : str
|
||||||
The RSS source to be removed. Autocomplete or enter the URL.
|
The Feed to be removed. Autocomplete or enter the URL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await inter.response.defer()
|
await inter.response.defer()
|
||||||
|
|
||||||
log.debug("Attempting to remove RSS source (url=%s)", url)
|
try:
|
||||||
|
source = await self.bot.functions.delete_feed(url, inter.guild_id)
|
||||||
async with DatabaseManager() as database:
|
except NoResultFound:
|
||||||
whereclause = and_(
|
await (
|
||||||
RssSourceModel.discord_server_id == inter.guild_id,
|
Followup(
|
||||||
RssSourceModel.rss_url == url
|
"Feed Not Found Error",
|
||||||
)
|
"A Feed with these parameters could not be found."
|
||||||
|
|
||||||
# We will select the item first, so we can reference it's nickname later.
|
|
||||||
select_query = select(RssSourceModel).filter(whereclause)
|
|
||||||
select_result = await database.session.execute(select_query)
|
|
||||||
|
|
||||||
try:
|
|
||||||
rss_source = select_result.scalars().one()
|
|
||||||
except NoResultFound:
|
|
||||||
await followup_error(inter,
|
|
||||||
title="Error Deleting Feed",
|
|
||||||
message=f"I couldn't find anything for `{url}`"
|
|
||||||
)
|
)
|
||||||
return
|
.error()
|
||||||
|
.send(inter)
|
||||||
nickname = rss_source.nick
|
)
|
||||||
|
else:
|
||||||
delete_query = delete(RssSourceModel).filter(whereclause)
|
await (
|
||||||
delete_result = await database.session.execute(delete_query)
|
Followup("Feed Deleted")
|
||||||
|
.image(source.icon_url)
|
||||||
await audit(self,
|
.fields(url=url)
|
||||||
f"Deleted RSS source ({nickname=}, {url=})",
|
.trash()
|
||||||
inter.user.id, database=database
|
.send(inter)
|
||||||
)
|
)
|
||||||
|
|
||||||
source = await Source.from_url(url)
|
|
||||||
|
|
||||||
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
|
|
||||||
embed.add_field(name="Nickname", value=nickname)
|
|
||||||
embed.add_field(name="URL", value=url)
|
|
||||||
embed.set_thumbnail(url=source.icon_url)
|
|
||||||
|
|
||||||
await followup(inter, embed=embed)
|
|
||||||
|
|
||||||
@feed_group.command(name="list")
|
@feed_group.command(name="list")
|
||||||
@choices(sort=rss_list_sort_choices)
|
async def list_rss_sources(self, inter: Interaction):
|
||||||
async def list_rss_sources(
|
"""Provides a with a list of Feeds available for this server.
|
||||||
self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False
|
|
||||||
):
|
|
||||||
"""Provides a with a list of RSS sources available for the current server.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -252,55 +247,32 @@ class FeedCog(commands.Cog):
|
|||||||
|
|
||||||
await inter.response.defer()
|
await inter.response.defer()
|
||||||
|
|
||||||
# Default to the first choice if not specified.
|
try:
|
||||||
if isinstance(sort, Choice):
|
sources = await self.bot.functions.get_feeds(inter.guild_id)
|
||||||
description = "Sort by "
|
except NoResultFound:
|
||||||
description += "Nickname " if sort.value == 0 else "Date Added "
|
await (
|
||||||
description += '\U000025BC' if sort_reverse else '\U000025B2'
|
Followup(
|
||||||
else:
|
"Feeds Not Found Error",
|
||||||
sort = rss_list_sort_choices[0]
|
"There are no available Feeds for this server.\n"
|
||||||
description = ""
|
"Add a new feed with `/feed add`."
|
||||||
|
|
||||||
match sort.value, sort_reverse:
|
|
||||||
case 0, False:
|
|
||||||
order_by = RssSourceModel.nick.asc()
|
|
||||||
case 0, True:
|
|
||||||
order_by = RssSourceModel.nick.desc()
|
|
||||||
case 1, False:
|
|
||||||
order_by = RssSourceModel.created.desc()
|
|
||||||
case 1, True:
|
|
||||||
order_by = RssSourceModel.created.asc()
|
|
||||||
case _, _:
|
|
||||||
raise ValueError(f"Unknown sort: {sort}")
|
|
||||||
|
|
||||||
async with DatabaseManager() as database:
|
|
||||||
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
|
||||||
query = select(RssSourceModel).where(whereclause).order_by(order_by)
|
|
||||||
result = await database.session.execute(query)
|
|
||||||
|
|
||||||
rss_sources = result.scalars().all()
|
|
||||||
rowcount = len(rss_sources)
|
|
||||||
|
|
||||||
if not rss_sources:
|
|
||||||
await followup_error(inter,
|
|
||||||
title="No Feeds Found",
|
|
||||||
message="I couldn't find any Feeds for this server."
|
|
||||||
)
|
)
|
||||||
return
|
.error()
|
||||||
|
.send()
|
||||||
output = "\n".join([
|
)
|
||||||
f"{i}. **[{rss.nick}]({rss.rss_url})** "
|
else:
|
||||||
for i, rss in enumerate(rss_sources)
|
description = "\n".join([
|
||||||
|
f"{i}. **[{source.name}]({source.url})**"
|
||||||
|
for i, source in enumerate(sources)
|
||||||
])
|
])
|
||||||
|
await (
|
||||||
|
Followup(
|
||||||
|
f"Available Feeds in {inter.guild.name}",
|
||||||
|
description
|
||||||
|
)
|
||||||
|
.info()
|
||||||
|
.send(inter)
|
||||||
|
)
|
||||||
|
|
||||||
embed = Embed(
|
|
||||||
title="Saved RSS Feeds",
|
|
||||||
description=f"{description}\n\n{output}",
|
|
||||||
colour=Colour.blue()
|
|
||||||
)
|
|
||||||
embed.set_footer(text=f"Showing {rowcount} results")
|
|
||||||
|
|
||||||
await followup(inter, embed=embed)
|
|
||||||
|
|
||||||
# @feed_group.command(name="fetch")
|
# @feed_group.command(name="fetch")
|
||||||
# @rename(max_="max")
|
# @rename(max_="max")
|
||||||
@ -431,7 +403,7 @@ class FeedCog(commands.Cog):
|
|||||||
query = select(RssSourceModel).where(whereclause)
|
query = select(RssSourceModel).where(whereclause)
|
||||||
result = await database.session.execute(query)
|
result = await database.session.execute(query)
|
||||||
sources = [
|
sources = [
|
||||||
Choice(name=rss.nick, value=rss.id)
|
Choice(name=rss.nick, value=rss.rss_url)
|
||||||
for rss in result.scalars().all()
|
for rss in result.scalars().all()
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -483,10 +455,10 @@ class FeedCog(commands.Cog):
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
@feed_group.command(name="assign")
|
@feed_group.command(name="assign")
|
||||||
@rename(rss="feed")
|
@rename(url="feed")
|
||||||
@autocomplete(rss=autocomplete_rss_sources)
|
@autocomplete(url=autocomplete_rss_sources)
|
||||||
async def include_feed(
|
async def include_feed(
|
||||||
self, inter: Interaction, rss: int, channel: TextChannel = None, prevent_spam: bool = True
|
self, inter: Interaction, url: str, channel: TextChannel = None, prevent_spam: bool = True
|
||||||
):
|
):
|
||||||
"""Include a feed within the specified channel.
|
"""Include a feed within the specified channel.
|
||||||
|
|
||||||
@ -494,7 +466,7 @@ class FeedCog(commands.Cog):
|
|||||||
----------
|
----------
|
||||||
inter : Interaction
|
inter : Interaction
|
||||||
Represents an app command interaction.
|
Represents an app command interaction.
|
||||||
rss : int
|
url : int
|
||||||
The RSS feed to include.
|
The RSS feed to include.
|
||||||
channel : TextChannel
|
channel : TextChannel
|
||||||
The channel to include the feed in.
|
The channel to include the feed in.
|
||||||
@ -504,30 +476,41 @@ class FeedCog(commands.Cog):
|
|||||||
|
|
||||||
channel = channel or inter.channel
|
channel = channel or inter.channel
|
||||||
|
|
||||||
async with DatabaseManager() as database:
|
try:
|
||||||
select_query = select(RssSourceModel).where(and_(
|
feed_id, source = await self.bot.functions.assign_feed(
|
||||||
RssSourceModel.id == rss,
|
url, channel.name, channel.id, inter.guild_id
|
||||||
RssSourceModel.discord_server_id == inter.guild_id
|
)
|
||||||
))
|
except IntegrityError:
|
||||||
|
await (
|
||||||
select_result = await database.session.execute(select_query)
|
Followup(
|
||||||
rss_source = select_result.scalars().one()
|
"Duplicate Assigned Feed Error",
|
||||||
nick, rss_url = rss_source.nick, rss_source.rss_url
|
f"This Feed has already been assigned to {channel.mention}"
|
||||||
|
)
|
||||||
insert_query = insert(FeedChannelModel).values(
|
.error()
|
||||||
discord_server_id = inter.guild_id,
|
.send(inter)
|
||||||
discord_channel_id = channel.id,
|
)
|
||||||
rss_source_id=rss,
|
except NoResultFound:
|
||||||
search_name=f"{nick} #{channel.name}"
|
await (
|
||||||
|
Followup(
|
||||||
|
"Feed Not Found Error",
|
||||||
|
"A Feed with these parameters could not be found."
|
||||||
|
)
|
||||||
|
.error()
|
||||||
|
.send(inter)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await (
|
||||||
|
Followup(
|
||||||
|
"Feed Assigned",
|
||||||
|
f"I've assigned {channel.mention} to receive content from "
|
||||||
|
f"[{source.name}]({source.url})."
|
||||||
|
)
|
||||||
|
.assign()
|
||||||
|
.send(inter)
|
||||||
)
|
)
|
||||||
|
|
||||||
insert_result = await database.session.execute(insert_query)
|
|
||||||
feed_id = insert_result.inserted_primary_key.id
|
|
||||||
|
|
||||||
if prevent_spam:
|
if prevent_spam:
|
||||||
await set_all_articles_as_sent(inter, channel, feed_id, rss_url)
|
await set_all_articles_as_sent(inter, channel, feed_id, url)
|
||||||
|
|
||||||
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
|
|
||||||
|
|
||||||
@feed_group.command(name="unassign")
|
@feed_group.command(name="unassign")
|
||||||
@autocomplete(option=autocomplete_existing_feeds)
|
@autocomplete(option=autocomplete_existing_feeds)
|
||||||
@ -544,20 +527,41 @@ class FeedCog(commands.Cog):
|
|||||||
|
|
||||||
await inter.response.defer()
|
await inter.response.defer()
|
||||||
|
|
||||||
async with DatabaseManager() as database:
|
try:
|
||||||
query = delete(FeedChannelModel).where(and_(
|
await self.bot.functions.unassign_feed(option, inter.guild_id)
|
||||||
FeedChannelModel.id == option,
|
except NoResultFound:
|
||||||
FeedChannelModel.discord_server_id == inter.guild_id
|
await (
|
||||||
))
|
Followup(
|
||||||
|
"Assigned Feed Not Found",
|
||||||
result = await database.session.execute(query)
|
"The assigned Feed doesn't exist."
|
||||||
|
)
|
||||||
if not result.rowcount:
|
.error()
|
||||||
await followup_error(inter,
|
.send(inter)
|
||||||
title="Assigned Feed Not Found",
|
|
||||||
message=f"I couldn't find any assigned feeds for the option: {option}"
|
|
||||||
)
|
)
|
||||||
return
|
else:
|
||||||
|
await (
|
||||||
|
Followup(
|
||||||
|
"Unassigned Feed",
|
||||||
|
"Feed has been unassigned."
|
||||||
|
)
|
||||||
|
.trash()
|
||||||
|
.send(inter)
|
||||||
|
)
|
||||||
|
|
||||||
|
# async with DatabaseManager() as database:
|
||||||
|
# query = delete(FeedChannelModel).where(and_(
|
||||||
|
# FeedChannelModel.id == option,
|
||||||
|
# FeedChannelModel.discord_server_id == inter.guild_id
|
||||||
|
# ))
|
||||||
|
|
||||||
|
# result = await database.session.execute(query)
|
||||||
|
|
||||||
|
# if not result.rowcount:
|
||||||
|
# await followup_error(inter,
|
||||||
|
# title="Assigned Feed Not Found",
|
||||||
|
# message=f"I couldn't find any assigned feeds for the option: {option}"
|
||||||
|
# )
|
||||||
|
# return
|
||||||
|
|
||||||
await followup(inter, "I've removed this item (placeholder response)")
|
await followup(inter, "I've removed this item (placeholder response)")
|
||||||
|
|
||||||
|
@ -6,14 +6,20 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
|
|||||||
import logging
|
import logging
|
||||||
from time import process_time
|
from time import process_time
|
||||||
|
|
||||||
from feedparser import parse
|
from discord import TextChannel
|
||||||
from sqlalchemy import insert, select, and_
|
|
||||||
from discord import Interaction, TextChannel
|
|
||||||
from discord.ext import commands, tasks
|
from discord.ext import commands, tasks
|
||||||
from discord.errors import Forbidden
|
from discord.errors import Forbidden
|
||||||
|
from sqlalchemy import insert, select, and_
|
||||||
|
from feedparser import parse
|
||||||
|
|
||||||
from feed import Source, Article # pylint disable=E0401
|
from feed import Source, Article
|
||||||
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401
|
from db import (
|
||||||
|
DatabaseManager,
|
||||||
|
FeedChannelModel,
|
||||||
|
RssSourceModel,
|
||||||
|
SentArticleModel
|
||||||
|
)
|
||||||
|
from utils import get_unparsed_feed
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -68,7 +74,7 @@ class TaskCog(commands.Cog):
|
|||||||
|
|
||||||
channel = self.bot.get_channel(feed.discord_channel_id)
|
channel = self.bot.get_channel(feed.discord_channel_id)
|
||||||
|
|
||||||
unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url)
|
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
|
||||||
parsed_feed = parse(unparsed_content)
|
parsed_feed = parse(unparsed_content)
|
||||||
source = Source.from_parsed(parsed_feed)
|
source = Source.from_parsed(parsed_feed)
|
||||||
articles = source.get_latest_articles(5)
|
articles = source.get_latest_articles(5)
|
||||||
|
183
src/feed.py
183
src/feed.py
@ -1,22 +1,23 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import async_timeout
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import validators
|
import validators
|
||||||
from textwrap import shorten
|
|
||||||
from markdownify import markdownify
|
|
||||||
from discord import Embed, Colour
|
from discord import Embed, Colour
|
||||||
from bs4 import BeautifulSoup as bs4
|
from bs4 import BeautifulSoup as bs4
|
||||||
from feedparser import FeedParserDict, parse
|
from feedparser import FeedParserDict, parse
|
||||||
|
from markdownify import markdownify
|
||||||
|
from sqlalchemy import select, insert, delete, and_
|
||||||
|
from sqlalchemy.exc import NoResultFound
|
||||||
|
from textwrap import shorten
|
||||||
|
|
||||||
from utils import audit
|
|
||||||
from errors import IllegalFeed
|
from errors import IllegalFeed
|
||||||
|
from db import DatabaseManager, RssSourceModel, FeedChannelModel
|
||||||
|
from utils import get_rss_data, get_unparsed_feed
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
dumps = lambda _dict: json.dumps(_dict, indent=8)
|
dumps = lambda _dict: json.dumps(_dict, indent=8)
|
||||||
@ -162,8 +163,8 @@ class Source:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_url(cls, url: str):
|
async def from_url(cls, url: str):
|
||||||
unparsed_content = await Functions.get_unparsed_feed(url)
|
unparsed_content = await get_unparsed_feed(url)
|
||||||
return
|
return cls.from_parsed(parse(unparsed_content))
|
||||||
|
|
||||||
def get_latest_articles(self, max: int = 999) -> list[Article]:
|
def get_latest_articles(self, max: int = 999) -> list[Article]:
|
||||||
"""Returns a list of Article objects.
|
"""Returns a list of Article objects.
|
||||||
@ -193,19 +194,26 @@ class Functions:
|
|||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def fetch(session, url: str) -> str:
|
|
||||||
async with async_timeout.timeout(20):
|
|
||||||
async with session.get(url) as response:
|
|
||||||
return await response.text()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_unparsed_feed(url: str):
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
return await self.fetch(session, url) # TODO: work from here
|
|
||||||
|
|
||||||
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
|
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
|
||||||
""""""
|
"""Validates a feed based on the given nickname and url.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
nickname : str
|
||||||
|
Human readable nickname used to refer to the feed.
|
||||||
|
url : str
|
||||||
|
URL to fetch content from the feed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
FeedParserDict
|
||||||
|
A Parsed Dictionary of the feed.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
IllegalFeed
|
||||||
|
If the feed is invalid.
|
||||||
|
"""
|
||||||
|
|
||||||
# Ensure the URL is valid
|
# Ensure the URL is valid
|
||||||
if not validators.url(url):
|
if not validators.url(url):
|
||||||
@ -215,28 +223,47 @@ class Functions:
|
|||||||
if validators.url(nickname):
|
if validators.url(nickname):
|
||||||
raise IllegalFeed(
|
raise IllegalFeed(
|
||||||
"It looks like the nickname you have entered is a URL.\n" \
|
"It looks like the nickname you have entered is a URL.\n" \
|
||||||
f"For security reasons, this is not allowed.\n`{nickname=}`"
|
"For security reasons, this is not allowed.",
|
||||||
|
nickname=nickname
|
||||||
)
|
)
|
||||||
|
|
||||||
feed_data, status_code = await get_rss_data(url)
|
feed_data, status_code = await get_rss_data(url)
|
||||||
|
|
||||||
if status_code != 200:
|
if status_code != 200:
|
||||||
raise IllegalFeed(
|
raise IllegalFeed(
|
||||||
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}"
|
"The URL provided returned an invalid status code:",
|
||||||
|
url=url, status_code=status_code
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check the contents is actually an RSS feed.
|
# Check the contents is actually an RSS feed.
|
||||||
feed = parse(feed_data)
|
feed = parse(feed_data)
|
||||||
if not feed.version:
|
if not feed.version:
|
||||||
raise IllegalFeed(
|
raise IllegalFeed(
|
||||||
f"The provided URL '{url}' does not seem to be a valid RSS feed."
|
"The provided URL does not seem to be a valid RSS feed.",
|
||||||
|
url=url
|
||||||
)
|
)
|
||||||
|
|
||||||
return feed
|
return feed
|
||||||
|
|
||||||
|
|
||||||
async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source:
|
async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source:
|
||||||
""""""
|
"""Create a new Feed, and return it as a Source object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
nickname : str
|
||||||
|
Human readable nickname used to refer to the feed.
|
||||||
|
url : str
|
||||||
|
URL to fetch content from the feed.
|
||||||
|
guild_id : int
|
||||||
|
Discord Server ID associated with the feed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Source
|
||||||
|
Dataclass containing attributes of the feed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
log.info("Creating new Feed: %s - %s", nickname, guild_id)
|
||||||
|
|
||||||
parsed_feed = await self.validate_feed(nickname, url)
|
parsed_feed = await self.validate_feed(nickname, url)
|
||||||
|
|
||||||
@ -248,4 +275,110 @@ class Functions:
|
|||||||
)
|
)
|
||||||
await database.session.execute(query)
|
await database.session.execute(query)
|
||||||
|
|
||||||
|
log.info("Created Feed: %s - %s", nickname, guild_id)
|
||||||
|
|
||||||
return Source.from_parsed(parsed_feed)
|
return Source.from_parsed(parsed_feed)
|
||||||
|
|
||||||
|
async def delete_feed(self, url: str, guild_id: int) -> Source:
|
||||||
|
"""Delete an existing Feed, then return it as a Source object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
url : str
|
||||||
|
URL of the feed, used in the whereclause.
|
||||||
|
guild_id : int
|
||||||
|
Discord Server ID of the feed, used in the whereclause.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Source
|
||||||
|
Dataclass containing attributes of the feed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
log.info("Deleting Feed: %s - %s", url, guild_id)
|
||||||
|
|
||||||
|
async with DatabaseManager() as database:
|
||||||
|
whereclause = and_(
|
||||||
|
RssSourceModel.discord_server_id == guild_id,
|
||||||
|
RssSourceModel.rss_url == url
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select the Feed entry, because an exception is raised if not found.
|
||||||
|
select_query = select(RssSourceModel).filter(whereclause)
|
||||||
|
select_result = await database.session.execute(select_query)
|
||||||
|
select_result.scalars().one()
|
||||||
|
|
||||||
|
delete_query = delete(RssSourceModel).filter(whereclause)
|
||||||
|
await database.session.execute(delete_query)
|
||||||
|
|
||||||
|
log.info("Deleted Feed: %s - %s", url, guild_id)
|
||||||
|
|
||||||
|
return await Source.from_url(url)
|
||||||
|
|
||||||
|
async def get_feeds(self, guild_id: int) -> list[Source]:
|
||||||
|
"""Returns a list of fetched Feed objects from the database.
|
||||||
|
Note: a request will be made too all found Feed URLs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
guild_id : int
|
||||||
|
The Discord Server ID, used to filter down the Feed query.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[Source]
|
||||||
|
List of Source objects, resulting from the query.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
NoResultFound
|
||||||
|
Raised if no results are found.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with DatabaseManager() as database:
|
||||||
|
whereclause = and_(RssSourceModel.discord_server_id == guild_id)
|
||||||
|
query = select(RssSourceModel).where(whereclause)
|
||||||
|
result = await database.session.execute(query)
|
||||||
|
rss_sources = result.scalars().all()
|
||||||
|
|
||||||
|
if not rss_sources:
|
||||||
|
raise NoResultFound
|
||||||
|
|
||||||
|
return [await Source.from_url(feed.rss_url) for feed in rss_sources]
|
||||||
|
|
||||||
|
async def assign_feed(
|
||||||
|
self, url: str, channel_name: str, channel_id: int, guild_id: int
|
||||||
|
) -> tuple[int, Source]:
|
||||||
|
""""""
|
||||||
|
|
||||||
|
async with DatabaseManager() as database:
|
||||||
|
select_query = select(RssSourceModel).where(and_(
|
||||||
|
RssSourceModel.rss_url == url,
|
||||||
|
RssSourceModel.discord_server_id == guild_id
|
||||||
|
))
|
||||||
|
select_result = await database.session.execute(select_query)
|
||||||
|
|
||||||
|
rss_source = select_result.scalars().one()
|
||||||
|
|
||||||
|
insert_query = insert(FeedChannelModel).values(
|
||||||
|
discord_server_id = guild_id,
|
||||||
|
discord_channel_id = channel_id,
|
||||||
|
rss_source_id=rss_source.id,
|
||||||
|
search_name=f"{rss_source.nick} #{channel_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
insert_result = await database.session.execute(insert_query)
|
||||||
|
return insert_result.inserted_primary_key.id, await Source.from_url(url)
|
||||||
|
|
||||||
|
async def unassign_feed( self, assigned_feed_id: int, guild_id: int):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
async with DatabaseManager() as database:
|
||||||
|
query = delete(FeedChannelModel).where(and_(
|
||||||
|
FeedChannelModel.id == assigned_feed_id,
|
||||||
|
FeedChannelModel.discord_server_id == guild_id
|
||||||
|
))
|
||||||
|
|
||||||
|
result = await database.session.execute(query)
|
||||||
|
if not result.rowcount:
|
||||||
|
raise NoResultFound
|
||||||
|
111
src/utils.py
111
src/utils.py
@ -2,11 +2,21 @@
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import logging
|
import logging
|
||||||
|
import async_timeout
|
||||||
|
|
||||||
from discord import Interaction, Embed, Colour
|
from discord import Interaction, Embed, Colour
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def fetch(session, url: str) -> str:
|
||||||
|
async with async_timeout.timeout(20):
|
||||||
|
async with session.get(url) as response:
|
||||||
|
return await response.text()
|
||||||
|
|
||||||
|
async def get_unparsed_feed(url: str):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
return await fetch(session, url)
|
||||||
|
|
||||||
async def get_rss_data(url: str):
|
async def get_rss_data(url: str):
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(url) as response:
|
async with session.get(url) as response:
|
||||||
@ -30,25 +40,94 @@ async def audit(cog, *args, **kwargs):
|
|||||||
|
|
||||||
await cog.bot.audit(*args, **kwargs)
|
await cog.bot.audit(*args, **kwargs)
|
||||||
|
|
||||||
async def followup_error(inter: Interaction, title: str, message: str, *args, **kwargs):
|
|
||||||
"""Shorthand for following up on an interaction, except returns an embed styled in
|
|
||||||
error colours.
|
|
||||||
|
|
||||||
Parameters
|
# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
|
||||||
----------
|
|
||||||
inter : Interaction
|
|
||||||
Represents an app command interaction.
|
|
||||||
"""
|
|
||||||
|
|
||||||
await inter.followup.send(
|
|
||||||
*args,
|
class FollowupIcons:
|
||||||
embed=Embed(
|
error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
|
||||||
|
success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
|
||||||
|
trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
|
||||||
|
info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
|
||||||
|
added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
|
||||||
|
assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
|
||||||
|
|
||||||
|
|
||||||
|
class Followup:
|
||||||
|
"""Wrapper for a discord embed to follow up an interaction."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
title: str = None,
|
||||||
|
description: str = None,
|
||||||
|
):
|
||||||
|
self._embed = Embed(
|
||||||
title=title,
|
title=title,
|
||||||
description=message,
|
description=description
|
||||||
colour=Colour.red()
|
)
|
||||||
),
|
|
||||||
**kwargs
|
async def send(self, inter: Interaction, message: str = None):
|
||||||
)
|
""""""
|
||||||
|
|
||||||
|
await inter.followup.send(content=message, embed=self._embed)
|
||||||
|
|
||||||
|
def fields(self, inline: bool = False, **fields: dict):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
for key, value in fields.items():
|
||||||
|
self._embed.add_field(name=key, value=value, inline=inline)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def image(self, url: str):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
self._embed.set_image(url=url)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def error(self):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
self._embed.colour = Colour.red()
|
||||||
|
self._embed.set_thumbnail(url=FollowupIcons.error)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def success(self):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
self._embed.colour = Colour.green()
|
||||||
|
self._embed.set_thumbnail(url=FollowupIcons.success)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def info(self):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
self._embed.colour = Colour.blue()
|
||||||
|
self._embed.set_thumbnail(url=FollowupIcons.info)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def added(self):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
self._embed.colour = Colour.blue()
|
||||||
|
self._embed.set_thumbnail(url=FollowupIcons.added)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def assign(self):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
self._embed.colour = Colour.blue()
|
||||||
|
self._embed.set_thumbnail(url=FollowupIcons.assigned)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def trash(self):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
self._embed.colour = Colour.red()
|
||||||
|
self._embed.set_thumbnail(url=FollowupIcons.trash)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
def extract_error_info(error: Exception) -> str:
|
def extract_error_info(error: Exception) -> str:
|
||||||
class_name = error.__class__.__name__
|
class_name = error.__class__.__name__
|
||||||
|
Loading…
x
Reference in New Issue
Block a user