Merge branch 'main' of https://gitea.corbz.dev/corbz/NewsBot
This commit is contained in:
commit
83a889cf61
18
README.md
18
README.md
@ -1,17 +1,7 @@
|
||||
# NewsBot
|
||||
# PYRSS
|
||||
|
||||
Bot delivering news articles to discord servers.
|
||||
An RSS driven Discord bot written in Python.
|
||||
|
||||
Plans
|
||||
Provides user commands for storing RSS feed URLs that can be assigned to any given discord channel.
|
||||
|
||||
- Multiple news providers
|
||||
- Choose how much of each provider should be delivered
|
||||
- Check for duplicate articles between providers, and only deliver preferred provider article
|
||||
|
||||
|
||||
## Dev Notes:
|
||||
|
||||
For the sake of development, the following defintions apply:
|
||||
|
||||
- Feed - An RSS feed stored within the database, submitted by a user.
|
||||
- Assigned Feed - A discord channel set to receive content from a Feed.
|
||||
Content is shared every 10 minutes as an Embed.
|
127
src/api.py
Normal file
127
src/api.py
Normal file
@ -0,0 +1,127 @@
|
||||
|
||||
import logging
|
||||
|
||||
import aiohttp
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NotCreatedException(APIException):
|
||||
pass
|
||||
|
||||
class BadStatusException(APIException):
|
||||
pass
|
||||
|
||||
|
||||
class API:
|
||||
"""Interactions with the API."""
|
||||
|
||||
API_HOST = "http://localhost:8000/"
|
||||
API_ENDPOINT = API_HOST + "api/"
|
||||
|
||||
RSS_FEED_ENDPOINT = API_ENDPOINT + "rssfeed/"
|
||||
FEED_CHANNEL_ENDPOINT = API_ENDPOINT + "feedchannel/"
|
||||
|
||||
def __init__(self, api_token: str, session: aiohttp.ClientSession):
|
||||
log.debug("API session initialised")
|
||||
self.session = session
|
||||
self.token_headers = {"Authorization": f"Token {api_token}"}
|
||||
|
||||
async def make_request(self, method: str, url: str, **kwargs) -> dict:
|
||||
"""Make a request to the given API endpoint.
|
||||
|
||||
Args:
|
||||
method (str): The request method to use, examples: GET, POST, DELETE...
|
||||
url (str): The API endpoint to request to.
|
||||
**kwargs: Passed into self.session.request.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing status code, json or text.
|
||||
"""
|
||||
|
||||
async with self.session.request(method, url, headers=self.token_headers, **kwargs) as response:
|
||||
response.raise_for_status()
|
||||
try:
|
||||
json = await response.json()
|
||||
text = None
|
||||
except aiohttp.ContentTypeError:
|
||||
json = None
|
||||
text = await response.text()
|
||||
|
||||
status = response.status
|
||||
|
||||
return {"json": json, "text": text, "status": status}
|
||||
|
||||
async def create_new_rssfeed(self, name: str, url: str, image_url: str, discord_server_id: int) -> dict:
|
||||
"""Create a new RSS Feed.
|
||||
|
||||
Args:
|
||||
name (str): Name of the RSS Feed.
|
||||
url (str): URL for the RSS Feed.
|
||||
image_url (str): URL of the image representation of the RSS Feed.
|
||||
discord_server_id (int): ID of the discord server behind this item.
|
||||
|
||||
Returns:
|
||||
dict: JSON representation of the newly created RSS Feed.
|
||||
"""
|
||||
|
||||
log.debug("creating rssfeed: %s %s %s %s", name, url, image_url, discord_server_id)
|
||||
|
||||
async with self.session.get(image_url) as response:
|
||||
image_data = await response.read()
|
||||
|
||||
# Using formdata to make the image transfer easier.
|
||||
form = aiohttp.FormData({
|
||||
"name": name,
|
||||
"url": url,
|
||||
"discord_server_id": discord_server_id
|
||||
})
|
||||
form.add_field("image", image_data, filename="file.jpg")
|
||||
|
||||
data = (await self.make_response("POST", self.RSS_FEED_ENDPOINT, data=form))["json"]
|
||||
return data
|
||||
|
||||
async def get_rssfeed(self, uuid: str) -> dict:
|
||||
"""Get a particular RSS Feed given it's UUID.
|
||||
|
||||
Args:
|
||||
uuid (str): Identifier of the desired RSS Feed.
|
||||
|
||||
Returns:
|
||||
dict: A JSON representation of the RSS Feed.
|
||||
"""
|
||||
|
||||
log.debug("getting rssfeed: %s", uuid)
|
||||
endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/"
|
||||
data = (await self.make_request("GET", endpoint))["json"]
|
||||
return data
|
||||
|
||||
async def get_rssfeed_list(self, **filters) -> tuple[list[dict], int]:
|
||||
"""Get all RSS Feeds with the associated filters.
|
||||
|
||||
Returns:
|
||||
tuple[list[dict], int] list contains dictionaries of each item, int is total items.
|
||||
"""
|
||||
|
||||
log.debug("getting list of rss feeds with filters: %s", filters)
|
||||
data = (await self.make_request("GET", self.RSS_FEED_ENDPOINT, params=filters))["json"]
|
||||
return data["results"], data["count"]
|
||||
|
||||
async def delete_rssfeed(self, uuid: str) -> int:
|
||||
"""Delete a specified RSS Feed.
|
||||
|
||||
Args:
|
||||
uuid (str): Identifier of the RSS Feed to delete.
|
||||
|
||||
Returns:
|
||||
int: Status code of the response.
|
||||
"""
|
||||
|
||||
log.debug("deleting rssfeed: %s", uuid)
|
||||
endpoint = f"{self.RSS_FEED_ENDPOINT}/{uuid}/"
|
||||
status = (await self.make_request("DELETE", endpoint))["status"]
|
||||
return status
|
10
src/bot.py
10
src/bot.py
@ -5,7 +5,7 @@ The discord bot for the application.
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from discord import Intents
|
||||
from discord import Intents, Game
|
||||
from discord.ext import commands
|
||||
from sqlalchemy import insert
|
||||
|
||||
@ -17,11 +17,13 @@ log = logging.getLogger(__name__)
|
||||
|
||||
class DiscordBot(commands.Bot):
|
||||
|
||||
def __init__(self, BASE_DIR: Path, developing: bool):
|
||||
super().__init__(command_prefix="-", intents=Intents.all())
|
||||
self.functions = Functions(self)
|
||||
def __init__(self, BASE_DIR: Path, developing: bool, api_token: str):
|
||||
activity = Game("Indev") if developing else None
|
||||
super().__init__(command_prefix="-", intents=Intents.all(), activity=activity)
|
||||
self.functions = Functions(self, api_token)
|
||||
self.BASE_DIR = BASE_DIR
|
||||
self.developing = developing
|
||||
self.api_token = api_token
|
||||
|
||||
log.info("developing=%s", developing)
|
||||
|
||||
|
@ -6,6 +6,7 @@ Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bo
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import aiohttp
|
||||
import validators
|
||||
from feedparser import FeedParserDict, parse
|
||||
from discord.ext import commands
|
||||
@ -14,7 +15,8 @@ from discord.app_commands import Choice, Group, autocomplete, choices, rename
|
||||
from sqlalchemy import insert, select, and_, delete
|
||||
from sqlalchemy.exc import NoResultFound, IntegrityError
|
||||
|
||||
from feed import Source
|
||||
from api import API
|
||||
from feed import Source, RSSFeed
|
||||
from errors import IllegalFeed
|
||||
from db import (
|
||||
DatabaseManager,
|
||||
@ -25,6 +27,7 @@ from db import (
|
||||
)
|
||||
from utils import (
|
||||
Followup,
|
||||
PaginationView,
|
||||
get_rss_data,
|
||||
followup,
|
||||
audit,
|
||||
@ -121,6 +124,21 @@ class FeedCog(commands.Cog):
|
||||
|
||||
log.info("%s cog is ready", self.__class__.__name__)
|
||||
|
||||
async def autocomplete_rssfeed(self, inter: Interaction, name: str) -> list[Choice]:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data, _ = await API(self.bot.api_token, session).get_rssfeed_list(
|
||||
discord_server_id=inter.guild_id
|
||||
)
|
||||
rssfeeds = RSSFeed.from_list(data)
|
||||
|
||||
choices = [
|
||||
Choice(name=item.name, value=item.uuid)
|
||||
for item in rssfeeds
|
||||
]
|
||||
|
||||
return choices
|
||||
|
||||
async def source_autocomplete(self, inter: Interaction, nickname: str):
|
||||
"""Provides RSS source autocomplete functionality for commands.
|
||||
|
||||
@ -154,69 +172,55 @@ class FeedCog(commands.Cog):
|
||||
# All RSS commands belong to this group.
|
||||
feed_group = Group(
|
||||
name="feed",
|
||||
description="Commands for rss sources.",
|
||||
description="Commands for RSS sources.",
|
||||
default_permissions=Permissions.elevated(),
|
||||
guild_only=True # We store guild IDs in the database, so guild only = True
|
||||
)
|
||||
|
||||
@feed_group.command(name="add")
|
||||
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
|
||||
"""Add a new Feed for this server.
|
||||
@feed_group.command(name="new")
|
||||
async def add_rssfeed(self, inter: Interaction, name: str, url: str):
|
||||
"""Add a new RSS Feed for this server.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
nickname : str
|
||||
A name used to identify the Feed.
|
||||
url : str
|
||||
The Feed URL.
|
||||
Args:
|
||||
inter (Interaction): Represents the discord command interaction.
|
||||
name (str): A nickname used to refer to this RSS Feed.
|
||||
url (str): The URL of the RSS Feed.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
try:
|
||||
source = await self.bot.functions.create_new_feed(nickname, url, inter.guild_id)
|
||||
except IllegalFeed as error:
|
||||
title, desc = extract_error_info(error)
|
||||
await Followup(title, desc).fields(**error.items).error().send(inter)
|
||||
except IntegrityError as error:
|
||||
rssfeed = await self.bot.functions.create_new_rssfeed(name, url, inter.guild_id)
|
||||
except Exception as exc:
|
||||
await (
|
||||
Followup(
|
||||
"Duplicate Feed Error",
|
||||
"A Feed with the same nickname already exist."
|
||||
)
|
||||
.fields(nickname=nickname)
|
||||
Followup(exc.__class__.__name__, str(exc))
|
||||
.error()
|
||||
.send(inter)
|
||||
)
|
||||
else:
|
||||
await (
|
||||
Followup("Feed Added")
|
||||
.image(source.icon_url)
|
||||
.fields(nickname=nickname, url=url)
|
||||
Followup("New RSS Feed")
|
||||
.image(rssfeed.image)
|
||||
.fields(uuid=rssfeed.uuid, name=name, url=url)
|
||||
.added()
|
||||
.send(inter)
|
||||
)
|
||||
|
||||
@feed_group.command(name="remove")
|
||||
@rename(url="option")
|
||||
@autocomplete(url=source_autocomplete)
|
||||
async def remove_rss_source(self, inter: Interaction, url: str):
|
||||
"""Delete an existing Feed from this server.
|
||||
@feed_group.command(name="delete")
|
||||
@autocomplete(uuid=autocomplete_rssfeed)
|
||||
@rename(uuid="rssfeed")
|
||||
async def delete_rssfeed(self, inter: Interaction, uuid: str):
|
||||
"""Delete an existing RSS Feed for this server.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
url : str
|
||||
The Feed to be removed. Autocomplete or enter the URL.
|
||||
Args:
|
||||
inter (Interaction): Represents the discord command interaction.
|
||||
uuid (str): The UUID of the
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
try:
|
||||
source = await self.bot.functions.delete_feed(url, inter.guild_id)
|
||||
rssfeed = await self.bot.functions.delete_rssfeed(uuid)
|
||||
except NoResultFound:
|
||||
await (
|
||||
Followup(
|
||||
@ -229,47 +233,43 @@ class FeedCog(commands.Cog):
|
||||
else:
|
||||
await (
|
||||
Followup("Feed Deleted")
|
||||
.image(source.icon_url)
|
||||
.fields(url=url)
|
||||
.image(rssfeed.image)
|
||||
.fields(uuid=rssfeed.uuid, name=rssfeed.name, url=rssfeed.url)
|
||||
.trash()
|
||||
.send(inter)
|
||||
)
|
||||
|
||||
@feed_group.command(name="list")
|
||||
async def list_rss_sources(self, inter: Interaction):
|
||||
"""Provides a with a list of Feeds available for this server.
|
||||
async def list_rssfeeds(self, inter: Interaction):
|
||||
"""Provides a list of all RSS Feeds
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
Args:
|
||||
inter (Interaction): Represents the discord command interaction.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
page = 1
|
||||
pagesize = 10
|
||||
|
||||
try:
|
||||
feeds = await self.bot.functions.get_feeds(inter.guild_id)
|
||||
except NoResultFound:
|
||||
def formatdata(index, item):
|
||||
key = f"{index}. {item.name}"
|
||||
value = f"[RSS]({item.url}) · [API](http://localhost:8000/api/rssfeed/{item.uuid}/)"
|
||||
return key, value
|
||||
|
||||
async def getdata(page):
|
||||
data, count = await self.bot.functions.get_rssfeeds(inter.guild_id, page, pagesize)
|
||||
return data, count
|
||||
|
||||
embed = Followup(f"Available RSS Feeds in {inter.guild.name}").info()._embed
|
||||
pagination = PaginationView(self.bot, inter, embed, getdata, formatdata, pagesize, 1)
|
||||
await pagination.send()
|
||||
|
||||
except Exception as exc:
|
||||
await (
|
||||
Followup(
|
||||
"Feeds Not Found Error",
|
||||
"There are no available Feeds for this server.\n"
|
||||
"Add a new feed with `/feed add`."
|
||||
)
|
||||
Followup(exc.__class__.__name__, str(exc))
|
||||
.error()
|
||||
.send()
|
||||
)
|
||||
else:
|
||||
description = "\n".join([
|
||||
f"{i}. **[{info[0]}]({info[1]})**" # info = (nick, url)
|
||||
for i, info in enumerate(feeds)
|
||||
])
|
||||
await (
|
||||
Followup(
|
||||
f"Available Feeds in {inter.guild.name}",
|
||||
description
|
||||
)
|
||||
.info()
|
||||
.send(inter)
|
||||
)
|
||||
|
||||
|
@ -10,12 +10,13 @@ from time import process_time
|
||||
|
||||
import aiohttp
|
||||
from discord import TextChannel
|
||||
from discord import app_commands
|
||||
from discord.ext import commands, tasks
|
||||
from discord.errors import Forbidden
|
||||
from sqlalchemy import insert, select, and_
|
||||
from feedparser import parse
|
||||
|
||||
from feed import Source, Article
|
||||
from feed import Source, Article, RSSFeed
|
||||
from db import (
|
||||
DatabaseManager,
|
||||
FeedChannelModel,
|
||||
@ -23,11 +24,13 @@ from db import (
|
||||
SentArticleModel
|
||||
)
|
||||
from utils import get_unparsed_feed
|
||||
from api import API
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES")
|
||||
|
||||
# task trigger times : must be of type list
|
||||
times = [
|
||||
datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
|
||||
for hour in range(24)
|
||||
@ -62,6 +65,12 @@ class TaskCog(commands.Cog):
|
||||
|
||||
self.rss_task.cancel()
|
||||
|
||||
@app_commands.command(name="debug-trigger-task")
|
||||
async def debug_trigger_task(self, inter):
|
||||
await inter.response.defer()
|
||||
await self.rss_task()
|
||||
await inter.followup.send("done")
|
||||
|
||||
@tasks.loop(time=times)
|
||||
async def rss_task(self):
|
||||
"""Automated task responsible for processing rss feeds."""
|
||||
@ -69,13 +78,24 @@ class TaskCog(commands.Cog):
|
||||
log.info("Running rss task")
|
||||
time = process_time()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
|
||||
result = await database.session.execute(query)
|
||||
feeds = result.scalars().all()
|
||||
# async with DatabaseManager() as database:
|
||||
# query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
|
||||
# result = await database.session.execute(query)
|
||||
# feeds = result.scalars().all()
|
||||
|
||||
# for feed in feeds:
|
||||
# await self.process_feed(feed, database)
|
||||
|
||||
guild_ids = [guild.id for guild in self.bot.guilds]
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
api = API(self.bot.api_token, session)
|
||||
data, count = await api.get_rssfeed_list(discord_server_id__in=guild_ids)
|
||||
rssfeeds = RSSFeed.from_list(data)
|
||||
for item in rssfeeds:
|
||||
log.info(item.name)
|
||||
|
||||
|
||||
for feed in feeds:
|
||||
await self.process_feed(feed, database)
|
||||
|
||||
log.info("Finished rss task, time elapsed: %s", process_time() - time)
|
||||
|
||||
|
167
src/feed.py
167
src/feed.py
@ -1,4 +1,5 @@
|
||||
|
||||
import ssl
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
@ -18,6 +19,7 @@ from textwrap import shorten
|
||||
from errors import IllegalFeed
|
||||
from db import DatabaseManager, RssSourceModel, FeedChannelModel
|
||||
from utils import get_rss_data, get_unparsed_feed
|
||||
from api import API
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
dumps = lambda _dict: json.dumps(_dict, indent=8)
|
||||
@ -140,7 +142,7 @@ class Article:
|
||||
@dataclass
|
||||
class Source:
|
||||
"""Represents an RSS source."""
|
||||
|
||||
|
||||
name: str | None
|
||||
url: str | None
|
||||
icon_url: str | None
|
||||
@ -173,7 +175,9 @@ class Source:
|
||||
@classmethod
|
||||
async def from_url(cls, url: str):
|
||||
unparsed_content = await get_unparsed_feed(url)
|
||||
return cls.from_parsed(parse(unparsed_content))
|
||||
source = cls.from_parsed(parse(unparsed_content))
|
||||
source.url = url
|
||||
return source
|
||||
|
||||
def get_latest_articles(self, max: int = 999) -> list[Article]:
|
||||
"""Returns a list of Article objects.
|
||||
@ -198,10 +202,37 @@ class Source:
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RSSFeed:
|
||||
|
||||
uuid: str
|
||||
name: str
|
||||
url: str
|
||||
image: str
|
||||
discord_server_id: id
|
||||
created_at: str
|
||||
|
||||
@classmethod
|
||||
def from_list(cls, data: list) -> list:
|
||||
result = []
|
||||
|
||||
for item in data:
|
||||
key = "discord_server_id"
|
||||
item[key] = int(item.get(key))
|
||||
result.append(cls(**item))
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class Functions:
|
||||
|
||||
def __init__(self, bot):
|
||||
def __init__(self, bot, api_token: str):
|
||||
self.bot = bot
|
||||
self.api_token = api_token
|
||||
|
||||
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
|
||||
"""Validates a feed based on the given nickname and url.
|
||||
@ -254,106 +285,50 @@ class Functions:
|
||||
|
||||
return feed
|
||||
|
||||
async def create_new_feed(self, nickname: str, url: str, guild_id: int) -> Source:
|
||||
"""Create a new Feed, and return it as a Source object.
|
||||
async def create_new_rssfeed(self, name: str, url: str, guild_id: int) -> RSSFeed:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nickname : str
|
||||
Human readable nickname used to refer to the feed.
|
||||
url : str
|
||||
URL to fetch content from the feed.
|
||||
guild_id : int
|
||||
Discord Server ID associated with the feed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Source
|
||||
Dataclass containing attributes of the feed.
|
||||
log.info("Creating new Feed: %s", name)
|
||||
|
||||
parsed_feed = await self.validate_feed(name, url)
|
||||
source = Source.from_parsed(parsed_feed)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data = await API(self.api_token, session).create_new_rssfeed(
|
||||
name, url, source.icon_url, guild_id
|
||||
)
|
||||
|
||||
return RSSFeed.from_dict(data)
|
||||
|
||||
async def delete_rssfeed(self, uuid: str) -> RSSFeed:
|
||||
|
||||
log.info("Deleting Feed '%s'", uuid)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
api = API(self.api_token, session)
|
||||
data = await api.get_rssfeed(uuid)
|
||||
await api.delete_rssfeed(uuid)
|
||||
|
||||
return RSSFeed.from_dict(data)
|
||||
|
||||
async def get_rssfeeds(self, guild_id: int, page: int, pagesize: int) -> list[RSSFeed]:
|
||||
"""Get a list of RSS Feeds.
|
||||
|
||||
Args:
|
||||
guild_id (int): The guild_id to filter by.
|
||||
|
||||
Returns:
|
||||
list[RSSFeed]: Resulting list of RSS Feeds
|
||||
"""
|
||||
|
||||
log.info("Creating new Feed: %s - %s", nickname, guild_id)
|
||||
|
||||
parsed_feed = await self.validate_feed(nickname, url)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(RssSourceModel).values(
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data, count = await API(self.api_token, session).get_rssfeed_list(
|
||||
discord_server_id=guild_id,
|
||||
rss_url=url,
|
||||
nick=nickname
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
log.info("Created Feed: %s - %s", nickname, guild_id)
|
||||
|
||||
return Source.from_parsed(parsed_feed)
|
||||
|
||||
async def delete_feed(self, url: str, guild_id: int) -> Source:
|
||||
"""Delete an existing Feed, then return it as a Source object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
URL of the feed, used in the whereclause.
|
||||
guild_id : int
|
||||
Discord Server ID of the feed, used in the whereclause.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Source
|
||||
Dataclass containing attributes of the feed.
|
||||
"""
|
||||
|
||||
log.info("Deleting Feed: %s - %s", url, guild_id)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
page=page,
|
||||
page_size=pagesize
|
||||
)
|
||||
|
||||
# Select the Feed entry, because an exception is raised if not found.
|
||||
select_query = select(RssSourceModel).filter(whereclause)
|
||||
select_result = await database.session.execute(select_query)
|
||||
select_result.scalars().one()
|
||||
|
||||
delete_query = delete(RssSourceModel).filter(whereclause)
|
||||
await database.session.execute(delete_query)
|
||||
|
||||
log.info("Deleted Feed: %s - %s", url, guild_id)
|
||||
|
||||
return await Source.from_url(url)
|
||||
|
||||
async def get_feeds(self, guild_id: int) -> list[tuple[str, str]]:
|
||||
"""Returns a list of fetched Feed objects from the database.
|
||||
Note: a request will be made too all found Feed UR Ls.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
guild_id : int
|
||||
The Discord Server ID, used to filter down the Feed query.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[tuple[str, str]]
|
||||
List of Source objects, resulting from the query.
|
||||
|
||||
Raises
|
||||
------
|
||||
NoResultFound
|
||||
Raised if no results are found.
|
||||
"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(RssSourceModel.discord_server_id == guild_id)
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
rss_sources = result.scalars().all()
|
||||
|
||||
if not rss_sources:
|
||||
raise NoResultFound
|
||||
|
||||
return [(feed.nick, feed.rss_url) for feed in rss_sources]
|
||||
return RSSFeed.from_list(data), count
|
||||
|
||||
async def assign_feed(
|
||||
self, url: str, channel_name: str, channel_id: int, guild_id: int
|
||||
|
22
src/main.py
22
src/main.py
@ -9,9 +9,9 @@ from os import getenv
|
||||
from pathlib import Path
|
||||
|
||||
# it's important to load environment variables before
|
||||
# importing the packages that depend on them.
|
||||
# importing the modules that depend on them.
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
load_dotenv(override=True)
|
||||
|
||||
from bot import DiscordBot
|
||||
from logs import LogSetup
|
||||
@ -26,12 +26,17 @@ async def main():
|
||||
|
||||
# Grab the token before anything else, because if there is no token
|
||||
# available then the bot cannot be started anyways.
|
||||
token = getenv("BOT_TOKEN")
|
||||
bot_token = getenv("BOT_TOKEN")
|
||||
if not bot_token:
|
||||
raise ValueError("Bot Token is empty")
|
||||
|
||||
if not token:
|
||||
raise ValueError("Token is empty")
|
||||
# ^ same story for the API token. Without it the API cannot be
|
||||
# interacted with, so grab it first.
|
||||
api_token = getenv("API_TOKEN")
|
||||
if not api_token:
|
||||
raise ValueError("API Token is empty")
|
||||
|
||||
developing = getenv("DEVELOPING") == "True"
|
||||
developing = getenv("DEVELOPING", "False") == "True"
|
||||
|
||||
# Setup logging settings and mute spammy loggers
|
||||
logsetup = LogSetup(BASE_DIR / "logs/")
|
||||
@ -41,9 +46,10 @@ async def main():
|
||||
level=logging.WARNING
|
||||
)
|
||||
|
||||
async with DiscordBot(BASE_DIR, developing=developing) as bot:
|
||||
|
||||
async with DiscordBot(BASE_DIR, developing=developing, api_token=api_token) as bot:
|
||||
await bot.load_extensions()
|
||||
await bot.start(token, reconnect=True)
|
||||
await bot.start(bot_token, reconnect=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
163
src/utils.py
163
src/utils.py
@ -3,8 +3,11 @@
|
||||
import aiohttp
|
||||
import logging
|
||||
import async_timeout
|
||||
from typing import Callable
|
||||
|
||||
from discord import Interaction, Embed, Colour
|
||||
from discord import Interaction, Embed, Colour, ButtonStyle, Button
|
||||
from discord.ui import View, button
|
||||
from discord.ext.commands import Bot
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -56,6 +59,153 @@ class FollowupIcons:
|
||||
assigned = "https://img.icons8.com/fluency-systems-filled/48/4598DA/hashtag-large.png"
|
||||
|
||||
|
||||
class PaginationView(View):
|
||||
"""A Discord UI View that adds pagination to an embed."""
|
||||
|
||||
def __init__(
|
||||
self, bot: Bot, inter: Interaction, embed: Embed, getdata: Callable,
|
||||
formatdata: Callable, pagesize: int, initpage: int=1
|
||||
):
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
bot (commands.Bot) The discord bot
|
||||
inter (Interaction): Represents a discord command interaction.
|
||||
embed (Embed): The base embed to paginate.
|
||||
getdata (Callable): A function that provides data, must return Tuple[List[Any], int].
|
||||
formatdata (Callable): A formatter function that determines how the data is displayed.
|
||||
pagesize (int): The size of each page.
|
||||
initpage (int, optional): The inital page. Defaults to 1.
|
||||
"""
|
||||
|
||||
self.bot = bot
|
||||
self.inter = inter
|
||||
self.embed = embed
|
||||
self.getdata = getdata
|
||||
self.formatdata = formatdata
|
||||
self.maxpage = None
|
||||
self.pagesize = pagesize
|
||||
self.index = initpage
|
||||
|
||||
# emoji reference
|
||||
next_emoji = bot.get_emoji(1204542366602502265)
|
||||
prev_emoji = bot.get_emoji(1204542365432422470)
|
||||
self.start_emoji = bot.get_emoji(1204542364073463818)
|
||||
self.end_emoji = bot.get_emoji(1204542367752003624)
|
||||
|
||||
super().__init__(timeout=100)
|
||||
|
||||
async def check_user_is_author(self, inter: Interaction) -> bool:
|
||||
"""Ensure the user is the author of the original command."""
|
||||
|
||||
if inter.user == self.inter.user:
|
||||
return True
|
||||
|
||||
await inter.response.defer()
|
||||
await (
|
||||
Followup(None, "Only the author can interact with this.")
|
||||
.error()
|
||||
.send(inter, ephemeral=True)
|
||||
)
|
||||
return False
|
||||
|
||||
async def on_timeout(self):
|
||||
"""Erase the controls on timeout."""
|
||||
|
||||
message = await self.inter.original_response()
|
||||
await message.edit(view=None)
|
||||
|
||||
@staticmethod
|
||||
def calc_total_pages(results: int, max_pagesize: int) -> int:
|
||||
result = ((results - 1) // max_pagesize) + 1
|
||||
log.debug("total pages calculated: %s", result)
|
||||
return result
|
||||
|
||||
def calc_dataitem_index(self, dataitem_index: int):
|
||||
"""Calculates a given index to be relative to the sum of all pages items
|
||||
|
||||
Example: dataitem_index = 6
|
||||
pagesize = 10
|
||||
if page == 1 then return 6
|
||||
else return 6 + 10 * (page - 1)"""
|
||||
|
||||
if self.index > 1:
|
||||
dataitem_index += self.pagesize * (self.index - 1)
|
||||
|
||||
dataitem_index += 1
|
||||
return dataitem_index
|
||||
|
||||
@button(emoji="◀️", style=ButtonStyle.blurple)
|
||||
async def backward(self, inter: Interaction, button: Button):
|
||||
self.index -= 1
|
||||
await inter.response.defer()
|
||||
self.inter = inter
|
||||
await self.navigate()
|
||||
|
||||
@button(emoji="▶️", style=ButtonStyle.blurple)
|
||||
async def forward(self, inter: Interaction, button: Button):
|
||||
self.index += 1
|
||||
await inter.response.defer()
|
||||
self.inter = inter
|
||||
await self.navigate()
|
||||
|
||||
@button(emoji="⏭️", style=ButtonStyle.blurple)
|
||||
async def start_or_end(self, inter: Interaction, button: Button):
|
||||
if self.index <= self.maxpage // 2:
|
||||
self.index = self.maxpage
|
||||
else:
|
||||
self.index = 1
|
||||
|
||||
await inter.response.defer()
|
||||
self.inter = inter
|
||||
await self.navigate()
|
||||
|
||||
async def navigate(self):
|
||||
log.debug("navigating to page: %s", self.index)
|
||||
|
||||
self.update_buttons()
|
||||
paged_embed = await self.create_paged_embed()
|
||||
await self.inter.edit_original_response(embed=paged_embed, view=self)
|
||||
|
||||
async def create_paged_embed(self) -> Embed:
|
||||
embed = self.embed.copy()
|
||||
data, total_results = await self.getdata(self.index)
|
||||
self.maxpage = self.calc_total_pages(total_results, self.pagesize)
|
||||
|
||||
for i, item in enumerate(data):
|
||||
i = self.calc_dataitem_index(i)
|
||||
key, value = self.formatdata(i, item)
|
||||
embed.add_field(name=key, value=value, inline=False)
|
||||
|
||||
if not total_results:
|
||||
embed.description = "There are no results"
|
||||
|
||||
if self.maxpage > 1:
|
||||
embed.set_footer(text=f"Page {self.index}/{self.maxpage}")
|
||||
|
||||
return embed
|
||||
|
||||
def update_buttons(self):
|
||||
if self.index >= self.maxpage:
|
||||
self.children[2].emoji = self.start_emoji
|
||||
else:
|
||||
self.children[2].emoji = self.end_emoji
|
||||
|
||||
self.children[0].disabled = self.index == 1
|
||||
self.children[1].disabled = self.index == self.maxpage
|
||||
|
||||
async def send(self):
|
||||
embed = await self.create_paged_embed()
|
||||
|
||||
if self.maxpage <= 1:
|
||||
await self.inter.edit_original_response(embed=embed)
|
||||
return
|
||||
|
||||
self.update_buttons()
|
||||
await self.inter.edit_original_response(embed=embed, view=self)
|
||||
|
||||
|
||||
|
||||
class Followup:
|
||||
"""Wrapper for a discord embed to follow up an interaction."""
|
||||
|
||||
@ -69,10 +219,10 @@ class Followup:
|
||||
description=description
|
||||
)
|
||||
|
||||
async def send(self, inter: Interaction, message: str = None):
|
||||
async def send(self, inter: Interaction, message: str = None, ephemeral: bool = False):
|
||||
""""""
|
||||
|
||||
await inter.followup.send(content=message, embed=self._embed)
|
||||
await inter.followup.send(content=message, embed=self._embed, ephemeral=ephemeral)
|
||||
|
||||
def fields(self, inline: bool = False, **fields: dict):
|
||||
""""""
|
||||
@ -89,6 +239,13 @@ class Followup:
|
||||
|
||||
return self
|
||||
|
||||
def footer(self, text: str, icon_url: str = None):
|
||||
""""""
|
||||
|
||||
self._embed.set_footer(text=text, icon_url=icon_url)
|
||||
|
||||
return self
|
||||
|
||||
def error(self):
|
||||
""""""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user