diff --git a/viking_file/clients/client.py b/viking_file/clients/client.py index 034986f..b2607d3 100644 --- a/viking_file/clients/client.py +++ b/viking_file/clients/client.py @@ -14,8 +14,20 @@ class VikingClient(AsyncVikingClient): session = ClientSession(loop=self._loop) super().__init__(user_hash, api_timeout, session) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def close(self): + if not self._session.closed: + self._loop.run_until_complete(self._session.close()) + self._loop.close() + def _cleanup(self): - self._loop.run_until_complete(self._session.close()) + if not self._session.closed: + self._loop.run_until_complete(self._session.close()) self._loop.close() def get_max_pages(self, path: str = "") -> int: diff --git a/viking_file/clients/client_async.py b/viking_file/clients/client_async.py index 4feddf3..0bd1b43 100644 --- a/viking_file/clients/client_async.py +++ b/viking_file/clients/client_async.py @@ -38,6 +38,24 @@ class AsyncVikingClient: atexit.register(self._cleanup) + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if self._close_session: + await self._session.close() + + async def close(self): + """ + Closes the client _session. + + This method is a coroutine and should be used with the `await` keyword. + + This method is idempotent and can be called multiple times without issue. + """ + if not self._session.closed: + await self._session.close() + def _cleanup(self): """ Cleans up the client _session when exiting. @@ -46,7 +64,7 @@ class AsyncVikingClient: when the program exits. It closes the client _session if it was created automatically by the client. """ - if self._close_session: + if self._close_session and not self._session.closed: loop = asyncio.new_event_loop() loop.run_until_complete(self._session.close())