incomplete commit, used to work on other machine
This commit is contained in:
parent
e95edf08df
commit
94b7972a38
@ -9,6 +9,7 @@ from discord import Intents
|
||||
from discord.ext import commands
|
||||
from sqlalchemy import insert
|
||||
|
||||
from feed import Functions
|
||||
from db import DatabaseManager, AuditModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -18,6 +19,7 @@ class DiscordBot(commands.Bot):
|
||||
|
||||
def __init__(self, BASE_DIR: Path):
|
||||
super().__init__(command_prefix="-", intents=Intents.all())
|
||||
self.functions = Functions(self)
|
||||
self.BASE_DIR = BASE_DIR
|
||||
|
||||
async def sync_app_commands(self):
|
||||
|
4
src/errors.py
Normal file
4
src/errors.py
Normal file
@ -0,0 +1,4 @@
|
||||
|
||||
class IllegalFeed(Exception):
|
||||
pass
|
||||
|
@ -14,8 +14,8 @@ from discord.app_commands import Choice, Group, autocomplete, choices, rename
|
||||
from sqlalchemy import insert, select, and_, delete
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
|
||||
from utils import get_rss_data, followup, audit, followup_error # pylint: disable=E0401
|
||||
from feed import get_source, get_unparsed_feed, Source # pylint: disable=E0401
|
||||
from utils import get_rss_data, followup, audit, followup_error, extract_error_info # pylint: disable=E0401
|
||||
from feed import Source # pylint: disable=E0401
|
||||
from db import ( # pylint: disable=E0401
|
||||
DatabaseManager,
|
||||
SentArticleModel,
|
||||
@ -23,6 +23,7 @@ from db import ( # pylint: disable=E0401
|
||||
FeedChannelModel,
|
||||
AuditModel
|
||||
)
|
||||
from errors import IllegalFeed
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -80,7 +81,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
|
||||
return None, feed
|
||||
|
||||
async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str):
|
||||
unparsed_feed = await get_unparsed_feed(rss_url)
|
||||
unparsed_feed = await self.bot.functions.get_unparsed_feed(rss_url)
|
||||
source = Source.from_parsed(parse(unparsed_feed))
|
||||
articles = source.get_latest_articles()
|
||||
|
||||
@ -167,30 +168,16 @@ class FeedCog(commands.Cog):
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
illegal_message, feed = await validate_rss_source(nickname, url)
|
||||
if illegal_message:
|
||||
await followup(inter, illegal_message, suppress_embeds=True)
|
||||
return
|
||||
|
||||
log.debug("RSS feed added")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(RssSourceModel).values(
|
||||
discord_server_id = inter.guild_id,
|
||||
rss_url = url,
|
||||
nick=nickname
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
await audit(self,
|
||||
f"Added RSS source ({nickname=}, {url=})",
|
||||
inter.user.id, database=database
|
||||
)
|
||||
try:
|
||||
source = self.bot.functions.create_new_feed(nickname, url)
|
||||
except IllegalFeed as error:
|
||||
title, desc = extract_error_info(error)
|
||||
followup_error(inter, title=title, description=desc)
|
||||
|
||||
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
|
||||
embed.add_field(name="Nickname", value=nickname)
|
||||
embed.add_field(name="URL", value=url)
|
||||
embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
|
||||
embed.set_thumbnail(url=source.thumb_url)
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@ -241,7 +228,7 @@ class FeedCog(commands.Cog):
|
||||
inter.user.id, database=database
|
||||
)
|
||||
|
||||
source = get_source(url) # TODO: replace with async function
|
||||
source = await Source.from_url(url)
|
||||
|
||||
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
|
||||
embed.add_field(name="Nickname", value=nickname)
|
||||
|
@ -12,7 +12,7 @@ from discord import Interaction, TextChannel
|
||||
from discord.ext import commands, tasks
|
||||
from discord.errors import Forbidden
|
||||
|
||||
from feed import Source, Article, get_unparsed_feed # pylint disable=E0401
|
||||
from feed import Source, Article # pylint disable=E0401
|
||||
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -68,7 +68,7 @@ class TaskCog(commands.Cog):
|
||||
|
||||
channel = self.bot.get_channel(feed.discord_channel_id)
|
||||
|
||||
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
|
||||
unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url)
|
||||
parsed_feed = parse(unparsed_content)
|
||||
source = Source.from_parsed(parsed_feed)
|
||||
articles = source.get_latest_articles(5)
|
||||
|
89
src/feed.py
89
src/feed.py
@ -4,6 +4,7 @@ import logging
|
||||
import async_timeout
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
@ -14,6 +15,9 @@ from discord import Embed, Colour
|
||||
from bs4 import BeautifulSoup as bs4
|
||||
from feedparser import FeedParserDict, parse
|
||||
|
||||
from utils import audit
|
||||
from errors import IllegalFeed
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
dumps = lambda _dict: json.dumps(_dict, indent=8)
|
||||
|
||||
@ -156,6 +160,11 @@ class Source:
|
||||
feed=feed
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def from_url(cls, url: str):
|
||||
unparsed_content = await Functions.get_unparsed_feed(url)
|
||||
return
|
||||
|
||||
def get_latest_articles(self, max: int = 999) -> list[Article]:
|
||||
"""Returns a list of Article objects.
|
||||
|
||||
@ -177,30 +186,66 @@ class Source:
|
||||
for i, entry in enumerate(self.feed.entries)
|
||||
if i < max
|
||||
]
|
||||
|
||||
async def fetch(session, url: str) -> str:
|
||||
async with async_timeout.timeout(20):
|
||||
async with session.get(url) as response:
|
||||
return await response.text()
|
||||
|
||||
async def get_unparsed_feed(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
return await fetch(session, url)
|
||||
|
||||
|
||||
def get_source(rss_url: str) -> Source:
|
||||
"""_summary_
|
||||
class Functions:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rss_url : str
|
||||
_description_
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
Returns
|
||||
-------
|
||||
Source
|
||||
_description_
|
||||
"""
|
||||
@staticmethod
|
||||
async def fetch(session, url: str) -> str:
|
||||
async with async_timeout.timeout(20):
|
||||
async with session.get(url) as response:
|
||||
return await response.text()
|
||||
|
||||
parsed_feed = parse(rss_url) # TODO: make asyncronous
|
||||
return Source.from_parsed(parsed_feed)
|
||||
@staticmethod
|
||||
async def get_unparsed_feed(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
return await self.fetch(session, url) # TODO: work from here
|
||||
|
||||
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
|
||||
""""""
|
||||
|
||||
# Ensure the URL is valid
|
||||
if not validators.url(url):
|
||||
raise IllegalFeed(f"The URL you have entered is malformed or invalid:\n`{url=}`")
|
||||
|
||||
# Check the nickname is not a URL
|
||||
if validators.url(nickname):
|
||||
raise IllegalFeed(
|
||||
"It looks like the nickname you have entered is a URL.\n" \
|
||||
f"For security reasons, this is not allowed.\n`{nickname=}`"
|
||||
)
|
||||
|
||||
feed_data, status_code = await get_rss_data(url)
|
||||
|
||||
if status_code != 200:
|
||||
raise IllegalFeed(
|
||||
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}"
|
||||
)
|
||||
|
||||
# Check the contents is actually an RSS feed.
|
||||
feed = parse(feed_data)
|
||||
if not feed.version:
|
||||
raise IllegalFeed(
|
||||
f"The provided URL '{url}' does not seem to be a valid RSS feed."
|
||||
)
|
||||
|
||||
return feed
|
||||
|
||||
|
||||
async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source:
|
||||
""""""
|
||||
|
||||
parsed_feed = await self.validate_feed(nickname, url)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(RssSourceModel).values(
|
||||
discord_server_id=guild_id,
|
||||
rss_url=url,
|
||||
nick=nickname
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
return Source.from_parsed(parsed_feed)
|
||||
|
@ -44,5 +44,4 @@ async def main():
|
||||
await bot.start(token, reconnect=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
@ -49,3 +49,8 @@ async def followup_error(inter: Interaction, title: str, message: str, *args, **
|
||||
),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def extract_error_info(error: Exception) -> str:
|
||||
class_name = error.__class__.__name__
|
||||
desc = str(error)
|
||||
return class_name, desc
|
||||
|
Loading…
x
Reference in New Issue
Block a user