diff --git a/viking_file/api.py b/viking_file/api.py index 95fbc06..7717d0d 100644 --- a/viking_file/api.py +++ b/viking_file/api.py @@ -1,14 +1,21 @@ from pathlib import Path +from contextlib import asynccontextmanager +from typing import AsyncGenerator from aiohttp import ClientSession, ClientResponse, FormData, ClientTimeout BASE_URL = "https://vikingfile.com/api/" -def _get_session(session: ClientSession | None) -> tuple[ClientSession, bool]: +@asynccontextmanager +async def _get_session(session: ClientSession | None) -> AsyncGenerator[ClientSession]: close_session = session is None session = session or ClientSession() - return session, close_session + try: + yield session + finally: + if close_session: + await session.close() async def get_upload_url(size: int, session: ClientSession = None, timeout: int = 10) -> ClientResponse: @@ -23,20 +30,17 @@ async def get_upload_url(size: int, session: ClientSession = None, timeout: int Returns: aiohttp.ClientResponse: The response from the server containing the upload server URL. """ - session, close_session = _get_session(session) request_data = { "size": size, } - response = await session.post( - url=BASE_URL + "get-upload-url", - data=request_data, - timeout=timeout - ) - - if close_session: - await session.close() + async with _get_session(session) as session: + response = await session.post( + url=BASE_URL + "get-upload-url", + data=request_data, + timeout=timeout + ) return response @@ -59,7 +63,6 @@ async def complete_upload(key: str, upload_id: str, parts: list[dict], filename: Returns: aiohttp.ClientResponse: Server response. """ - session, close_session = _get_session(session) request_data = { "key": key, @@ -73,14 +76,12 @@ async def complete_upload(key: str, upload_id: str, parts: list[dict], filename: request_data[f"parts[{idx}][PartNumber]"] = str(part["PartNumber"]) request_data[f"parts[{idx}][ETag]"] = part["ETag"] - response = await session.post( - url=BASE_URL + "complete-upload", - data=request_data, - timeout=timeout - ) - - if close_session: - await session.close() + async with _get_session(session) as session: + response = await session.post( + url=BASE_URL + "complete-upload", + data=request_data, + timeout=timeout + ) return response @@ -97,15 +98,11 @@ async def get_upload_server(session: ClientSession = None, timeout: int = 10) -> aiohttp.ClientResponse: The response from the server containing the upload server URL. """ - session, close_session = _get_session(session) - - response = await session.get( - url=BASE_URL + "get-server", - timeout=timeout - ) - - if close_session: - await session.close() + async with _get_session(session) as session: + response = await session.get( + url=BASE_URL + "get-server", + timeout=timeout + ) return response @@ -131,21 +128,17 @@ async def upload_file_legacy(upload_url: str, filepath: str, user: str = "", pat filepath = Path(filepath).resolve() assert filepath.exists(), f"File {filepath} doesn't exist!" - session, close_session = _get_session(session) - data = FormData() data.add_field("user", user) data.add_field("path", path) data.add_field("file", open(filepath, "rb"), filename=filepath.name) - response = await session.post( - url=upload_url, - data=data, - timeout=ClientTimeout(connect=timeout, sock_read=None, sock_connect=None) - ) - - if close_session: - await session.close() + async with _get_session(session) as session: + response = await session.post( + url=upload_url, + data=data, + timeout=ClientTimeout(connect=timeout, sock_read=None, sock_connect=None) + ) return response @@ -167,7 +160,6 @@ async def upload_remote_file(upload_server: str, link: str, user: str = "", file Returns: aiohttp.ClientResponse: Server response. """ - session, close_session = _get_session(session) request_data = { "link": link, @@ -176,14 +168,12 @@ async def upload_remote_file(upload_server: str, link: str, user: str = "", file "path": path, } - response = await session.post( - url=upload_server, - data=request_data, - timeout=timeout - ) - - if close_session: - await session.close() + async with _get_session(session) as session: + response = await session.post( + url=upload_server, + data=request_data, + timeout=timeout + ) return response @@ -201,21 +191,18 @@ async def delete_file(file_hash: str, user: str, session: ClientSession = None, Returns: aiohttp.ClientResponse: Server response. """ - session, close_session = _get_session(session) request_data = { "hash": file_hash, "user": user, } - response = await session.post( - url=BASE_URL + "delete-file", - data=request_data, - timeout=timeout - ) - - if close_session: - await session.close() + async with _get_session(session) as session: + response = await session.post( + url=BASE_URL + "delete-file", + data=request_data, + timeout=timeout + ) return response @@ -235,7 +222,6 @@ async def rename_file(file_hash: str, user: str, filename: str, session: ClientS Returns: aiohttp.ClientResponse: Server response. """ - session, close_session = _get_session(session) request_data = { "hash": file_hash, @@ -243,14 +229,12 @@ async def rename_file(file_hash: str, user: str, filename: str, session: ClientS "filename": filename, } - response = await session.post( - url=BASE_URL + "rename-file", - data=request_data, - timeout=timeout - ) - - if close_session: - await session.close() + async with _get_session(session) as session: + response = await session.post( + url=BASE_URL + "rename-file", + data=request_data, + timeout=timeout + ) return response @@ -267,20 +251,17 @@ async def check_file(file_hash: str, session: ClientSession = None, timeout: int Returns: aiohttp.ClientResponse: Server response. """ - session, close_session = _get_session(session) request_data = { "hash": file_hash, } - response = await session.post( - url=BASE_URL + "check-file", - data=request_data, - timeout=timeout - ) - - if close_session: - await session.close() + async with _get_session(session) as session: + response = await session.post( + url=BASE_URL + "check-file", + data=request_data, + timeout=timeout + ) return response @@ -303,7 +284,6 @@ async def list_files(user: str, page: int, path: str = "", session: ClientSessio Returns: aiohttp.ClientResponse: Server response. """ - session, close_session = _get_session(session) if page <= 0: raise ValueError("Page must be positive number") @@ -314,13 +294,11 @@ async def list_files(user: str, page: int, path: str = "", session: ClientSessio "path": path } - response = await session.post( - url=BASE_URL + "list-files", - data=request_data, - timeout=timeout - ) - - if close_session: - await session.close() + async with _get_session(session) as session: + response = await session.post( + url=BASE_URL + "list-files", + data=request_data, + timeout=timeout + ) return response