This commit is contained in:
Corban-Lee Jones 2024-02-07 09:25:32 +00:00
commit 83a889cf61
8 changed files with 475 additions and 198 deletions

View File

@ -1,17 +1,7 @@
# NewsBot
# PYRSS
Bot delivering news articles to discord servers.
An RSS driven Discord bot written in Python.
Plans
Provides user commands for storing RSS feed URLs that can be assigned to any given discord channel.
- Multiple news providers
- Choose how much of each provider should be delivered
- Check for duplicate articles between providers, and only deliver preferred provider article
## Dev Notes:
For the sake of development, the following defintions apply:
- Feed - An RSS feed stored within the database, submitted by a user.
- Assigned Feed - A discord channel set to receive content from a Feed.
Content is shared every 10 minutes as an Embed.

127
src/api.py Normal file
View File

@ -0,0 +1,127 @@
import logging
import aiohttp
log = logging.getLogger(__name__)
class APIException(Exception):
pass
class NotCreatedException(APIException):
pass
class BadStatusException(APIException):
pass
class API:
"""Interactions with the API."""
API_HOST = "http://localhost:8000/"
API_ENDPOINT = API_HOST + "api/"
RSS_FEED_ENDPOINT = API_ENDPOINT + "rssfeed/"
FEED_CHANNEL_ENDPOINT = API_ENDPOINT + "feedchannel/"
def __init__(self, api_token: str, session: aiohttp.ClientSession):
log.debug("API session initialised")
self.session = session
self.token_headers = {"Authorization": f"Token {api_token}"}
async def make_request(self, method: str, url: str, **kwargs) -> dict:
"""Make a request to the given API endpoint.
Args:
method (str): The request method to use, examples: GET, POST, DELETE...
url (str): The API endpoint to request to.
**kwargs: Passed into self.session.request.
Returns:
dict: Dictionary containing status code, json or text.
"""
async with self.session.request(method, url, headers=self.token_headers, **kwargs) as response:
response.raise_for_status()
try:
json = await response.json()
text = None
except aiohttp.ContentTypeError:
json = None
text = await response.text()
status = response.status
return {"json": json, "text": text, "status": status}
async def create_new_rssfeed(self, name: str, url: str, image_url: str, discord_server_id: int) -> dict:
"""Create a new RSS Feed.
Args:
name (str): Name of the RSS Feed.
url (str): URL for the RSS Feed.
image_url (str): URL of the image representation of the RSS Feed.
discord_server_id (int): ID of the discord server behind this item.
Returns:
dict: JSON representation of the newly created RSS Feed.
"""
log.debug("creating rssfeed: %s %s %s %s", name, url, image_url, discord_server_id)
async with self.session.get(image_url) as response:
image_data = await response.read()
# Using formdata to make the image transfer easier.
form = aiohttp.FormData({
"name": name,
"url": url,
"discord_server_id": discord_server_id
})
form.add_field("image", image_data, filename="file.jpg")
data = (await self.make_response("POST", self.RSS_FEED_ENDPOINT, data=form))["json"]
return data
async def get_rssfeed(self, uuid: str) -> dict:
"""Get a particular RSS Feed given it's UUID.
Args:
uuid (str): Identifier of the desired RSS Feed.
Returns:
dict: A JSON representation of the RSS Feed.
"""
log.debug("getting rssfeed: %s", uuid)
endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/"
data = (await self.make_request("GET", endpoint))["json"]
return data
async def get_rssfeed_list(self, **filters) -> tuple[list[dict], int]:
"""Get all RSS Feeds with the associated filters.
Returns:
tuple[list[dict], int] list contains dictionaries of each item, int is total items.
"""
log.debug("getting list of rss feeds with filters: %s", filters)
data = (await self.make_request("GET", self.RSS_FEED_ENDPOINT, params=filters))["json"]
return data["results"], data["count"]
async def delete_rssfeed(self, uuid: str) -> int:
"""Delete a specified RSS Feed.
Args:
uuid (str): Identifier of the RSS Feed to delete.
Returns:
int: Status code of the response.
"""
log.debug("deleting rssfeed: %s", uuid)
endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/"
status = (await self.make_request("DELETE", endpoint))["status"]
return status

