Reusable aiohttp session
This commit is contained in:
parent
1d8ad24569
commit
df15ad07a5
@ -28,7 +28,7 @@ class AuditModel(Base):
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
discord_user_id = Column(BigInteger, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
# discord_server_id = Column(BigInteger, nullable=False) # TODO: this doesnt exist, integrate it.
|
||||
message = Column(String, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||
|
||||
|
@ -248,7 +248,7 @@ class FeedCog(commands.Cog):
|
||||
await inter.response.defer()
|
||||
|
||||
try:
|
||||
sources = await self.bot.functions.get_feeds(inter.guild_id)
|
||||
feeds = await self.bot.functions.get_feeds(inter.guild_id)
|
||||
except NoResultFound:
|
||||
await (
|
||||
Followup(
|
||||
@ -261,8 +261,8 @@ class FeedCog(commands.Cog):
|
||||
)
|
||||
else:
|
||||
description = "\n".join([
|
||||
f"{i}. **[{source.name}]({source.url})**"
|
||||
for i, source in enumerate(sources)
|
||||
f"{i}. **[{info[0]}]({info[1]})**" # info = (nick, url)
|
||||
for i, info in enumerate(feeds)
|
||||
])
|
||||
await (
|
||||
Followup(
|
||||
@ -273,7 +273,6 @@ class FeedCog(commands.Cog):
|
||||
.send(inter)
|
||||
)
|
||||
|
||||
|
||||
# @feed_group.command(name="fetch")
|
||||
# @rename(max_="max")
|
||||
# @autocomplete(rss=source_autocomplete)
|
||||
@ -308,7 +307,8 @@ class FeedCog(commands.Cog):
|
||||
await followup(inter, "Sorry, I couldn't find any articles from this feed.")
|
||||
return
|
||||
|
||||
embeds = [await article.to_embed() for article in articles]
|
||||
async with aiohttp.ClientSession() as session:
|
||||
embeds = [await article.to_embed(session) for article in articles]
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(SentArticleModel).values([
|
||||
@ -630,7 +630,7 @@ class FeedCog(commands.Cog):
|
||||
return
|
||||
|
||||
output = "\n".join([
|
||||
f"{i}. <#{feed.discord_channel_id}> · "
|
||||
f"{i}. <#{feed.discord_channel_id}> · " # TODO: add icon indicating inaccessible channel, if is the case.
|
||||
f"[{feed.rss_source.nick}]({feed.rss_source.rss_url})"
|
||||
for i, feed in enumerate(feed_channels)
|
||||
])
|
||||
|
@ -8,6 +8,7 @@ import datetime
|
||||
from os import getenv
|
||||
from time import process_time
|
||||
|
||||
import aiohttp
|
||||
from discord import TextChannel
|
||||
from discord.ext import commands, tasks
|
||||
from discord.errors import Forbidden
|
||||
@ -61,7 +62,7 @@ class TaskCog(commands.Cog):
|
||||
|
||||
self.rss_task.cancel()
|
||||
|
||||
@tasks.loop(minutes=10)
|
||||
@tasks.loop(time=times)
|
||||
async def rss_task(self):
|
||||
"""Automated task responsible for processing rss feeds."""
|
||||
|
||||
@ -95,20 +96,23 @@ class TaskCog(commands.Cog):
|
||||
|
||||
# TODO: integrate the `validate_feed` code into here, also do on list command and show errors.
|
||||
|
||||
unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url)
|
||||
parsed_feed = parse(unparsed_content)
|
||||
source = Source.from_parsed(parsed_feed)
|
||||
articles = source.get_latest_articles(5)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
||||
if not articles:
|
||||
log.info("No articles to process")
|
||||
return
|
||||
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
|
||||
parsed_feed = parse(unparsed_content)
|
||||
source = Source.from_parsed(parsed_feed)
|
||||
articles = source.get_latest_articles(5)
|
||||
|
||||
for article in articles:
|
||||
await self.process_article(feed.id, article, channel, database)
|
||||
if not articles:
|
||||
log.info("No articles to process for %s in ", feed.rss_source.nick, feed.discord_server_id)
|
||||
return
|
||||
|
||||
for article in articles:
|
||||
await self.process_article(feed.id, article, channel, database, session)
|
||||
|
||||
async def process_article(
|
||||
self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager
|
||||
self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager,
|
||||
session: aiohttp.ClientSession
|
||||
):
|
||||
"""Process the passed article. Will send the embed to a channel if all is valid.
|
||||
|
||||
@ -136,11 +140,11 @@ class TaskCog(commands.Cog):
|
||||
log.debug("Article already processed: %s", article.url)
|
||||
return
|
||||
|
||||
embed = await article.to_embed()
|
||||
embed = await article.to_embed(session)
|
||||
try:
|
||||
await channel.send(embed=embed)
|
||||
except Forbidden:
|
||||
log.error("Can't send article to channel: %s · %s", channel.name, channel.id)
|
||||
except Forbidden as error: # TODO: find some way of informing the user about this error.
|
||||
log.error("Can't send article to channel: %s · %s · %s", channel.name, channel.id, error)
|
||||
return
|
||||
|
||||
query = insert(SentArticleModel).values(
|
||||
|
31
src/feed.py
31
src/feed.py
@ -63,9 +63,14 @@ class Article:
|
||||
source=source
|
||||
)
|
||||
|
||||
async def get_thumbnail_url(self) -> str | None:
|
||||
async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None:
|
||||
"""Returns the thumbnail URL for an article.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : aiohttp.ClientSession
|
||||
A client session used to get the thumbnail.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
@ -75,11 +80,10 @@ class Article:
|
||||
log.debug("Fetching thumbnail for article: %s", self)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.url) as response:
|
||||
html = await response.text()
|
||||
async with session.get(self.url, timeout=15) as response:
|
||||
html = await response.text()
|
||||
except aiohttp.InvalidURL as error:
|
||||
log.error(error)
|
||||
log.error("invalid thumbnail url: %s", error)
|
||||
return None
|
||||
|
||||
soup = bs4(html, "html.parser")
|
||||
@ -90,9 +94,14 @@ class Article:
|
||||
image_content = image_element.get("content")
|
||||
return image_content if validators.url(image_content) else None
|
||||
|
||||
async def to_embed(self) -> Embed:
|
||||
async def to_embed(self, session: aiohttp.ClientSession) -> Embed:
|
||||
"""Creates and returns a Discord Embed object from the article.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : aiohttp.ClientSession
|
||||
A client session used to get additional article data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Embed
|
||||
@ -110,7 +119,7 @@ class Article:
|
||||
embed_url = self.url if validators.url(self.url) else None
|
||||
author_url = self.source.url if validators.url(self.source.url) else None
|
||||
icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
|
||||
thumb_url = await self.get_thumbnail_url() # validation done inside func
|
||||
thumb_url = await self.get_thumbnail_url(session) # validation done inside func
|
||||
|
||||
embed = Embed(
|
||||
title=title,
|
||||
@ -315,9 +324,9 @@ class Functions:
|
||||
|
||||
return await Source.from_url(url)
|
||||
|
||||
async def get_feeds(self, guild_id: int) -> list[Source]:
|
||||
async def get_feeds(self, guild_id: int) -> list[tuple[str, str]]:
|
||||
"""Returns a list of fetched Feed objects from the database.
|
||||
Note: a request will be made too all found Feed URLs.
|
||||
Note: a request will be made too all found Feed UR Ls.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -326,7 +335,7 @@ class Functions:
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Source]
|
||||
list[tuple[str, str]]
|
||||
List of Source objects, resulting from the query.
|
||||
|
||||
Raises
|
||||
@ -344,7 +353,7 @@ class Functions:
|
||||
if not rss_sources:
|
||||
raise NoResultFound
|
||||
|
||||
return [await Source.from_url(feed.rss_url) for feed in rss_sources]
|
||||
return [(feed.nick, feed.rss_url) for feed in rss_sources]
|
||||
|
||||
async def assign_feed(
|
||||
self, url: str, channel_name: str, channel_id: int, guild_id: int
|
||||
|
@ -13,7 +13,10 @@ async def fetch(session, url: str) -> str:
|
||||
async with session.get(url) as response:
|
||||
return await response.text()
|
||||
|
||||
async def get_unparsed_feed(url: str):
|
||||
async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None):
|
||||
if session is not None:
|
||||
return await fetch(session, url)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
return await fetch(session, url)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user