This commit is contained in:
Corban-Lee Jones 2023-12-18 10:21:15 +00:00
commit 76d27c4782
11 changed files with 674 additions and 339 deletions

17
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,17 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: NewsBot",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/src/main.py",
"python": "${workspaceFolder}/venv/bin/python",
"console": "integratedTerminal",
"justMyCode": true
}
]
}

View File

@ -46,9 +46,16 @@ class DiscordBot(commands.Bot):
if path.suffix == ".py":
await self.load_extension(f"extensions.{path.stem}")
async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
    """Insert an audit row for a user action.

    Parameters
    ----------
    message : str
        Text describing the action being audited.
    user_id : int
        Discord ID of the user the entry is recorded against.
    database : DatabaseManager, optional
        An already opened database manager. When given, the insert is
        executed on its session; otherwise a new manager is opened for
        this single insert.
    """

    # The caller's message is stored as given. It used to be overwritten
    # with a hard-coded "Requesting latest article" string.
    query = insert(AuditModel).values(discord_user_id=user_id, message=message)

    if database is not None:
        await database.session.execute(query)
    else:
        async with DatabaseManager() as new_database:
            await new_database.session.execute(query)

    # Logged on both paths; the early return used to skip it.
    log.debug("Audit logged")

View File

@ -24,7 +24,7 @@ class DatabaseManager:
"""
def __init__(self, no_commit: bool = False):
database_url = self.get_database_url()
database_url = self.get_database_url() # TODO: This is called every time a connection is established, maybe make it once and reference it?
self.engine = create_async_engine(database_url, future=True)
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
self.session = None

View File

@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`.
from enum import Enum, auto
from sqlalchemy import Column, Integer, String, DateTime, BigInteger
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
@ -24,7 +24,6 @@ class AuditModel(Base):
discord_user_id = Column(BigInteger, nullable=False)
message = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
active = Column(Integer, default=True, nullable=False)
class SentArticleModel(Base):
@ -39,7 +38,7 @@ class SentArticleModel(Base):
discord_channel_id = Column(BigInteger, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
article_url = Column(String, nullable=False)
active = Column(Integer, default=True, nullable=False)
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
class RssSourceModel(Base):
@ -50,9 +49,17 @@ class RssSourceModel(Base):
__tablename__ = "rss_source"
id = Column(Integer, primary_key=True, autoincrement=True)
nick = Column(String, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
rss_url = Column(String, nullable=False)
active = Column(Integer, default=True, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
feed_channels = relationship("FeedChannelModel", cascade="all, delete")
# the nickname must be unique, but only within the same discord server
__table_args__ = (
UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'),
)
class FeedChannelModel(Base):
@ -64,4 +71,6 @@ class FeedChannelModel(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
discord_channel_id = Column(BigInteger, nullable=False)
active = Column(Integer, default=True, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
search_name = Column(String, nullable=False)
rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False)

106
src/extensions/channels.py Normal file
View File

@ -0,0 +1,106 @@
"""
Extension for the `ChannelCog`.
Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot.
"""
import logging
from sqlalchemy import select, and_
from discord import Interaction, TextChannel
from discord.ext import commands
from discord.app_commands import Group, Choice, autocomplete
from db import DatabaseManager, FeedChannelModel
from utils import followup
log = logging.getLogger(__name__)
class ChannelCog(commands.Cog):
    """
    Command cog for the `/channel` command group.
    """

    def __init__(self, bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        log.info(f"{self.__class__.__name__} cog is ready")

    async def autocomplete_existing_feeds(self, inter: Interaction, current: str):
        """Returns a list of existing RSS + Channel feeds.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        current : str
            The current text entered for the autocomplete.

        Returns
        -------
        list of Choice
            One choice per matching feed of the current server, named by
            its `search_name` and carrying the feed `id` as value.
        """

        async with DatabaseManager() as database:
            whereclause = and_(
                FeedChannelModel.discord_server_id == inter.guild_id,
                # The pattern is handed to SQLAlchemy as a value, which binds
                # it as a parameter instead of splicing it into the SQL text.
                # Any `%` or `_` typed by the user still acts as a wildcard.
                FeedChannelModel.search_name.ilike(f"%{current}%")
            )
            # Discord accepts at most 25 choices in an autocomplete response.
            query = select(FeedChannelModel).where(whereclause).limit(25)
            result = await database.session.execute(query)
            feeds = [
                Choice(name=feed.search_name, value=feed.id)
                for feed in result.scalars().all()
            ]

        return feeds

    # All channel commands belong to this group.
    channel_group = Group(
        name="channel",
        description="Commands for channel assignment.",
        guild_only=True  # We store guild IDs in the database, so guild only = True
    )

    # The leading `@` was missing here, so the decorator returned by
    # `channel_group.command` was discarded and the command never registered.
    @channel_group.command(name="include-feed")
    async def include_feed(self, inter: Interaction, rss: str, channel: TextChannel):
        """Include a feed within the specified channel.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        rss : str
            The RSS feed to include.
        channel : TextChannel
            The channel to include the feed in.
        """

        await inter.response.defer()
        await followup(inter, "Ping")

    @channel_group.command(name="exclude-feed")
    @autocomplete(option=autocomplete_existing_feeds)
    async def exclude_feed(self, inter: Interaction, option: int):
        """Undo command for the `/channel include-feed` command.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        option : int
            The RSS feed and channel to exclude.
        """

        await inter.response.defer()
        await followup(inter, "Pong")
async def setup(bot):
    """
    Entry point used when this extension is loaded.
    Creates a `ChannelCog` and registers it on the bot.
    """

    await bot.add_cog(cog := ChannelCog(bot))
    log.info(f"Added {cog.__class__.__name__} cog")

View File

@ -1,184 +0,0 @@
"""
Extension for the `CommandCog`.
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
"""
import logging
import validators
import aiohttp
import textwrap
import feedparser
from markdownify import markdownify
from discord import app_commands, Interaction, Embed
from discord.ext import commands, tasks
from sqlalchemy import insert, select, update, and_, or_
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
from feed import Feeds, get_source
log = logging.getLogger(__name__)
async def get_rss_data(url: str):
    """Request `url` and return a `(body text, HTTP status code)` tuple."""

    async with aiohttp.ClientSession() as http, http.get(url) as reply:
        body = await reply.text()
        return body, reply.status
class CommandCog(commands.Cog):
    """
    Command cog for the `/rss` command group.
    """

    def __init__(self, bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        log.info(f"{self.__class__.__name__} cog is ready")

    rss_group = app_commands.Group(
        name="rss",
        description="Commands for rss sources.",
        guild_only=True
    )

    @rss_group.command(name="add")
    async def add_rss_source(self, inter: Interaction, url: str):
        """Add a new RSS source for the current server.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        url : str
            The RSS feed URL.
        """

        await inter.response.defer()

        # validate the input
        if not validators.url(url):
            await inter.followup.send(
                "The URL you have entered is malformed or invalid:\n"
                f"`{url=}`",
                suppress_embeds=True
            )
            return

        feed_data, status_code = await get_rss_data(url)
        if status_code != 200:
            await inter.followup.send(
                f"The URL provided returned an invalid status code:\n"
                f"{url=}, {status_code=}",
                suppress_embeds=True
            )
            return

        # feedparser leaves `version` empty when the content is not a feed
        feed = feedparser.parse(feed_data)
        if not feed.version:
            await inter.followup.send(
                f"The provided URL '{url}' does not seem to be a valid RSS feed.",
                suppress_embeds=True
            )
            return

        async with DatabaseManager() as database:
            query = insert(RssSourceModel).values(
                discord_server_id=inter.guild_id,
                rss_url=url
            )
            await database.session.execute(query)

        await inter.followup.send("RSS source added")

    @rss_group.command(name="remove")
    async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None):
        """Mark an RSS source of the current server as inactive.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        number : int or None
            Number of the RSS source. Accepted, but not resolved yet.
        url : str or None
            URL of the RSS source to deactivate.
        """

        await inter.response.defer()

        # Compare against None explicitly: 0 is a valid number but is falsy.
        url_exists = url is not None
        num_exists = number is not None

        # Exactly one of the two must be given.
        if url_exists == num_exists:
            await inter.followup.send(
                "Please only specify either the existing rss number or url, "
                "enter at least one of these, but don't enter both."
            )
            return

        if num_exists:
            # `number` was never used by the query below, which then matched
            # on `rss_url == None` and reported 0 updated rows. Say so
            # explicitly until lookup by number exists.
            await inter.followup.send(
                "Removing an RSS source by number is not supported yet, "
                "please provide its URL instead."
            )
            return

        if not validators.url(url):
            await inter.followup.send(
                "The URL you have entered is malformed or invalid:\n"
                f"`{url=}`",
                suppress_embeds=True
            )
            return

        async with DatabaseManager() as database:
            whereclause = and_(
                RssSourceModel.discord_server_id == inter.guild_id,
                RssSourceModel.rss_url == url
            )
            # Sources are deactivated rather than deleted.
            query = update(RssSourceModel).where(whereclause).values(active=False)
            result = await database.session.execute(query)

        await inter.followup.send(f"I've updated {result.rowcount} rows")

    @rss_group.command(name="list")
    @app_commands.choices(filter=[
        app_commands.Choice(name="Active Only [default]", value=1),
        app_commands.Choice(name="Inactive Only", value=0),
        app_commands.Choice(name="All", value=2),
    ])
    async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int] = None):
        """List the RSS sources of the current server.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        filter : Choice of int, optional
            Which sources to show. Active only when omitted.
        """

        await inter.response.defer()

        # The first choice is labelled "[default]" but the parameter used to
        # be required; an omitted filter now means "Active Only".
        filter_value = filter.value if filter is not None else 1

        if filter_value == 2:
            whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
        else:
            whereclause = and_(
                RssSourceModel.discord_server_id == inter.guild_id,
                RssSourceModel.active == filter_value  # should result to 0 or 1
            )

        async with DatabaseManager() as database:
            query = select(RssSourceModel).where(whereclause)
            result = await database.session.execute(query)
            rss_sources = result.scalars().all()

        embed_fields = [{
            "name": f"[{i}]",
            "value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}"
        } for i, rss in enumerate(rss_sources)]

        if not embed_fields:
            await inter.followup.send("It looks like you have no rss sources.")
            return

        embed = Embed(
            title="RSS Sources",
            description="Here are your rss sources:"
        )
        for field in embed_fields:
            embed.add_field(**field, inline=False)

        await inter.followup.send(embed=embed)
async def setup(bot):
    """
    Entry point used when this extension is loaded.
    Creates a `CommandCog` and registers it on the bot.
    """

    await bot.add_cog(cog := CommandCog(bot))
    log.info(f"Added {cog.__class__.__name__} cog")

349
src/extensions/rss.py Normal file
View File

@ -0,0 +1,349 @@
"""
Extension for the `RssCog`.
Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot.
"""
import logging
import validators
from typing import Tuple
import textwrap
import feedparser
from markdownify import markdownify
from discord import Interaction, Embed, Colour
from discord.ext import commands
from discord.app_commands import Choice, Group, autocomplete, choices
from sqlalchemy import insert, select, and_, delete
from utils import get_rss_data, followup, audit
from feed import get_source, Source
from db import DatabaseManager, SentArticleModel, RssSourceModel
log = logging.getLogger(__name__)
# Sort options offered by `/rss list`. The value of each choice is its
# position here and is what `list_rss_sources` matches on.
rss_list_sort_choices = [
    Choice(name=label, value=position)
    for position, label in enumerate(("Nickname", "Date Added"))
]
# TODO SECURITY: a potential attack is that the user submits a valid RSS feed and later changes the target resource.
# Run a periodic task to re-validate stored sources.
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]:
    """Check that a nickname / URL pair describes a usable RSS source.

    The checks run in order and stop at the first failure: the URL must be
    well formed, the nickname must not itself be a URL, the URL must answer
    with status 200, and the response must parse as a feed.

    Parameters
    ----------
    nickname : str
        Nickname of the source. Rejected when it validates as a URL.
    url : str
        URL of the source.

    Returns
    -------
    str or None
        A message describing the problem, or None when the source is valid.
    FeedParserDict or None
        The parsed feed, or None when the source is invalid.
    """

    url_is_valid = validators.url(url)
    if not url_is_valid:
        return f"The URL you have entered is malformed or invalid:\n`{url=}`", None

    nickname_is_url = validators.url(nickname)
    if nickname_is_url:
        rejection = (
            "It looks like the nickname you have entered is a URL.\n"
            f"For security reasons, this is not allowed.\n`{nickname=}`"
        )
        return rejection, None

    # Anything other than a plain 200 is treated as a failure.
    body, status_code = await get_rss_data(url)
    if status_code != 200:
        return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None

    # A falsy `version` on the parse result is taken to mean "not a feed".
    parsed = feedparser.parse(body)
    if parsed.version:
        return None, parsed

    return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
class RssCog(commands.Cog):
"""
Command cog.
"""
def __init__(self, bot):
super().__init__()
self.bot = bot
@commands.Cog.listener()
async def on_ready(self):
log.info(f"{self.__class__.__name__} cog is ready")
async def source_autocomplete(self, inter: Interaction, nickname: str):
"""Provides RSS source autocomplete functionality for commands.
Parameters
----------
inter : Interaction
Represents an app command interaction.
nickname : str
_description_
Returns
-------
list of app_commands.Choice
_description_
"""
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.nick.ilike(f"%{nickname}%")
)
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
sources = [
Choice(name=rss.nick, value=rss.rss_url)
for rss in result.scalars().all()
]
return sources
# All RSS commands belong to this group.
rss_group = Group(
name="rss",
description="Commands for rss sources.",
guild_only=True # We store guild IDs in the database, so guild only = True
)
@rss_group.command(name="add")
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
"""Add a new RSS source.
Parameters
----------
inter : Interaction
Represents an app command interaction.
nickname : str
A name used to identify the RSS source.
url : str
The RSS feed URL.
"""
await inter.response.defer()
illegal_message, feed = await validate_rss_source(nickname, url)
if illegal_message:
await followup(inter, illegal_message, suppress_embeds=True)
return
log.debug("RSS feed added")
async with DatabaseManager() as database:
query = insert(RssSourceModel).values(
discord_server_id = inter.guild_id,
rss_url = url,
nick=nickname
)
await database.session.execute(query)
await audit(self,
f"Added RSS source ({nickname=}, {url=})",
inter.user.id, database=database
)
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
embed.add_field(name="Nickname", value=nickname)
embed.add_field(name="URL", value=url)
embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
await followup(inter, embed=embed)
@rss_group.command(name="remove")
@autocomplete(url=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, url: str):
"""Delete an existing RSS source.
Parameters
----------
inter : Interaction
Represents an app command interaction.
url : str
The RSS source to be removed. Autocomplete or enter the URL.
"""
await inter.response.defer()
log.debug(f"Attempting to remove RSS source ({url=})")
async with DatabaseManager() as database:
select_result = await database.session.execute(
select(RssSourceModel).filter(
and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == url
)
)
)
rss_source = select_result.scalars().one()
nickname = rss_source.nick
delete_result = await database.session.execute(
delete(RssSourceModel).filter(
and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == url
)
)
)
await audit(self,
f"Added RSS source ({nickname=}, {url=})",
inter.user.id, database=database
)
if not delete_result.rowcount:
await followup(inter, "Couldn't find any RSS sources with this name.")
return
source = get_source(url)
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
embed.add_field(name="Nickname", value=nickname)
embed.add_field(name="URL", value=url)
embed.set_thumbnail(url=source.icon_url)
await followup(inter, embed=embed)
@rss_group.command(name="list")
@choices(sort=rss_list_sort_choices)
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
"""Provides a with a list of RSS sources available for the current server.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
await inter.response.defer()
# Default to the first choice if not specified.
if type(sort) is Choice:
description = "Sort by "
description += "Nickname " if sort.value == 0 else "Date Added "
description += '\U000025BC' if sort_reverse else '\U000025B2'
else:
sort = rss_list_sort_choices[0]
description = ""
sort = sort if type(sort) == Choice else rss_list_sort_choices[0]
match sort.value, sort_reverse:
case 0, False:
order_by = RssSourceModel.nick.asc()
case 0, True:
order_by = RssSourceModel.nick.desc()
case 1, False: # NOTE:
order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest
case 1, True: # date first, not the oldest as it would sort otherwise.
order_by = RssSourceModel.created.asc()
case _, _:
raise ValueError("Unknown sort: %s" % sort)
async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
query = select(RssSourceModel).where(whereclause).order_by(order_by)
result = await database.session.execute(query)
rss_sources = result.scalars().all()
if not rss_sources:
await followup(inter, "It looks like you have no rss sources.")
return
output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)])
embed = Embed(
title="Saved RSS Feeds",
description=f"{description}\n\n{output}",
colour=Colour.lighter_grey()
)
await followup(inter, embed=embed)
@rss_group.command(name="fetch")
@autocomplete(rss=source_autocomplete)
async def fetch_rss(self, inter: Interaction, rss: str, max: int=1):
# """"""
await inter.response.defer()
if max > 5:
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
return
invalid_message, feed = await validate_rss_source("", rss)
if invalid_message:
await followup(inter, invalid_message)
return
source = Source.from_parsed(feed)
articles = source.get_latest_articles(max)
if not articles:
await followup(inter, "Sorry, I couldn't find any articles from this feed.")
return
embeds = []
for article in articles:
md_description = markdownify(article.description, strip=("img",))
article_description = textwrap.shorten(md_description, 4096)
embed = Embed(
title=article.title,
description=article_description,
url=article.url,
timestamp=article.published,
colour=Colour.brand_red()
)
thumbail_url = await article.get_thumbnail_url()
thumbail_url = thumbail_url if validators.url(thumbail_url) else None
embed.set_thumbnail(url=source.icon_url)
embed.set_image(url=thumbail_url)
embed.set_footer(text=article.author)
embed.set_author(
name=source.name,
url=source.url,
)
embeds.append(embed)
async with DatabaseManager() as database:
query = insert(SentArticleModel).values([
{
"discord_server_id": inter.guild_id,
"discord_channel_id": inter.channel_id,
"discord_message_id": inter.id,
"article_url": article.url,
}
for article in articles
])
await database.session.execute(query)
await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database)
await followup(inter, embeds=embeds)
async def setup(bot):
    """
    Entry point used when this extension is loaded.
    Creates a `RssCog` and registers it on the bot.
    """

    await bot.add_cog(cog := RssCog(bot))
    log.info(f"Added {cog.__class__.__name__} cog")

