incomplete commit, used to work on other machine

This commit is contained in:
Corban-Lee Jones 2023-12-24 16:40:04 +00:00
parent e95edf08df
commit 94b7972a38
7 changed files with 92 additions and 50 deletions

View File

@ -9,6 +9,7 @@ from discord import Intents
from discord.ext import commands
from sqlalchemy import insert
from feed import Functions
from db import DatabaseManager, AuditModel
log = logging.getLogger(__name__)
@ -18,6 +19,7 @@ class DiscordBot(commands.Bot):
def __init__(self, BASE_DIR: Path):
super().__init__(command_prefix="-", intents=Intents.all())
self.functions = Functions(self)
self.BASE_DIR = BASE_DIR
async def sync_app_commands(self):

4
src/errors.py Normal file
View File

@ -0,0 +1,4 @@
class IllegalFeed(Exception):
pass

View File

@ -14,8 +14,8 @@ from discord.app_commands import Choice, Group, autocomplete, choices, rename
from sqlalchemy import insert, select, and_, delete
from sqlalchemy.exc import NoResultFound
from utils import get_rss_data, followup, audit, followup_error # pylint: disable=E0401
from feed import get_source, get_unparsed_feed, Source # pylint: disable=E0401
from utils import get_rss_data, followup, audit, followup_error, extract_error_info # pylint: disable=E0401
from feed import Source # pylint: disable=E0401
from db import ( # pylint: disable=E0401
DatabaseManager,
SentArticleModel,
@ -23,6 +23,7 @@ from db import ( # pylint: disable=E0401
FeedChannelModel,
AuditModel
)
from errors import IllegalFeed
log = logging.getLogger(__name__)
@ -80,7 +81,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
return None, feed
async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str):
unparsed_feed = await get_unparsed_feed(rss_url)
unparsed_feed = await self.bot.functions.get_unparsed_feed(rss_url)
source = Source.from_parsed(parse(unparsed_feed))
articles = source.get_latest_articles()
@ -167,30 +168,16 @@ class FeedCog(commands.Cog):
await inter.response.defer()
illegal_message, feed = await validate_rss_source(nickname, url)
if illegal_message:
await followup(inter, illegal_message, suppress_embeds=True)
return
log.debug("RSS feed added")
async with DatabaseManager() as database:
query = insert(RssSourceModel).values(
discord_server_id = inter.guild_id,
rss_url = url,
nick=nickname
)
await database.session.execute(query)
await audit(self,
f"Added RSS source ({nickname=}, {url=})",
inter.user.id, database=database
)
try:
source = self.bot.functions.create_new_feed(nickname, url)
except IllegalFeed as error:
title, desc = extract_error_info(error)
followup_error(inter, title=title, description=desc)
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
embed.add_field(name="Nickname", value=nickname)
embed.add_field(name="URL", value=url)
embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
embed.set_thumbnail(url=source.thumb_url)
await followup(inter, embed=embed)
@ -241,7 +228,7 @@ class FeedCog(commands.Cog):
inter.user.id, database=database
)
source = get_source(url) # TODO: replace with async function
source = await Source.from_url(url)
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
embed.add_field(name="Nickname", value=nickname)

View File

@ -12,7 +12,7 @@ from discord import Interaction, TextChannel
from discord.ext import commands, tasks
from discord.errors import Forbidden
from feed import Source, Article, get_unparsed_feed # pylint disable=E0401
from feed import Source, Article # pylint disable=E0401
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401
log = logging.getLogger(__name__)
@ -68,7 +68,7 @@ class TaskCog(commands.Cog):
channel = self.bot.get_channel(feed.discord_channel_id)
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url)
parsed_feed = parse(unparsed_content)
source = Source.from_parsed(parsed_feed)
articles = source.get_latest_articles(5)

View File

@ -4,6 +4,7 @@ import logging
import async_timeout
from dataclasses import dataclass
from datetime import datetime
from typing import Tuple
import aiohttp
@ -14,6 +15,9 @@ from discord import Embed, Colour
from bs4 import BeautifulSoup as bs4
from feedparser import FeedParserDict, parse
from utils import audit
from errors import IllegalFeed
log = logging.getLogger(__name__)
dumps = lambda _dict: json.dumps(_dict, indent=8)
@ -156,6 +160,11 @@ class Source:
feed=feed
)
@classmethod
async def from_url(cls, url: str):
unparsed_content = await Functions.get_unparsed_feed(url)
return
def get_latest_articles(self, max: int = 999) -> list[Article]:
"""Returns a list of Article objects.
@ -177,30 +186,66 @@ class Source:
for i, entry in enumerate(self.feed.entries)
if i < max
]
async def fetch(session, url: str) -> str:
async with async_timeout.timeout(20):
async with session.get(url) as response:
return await response.text()
async def get_unparsed_feed(url: str):
async with aiohttp.ClientSession() as session:
return await fetch(session, url)
def get_source(rss_url: str) -> Source:
"""_summary_
class Functions:
Parameters
----------
rss_url : str
_description_
def __init__(self, bot):
self.bot = bot
Returns
-------
Source
_description_
"""
@staticmethod
async def fetch(session, url: str) -> str:
async with async_timeout.timeout(20):
async with session.get(url) as response:
return await response.text()
parsed_feed = parse(rss_url) # TODO: make asyncronous
return Source.from_parsed(parsed_feed)
@staticmethod
async def get_unparsed_feed(url: str):
async with aiohttp.ClientSession() as session:
return await self.fetch(session, url) # TODO: work from here
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
""""""
# Ensure the URL is valid
if not validators.url(url):
raise IllegalFeed(f"The URL you have entered is malformed or invalid:\n`{url=}`")
# Check the nickname is not a URL
if validators.url(nickname):
raise IllegalFeed(
"It looks like the nickname you have entered is a URL.\n" \
f"For security reasons, this is not allowed.\n`{nickname=}`"
)
feed_data, status_code = await get_rss_data(url)
if status_code != 200:
raise IllegalFeed(
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}"
)
# Check the contents is actually an RSS feed.
feed = parse(feed_data)
if not feed.version:
raise IllegalFeed(
f"The provided URL '{url}' does not seem to be a valid RSS feed."
)
return feed
async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source:
""""""
parsed_feed = await self.validate_feed(nickname, url)
async with DatabaseManager() as database:
query = insert(RssSourceModel).values(
discord_server_id=guild_id,
rss_url=url,
nick=nickname
)
await database.session.execute(query)
return Source.from_parsed(parsed_feed)

View File

@ -44,5 +44,4 @@ async def main():
await bot.start(token, reconnect=True)
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@ -49,3 +49,8 @@ async def followup_error(inter: Interaction, title: str, message: str, *args, **
),
**kwargs
)
def extract_error_info(error: Exception) -> str:
class_name = error.__class__.__name__
desc = str(error)
return class_name, desc