From 6feeff1f31d91b48c80a2666271dc49a2edafc64 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 31 May 2024 21:22:06 -0700
Subject: [PATCH 1/5] feat - cache openai clients

---
 litellm/llms/openai.py | 123 ++++++++++++++++++++++++++---------------
 1 file changed, 79 insertions(+), 44 deletions(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 1c65333591..761bbe7c8f 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -27,6 +27,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt
 from openai import OpenAI, AsyncOpenAI
 from ..types.llms.openai import *
 import openai
+from functools import lru_cache
 
 
 class OpenAIError(Exception):
@@ -504,6 +505,46 @@ class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
+    @lru_cache(maxsize=10)
+    def _get_openai_client(
+        self,
+        is_async: bool,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+        max_retries: Optional[int] = None,
+        organization: Optional[str] = None,
+        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+    ):
+        if client is None:
+            if not isinstance(max_retries, int):
+                raise OpenAIError(
+                    status_code=422,
+                    message="max retries must be an int. Passed in value: {}".format(
+                        max_retries
+                    ),
+                )
+            if is_async:
+                return AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+            else:
+                return OpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.client_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+        else:
+            return client
+
     def completion(
         self,
         model_response: ModelResponse,
@@ -610,17 +651,16 @@ class OpenAIChatCompletion(BaseLLM):
                         raise OpenAIError(
                             status_code=422, message="max retries must be an int"
                         )
-                    if client is None:
-                        openai_client = OpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            http_client=litellm.client_session,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                            organization=organization,
-                        )
-                    else:
-                        openai_client = client
+
+                    openai_client = self._get_openai_client(
+                        is_async=False,
+                        api_key=api_key,
+                        api_base=api_base,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        organization=organization,
+                        client=client,
+                    )
 
                     ## LOGGING
                     logging_obj.pre_call(
@@ -700,17 +740,15 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                    organization=organization,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
 
             ## LOGGING
             logging_obj.pre_call(
@@ -754,17 +792,15 @@ class OpenAIChatCompletion(BaseLLM):
         max_retries=None,
         headers=None,
     ):
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
         ## LOGGING
         logging_obj.pre_call(
             input=data["messages"],
@@ -801,17 +837,16 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                    organization=organization,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+                is_async=True,
+            )
         ## LOGGING
         logging_obj.pre_call(
             input=data["messages"],

From cedeb10a089a5257e3c38f47cc84ab98ce3f870d Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 31 May 2024 21:24:14 -0700
Subject: [PATCH 2/5] fix - linting error

---
 litellm/llms/openai.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 761bbe7c8f..76568455ed 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -845,7 +845,6 @@ class OpenAIChatCompletion(BaseLLM):
                 max_retries=max_retries,
                 organization=organization,
                 client=client,
-                is_async=True,
             )
             ## LOGGING
             logging_obj.pre_call(

From 1c16904566d62c39dfda5d11904c827746b8c3d8 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 31 May 2024 21:35:03 -0700
Subject: [PATCH 3/5] fix cache openai client for embeddings, text, speech

---
 litellm/llms/openai.py | 158 ++++++++++++++++++-----------------------
 1 file changed, 71 insertions(+), 87 deletions(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 76568455ed..c8850734fd 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -899,16 +899,14 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
             response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
             stringified_response = response.model_dump()
             ## LOGGING
@@ -956,19 +954,18 @@ class OpenAIChatCompletion(BaseLLM):
             additional_args={"complete_input_dict": data, "api_base": api_base},
         )
 
-        if aembedding == True:
+        if aembedding is True:
             response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
             return response
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
 
         ## COMPLETION CALL
         response = openai_client.embeddings.create(**data, timeout=timeout)  # type: ignore
@@ -1004,16 +1001,16 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_aclient = client
+
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+
             response = await openai_aclient.images.generate(**data, timeout=timeout)  # type: ignore
             stringified_response = response.model_dump()
             ## LOGGING
@@ -1058,16 +1055,14 @@ class OpenAIChatCompletion(BaseLLM):
             response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
             return response
 
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
 
         ## LOGGING
         logging_obj.pre_call(
@@ -1132,7 +1127,7 @@ class OpenAIChatCompletion(BaseLLM):
         atranscription: bool = False,
     ):
         data = {"model": model, "file": audio_file, **optional_params}
-        if atranscription == True:
+        if atranscription is True:
             return self.async_audio_transcriptions(
                 audio_file=audio_file,
                 data=data,
@@ -1144,16 +1139,14 @@ class OpenAIChatCompletion(BaseLLM):
                 max_retries=max_retries,
                 logging_obj=logging_obj,
             )
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+        )
         response = openai_client.audio.transcriptions.create(
             **data, timeout=timeout  # type: ignore
         )
@@ -1183,16 +1176,15 @@ class OpenAIChatCompletion(BaseLLM):
         logging_obj=None,
     ):
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+
             response = await openai_aclient.audio.transcriptions.create(
                 **data, timeout=timeout
             )  # type: ignore
@@ -1231,7 +1223,7 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
     ) -> HttpxBinaryResponseContent:
 
-        if aspeech is not None and aspeech == True:
+        if aspeech is not None and aspeech is True:
             return self.async_audio_speech(
                 model=model,
                 input=input,
@@ -1246,18 +1238,14 @@ class OpenAIChatCompletion(BaseLLM):
                 client=client,
             )  # type: ignore
 
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                organization=organization,
-                project=project,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
 
         response = openai_client.audio.speech.create(
             model=model,
@@ -1282,18 +1270,14 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
     ) -> HttpxBinaryResponseContent:
 
-        if client is None:
-            openai_client = AsyncOpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                organization=organization,
-                project=project,
-                http_client=litellm.aclient_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=True,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
 
         response = await openai_client.audio.speech.create(
             model=model,
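Why the next two patches rework the cache: functools.lru_cache (added in PATCH 1/5) is a fragile fit for _get_openai_client. It builds its key from self plus every argument, which pins the OpenAIChatCompletion instance in memory for the cache's lifetime, and it raises TypeError as soon as any argument is unhashable (an httpx.Timeout, for example, defines equality but not a hash). PATCH 4/5 below replaces the decorator with an explicit in-memory dict keyed by a settings string. A minimal sketch of that pattern, with hypothetical names (_client_cache, get_cached_client) standing in for the real litellm.in_memory_llm_clients_cache and _get_openai_client:

    # Sketch only, not part of the patches: a dict-based client cache keyed
    # by a plain string, so unhashable settings objects never need to be
    # hashed the way lru_cache would hash them.
    from typing import Dict, Optional, Union

    from openai import AsyncOpenAI, OpenAI

    _client_cache: Dict[str, Union[OpenAI, AsyncOpenAI]] = {}


    def get_cached_client(
        is_async: bool,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        max_retries: int = 2,
    ) -> Union[OpenAI, AsyncOpenAI]:
        # Build a string key from the settings that define a distinct client.
        _cache_key = (
            f"is_async={is_async},api_key={api_key},"
            f"api_base={api_base},max_retries={max_retries}"
        )
        if _cache_key in _client_cache:
            # Cache hit: reuse the client and, with it, its HTTP connection pool.
            return _client_cache[_cache_key]
        _cls = AsyncOpenAI if is_async else OpenAI
        _new_client = _cls(api_key=api_key, base_url=api_base, max_retries=max_retries)
        _client_cache[_cache_key] = _new_client
        return _new_client

Reusing one client per unique settings combination is the point of the whole series: each OpenAI/AsyncOpenAI instance owns an httpx connection pool, so a cache hit also reuses warm connections instead of constructing a fresh client per request.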
From 47337c172ebb6e48980f8c76c999812c37c75b2b Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 1 Jun 2024 08:58:22 -0700
Subject: [PATCH 4/5] fix - in memory client cache

---
 litellm/llms/openai.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index c8850734fd..c560cd3b9d 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -505,7 +505,6 @@ class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    @lru_cache(maxsize=10)
     def _get_openai_client(
         self,
         is_async: bool,
@@ -524,8 +523,14 @@ class OpenAIChatCompletion(BaseLLM):
                     max_retries
                 ),
             )
+            # Creating a new OpenAI Client
+            # check in memory cache before doing so
+            _cache_key = f"api_key={api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization}"
+
+            if _cache_key in litellm.in_memory_llm_clients_cache:
+                return litellm.in_memory_llm_clients_cache[_cache_key]
             if is_async:
-                return AsyncOpenAI(
+                _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
                     api_key=api_key,
                     base_url=api_base,
                     http_client=litellm.aclient_session,
@@ -534,7 +539,7 @@ class OpenAIChatCompletion(BaseLLM):
                     organization=organization,
                 )
             else:
-                return OpenAI(
+                _new_client = OpenAI(
                     api_key=api_key,
                     base_url=api_base,
                     http_client=litellm.client_session,
@@ -542,6 +547,10 @@ class OpenAIChatCompletion(BaseLLM):
                     max_retries=max_retries,
                     organization=organization,
                 )
+
+            litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
+            return _new_client
+
         else:
             return client
 

From 47dd52c5666dc6fcbbbd5bd47993507e24acc9f2 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 1 Jun 2024 09:24:16 -0700
Subject: [PATCH 5/5] fix used hashed api key

---
 litellm/llms/openai.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index c560cd3b9d..c72b404c37 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -6,6 +6,7 @@ from typing import (
     Literal,
     Iterable,
 )
+import hashlib
 from typing_extensions import override, overload
 from pydantic import BaseModel
 import types, time, json, traceback
@@ -27,7 +28,6 @@ from .prompt_templates.factory import prompt_factory, custom_prompt
 from openai import OpenAI, AsyncOpenAI
 from ..types.llms.openai import *
 import openai
-from functools import lru_cache
 
 
 class OpenAIError(Exception):
@@ -524,8 +524,15 @@ class OpenAIChatCompletion(BaseLLM):
                 ),
             )
             # Creating a new OpenAI Client
-            # check in memory cache before doing so
-            _cache_key = f"api_key={api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization}"
+            # check in memory cache before creating a new one
+            # Convert the API key to bytes
+            hashed_api_key = None
+            if api_key is not None:
+                hash_object = hashlib.sha256(api_key.encode())
+                # Hexadecimal representation of the hash
+                hashed_api_key = hash_object.hexdigest()
+
+            _cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization}"
 
             if _cache_key in litellm.in_memory_llm_clients_cache:
                 return litellm.in_memory_llm_clients_cache[_cache_key]
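PATCH 5/5 hardens the cache key from PATCH 4/5: the raw API key is run through SHA-256 before being embedded in the key string, so the secret never sits inside a dict key that could surface in logs or debug dumps, while identical credentials still produce identical digests and therefore still hit the cache. A self-contained sketch of the key construction — the field layout mirrors the diff above, but the function name is hypothetical:

    import hashlib
    from typing import Optional


    def make_client_cache_key(
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Optional[float],
        max_retries: Optional[int],
        organization: Optional[str],
    ) -> str:
        # Hash the API key so the cache key never embeds the raw secret.
        hashed_api_key = None
        if api_key is not None:
            hashed_api_key = hashlib.sha256(api_key.encode()).hexdigest()
        return (
            f"hashed_api_key={hashed_api_key},api_base={api_base},"
            f"timeout={timeout},max_retries={max_retries},organization={organization}"
        )


    # Identical settings yield an identical key, so both calls resolve to the
    # same cached client; the secret itself never appears in the key.
    assert make_client_cache_key("sk-test", None, 600.0, 2, None) == make_client_cache_key(
        "sk-test", None, 600.0, 2, None
    )

Because SHA-256 is deterministic, equal keys always map to the same cache slot; because it is one-way, recovering the API key from a leaked cache dump is impractical, and the hash is computed only once per client construction, so lookups are unaffected.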