improvement: _get_session() is now an async generator that automatically closes session

This commit is contained in:
2025-08-14 08:22:57 +03:00
parent 8c9095a88f
commit 465ea9bc63
+18 -40
View File
@@ -1,14 +1,21 @@
from pathlib import Path from pathlib import Path
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from aiohttp import ClientSession, ClientResponse, FormData, ClientTimeout from aiohttp import ClientSession, ClientResponse, FormData, ClientTimeout
BASE_URL = "https://vikingfile.com/api/" BASE_URL = "https://vikingfile.com/api/"
def _get_session(session: ClientSession | None) -> tuple[ClientSession, bool]: @asynccontextmanager
async def _get_session(session: ClientSession | None) -> AsyncGenerator[ClientSession]:
close_session = session is None close_session = session is None
session = session or ClientSession() session = session or ClientSession()
return session, close_session try:
yield session
finally:
if close_session:
await session.close()
async def get_upload_url(size: int, session: ClientSession = None, timeout: int = 10) -> ClientResponse: async def get_upload_url(size: int, session: ClientSession = None, timeout: int = 10) -> ClientResponse:
@@ -23,21 +30,18 @@ async def get_upload_url(size: int, session: ClientSession = None, timeout: int
Returns: Returns:
aiohttp.ClientResponse: The response from the server containing the upload server URL. aiohttp.ClientResponse: The response from the server containing the upload server URL.
""" """
session, close_session = _get_session(session)
request_data = { request_data = {
"size": size, "size": size,
} }
async with _get_session(session) as session:
response = await session.post( response = await session.post(
url=BASE_URL + "get-upload-url", url=BASE_URL + "get-upload-url",
data=request_data, data=request_data,
timeout=timeout timeout=timeout
) )
if close_session:
await session.close()
return response return response
@@ -59,7 +63,6 @@ async def complete_upload(key: str, upload_id: str, parts: list[dict], filename:
Returns: Returns:
aiohttp.ClientResponse: Server response. aiohttp.ClientResponse: Server response.
""" """
session, close_session = _get_session(session)
request_data = { request_data = {
"key": key, "key": key,
@@ -73,15 +76,13 @@ async def complete_upload(key: str, upload_id: str, parts: list[dict], filename:
request_data[f"parts[{idx}][PartNumber]"] = str(part["PartNumber"]) request_data[f"parts[{idx}][PartNumber]"] = str(part["PartNumber"])
request_data[f"parts[{idx}][ETag]"] = part["ETag"] request_data[f"parts[{idx}][ETag]"] = part["ETag"]
async with _get_session(session) as session:
response = await session.post( response = await session.post(
url=BASE_URL + "complete-upload", url=BASE_URL + "complete-upload",
data=request_data, data=request_data,
timeout=timeout timeout=timeout
) )
if close_session:
await session.close()
return response return response
@@ -97,16 +98,12 @@ async def get_upload_server(session: ClientSession = None, timeout: int = 10) ->
aiohttp.ClientResponse: The response from the server containing the upload server URL. aiohttp.ClientResponse: The response from the server containing the upload server URL.
""" """
session, close_session = _get_session(session) async with _get_session(session) as session:
response = await session.get( response = await session.get(
url=BASE_URL + "get-server", url=BASE_URL + "get-server",
timeout=timeout timeout=timeout
) )
if close_session:
await session.close()
return response return response
@@ -131,22 +128,18 @@ async def upload_file_legacy(upload_url: str, filepath: str, user: str = "", pat
filepath = Path(filepath).resolve() filepath = Path(filepath).resolve()
assert filepath.exists(), f"File {filepath} doesn't exist!" assert filepath.exists(), f"File {filepath} doesn't exist!"
session, close_session = _get_session(session)
data = FormData() data = FormData()
data.add_field("user", user) data.add_field("user", user)
data.add_field("path", path) data.add_field("path", path)
data.add_field("file", open(filepath, "rb"), filename=filepath.name) data.add_field("file", open(filepath, "rb"), filename=filepath.name)
async with _get_session(session) as session:
response = await session.post( response = await session.post(
url=upload_url, url=upload_url,
data=data, data=data,
timeout=ClientTimeout(connect=timeout, sock_read=None, sock_connect=None) timeout=ClientTimeout(connect=timeout, sock_read=None, sock_connect=None)
) )
if close_session:
await session.close()
return response return response
@@ -167,7 +160,6 @@ async def upload_remote_file(upload_server: str, link: str, user: str = "", file
Returns: Returns:
aiohttp.ClientResponse: Server response. aiohttp.ClientResponse: Server response.
""" """
session, close_session = _get_session(session)
request_data = { request_data = {
"link": link, "link": link,
@@ -176,15 +168,13 @@ async def upload_remote_file(upload_server: str, link: str, user: str = "", file
"path": path, "path": path,
} }
async with _get_session(session) as session:
response = await session.post( response = await session.post(
url=upload_server, url=upload_server,
data=request_data, data=request_data,
timeout=timeout timeout=timeout
) )
if close_session:
await session.close()
return response return response
@@ -201,22 +191,19 @@ async def delete_file(file_hash: str, user: str, session: ClientSession = None,
Returns: Returns:
aiohttp.ClientResponse: Server response. aiohttp.ClientResponse: Server response.
""" """
session, close_session = _get_session(session)
request_data = { request_data = {
"hash": file_hash, "hash": file_hash,
"user": user, "user": user,
} }
async with _get_session(session) as session:
response = await session.post( response = await session.post(
url=BASE_URL + "delete-file", url=BASE_URL + "delete-file",
data=request_data, data=request_data,
timeout=timeout timeout=timeout
) )
if close_session:
await session.close()
return response return response
@@ -235,7 +222,6 @@ async def rename_file(file_hash: str, user: str, filename: str, session: ClientS
Returns: Returns:
aiohttp.ClientResponse: Server response. aiohttp.ClientResponse: Server response.
""" """
session, close_session = _get_session(session)
request_data = { request_data = {
"hash": file_hash, "hash": file_hash,
@@ -243,15 +229,13 @@ async def rename_file(file_hash: str, user: str, filename: str, session: ClientS
"filename": filename, "filename": filename,
} }
async with _get_session(session) as session:
response = await session.post( response = await session.post(
url=BASE_URL + "rename-file", url=BASE_URL + "rename-file",
data=request_data, data=request_data,
timeout=timeout timeout=timeout
) )
if close_session:
await session.close()
return response return response
@@ -267,21 +251,18 @@ async def check_file(file_hash: str, session: ClientSession = None, timeout: int
Returns: Returns:
aiohttp.ClientResponse: Server response. aiohttp.ClientResponse: Server response.
""" """
session, close_session = _get_session(session)
request_data = { request_data = {
"hash": file_hash, "hash": file_hash,
} }
async with _get_session(session) as session:
response = await session.post( response = await session.post(
url=BASE_URL + "check-file", url=BASE_URL + "check-file",
data=request_data, data=request_data,
timeout=timeout timeout=timeout
) )
if close_session:
await session.close()
return response return response
@@ -303,7 +284,6 @@ async def list_files(user: str, page: int, path: str = "", session: ClientSessio
Returns: Returns:
aiohttp.ClientResponse: Server response. aiohttp.ClientResponse: Server response.
""" """
session, close_session = _get_session(session)
if page <= 0: if page <= 0:
raise ValueError("Page must be positive number") raise ValueError("Page must be positive number")
@@ -314,13 +294,11 @@ async def list_files(user: str, page: int, path: str = "", session: ClientSessio
"path": path "path": path
} }
async with _get_session(session) as session:
response = await session.post( response = await session.post(
url=BASE_URL + "list-files", url=BASE_URL + "list-files",
data=request_data, data=request_data,
timeout=timeout timeout=timeout
) )
if close_session:
await session.close()
return response return response