35
src/extensions/tasks.py Normal file
View File

@ -0,0 +1,35 @@
"""
Extension for the `TaskCog`.
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
"""
import logging
from discord.ext import commands
log = logging.getLogger(__name__)
class TaskCog(commands.Cog):
    """
    Cog that currently only keeps a reference to the bot and logs
    when it is ready.
    """

    def __init__(self, bot):
        # Keep the bot around for later use by this cog.
        self.bot = bot
        super().__init__()

    @commands.Cog.listener()
    async def on_ready(self):
        """Log that this cog is ready."""
        log.info(f"{self.__class__.__name__} cog is ready")
async def setup(bot):
    """
    Entry point used when this extension is loaded.
    Creates a `TaskCog` and registers it on the bot.
    """

    await bot.add_cog(cog := TaskCog(bot))
    log.info(f"Added {cog.__class__.__name__} cog")

View File

@ -1,88 +0,0 @@
"""
Extension for the `test` cog.
Loading this file via `commands.Bot.load_extension` will add the `test` cog to the bot.
"""
import logging
import textwrap
from markdownify import markdownify
from discord import app_commands, Interaction, Embed
from discord.ext import commands, tasks
from sqlalchemy import insert, select
from db import DatabaseManager, AuditModel, SentArticleModel
from feed import Feeds, get_source
log = logging.getLogger(__name__)
class Test(commands.Cog):
    """
    Test cog.
    Sends an embed of the latest article of a chosen feed.
    """

    def __init__(self, bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        log.info(f"{self.__class__.__name__} cog is ready")

    @app_commands.command(name="test-latest-article")
    async def test_bee(self, inter: Interaction, source: Feeds):
        await inter.response.defer()
        await self.bot.audit("Requesting latest article.", inter.user.id)

        news_source = get_source(source)
        latest = news_source.get_latest_article()

        # Convert the description to markdown without images, then cut it
        # down to 4096 characters.
        as_markdown = markdownify(latest.description, strip=("img",))
        summary = textwrap.shorten(as_markdown, 4096)

        embed = Embed(
            title=latest.title,
            description=summary,
            url=latest.url,
            timestamp=latest.published,
        )
        embed.set_thumbnail(url=news_source.icon_url)
        image_url = await latest.get_thumbnail_url()
        embed.set_image(url=image_url)
        embed.set_footer(text=latest.author)
        embed.set_author(name=news_source.name, url=news_source.url)

        # Record which article was sent for this interaction.
        async with DatabaseManager() as database:
            sent_article = insert(SentArticleModel).values(
                discord_server_id=inter.guild_id,
                discord_channel_id=inter.channel_id,
                discord_message_id=inter.id,
                article_url=latest.url
            )
            await database.session.execute(sent_article)

        await inter.followup.send(embed=embed)
async def setup(bot):
    """
    Entry point used when this extension is loaded.
    Creates a `Test` cog and registers it on the bot.
    """

    await bot.add_cog(cog := Test(bot))
    log.info(f"Added {cog.__class__.__name__} cog")

View File

@ -1,7 +1,9 @@
"""
"""
import json
import logging
from enum import Enum
from dataclasses import dataclass
from datetime import datetime
@ -10,85 +12,136 @@ from bs4 import BeautifulSoup as bs4
from feedparser import FeedParserDict, parse
log = logging.getLogger(__name__)
class Feeds(Enum):
    """RSS feed URLs known to the bot, keyed by publication."""

    THE_UPPER_LIP = "https://theupperlip.co.uk/rss"
    THE_BABYLON_BEE = "https://babylonbee.com/feed"
    BBC_NEWS = "https://feeds.bbci.co.uk/news/rss.xml"
@dataclass
class Source:
    """An RSS source together with the parsed feed it was built from."""

    name: str
    url: str
    icon_url: str
    feed: FeedParserDict

    @classmethod
    def from_parsed(cls, feed:FeedParserDict):
        """Build a Source from the channel title, link and image of `feed`."""

        source_name = feed.channel.title
        source_link = feed.channel.link
        icon = feed.feed.image.href
        return cls(name=source_name, url=source_link, icon_url=icon, feed=feed)

    def get_latest_article(self):
        """Return the Article that `Article.from_parsed` builds from this feed."""

        latest = Article.from_parsed(self.feed)
        return latest
def dumps(_dict):
    """Return `_dict` as JSON text indented by 8 spaces, for debug logging.

    Values that JSON cannot represent are written as their `str()` form
    instead of raising `TypeError`. This helper only feeds log output, and
    a parsed feed can carry such values (for example an exception object).
    """
    return json.dumps(_dict, indent=8, default=str)
@dataclass
class Article:
    """Represents a news article, or entry from an RSS feed."""

    title: str | None
    description: str | None
    url: str | None
    published: datetime | None
    author: str | None

    @classmethod
    def from_entry(cls, entry:FeedParserDict):
        """Create an Article from an RSS feed entry.

        Parameters
        ----------
        entry : FeedParserDict
            An entry pulled from a complete FeedParserDict object.

        Returns
        -------
        Article
            The Article created from the feed entry. Fields the entry
            does not provide are None.
        """

        # Serialising the whole entry is only worth it when it gets logged.
        if log.isEnabledFor(logging.DEBUG):
            log.debug("Creating Article from entry: %s", dumps(entry))

        # Only year..second are handed to `datetime`. The former `[0:-2]`
        # slice of the 9-field time tuple also passed the weekday, which
        # `datetime` took as the microsecond.
        # NOTE(review): the result is a naive datetime -- confirm which
        # timezone the parsed time is expressed in.
        published_parsed = entry.get("published_parsed")
        published = datetime(*published_parsed[:6]) if published_parsed else None

        return cls(
            title=entry.get("title"),
            description=entry.get("description"),
            url=entry.get("link"),
            published=published,
            author=entry.get("author")
        )

    async def get_thumbnail_url(self) -> str | None:
        """Returns the thumbnail URL for an article.

        The article page is downloaded and its `og:image` meta tag is read.

        Returns
        -------
        str or None
            The thumbnail URL, or None if not found.
        """

        log.debug("Fetching thumbnail for article: %s", self)

        # An entry without a link has no page to look the image up on.
        if not self.url:
            return None

        async with aiohttp.ClientSession() as session:
            async with session.get(self.url) as response:
                html = await response.text()

        # Parse the thumbnail for the news story
        soup = bs4(html, "html.parser")
        image_element = soup.select_one("meta[property='og:image']")
        return image_element.get("content") if image_element else None
def get_source(feed: Feeds) -> Source:
@dataclass
class Source:
    """Represents an RSS source."""

    name: str | None
    url: str | None
    icon_url: str | None
    feed: FeedParserDict

    @classmethod
    def from_parsed(cls, feed:FeedParserDict):
        """Returns a Source object from a parsed feed.

        Parameters
        ----------
        feed : FeedParserDict
            The feed used to create the Source.

        Returns
        -------
        Source
            The Source object. Name, URL and icon URL are None when the
            feed does not provide them.
        """

        # Serialising the whole feed is only worth it when it gets logged.
        if log.isEnabledFor(logging.DEBUG):
            log.debug("Creating Source from feed: %s", dumps(feed))

        channel = feed.get("channel", {})
        return cls(
            name=channel.get("title"),
            url=channel.get("link"),
            icon_url=feed.get("feed", {}).get("image", {}).get("href"),
            feed=feed
        )

    def get_latest_articles(self, max: int) -> list[Article]:
        """Returns a list of Article objects.

        Parameters
        ----------
        max : int
            The maximum number of articles to return. Zero or a negative
            number returns an empty list.

        Returns
        -------
        list of Article
            The first `max` entries of the feed, as Article objects.
        """

        log.debug("Fetching latest articles from %s, max=%s", self, max)

        # Slice instead of enumerating the whole feed. The clamp keeps a
        # negative `max` meaning "none" rather than "all but the last n".
        limit = max if max > 0 else 0
        return [
            Article.from_entry(entry)
            for entry in self.feed.entries[:limit]
        ]
def get_source(rss_url: str) -> Source:
    """Parse an RSS URL and wrap the result in a Source.

    Parameters
    ----------
    rss_url : str
        The URL of the RSS feed to parse.

    Returns
    -------
    Source
        The Source built from the parsed feed.
    """

    # Only the URL that was asked for is parsed; a leftover parse of a
    # hard-coded URL used to sit next to this call.
    parsed_feed = parse(rss_url)  # TODO: make asynchronous
    return Source.from_parsed(parsed_feed)
def get_test():
    """Parse The Upper Lip feed, print it as indented JSON and return it."""

    upper_lip_feed = parse(Feeds.THE_UPPER_LIP.value)
    as_json = json.dumps(upper_lip_feed, indent=4)
    print(as_json)
    return upper_lip_feed

31
src/utils.py Normal file
View File

@ -0,0 +1,31 @@
"""A collection of utility functions that can be used in various places."""
import aiohttp
import logging
from discord import Interaction
log = logging.getLogger(__name__)
async def get_rss_data(url: str):
    """Request `url` and return its body text together with the status code.

    Parameters
    ----------
    url : str
        The address to request.

    Returns
    -------
    tuple
        The response body as text, followed by the HTTP status code.
    """

    async with aiohttp.ClientSession() as session:
        pending_request = session.get(url)
        async with pending_request as response:
            body = await response.text()
            status = response.status

    return body, status
async def followup(inter: Interaction, *args, **kwargs):
    """Send a follow-up message for an interaction.

    Every extra positional and keyword argument is passed on unchanged to
    `inter.followup.send`.

    Parameters
    ----------
    inter : Interaction
        Represents an app command interaction.
    """

    responder = inter.followup
    await responder.send(*args, **kwargs)
async def audit(cog, *args, **kwargs):
    """Forward an audit request to the `audit` coroutine of the cog's bot."""

    bot = cog.bot
    await bot.audit(*args, **kwargs)