From 8c9095a88feffdd62a337586d0a67da900b50dcd Mon Sep 17 00:00:00 2001 From: Alexander Tarasov Date: Thu, 14 Aug 2025 08:13:00 +0300 Subject: [PATCH] small change: add _get_session() function to api.py to reduce repeating code --- viking_file/api.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/viking_file/api.py b/viking_file/api.py index 69413dd..95fbc06 100644 --- a/viking_file/api.py +++ b/viking_file/api.py @@ -5,6 +5,12 @@ from aiohttp import ClientSession, ClientResponse, FormData, ClientTimeout BASE_URL = "https://vikingfile.com/api/" +def _get_session(session: ClientSession | None) -> tuple[ClientSession, bool]: + close_session = session is None + session = session or ClientSession() + return session, close_session + + async def get_upload_url(size: int, session: ClientSession = None, timeout: int = 10) -> ClientResponse: """ Get the URL of the upload server. @@ -17,8 +23,7 @@ async def get_upload_url(size: int, session: ClientSession = None, timeout: int Returns: aiohttp.ClientResponse: The response from the server containing the upload server URL. """ - close_session = session is None - session = session or ClientSession() + session, close_session = _get_session(session) request_data = { "size": size, @@ -54,8 +59,7 @@ async def complete_upload(key: str, upload_id: str, parts: list[dict], filename: Returns: aiohttp.ClientResponse: Server response. """ - close_session = session is None - session = session or ClientSession() + session, close_session = _get_session(session) request_data = { "key": key, @@ -93,8 +97,7 @@ async def get_upload_server(session: ClientSession = None, timeout: int = 10) -> aiohttp.ClientResponse: The response from the server containing the upload server URL. """ - close_session = session is None - session = session or ClientSession() + session, close_session = _get_session(session) response = await session.get( url=BASE_URL + "get-server", @@ -128,8 +131,7 @@ async def upload_file_legacy(upload_url: str, filepath: str, user: str = "", pat filepath = Path(filepath).resolve() assert filepath.exists(), f"File {filepath} doesn't exist!" - close_session = session is None - session = session or ClientSession() + session, close_session = _get_session(session) data = FormData() data.add_field("user", user) @@ -165,8 +167,7 @@ async def upload_remote_file(upload_server: str, link: str, user: str = "", file Returns: aiohttp.ClientResponse: Server response. """ - close_session = session is None - session = session or ClientSession() + session, close_session = _get_session(session) request_data = { "link": link, @@ -200,8 +201,7 @@ async def delete_file(file_hash: str, user: str, session: ClientSession = None, Returns: aiohttp.ClientResponse: Server response. """ - close_session = session is None - session = session or ClientSession() + session, close_session = _get_session(session) request_data = { "hash": file_hash, @@ -235,8 +235,7 @@ async def rename_file(file_hash: str, user: str, filename: str, session: ClientS Returns: aiohttp.ClientResponse: Server response. """ - close_session = session is None - session = session or ClientSession() + session, close_session = _get_session(session) request_data = { "hash": file_hash, @@ -268,8 +267,7 @@ async def check_file(file_hash: str, session: ClientSession = None, timeout: int Returns: aiohttp.ClientResponse: Server response. """ - close_session = session is None - session = session or ClientSession() + session, close_session = _get_session(session) request_data = { "hash": file_hash, @@ -305,8 +303,7 @@ async def list_files(user: str, page: int, path: str = "", session: ClientSessio Returns: aiohttp.ClientResponse: Server response. """ - close_session = session is None - session = session or ClientSession() + session, close_session = _get_session(session) if page <= 0: raise ValueError("Page must be positive number")