Reusable aiohttp session
This commit is contained in:
parent
1d8ad24569
commit
df15ad07a5
@ -28,7 +28,7 @@ class AuditModel(Base):
|
|||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
discord_user_id = Column(BigInteger, nullable=False)
|
discord_user_id = Column(BigInteger, nullable=False)
|
||||||
discord_server_id = Column(BigInteger, nullable=False)
|
# discord_server_id = Column(BigInteger, nullable=False) # TODO: this doesnt exist, integrate it.
|
||||||
message = Column(String, nullable=False)
|
message = Column(String, nullable=False)
|
||||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||||
|
|
||||||
|
@ -248,7 +248,7 @@ class FeedCog(commands.Cog):
|
|||||||
await inter.response.defer()
|
await inter.response.defer()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sources = await self.bot.functions.get_feeds(inter.guild_id)
|
feeds = await self.bot.functions.get_feeds(inter.guild_id)
|
||||||
except NoResultFound:
|
except NoResultFound:
|
||||||
await (
|
await (
|
||||||
Followup(
|
Followup(
|
||||||
@ -261,8 +261,8 @@ class FeedCog(commands.Cog):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
description = "\n".join([
|
description = "\n".join([
|
||||||
f"{i}. **[{source.name}]({source.url})**"
|
f"{i}. **[{info[0]}]({info[1]})**" # info = (nick, url)
|
||||||
for i, source in enumerate(sources)
|
for i, info in enumerate(feeds)
|
||||||
])
|
])
|
||||||
await (
|
await (
|
||||||
Followup(
|
Followup(
|
||||||
@ -273,7 +273,6 @@ class FeedCog(commands.Cog):
|
|||||||
.send(inter)
|
.send(inter)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# @feed_group.command(name="fetch")
|
# @feed_group.command(name="fetch")
|
||||||
# @rename(max_="max")
|
# @rename(max_="max")
|
||||||
# @autocomplete(rss=source_autocomplete)
|
# @autocomplete(rss=source_autocomplete)
|
||||||
@ -308,7 +307,8 @@ class FeedCog(commands.Cog):
|
|||||||
await followup(inter, "Sorry, I couldn't find any articles from this feed.")
|
await followup(inter, "Sorry, I couldn't find any articles from this feed.")
|
||||||
return
|
return
|
||||||
|
|
||||||
embeds = [await article.to_embed() for article in articles]
|
async with aiohttp.ClientSession() as session:
|
||||||
|
embeds = [await article.to_embed(session) for article in articles]
|
||||||
|
|
||||||
async with DatabaseManager() as database:
|
async with DatabaseManager() as database:
|
||||||
query = insert(SentArticleModel).values([
|
query = insert(SentArticleModel).values([
|
||||||
@ -630,7 +630,7 @@ class FeedCog(commands.Cog):
|
|||||||
return
|
return
|
||||||
|
|
||||||
output = "\n".join([
|
output = "\n".join([
|
||||||
f"{i}. <#{feed.discord_channel_id}> · "
|
f"{i}. <#{feed.discord_channel_id}> · " # TODO: add icon indicating inaccessible channel, if is the case.
|
||||||
f"[{feed.rss_source.nick}]({feed.rss_source.rss_url})"
|
f"[{feed.rss_source.nick}]({feed.rss_source.rss_url})"
|
||||||
for i, feed in enumerate(feed_channels)
|
for i, feed in enumerate(feed_channels)
|
||||||
])
|
])
|
||||||
|
@ -8,6 +8,7 @@ import datetime
|
|||||||
from os import getenv
|
from os import getenv
|
||||||
from time import process_time
|
from time import process_time
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
from discord import TextChannel
|
from discord import TextChannel
|
||||||
from discord.ext import commands, tasks
|
from discord.ext import commands, tasks
|
||||||
from discord.errors import Forbidden
|
from discord.errors import Forbidden
|
||||||
@ -61,7 +62,7 @@ class TaskCog(commands.Cog):
|
|||||||
|
|
||||||
self.rss_task.cancel()
|
self.rss_task.cancel()
|
||||||
|
|
||||||
@tasks.loop(minutes=10)
|
@tasks.loop(time=times)
|
||||||
async def rss_task(self):
|
async def rss_task(self):
|
||||||
"""Automated task responsible for processing rss feeds."""
|
"""Automated task responsible for processing rss feeds."""
|
||||||
|
|
||||||
@ -95,20 +96,23 @@ class TaskCog(commands.Cog):
|
|||||||
|
|
||||||
# TODO: integrate the `validate_feed` code into here, also do on list command and show errors.
|
# TODO: integrate the `validate_feed` code into here, also do on list command and show errors.
|
||||||
|
|
||||||
unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url)
|
async with aiohttp.ClientSession() as session:
|
||||||
parsed_feed = parse(unparsed_content)
|
|
||||||
source = Source.from_parsed(parsed_feed)
|
|
||||||
articles = source.get_latest_articles(5)
|
|
||||||
|
|
||||||
if not articles:
|
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
|
||||||
log.info("No articles to process")
|
parsed_feed = parse(unparsed_content)
|
||||||
return
|
source = Source.from_parsed(parsed_feed)
|
||||||
|
articles = source.get_latest_articles(5)
|
||||||
|
|
||||||
for article in articles:
|
if not articles:
|
||||||
await self.process_article(feed.id, article, channel, database)
|
log.info("No articles to process for %s in ", feed.rss_source.nick, feed.discord_server_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
for article in articles:
|
||||||
|
await self.process_article(feed.id, article, channel, database, session)
|
||||||
|
|
||||||
async def process_article(
|
async def process_article(
|
||||||
self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager
|
self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager,
|
||||||
|
session: aiohttp.ClientSession
|
||||||
):
|
):
|
||||||
"""Process the passed article. Will send the embed to a channel if all is valid.
|
"""Process the passed article. Will send the embed to a channel if all is valid.
|
||||||
|
|
||||||
@ -136,11 +140,11 @@ class TaskCog(commands.Cog):
|
|||||||
log.debug("Article already processed: %s", article.url)
|
log.debug("Article already processed: %s", article.url)
|
||||||
return
|
return
|
||||||
|
|
||||||
embed = await article.to_embed()
|
embed = await article.to_embed(session)
|
||||||
try:
|
try:
|
||||||
await channel.send(embed=embed)
|
await channel.send(embed=embed)
|
||||||
except Forbidden:
|
except Forbidden as error: # TODO: find some way of informing the user about this error.
|
||||||
log.error("Can't send article to channel: %s · %s", channel.name, channel.id)
|
log.error("Can't send article to channel: %s · %s · %s", channel.name, channel.id, error)
|
||||||
return
|
return
|
||||||
|
|
||||||
query = insert(SentArticleModel).values(
|
query = insert(SentArticleModel).values(
|
||||||
|
31
src/feed.py
31
src/feed.py
@ -63,9 +63,14 @@ class Article:
|
|||||||
source=source
|
source=source
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_thumbnail_url(self) -> str | None:
|
async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None:
|
||||||
"""Returns the thumbnail URL for an article.
|
"""Returns the thumbnail URL for an article.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
session : aiohttp.ClientSession
|
||||||
|
A client session used to get the thumbnail.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
str or None
|
str or None
|
||||||
@ -75,11 +80,10 @@ class Article:
|
|||||||
log.debug("Fetching thumbnail for article: %s", self)
|
log.debug("Fetching thumbnail for article: %s", self)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with session.get(self.url, timeout=15) as response:
|
||||||
async with session.get(self.url) as response:
|
html = await response.text()
|
||||||
html = await response.text()
|
|
||||||
except aiohttp.InvalidURL as error:
|
except aiohttp.InvalidURL as error:
|
||||||
log.error(error)
|
log.error("invalid thumbnail url: %s", error)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
soup = bs4(html, "html.parser")
|
soup = bs4(html, "html.parser")
|
||||||
@ -90,9 +94,14 @@ class Article:
|
|||||||
image_content = image_element.get("content")
|
image_content = image_element.get("content")
|
||||||
return image_content if validators.url(image_content) else None
|
return image_content if validators.url(image_content) else None
|
||||||
|
|
||||||
async def to_embed(self) -> Embed:
|
async def to_embed(self, session: aiohttp.ClientSession) -> Embed:
|
||||||
"""Creates and returns a Discord Embed object from the article.
|
"""Creates and returns a Discord Embed object from the article.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
session : aiohttp.ClientSession
|
||||||
|
A client session used to get additional article data.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Embed
|
Embed
|
||||||
@ -110,7 +119,7 @@ class Article:
|
|||||||
embed_url = self.url if validators.url(self.url) else None
|
embed_url = self.url if validators.url(self.url) else None
|
||||||
author_url = self.source.url if validators.url(self.source.url) else None
|
author_url = self.source.url if validators.url(self.source.url) else None
|
||||||
icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
|
icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
|
||||||
thumb_url = await self.get_thumbnail_url() # validation done inside func
|
thumb_url = await self.get_thumbnail_url(session) # validation done inside func
|
||||||
|
|
||||||
embed = Embed(
|
embed = Embed(
|
||||||
title=title,
|
title=title,
|
||||||
@ -315,9 +324,9 @@ class Functions:
|
|||||||
|
|
||||||
return await Source.from_url(url)
|
return await Source.from_url(url)
|
||||||
|
|
||||||
async def get_feeds(self, guild_id: int) -> list[Source]:
|
async def get_feeds(self, guild_id: int) -> list[tuple[str, str]]:
|
||||||
"""Returns a list of fetched Feed objects from the database.
|
"""Returns a list of fetched Feed objects from the database.
|
||||||
Note: a request will be made too all found Feed URLs.
|
Note: a request will be made too all found Feed UR Ls.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -326,7 +335,7 @@ class Functions:
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
list[Source]
|
list[tuple[str, str]]
|
||||||
List of Source objects, resulting from the query.
|
List of Source objects, resulting from the query.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
@ -344,7 +353,7 @@ class Functions:
|
|||||||
if not rss_sources:
|
if not rss_sources:
|
||||||
raise NoResultFound
|
raise NoResultFound
|
||||||
|
|
||||||
return [await Source.from_url(feed.rss_url) for feed in rss_sources]
|
return [(feed.nick, feed.rss_url) for feed in rss_sources]
|
||||||
|
|
||||||
async def assign_feed(
|
async def assign_feed(
|
||||||
self, url: str, channel_name: str, channel_id: int, guild_id: int
|
self, url: str, channel_name: str, channel_id: int, guild_id: int
|
||||||
|
@ -13,7 +13,10 @@ async def fetch(session, url: str) -> str:
|
|||||||
async with session.get(url) as response:
|
async with session.get(url) as response:
|
||||||
return await response.text()
|
return await response.text()
|
||||||
|
|
||||||
async def get_unparsed_feed(url: str):
|
async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None):
|
||||||
|
if session is not None:
|
||||||
|
return await fetch(session, url)
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
return await fetch(session, url)
|
return await fetch(session, url)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user