Created Followup and move functions to feeds.py
This commit is contained in:
parent
94b7972a38
commit
e8d13ae26b
@ -4,8 +4,7 @@ All table classes should be suffixed with `Model`.
|
||||
"""
|
||||
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
|
@ -1,4 +1,5 @@
|
||||
|
||||
class IllegalFeed(Exception):
|
||||
pass
|
||||
|
||||
def __init__(self, message: str, **items):
|
||||
super().__init__(message)
|
||||
self.items = items
|
||||
|
@ -12,18 +12,25 @@ from discord.ext import commands
|
||||
from discord import Interaction, Embed, Colour, TextChannel, Permissions
|
||||
from discord.app_commands import Choice, Group, autocomplete, choices, rename
|
||||
from sqlalchemy import insert, select, and_, delete
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.exc import NoResultFound, IntegrityError
|
||||
|
||||
from utils import get_rss_data, followup, audit, followup_error, extract_error_info # pylint: disable=E0401
|
||||
from feed import Source # pylint: disable=E0401
|
||||
from db import ( # pylint: disable=E0401
|
||||
from feed import Source
|
||||
from errors import IllegalFeed
|
||||
from db import (
|
||||
DatabaseManager,
|
||||
SentArticleModel,
|
||||
RssSourceModel,
|
||||
FeedChannelModel,
|
||||
AuditModel
|
||||
)
|
||||
from errors import IllegalFeed
|
||||
from utils import (
|
||||
Followup,
|
||||
get_rss_data,
|
||||
followup,
|
||||
audit,
|
||||
extract_error_info,
|
||||
get_unparsed_feed
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -81,7 +88,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
|
||||
return None, feed
|
||||
|
||||
async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str):
|
||||
unparsed_feed = await self.bot.functions.get_unparsed_feed(rss_url)
|
||||
unparsed_feed = await get_unparsed_feed(rss_url)
|
||||
source = Source.from_parsed(parse(unparsed_feed))
|
||||
articles = source.get_latest_articles()
|
||||
|
||||
@ -148,101 +155,89 @@ class FeedCog(commands.Cog):
|
||||
feed_group = Group(
|
||||
name="feed",
|
||||
description="Commands for rss sources.",
|
||||
guild_only=True, # We store guild IDs in the database, so guild only = True
|
||||
default_permissions=Permissions.elevated()
|
||||
default_permissions=Permissions.elevated(),
|
||||
guild_only=True # We store guild IDs in the database, so guild only = True
|
||||
)
|
||||
|
||||
@feed_group.command(name="add")
|
||||
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
|
||||
"""Add a new RSS source.
|
||||
"""Add a new Feed for this server.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
nickname : str
|
||||
A name used to identify the RSS source.
|
||||
A name used to identify the Feed.
|
||||
url : str
|
||||
The RSS feed URL.
|
||||
The Feed URL.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
try:
|
||||
source = self.bot.functions.create_new_feed(nickname, url)
|
||||
source = await self.bot.functions.create_new_feed(nickname, url, inter.guild_id)
|
||||
except IllegalFeed as error:
|
||||
title, desc = extract_error_info(error)
|
||||
followup_error(inter, title=title, description=desc)
|
||||
|
||||
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
|
||||
embed.add_field(name="Nickname", value=nickname)
|
||||
embed.add_field(name="URL", value=url)
|
||||
embed.set_thumbnail(url=source.thumb_url)
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
await Followup(title, desc).fields(**error.items).error().send(inter)
|
||||
except IntegrityError as error:
|
||||
await (
|
||||
Followup(
|
||||
"Duplicate Feed Error",
|
||||
"A Feed with the same nickname already exist."
|
||||
)
|
||||
.fields(nickname=nickname)
|
||||
.error()
|
||||
.send(inter)
|
||||
)
|
||||
else:
|
||||
await (
|
||||
Followup("Feed Added")
|
||||
.image(source.icon_url)
|
||||
.fields(nickname=nickname, url=url)
|
||||
.added()
|
||||
.send(inter)
|
||||
)
|
||||
|
||||
@feed_group.command(name="remove")
|
||||
@rename(url="option")
|
||||
@autocomplete(url=source_autocomplete)
|
||||
async def remove_rss_source(self, inter: Interaction, url: str):
|
||||
"""Delete an existing RSS source.
|
||||
"""Delete an existing Feed from this server.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
url : str
|
||||
The RSS source to be removed. Autocomplete or enter the URL.
|
||||
The Feed to be removed. Autocomplete or enter the URL.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
log.debug("Attempting to remove RSS source (url=%s)", url)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
|
||||
# We will select the item first, so we can reference it's nickname later.
|
||||
select_query = select(RssSourceModel).filter(whereclause)
|
||||
select_result = await database.session.execute(select_query)
|
||||
|
||||
try:
|
||||
rss_source = select_result.scalars().one()
|
||||
except NoResultFound:
|
||||
await followup_error(inter,
|
||||
title="Error Deleting Feed",
|
||||
message=f"I couldn't find anything for `{url}`"
|
||||
try:
|
||||
source = await self.bot.functions.delete_feed(url, inter.guild_id)
|
||||
except NoResultFound:
|
||||
await (
|
||||
Followup(
|
||||
"Feed Not Found Error",
|
||||
"A Feed with these parameters could not be found."
|
||||
)
|
||||
return
|
||||
|
||||
nickname = rss_source.nick
|
||||
|
||||
delete_query = delete(RssSourceModel).filter(whereclause)
|
||||
delete_result = await database.session.execute(delete_query)
|
||||
|
||||
await audit(self,
|
||||
f"Deleted RSS source ({nickname=}, {url=})",
|
||||
inter.user.id, database=database
|
||||
.error()
|
||||
.send(inter)
|
||||
)
|
||||
else:
|
||||
await (
|
||||
Followup("Feed Deleted")
|
||||
.image(source.icon_url)
|
||||
.fields(url=url)
|
||||
.trash()
|
||||
.send(inter)
|
||||
)
|
||||
|
||||
source = await Source.from_url(url)
|
||||
|
||||
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
|
||||
embed.add_field(name="Nickname", value=nickname)
|
||||
embed.add_field(name="URL", value=url)
|
||||
embed.set_thumbnail(url=source.icon_url)
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@feed_group.command(name="list")
|
||||
@choices(sort=rss_list_sort_choices)
|
||||
async def list_rss_sources(
|
||||
self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False
|
||||
):
|
||||
"""Provides a with a list of RSS sources available for the current server.
|
||||
async def list_rss_sources(self, inter: Interaction):
|
||||
"""Provides a with a list of Feeds available for this server.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -252,55 +247,32 @@ class FeedCog(commands.Cog):
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
# Default to the first choice if not specified.
|
||||
if isinstance(sort, Choice):
|
||||
description = "Sort by "
|
||||
description += "Nickname " if sort.value == 0 else "Date Added "
|
||||
description += '\U000025BC' if sort_reverse else '\U000025B2'
|
||||
else:
|
||||
sort = rss_list_sort_choices[0]
|
||||
description = ""
|
||||
|
||||
match sort.value, sort_reverse:
|
||||
case 0, False:
|
||||
order_by = RssSourceModel.nick.asc()
|
||||
case 0, True:
|
||||
order_by = RssSourceModel.nick.desc()
|
||||
case 1, False:
|
||||
order_by = RssSourceModel.created.desc()
|
||||
case 1, True:
|
||||
order_by = RssSourceModel.created.asc()
|
||||
case _, _:
|
||||
raise ValueError(f"Unknown sort: {sort}")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
||||
query = select(RssSourceModel).where(whereclause).order_by(order_by)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
rss_sources = result.scalars().all()
|
||||
rowcount = len(rss_sources)
|
||||
|
||||
if not rss_sources:
|
||||
await followup_error(inter,
|
||||
title="No Feeds Found",
|
||||
message="I couldn't find any Feeds for this server."
|
||||
try:
|
||||
sources = await self.bot.functions.get_feeds(inter.guild_id)
|
||||
except NoResultFound:
|
||||
await (
|
||||
Followup(
|
||||
"Feeds Not Found Error",
|
||||
"There are no available Feeds for this server.\n"
|
||||
"Add a new feed with `/feed add`."
|
||||
)
|
||||
return
|
||||
|
||||
output = "\n".join([
|
||||
f"{i}. **[{rss.nick}]({rss.rss_url})** "
|
||||
for i, rss in enumerate(rss_sources)
|
||||
.error()
|
||||
.send()
|
||||
)
|
||||
else:
|
||||
description = "\n".join([
|
||||
f"{i}. **[{source.name}]({source.url})**"
|
||||
for i, source in enumerate(sources)
|
||||
])
|
||||
await (
|
||||
Followup(
|
||||
f"Available Feeds in {inter.guild.name}",
|
||||
description
|
||||
)
|
||||
.info()
|
||||
.send(inter)
|
||||
)
|
||||
|
||||
embed = Embed(
|
||||
title="Saved RSS Feeds",
|
||||
description=f"{description}\n\n{output}",
|
||||
colour=Colour.blue()
|
||||
)
|
||||
embed.set_footer(text=f"Showing {rowcount} results")
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
# @feed_group.command(name="fetch")
|
||||
# @rename(max_="max")
|
||||
@ -431,7 +403,7 @@ class FeedCog(commands.Cog):
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
sources = [
|
||||
Choice(name=rss.nick, value=rss.id)
|
||||
Choice(name=rss.nick, value=rss.rss_url)
|
||||
for rss in result.scalars().all()
|
||||
]
|
||||
|
||||
@ -483,10 +455,10 @@ class FeedCog(commands.Cog):
|
||||
# )
|
||||
|
||||
@feed_group.command(name="assign")
|
||||
@rename(rss="feed")
|
||||
@autocomplete(rss=autocomplete_rss_sources)
|
||||
@rename(url="feed")
|
||||
@autocomplete(url=autocomplete_rss_sources)
|
||||
async def include_feed(
|
||||
self, inter: Interaction, rss: int, channel: TextChannel = None, prevent_spam: bool = True
|
||||
self, inter: Interaction, url: str, channel: TextChannel = None, prevent_spam: bool = True
|
||||
):
|
||||
"""Include a feed within the specified channel.
|
||||
|
||||
@ -494,7 +466,7 @@ class FeedCog(commands.Cog):
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
rss : int
|
||||
url : int
|
||||
The RSS feed to include.
|
||||
channel : TextChannel
|
||||
The channel to include the feed in.
|
||||
@ -504,30 +476,41 @@ class FeedCog(commands.Cog):
|
||||
|
||||
channel = channel or inter.channel
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
select_query = select(RssSourceModel).where(and_(
|
||||
RssSourceModel.id == rss,
|
||||
RssSourceModel.discord_server_id == inter.guild_id
|
||||
))
|
||||
|
||||
select_result = await database.session.execute(select_query)
|
||||
rss_source = select_result.scalars().one()
|
||||
nick, rss_url = rss_source.nick, rss_source.rss_url
|
||||
|
||||
insert_query = insert(FeedChannelModel).values(
|
||||
discord_server_id = inter.guild_id,
|
||||
discord_channel_id = channel.id,
|
||||
rss_source_id=rss,
|
||||
search_name=f"{nick} #{channel.name}"
|
||||
try:
|
||||
feed_id, source = await self.bot.functions.assign_feed(
|
||||
url, channel.name, channel.id, inter.guild_id
|
||||
)
|
||||
except IntegrityError:
|
||||
await (
|
||||
Followup(
|
||||
"Duplicate Assigned Feed Error",
|
||||
f"This Feed has already been assigned to {channel.mention}"
|
||||
)
|
||||
.error()
|
||||
.send(inter)
|
||||
)
|
||||
except NoResultFound:
|
||||
await (
|
||||
Followup(
|
||||
"Feed Not Found Error",
|
||||
"A Feed with these parameters could not be found."
|
||||
)
|
||||
.error()
|
||||
.send(inter)
|
||||
)
|
||||
else:
|
||||
await (
|
||||
Followup(
|
||||
"Feed Assigned",
|
||||
f"I've assigned {channel.mention} to receive content from "
|
||||
f"[{source.name}]({source.url})."
|
||||
)
|
||||
.assign()
|
||||
.send(inter)
|
||||
)
|
||||
|
||||
insert_result = await database.session.execute(insert_query)
|
||||
feed_id = insert_result.inserted_primary_key.id
|
||||
|
||||
if prevent_spam:
|
||||
await set_all_articles_as_sent(inter, channel, feed_id, rss_url)
|
||||
|
||||
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
|
||||
await set_all_articles_as_sent(inter, channel, feed_id, url)
|
||||
|
||||
@feed_group.command(name="unassign")
|
||||
@autocomplete(option=autocomplete_existing_feeds)
|
||||
@ -544,20 +527,41 @@ class FeedCog(commands.Cog):
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = delete(FeedChannelModel).where(and_(
|
||||
FeedChannelModel.id == option,
|
||||
FeedChannelModel.discord_server_id == inter.guild_id
|
||||
))
|
||||
|
||||
result = await database.session.execute(query)
|
||||
|
||||
if not result.rowcount:
|
||||
await followup_error(inter,
|
||||
title="Assigned Feed Not Found",
|
||||
message=f"I couldn't find any assigned feeds for the option: {option}"
|
||||
try:
|
||||
await self.bot.functions.unassign_feed(option, inter.guild_id)
|
||||
except NoResultFound:
|
||||
await (
|
||||
Followup(
|
||||
"Assigned Feed Not Found",
|
||||
"The assigned Feed doesn't exist."
|
||||
)
|
||||
.error()
|
||||
.send(inter)
|
||||
)
|
||||
return
|
||||
else:
|
||||
await (
|
||||
Followup(
|
||||
"Unassigned Feed",
|
||||
"Feed has been unassigned."
|
||||
)
|
||||
.trash()
|
||||
.send(inter)
|
||||
)
|
||||
|
||||
# async with DatabaseManager() as database:
|
||||
# query = delete(FeedChannelModel).where(and_(
|
||||
# FeedChannelModel.id == option,
|
||||
# FeedChannelModel.discord_server_id == inter.guild_id
|
||||
# ))
|
||||
|
||||
# result = await database.session.execute(query)
|
||||
|
||||
# if not result.rowcount:
|
||||
# await followup_error(inter,
|
||||
# title="Assigned Feed Not Found",
|
||||
# message=f"I couldn't find any assigned feeds for the option: {option}"
|
||||
# )
|
||||
# return
|
||||
|
||||
await followup(inter, "I've removed this item (placeholder response)")
|
||||
|
||||
|
@ -6,14 +6,20 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
|
||||
import logging
|
||||
from time import process_time
|
||||
|
||||
from feedparser import parse
|
||||
from sqlalchemy import insert, select, and_
|
||||
from discord import Interaction, TextChannel
|
||||
from discord import TextChannel
|
||||
from discord.ext import commands, tasks
|
||||
from discord.errors import Forbidden
|
||||
from sqlalchemy import insert, select, and_
|
||||
from feedparser import parse
|
||||
|
||||
from feed import Source, Article # pylint disable=E0401
|
||||
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401
|
||||
from feed import Source, Article
|
||||
from db import (
|
||||
DatabaseManager,
|
||||
FeedChannelModel,
|
||||
RssSourceModel,
|
||||
SentArticleModel
|
||||
)
|
||||
from utils import get_unparsed_feed
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -68,7 +74,7 @@ class TaskCog(commands.Cog):
|
||||
|
||||
channel = self.bot.get_channel(feed.discord_channel_id)
|
||||
|
||||
unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url)
|
||||
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
|
||||
parsed_feed = parse(unparsed_content)
|
||||
source = Source.from_parsed(parsed_feed)
|
||||
articles = source.get_latest_articles(5)
|
||||
|
183
src/feed.py
183
src/feed.py
@ -1,22 +1,23 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import async_timeout
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
import aiohttp
|
||||
import validators
|
||||
from textwrap import shorten
|
||||
from markdownify import markdownify
|
||||
from discord import Embed, Colour
|
||||
from bs4 import BeautifulSoup as bs4
|
||||
from feedparser import FeedParserDict, parse
|
||||
from markdownify import markdownify
|
||||
from sqlalchemy import select, insert, delete, and_
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from textwrap import shorten
|
||||
|
||||
from utils import audit
|
||||
from errors import IllegalFeed
|
||||
from db import DatabaseManager, RssSourceModel, FeedChannelModel
|
||||
from utils import get_rss_data, get_unparsed_feed
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
dumps = lambda _dict: json.dumps(_dict, indent=8)
|
||||
@ -162,8 +163,8 @@ class Source:
|
||||
|
||||
@classmethod
|
||||
async def from_url(cls, url: str):
|
||||
unparsed_content = await Functions.get_unparsed_feed(url)
|
||||
return
|
||||
unparsed_content = await get_unparsed_feed(url)
|
||||
return cls.from_parsed(parse(unparsed_content))
|
||||
|
||||
def get_latest_articles(self, max: int = 999) -> list[Article]:
|
||||
"""Returns a list of Article objects.
|
||||
@ -193,19 +194,26 @@ class Functions:
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def fetch(session, url: str) -> str:
|
||||
async with async_timeout.timeout(20):
|
||||
async with session.get(url) as response:
|
||||
return await response.text()
|
||||
|
||||
@staticmethod
|
||||
async def get_unparsed_feed(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
return await self.fetch(session, url) # TODO: work from here
|
||||
|
||||
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
|
||||
""""""
|
||||
"""Validates a feed based on the given nickname and url.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nickname : str
|
||||
Human readable nickname used to refer to the feed.
|
||||
url : str
|
||||
URL to fetch content from the feed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FeedParserDict
|
||||
A Parsed Dictionary of the feed.
|
||||
|
||||
Raises
|
||||
------
|
||||
IllegalFeed
|
||||
If the feed is invalid.
|
||||
"""
|
||||
|
||||
# Ensure the URL is valid
|
||||
if not validators.url(url):
|
||||
@ -215,28 +223,47 @@ class Functions:
|
||||
if validators.url(nickname):
|
||||
raise IllegalFeed(
|
||||
"It looks like the nickname you have entered is a URL.\n" \
|
||||
f"For security reasons, this is not allowed.\n`{nickname=}`"
|
||||
"For security reasons, this is not allowed.",
|
||||
nickname=nickname
|
||||
)
|
||||
|
||||
feed_data, status_code = await get_rss_data(url)
|
||||
|
||||
if status_code != 200:
|
||||
raise IllegalFeed(
|
||||
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}"
|
||||
"The URL provided returned an invalid status code:",
|
||||
url=url, status_code=status_code
|
||||
)
|
||||
|
||||
# Check the contents is actually an RSS feed.
|
||||
feed = parse(feed_data)
|
||||
if not feed.version:
|
||||
raise IllegalFeed(
|
||||
f"The provided URL '{url}' does not seem to be a valid RSS feed."
|
||||
"The provided URL does not seem to be a valid RSS feed.",
|
||||
url=url
|
||||
)
|
||||
|
||||
return feed
|
||||
|
||||
|
||||
async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source:
|
||||
""""""
|
||||
"""Create a new Feed, and return it as a Source object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nickname : str
|
||||
Human readable nickname used to refer to the feed.
|
||||
url : str
|
||||
URL to fetch content from the feed.
|
||||
guild_id : int
|
||||
Discord Server ID associated with the feed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Source
|
||||
Dataclass containing attributes of the feed.
|
||||
"""
|
||||
|
||||
log.info("Creating new Feed: %s - %s", nickname, guild_id)
|
||||
|
||||
parsed_feed = await self.validate_feed(nickname, url)
|
||||
|
||||
@ -248,4 +275,110 @@ class Functions:
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
log.info("Created Feed: %s - %s", nickname, guild_id)
|
||||
|
||||
return Source.from_parsed(parsed_feed)
|
||||
|
||||
async def delete_feed(self, url: str, guild_id: int) -> Source:
|
||||
"""Delete an existing Feed, then return it as a Source object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
URL of the feed, used in the whereclause.
|
||||
guild_id : int
|
||||
Discord Server ID of the feed, used in the whereclause.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Source
|
||||
Dataclass containing attributes of the feed.
|
||||
"""
|
||||
|
||||
log.info("Deleting Feed: %s - %s", url, guild_id)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
|
||||
# Select the Feed entry, because an exception is raised if not found.
|
||||
select_query = select(RssSourceModel).filter(whereclause)
|
||||
select_result = await database.session.execute(select_query)
|
||||
select_result.scalars().one()
|
||||
|
||||
delete_query = delete(RssSourceModel).filter(whereclause)
|
||||
await database.session.execute(delete_query)
|
||||
|
||||
log.info("Deleted Feed: %s - %s", url, guild_id)
|
||||
|
||||
return await Source.from_url(url)
|
||||
|
||||
async def get_feeds(self, guild_id: int) -> list[Source]:
|
||||
"""Returns a list of fetched Feed objects from the database.
|
||||
Note: a request will be made too all found Feed URLs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
guild_id : int
|
||||
The Discord Server ID, used to filter down the Feed query.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Source]
|
||||
List of Source objects, resulting from the query.
|
||||
|
||||
Raises
|
||||
------
|
||||
NoResultFound
|
||||
Raised if no results are found.
|
||||
"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(RssSourceModel.discord_server_id == guild_id)
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
rss_sources = result.scalars().all()
|
||||
|
||||
if not rss_sources:
|
||||
raise NoResultFound
|
||||
|
||||
return [await Source.from_url(feed.rss_url) for feed in rss_sources]
|
||||
|
||||
async def assign_feed(
|
||||
self, url: str, channel_name: str, channel_id: int, guild_id: int
|
||||
) -> tuple[int, Source]:
|
||||
""""""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
select_query = select(RssSourceModel).where(and_(
|
||||
RssSourceModel.rss_url == url,
|
||||
RssSourceModel.discord_server_id == guild_id
|
||||
))
|
||||
select_result = await database.session.execute(select_query)
|
||||
|
||||
rss_source = select_result.scalars().one()
|
||||
|
||||
insert_query = insert(FeedChannelModel).values(
|
||||
discord_server_id = guild_id,
|
||||
discord_channel_id = channel_id,
|
||||
rss_source_id=rss_source.id,
|
||||
search_name=f"{rss_source.nick} #{channel_name}"
|
||||
)
|
||||
|
||||
insert_result = await database.session.execute(insert_query)
|
||||
return insert_result.inserted_primary_key.id, await Source.from_url(url)
|
||||
|
||||
async def unassign_feed( self, assigned_feed_id: int, guild_id: int):
|
||||
""""""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = delete(FeedChannelModel).where(and_(
|
||||
FeedChannelModel.id == assigned_feed_id,
|
||||
FeedChannelModel.discord_server_id == guild_id
|
||||
))
|
||||
|
||||
result = await database.session.execute(query)
|
||||
if not result.rowcount:
|
||||
raise NoResultFound
|
||||
|
111
src/utils.py
111
src/utils.py
@ -2,11 +2,21 @@
|
||||
|
||||
import aiohttp
|
||||
import logging
|
||||
import async_timeout
|
||||
|
||||
from discord import Interaction, Embed, Colour
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
async def fetch(session, url: str) -> str:
|
||||
async with async_timeout.timeout(20):
|
||||
async with session.get(url) as response:
|
||||
return await response.text()
|
||||
|
||||
async def get_unparsed_feed(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
return await fetch(session, url)
|
||||
|
||||
async def get_rss_data(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
@ -30,25 +40,94 @@ async def audit(cog, *args, **kwargs):
|
||||
|
||||
await cog.bot.audit(*args, **kwargs)
|
||||
|
||||
async def followup_error(inter: Interaction, title: str, message: str, *args, **kwargs):
|
||||
"""Shorthand for following up on an interaction, except returns an embed styled in
|
||||
error colours.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
|
||||
|
||||
await inter.followup.send(
|
||||
*args,
|
||||
embed=Embed(
|
||||
|
||||
class FollowupIcons:
|
||||
error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
|
||||
success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
|
||||
trash = "https://img.icons8.com/fluency-systems-filled/48/DC573C/trash.png"
|
||||
info = "https://img.icons8.com/fluency-systems-filled/48/4598DA/info.png"
|
||||
added = "https://img.icons8.com/fluency-systems-filled/48/4598DA/plus.png"
|
||||
assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
|
||||
|
||||
|
||||
class Followup:
|
||||
"""Wrapper for a discord embed to follow up an interaction."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
title: str = None,
|
||||
description: str = None,
|
||||
):
|
||||
self._embed = Embed(
|
||||
title=title,
|
||||
description=message,
|
||||
colour=Colour.red()
|
||||
),
|
||||
**kwargs
|
||||
)
|
||||
description=description
|
||||
)
|
||||
|
||||
async def send(self, inter: Interaction, message: str = None):
|
||||
""""""
|
||||
|
||||
await inter.followup.send(content=message, embed=self._embed)
|
||||
|
||||
def fields(self, inline: bool = False, **fields: dict):
|
||||
""""""
|
||||
|
||||
for key, value in fields.items():
|
||||
self._embed.add_field(name=key, value=value, inline=inline)
|
||||
|
||||
return self
|
||||
|
||||
def image(self, url: str):
|
||||
""""""
|
||||
|
||||
self._embed.set_image(url=url)
|
||||
|
||||
return self
|
||||
|
||||
def error(self):
|
||||
""""""
|
||||
|
||||
self._embed.colour = Colour.red()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.error)
|
||||
return self
|
||||
|
||||
def success(self):
|
||||
""""""
|
||||
|
||||
self._embed.colour = Colour.green()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.success)
|
||||
return self
|
||||
|
||||
def info(self):
|
||||
""""""
|
||||
|
||||
self._embed.colour = Colour.blue()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.info)
|
||||
return self
|
||||
|
||||
def added(self):
|
||||
""""""
|
||||
|
||||
self._embed.colour = Colour.blue()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.added)
|
||||
return self
|
||||
|
||||
def assign(self):
|
||||
""""""
|
||||
|
||||
self._embed.colour = Colour.blue()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.assigned)
|
||||
return self
|
||||
|
||||
def trash(self):
|
||||
""""""
|
||||
|
||||
self._embed.colour = Colour.red()
|
||||
self._embed.set_thumbnail(url=FollowupIcons.trash)
|
||||
return self
|
||||
|
||||
|
||||
def extract_error_info(error: Exception) -> str:
|
||||
class_name = error.__class__.__name__
|
||||
|
Loading…
x
Reference in New Issue
Block a user