View File

@ -5,7 +5,7 @@ The discord bot for the application.
import logging
from pathlib import Path
from discord import Intents
from discord import Intents, Game
from discord.ext import commands
from sqlalchemy import insert
@ -17,11 +17,13 @@ log = logging.getLogger(__name__)
class DiscordBot(commands.Bot):
def __init__(self, BASE_DIR: Path, developing: bool):
super().__init__(command_prefix="-", intents=Intents.all())
self.functions = Functions(self)
def __init__(self, BASE_DIR: Path, developing: bool, api_token: str):
activity = Game("Indev") if developing else None
super().__init__(command_prefix="-", intents=Intents.all(), activity=activity)
self.functions = Functions(self, api_token)
self.BASE_DIR = BASE_DIR
self.developing = developing
self.api_token = api_token
log.info("developing=%s", developing)

View File

@ -6,6 +6,7 @@ Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bo
import logging
from typing import Tuple
import aiohttp
import validators
from feedparser import FeedParserDict, parse
from discord.ext import commands
@ -14,7 +15,8 @@ from discord.app_commands import Choice, Group, autocomplete, choices, rename
from sqlalchemy import insert, select, and_, delete
from sqlalchemy.exc import NoResultFound, IntegrityError
from feed import Source
from api import API
from feed import Source, RSSFeed
from errors import IllegalFeed
from db import (
DatabaseManager,
@ -25,6 +27,7 @@ from db import (
)
from utils import (
Followup,
PaginationView,
get_rss_data,
followup,
audit,
@ -121,6 +124,21 @@ class FeedCog(commands.Cog):
log.info("%s cog is ready", self.__class__.__name__)
async def autocomplete_rssfeed(self, inter: Interaction, name: str) -> list[Choice]:
async with aiohttp.ClientSession() as session:
data, _ = await API(self.bot.api_token, session).get_rssfeed_list(
discord_server_id=inter.guild_id
)
rssfeeds = RSSFeed.from_list(data)
choices = [
Choice(name=item.name, value=item.uuid)
for item in rssfeeds
]
return choices
async def source_autocomplete(self, inter: Interaction, nickname: str):
"""Provides RSS source autocomplete functionality for commands.
@ -154,69 +172,55 @@ class FeedCog(commands.Cog):
# All RSS commands belong to this group.
feed_group = Group(
name="feed",
description="Commands for rss sources.",
description="Commands for RSS sources.",
default_permissions=Permissions.elevated(),
guild_only=True # We store guild IDs in the database, so guild only = True
)
@feed_group.command(name="add")
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
"""Add a new Feed for this server.
@feed_group.command(name="new")
async def add_rssfeed(self, inter: Interaction, name: str, url: str):
"""Add a new RSS Feed for this server.
Parameters
----------
inter : Interaction
Represents an app command interaction.
nickname : str
A name used to identify the Feed.
url : str
The Feed URL.
Args:
inter (Interaction): Represents the discord command interaction.
name (str): A nickname used to refer to this RSS Feed.
url (str): The URL of the RSS Feed.
"""
await inter.response.defer()
try:
source = await self.bot.functions.create_new_feed(nickname, url, inter.guild_id)
except IllegalFeed as error:
title, desc = extract_error_info(error)
await Followup(title, desc).fields(**error.items).error().send(inter)
except IntegrityError as error:
rssfeed = await self.bot.functions.create_new_rssfeed(name, url, inter.guild_id)
except Exception as exc:
await (
Followup(
"Duplicate Feed Error",
"A Feed with the same nickname already exist."
)
.fields(nickname=nickname)
Followup(exc.__class__.__name__, str(exc))
.error()
.send(inter)
)
else:
await (
Followup("Feed Added")
.image(source.icon_url)
.fields(nickname=nickname, url=url)
Followup("New RSS Feed")
.image(rssfeed.image)
.fields(uuid=rssfeed.uuid, name=name, url=url)
.added()
.send(inter)
)
@feed_group.command(name="remove")
@rename(url="option")
@autocomplete(url=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, url: str):
"""Delete an existing Feed from this server.
@feed_group.command(name="delete")
@autocomplete(uuid=autocomplete_rssfeed)
@rename(uuid="rssfeed")
async def delete_rssfeed(self, inter: Interaction, uuid: str):
"""Delete an existing RSS Feed for this server.
Parameters
----------
inter : Interaction
Represents an app command interaction.
url : str
The Feed to be removed. Autocomplete or enter the URL.
Args:
inter (Interaction): Represents the discord command interaction.
uuid (str): The UUID of the
"""
await inter.response.defer()
try:
source = await self.bot.functions.delete_feed(url, inter.guild_id)
rssfeed = await self.bot.functions.delete_rssfeed(uuid)
except NoResultFound:
await (
Followup(
@ -229,47 +233,43 @@ class FeedCog(commands.Cog):
else:
await (
Followup("Feed Deleted")
.image(source.icon_url)
.fields(url=url)
.image(rssfeed.image)
.fields(uuid=rssfeed.uuid, name=rssfeed.name, url=rssfeed.url)
.trash()
.send(inter)
)
@feed_group.command(name="list")
async def list_rss_sources(self, inter: Interaction):
"""Provides a with a list of Feeds available for this server.
async def list_rssfeeds(self, inter: Interaction):
"""Provides a list of all RSS Feeds
Parameters
----------
inter : Interaction
Represents an app command interaction.
Args:
inter (Interaction): Represents the discord command interaction.
"""
await inter.response.defer()
page = 1
pagesize = 10
try:
feeds = await self.bot.functions.get_feeds(inter.guild_id)
except NoResultFound:
def formatdata(index, item):
key = f"{index}. {item.name}"
value = f"[RSS]({item.url}) · [API](http://localhost:8000/api/rssfeed/{item.uuid}/)"
return key, value
async def getdata(page):
data, count = await self.bot.functions.get_rssfeeds(inter.guild_id, page, pagesize)
return data, count
embed = Followup(f"Available RSS Feeds in {inter.guild.name}").info()._embed
pagination = PaginationView(self.bot, inter, embed, getdata, formatdata, pagesize, 1)
await pagination.send()
except Exception as exc:
await (
Followup(
"Feeds Not Found Error",
"There are no available Feeds for this server.\n"
"Add a new feed with `/feed add`."
)
Followup(exc.__class__.__name__, str(exc))
.error()
.send()
)
else:
description = "\n".join([
f"{i}. **[{info[0]}]({info[1]})**" # info = (nick, url)
for i, info in enumerate(feeds)
])
await (
Followup(
f"Available Feeds in {inter.guild.name}",
description
)
.info()
.send(inter)
)

View File

@ -10,12 +10,13 @@ from time import process_time
import aiohttp
from discord import TextChannel
from discord import app_commands
from discord.ext import commands, tasks
from discord.errors import Forbidden
from sqlalchemy import insert, select, and_
from feedparser import parse
from feed import Source, Article
from feed import Source, Article, RSSFeed
from db import (
DatabaseManager,
FeedChannelModel,
@ -23,11 +24,13 @@ from db import (
SentArticleModel
)
from utils import get_unparsed_feed
from api import API
log = logging.getLogger(__name__)
TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES")
# task trigger times : must be of type list
times = [
datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
for hour in range(24)
@ -62,6 +65,12 @@ class TaskCog(commands.Cog):
self.rss_task.cancel()
@app_commands.command(name="debug-trigger-task")
async def debug_trigger_task(self, inter):
await inter.response.defer()
await self.rss_task()
await inter.followup.send("done")
@tasks.loop(time=times)
async def rss_task(self):
"""Automated task responsible for processing rss feeds."""
@ -69,13 +78,24 @@ class TaskCog(commands.Cog):
log.info("Running rss task")
time = process_time()
async with DatabaseManager() as database:
query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
result = await database.session.execute(query)
feeds = result.scalars().all()
# async with DatabaseManager() as database:
# query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
# result = await database.session.execute(query)
# feeds = result.scalars().all()
# for feed in feeds:
# await self.process_feed(feed, database)
guild_ids = [guild.id for guild in self.bot.guilds]
async with aiohttp.ClientSession() as session:
api = API(self.bot.api_token, session)
data, count = await api.get_rssfeed_list(discord_server_id__in=guild_ids)
rssfeeds = RSSFeed.from_list(data)
for item in rssfeeds:
log.info(item.name)
for feed in feeds:
await self.process_feed(feed, database)
log.info("Finished rss task, time elapsed: %s", process_time() - time)

View File

@ -1,4 +1,5 @@
import ssl
import json
import logging
from dataclasses import dataclass
@ -18,6 +19,7 @@ from textwrap import shorten
from errors import IllegalFeed
from db import DatabaseManager, RssSourceModel, FeedChannelModel
from utils import get_rss_data, get_unparsed_feed
from api import API
log = logging.getLogger(__name__)
dumps = lambda _dict: json.dumps(_dict, indent=8)
@ -140,7 +142,7 @@ class Article:
@dataclass
class Source:
"""Represents an RSS source."""
name: str | None
url: str | None
icon_url: str | None
@ -173,7 +175,9 @@ class Source:
@classmethod
async def from_url(cls, url: str):
unparsed_content = await get_unparsed_feed(url)
return cls.from_parsed(parse(unparsed_content))
source = cls.from_parsed(parse(unparsed_content))
source.url = url
return source
def get_latest_articles(self, max: int = 999) -> list[Article]:
"""Returns a list of Article objects.
@ -198,10 +202,37 @@ class Source:
]
@dataclass
class RSSFeed:
uuid: str
name: str
url: str
image: str
discord_server_id: id
created_at: str
@classmethod
def from_list(cls, data: list) -> list:
result = []
for item in data:
key = "discord_server_id"
item[key] = int(item.get(key))
result.append(cls(**item))
return result
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
class Functions:
def __init__(self, bot):
def __init__(self, bot, api_token: str):
self.bot = bot
self.api_token = api_token
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
"""Validates a feed based on the given nickname and url.
@ -254,106 +285,50 @@ class Functions:
return feed
async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source:
"""Create a new Feed, and return it as a Source object.
async def create_new_rssfeed(self, name: str, url: str, guild_id: int) -> RSSFeed:
Parameters
----------
nickname : str
Human readable nickname used to refer to the feed.
url : str
URL to fetch content from the feed.
guild_id : int
Discord Server ID associated with the feed.
Returns
-------
Source
Dataclass containing attributes of the feed.
log.info("Creating new Feed: %s", name)
parsed_feed = await self.validate_feed(name, url)
source = Source.from_parsed(parsed_feed)
async with aiohttp.ClientSession() as session:
data = await API(self.api_token, session).create_new_rssfeed(
name, url, source.icon_url, guild_id
)
return RSSFeed.from_dict(data)
async def delete_rssfeed(self, uuid: str) -> RSSFeed:
log.info("Deleting Feed '%s'", uuid)
async with aiohttp.ClientSession() as session:
api = API(self.api_token, session)
data = await api.get_rssfeed(uuid)
await api.delete_rssfeed(uuid)
return RSSFeed.from_dict(data)
async def get_rssfeeds(self, guild_id: int, page: int, pagesize: int) -> list[RSSFeed]:
"""Get a list of RSS Feeds.
Args:
guild_id (int): The guild_id to filter by.
Returns:
list[RSSFeed]: Resulting list of RSS Feeds
"""
log.info("Creating new Feed: %s - %s", nickname, guild_id)
parsed_feed = await self.validate_feed(nickname, url)
async with DatabaseManager() as database:
query = insert(RssSourceModel).values(
async with aiohttp.ClientSession() as session:
data, count = await API(self.api_token, session).get_rssfeed_list(
discord_server_id=guild_id,
rss_url=url,
nick=nickname
)
await database.session.execute(query)
log.info("Created Feed: %s - %s", nickname, guild_id)
return Source.from_parsed(parsed_feed)
async def delete_feed(self, url: str, guild_id: int) -> Source:
"""Delete an existing Feed, then return it as a Source object.
Parameters
----------
url : str
URL of the feed, used in the whereclause.
guild_id : int
Discord Server ID of the feed, used in the whereclause.
Returns
-------
Source
Dataclass containing attributes of the feed.
"""
log.info("Deleting Feed: %s - %s", url, guild_id)
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == guild_id,
RssSourceModel.rss_url == url
page=page,
page_size=pagesize
)
# Select the Feed entry, because an exception is raised if not found.
select_query = select(RssSourceModel).filter(whereclause)
select_result = await database.session.execute(select_query)
select_result.scalars().one()
delete_query = delete(RssSourceModel).filter(whereclause)
await database.session.execute(delete_query)
log.info("Deleted Feed: %s - %s", url, guild_id)
return await Source.from_url(url)
async def get_feeds(self, guild_id: int) -> list[tuple[str, str]]:
"""Returns a list of fetched Feed objects from the database.
Note: a request will be made too all found Feed UR Ls.
Parameters
----------
guild_id : int
The Discord Server ID, used to filter down the Feed query.
Returns
-------
list[tuple[str, str]]
List of Source objects, resulting from the query.
Raises
------
NoResultFound
Raised if no results are found.
"""
async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == guild_id)
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
rss_sources = result.scalars().all()
if not rss_sources:
raise NoResultFound
return [(feed.nick, feed.rss_url) for feed in rss_sources]
return RSSFeed.from_list(data), count
async def assign_feed(
self, url: str, channel_name: str, channel_id: int, guild_id: int

View File

@ -9,9 +9,9 @@ from os import getenv
from pathlib import Path
# it's important to load environment variables before
# importing the packages that depend on them.
# importing the modules that depend on them.
from dotenv import load_dotenv
load_dotenv()
load_dotenv(override=True)
from bot import DiscordBot
from logs import LogSetup
@ -26,12 +26,17 @@ async def main():
# Grab the token before anything else, because if there is no token
# available then the bot cannot be started anyways.
token = getenv("BOT_TOKEN")
bot_token = getenv("BOT_TOKEN")
if not bot_token:
raise ValueError("Bot Token is empty")
if not token:
raise ValueError("Token is empty")
# ^ same story for the API token. Without it the API cannot be
# interacted with, so grab it first.
api_token = getenv("API_TOKEN")
if not api_token:
raise ValueError("API Token is empty")
developing = getenv("DEVELOPING") == "True"
developing = getenv("DEVELOPING", "False") == "True"
# Setup logging settings and mute spammy loggers
logsetup = LogSetup(BASE_DIR / "logs/")
@ -41,9 +46,10 @@ async def main():
level=logging.WARNING
)
async with DiscordBot(BASE_DIR, developing=developing) as bot:
async with DiscordBot(BASE_DIR, developing=developing, api_token=api_token) as bot:
await bot.load_extensions()
await bot.start(token, reconnect=True)
await bot.start(bot_token, reconnect=True)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -3,8 +3,11 @@
import aiohttp
import logging
import async_timeout
from typing import Callable
from discord import Interaction, Embed, Colour
from discord import Interaction, Embed, Colour, ButtonStyle, Button
from discord.ui import View, button
from discord.ext.commands import Bot
log = logging.getLogger(__name__)
@ -56,6 +59,153 @@ class FollowupIcons:
assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
class PaginationView(View):
"""A Discord UI View that adds pagination to an embed."""
def __init__(
self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
formatdata: Callable, pagesize: int, initpage: int=1
):
"""_summary_
Args:
bot (commands.Bot) The discord bot
inter (Interaction): Represents a discord command interaction.
embed (Embed): The base embed to paginate.
getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
formatdata (Callable): A formatter function that determines how the data is displayed.
pagesize (int): The size of each page.
initpage (int, optional): The inital page. Defaults to 1.
"""
self.bot = bot
self.inter = inter
self.embed = embed
self.getdata = getdata
self.formatdata = formatdata
self.maxpage = None
self.pagesize = pagesize
self.index = initpage
# emoji reference
next_emoji = bot.get_emoji(1204542366602502265)
prev_emoji = bot.get_emoji(1204542365432422470)
self.start_emoji = bot.get_emoji(1204542364073463818)
self.end_emoji = bot.get_emoji(1204542367752003624)
super().__init__(timeout=100)
async def check_user_is_author(self, inter: Interaction) -> bool:
"""Ensure the user is the author of the original command."""
if inter.user == self.inter.user:
return True
await inter.response.defer()
await (
Followup(None, "Only the author can interact with this.")
.error()
.send(inter, ephemeral=True)
)
return False
async def on_timeout(self):
"""Erase the controls on timeout."""
message = await self.inter.original_response()
await message.edit(view=None)
@staticmethod
def calc_total_pages(results: int, max_pagesize: int) -> int:
result = ((results - 1) // max_pagesize) + 1
log.debug("total pages calculated: %s", result)
return result
def calc_dataitem_index(self, dataitem_index: int):
"""Calculates a given index to be relative to the sum of all pages items
Example: dataitem_index = 6
pagesize = 10
if page == 1 then return 6
else return 6 + 10 * (page - 1)"""
if self.index > 1:
dataitem_index += self.pagesize * (self.index - 1)
dataitem_index += 1
return dataitem_index
@button(emoji="◀️", style=ButtonStyle.blurple)
async def backward(self, inter: Interaction, button: Button):
self.index -= 1
await inter.response.defer()
self.inter = inter
await self.navigate()
@button(emoji="▶️", style=ButtonStyle.blurple)
async def forward(self, inter: Interaction, button: Button):
self.index += 1
await inter.response.defer()
self.inter = inter
await self.navigate()
@button(emoji="⏭️", style=ButtonStyle.blurple)
async def start_or_end(self, inter: Interaction, button: Button):
if self.index <= self.maxpage // 2:
self.index = self.maxpage
else:
self.index = 1
await inter.response.defer()
self.inter = inter
await self.navigate()
async def navigate(self):
log.debug("navigating to page: %s", self.index)
self.update_buttons()
paged_embed = await self.create_paged_embed()
await self.inter.edit_original_response(embed=paged_embed, view=self)
async def create_paged_embed(self) -> Embed:
embed = self.embed.copy()
data, total_results = await self.getdata(self.index)
self.maxpage = self.calc_total_pages(total_results, self.pagesize)
for i, item in enumerate(data):
i = self.calc_dataitem_index(i)
key, value = self.formatdata(i, item)
embed.add_field(name=key, value=value, inline=False)
if not total_results:
embed.description = "There are no results"
if self.maxpage > 1:
embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
return embed
def update_buttons(self):
if self.index >= self.maxpage:
self.children[2].emoji = self.start_emoji
else:
self.children[2].emoji = self.end_emoji
self.children[0].disabled = self.index == 1
self.children[1].disabled = self.index == self.maxpage
async def send(self):
embed = await self.create_paged_embed()
if self.maxpage <= 1:
await self.inter.edit_original_response(embed=embed)
return
self.update_buttons()
await self.inter.edit_original_response(embed=embed, view=self)
class Followup:
"""Wrapper for a discord embed to follow up an interaction."""
@ -69,10 +219,10 @@ class Followup:
description=description
)
async def send(self, inter: Interaction, message: str = None):
async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
""""""
await inter.followup.send(content=message, embed=self._embed)
await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)
def fields(self, inline: bool = False, **fields: dict):
""""""
@ -89,6 +239,13 @@ class Followup:
return self
def footer(self, text: str, icon_url: str = None):
""""""
self._embed.set_footer(text=text, icon_url=icon_url)
return self
def error(self):
""""""