Reusable aiohttp session

This commit is contained in:
Corban-Lee Jones 2024-02-07 09:21:13 +00:00
parent 1d8ad24569
commit df15ad07a5
5 changed files with 49 additions and 33 deletions

View File

@ -28,7 +28,7 @@ class AuditModel(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
discord_user_id = Column(BigInteger, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
# discord_server_id = Column(BigInteger, nullable=False) # TODO: this doesnt exist, integrate it.
message = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102

View File

@ -248,7 +248,7 @@ class FeedCog(commands.Cog):
await inter.response.defer()
try:
sources = await self.bot.functions.get_feeds(inter.guild_id)
feeds = await self.bot.functions.get_feeds(inter.guild_id)
except NoResultFound:
await (
Followup(
@ -261,8 +261,8 @@ class FeedCog(commands.Cog):
)
else:
description = "\n".join([
f"{i}. **[{source.name}]({source.url})**"
for i, source in enumerate(sources)
f"{i}. **[{info[0]}]({info[1]})**" # info = (nick, url)
for i, info in enumerate(feeds)
])
await (
Followup(
@ -273,7 +273,6 @@ class FeedCog(commands.Cog):
.send(inter)
)
# @feed_group.command(name="fetch")
# @rename(max_="max")
# @autocomplete(rss=source_autocomplete)
@ -308,7 +307,8 @@ class FeedCog(commands.Cog):
await followup(inter, "Sorry, I couldn't find any articles from this feed.")
return
embeds = [await article.to_embed() for article in articles]
async with aiohttp.ClientSession() as session:
embeds = [await article.to_embed(session) for article in articles]
async with DatabaseManager() as database:
query = insert(SentArticleModel).values([
@ -630,7 +630,7 @@ class FeedCog(commands.Cog):
return
output = "\n".join([
f"{i}. <#{feed.discord_channel_id}> · "
f"{i}. <#{feed.discord_channel_id}> · " # TODO: add icon indicating inaccessible channel, if is the case.
f"[{feed.rss_source.nick}]({feed.rss_source.rss_url})"
for i, feed in enumerate(feed_channels)
])

View File

@ -8,6 +8,7 @@ import datetime
from os import getenv
from time import process_time
import aiohttp
from discord import TextChannel
from discord.ext import commands, tasks
from discord.errors import Forbidden
@ -61,7 +62,7 @@ class TaskCog(commands.Cog):
self.rss_task.cancel()
@tasks.loop(minutes=10)
@tasks.loop(time=times)
async def rss_task(self):
"""Automated task responsible for processing rss feeds."""
@ -95,20 +96,23 @@ class TaskCog(commands.Cog):
# TODO: integrate the `validate_feed` code into here, also do on list command and show errors.
unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url)
parsed_feed = parse(unparsed_content)
source = Source.from_parsed(parsed_feed)
articles = source.get_latest_articles(5)
async with aiohttp.ClientSession() as session:
if not articles:
log.info("No articles to process")
return
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
parsed_feed = parse(unparsed_content)
source = Source.from_parsed(parsed_feed)
articles = source.get_latest_articles(5)
for article in articles:
await self.process_article(feed.id, article, channel, database)
if not articles:
log.info("No articles to process for %s in ", feed.rss_source.nick, feed.discord_server_id)
return
for article in articles:
await self.process_article(feed.id, article, channel, database, session)
async def process_article(
self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager
self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager,
session: aiohttp.ClientSession
):
"""Process the passed article. Will send the embed to a channel if all is valid.
@ -136,11 +140,11 @@ class TaskCog(commands.Cog):
log.debug("Article already processed: %s", article.url)
return
embed = await article.to_embed()
embed = await article.to_embed(session)
try:
await channel.send(embed=embed)
except Forbidden:
log.error("Can't send article to channel: %s · %s", channel.name, channel.id)
except Forbidden as error: # TODO: find some way of informing the user about this error.
log.error("Can't send article to channel: %s · %s · %s", channel.name, channel.id, error)
return
query = insert(SentArticleModel).values(

View File

@ -63,9 +63,14 @@ class Article:
source=source
)
async def get_thumbnail_url(self) -> str | None:
async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None:
"""Returns the thumbnail URL for an article.
Parameters
----------
session : aiohttp.ClientSession
A client session used to get the thumbnail.
Returns
-------
str or None
@ -75,11 +80,10 @@ class Article:
log.debug("Fetching thumbnail for article: %s", self)
try:
async with aiohttp.ClientSession() as session:
async with session.get(self.url) as response:
html = await response.text()
async with session.get(self.url, timeout=15) as response:
html = await response.text()
except aiohttp.InvalidURL as error:
log.error(error)
log.error("invalid thumbnail url: %s", error)
return None
soup = bs4(html, "html.parser")
@ -90,9 +94,14 @@ class Article:
image_content = image_element.get("content")
return image_content if validators.url(image_content) else None
async def to_embed(self) -> Embed:
async def to_embed(self, session: aiohttp.ClientSession) -> Embed:
"""Creates and returns a Discord Embed object from the article.
Parameters
----------
session : aiohttp.ClientSession
A client session used to get additional article data.
Returns
-------
Embed
@ -110,7 +119,7 @@ class Article:
embed_url = self.url if validators.url(self.url) else None
author_url = self.source.url if validators.url(self.source.url) else None
icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
thumb_url = await self.get_thumbnail_url() # validation done inside func
thumb_url = await self.get_thumbnail_url(session) # validation done inside func
embed = Embed(
title=title,
@ -315,9 +324,9 @@ class Functions:
return await Source.from_url(url)
async def get_feeds(self, guild_id: int) -> list[Source]:
async def get_feeds(self, guild_id: int) -> list[tuple[str, str]]:
"""Returns a list of fetched Feed objects from the database.
Note: a request will be made too all found Feed URLs.
Note: a request will be made too all found Feed UR Ls.
Parameters
----------
@ -326,7 +335,7 @@ class Functions:
Returns
-------
list[Source]
list[tuple[str, str]]
List of Source objects, resulting from the query.
Raises
@ -344,7 +353,7 @@ class Functions:
if not rss_sources:
raise NoResultFound
return [await Source.from_url(feed.rss_url) for feed in rss_sources]
return [(feed.nick, feed.rss_url) for feed in rss_sources]
async def assign_feed(
self, url: str, channel_name: str, channel_id: int, guild_id: int

View File

@ -13,7 +13,10 @@ async def fetch(session, url: str) -> str:
async with session.get(url) as response:
return await response.text()
async def get_unparsed_feed(url: str):
async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None):
if session is not None:
return await fetch(session, url)
async with aiohttp.ClientSession() as session:
return await fetch(session, url)