diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 1c65333591..c72b404c37 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -6,6 +6,7 @@ from typing import (
     Literal,
     Iterable,
 )
+import hashlib
 from typing_extensions import override, overload
 from pydantic import BaseModel
 import types, time, json, traceback
@@ -504,6 +505,62 @@ class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
+    def _get_openai_client(
+        self,
+        is_async: bool,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+        max_retries: Optional[int] = None,
+        organization: Optional[str] = None,
+        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+    ):
+        if client is None:
+            if not isinstance(max_retries, int):
+                raise OpenAIError(
+                    status_code=422,
+                    message="max retries must be an int. Passed in value: {}".format(
+                        max_retries
+                    ),
+                )
+            # Creating a new OpenAI Client
+            # check in memory cache before creating a new one
+            # Convert the API key to bytes
+            hashed_api_key = None
+            if api_key is not None:
+                hash_object = hashlib.sha256(api_key.encode())
+                # Hexadecimal representation of the hash
+                hashed_api_key = hash_object.hexdigest()
+
+            _cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization}"
+
+            if _cache_key in litellm.in_memory_llm_clients_cache:
+                return litellm.in_memory_llm_clients_cache[_cache_key]
+            if is_async:
+                _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+            else:
+                _new_client = OpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.client_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+
+            litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
+            return _new_client
+
+        else:
+            return client
+
     def completion(
         self,
         model_response: ModelResponse,
@@ -610,17 +667,16 @@ class OpenAIChatCompletion(BaseLLM):
                             raise OpenAIError(
                                 status_code=422, message="max retries must be an int"
                             )
-                        if client is None:
-                            openai_client = OpenAI(
-                                api_key=api_key,
-                                base_url=api_base,
-                                http_client=litellm.client_session,
-                                timeout=timeout,
-                                max_retries=max_retries,
-                                organization=organization,
-                            )
-                        else:
-                            openai_client = client
+
+                        openai_client = self._get_openai_client(
+                            is_async=False,
+                            api_key=api_key,
+                            api_base=api_base,
+                            timeout=timeout,
+                            max_retries=max_retries,
+                            organization=organization,
+                            client=client,
+                        )
 
                         ## LOGGING
                         logging_obj.pre_call(
@@ -700,17 +756,15 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
        try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                    organization=organization,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
 
             ## LOGGING
             logging_obj.pre_call(
@@ -754,17 +808,15 @@ class OpenAIChatCompletion(BaseLLM):
         max_retries=None,
         headers=None,
     ):
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
         ## LOGGING
         logging_obj.pre_call(
             input=data["messages"],
@@ -801,17 +853,15 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                    organization=organization,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
             ## LOGGING
             logging_obj.pre_call(
                 input=data["messages"],
@@ -865,16 +915,14 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
             response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
             stringified_response = response.model_dump()
             ## LOGGING
@@ -922,19 +970,18 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data, "api_base": api_base},
             )
 
-            if aembedding == True:
+            if aembedding is True:
                 response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
                 return response
-            if client is None:
-                openai_client = OpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.client_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_client = client
+
+            openai_client = self._get_openai_client(
+                is_async=False,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
 
             ## COMPLETION CALL
             response = openai_client.embeddings.create(**data, timeout=timeout)  # type: ignore
@@ -970,16 +1017,16 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_aclient = client
+
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+
             response = await openai_aclient.images.generate(**data, timeout=timeout)  # type: ignore
             stringified_response = response.model_dump()
             ## LOGGING
@@ -1024,16 +1071,14 @@ class OpenAIChatCompletion(BaseLLM):
                 response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
                 return response
 
-            if client is None:
-                openai_client = OpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.client_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_client = client
+            openai_client = self._get_openai_client(
+                is_async=False,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
 
             ## LOGGING
             logging_obj.pre_call(
@@ -1098,7 +1143,7 @@ class OpenAIChatCompletion(BaseLLM):
         atranscription: bool = False,
     ):
         data = {"model": model, "file": audio_file, **optional_params}
-        if atranscription == True:
+        if atranscription is True:
             return self.async_audio_transcriptions(
                 audio_file=audio_file,
                 data=data,
@@ -1110,16 +1155,15 @@ class OpenAIChatCompletion(BaseLLM):
                 max_retries=max_retries,
                 logging_obj=logging_obj,
             )
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
         response = openai_client.audio.transcriptions.create(
             **data, timeout=timeout  # type: ignore
         )
@@ -1149,16 +1193,15 @@ class OpenAIChatCompletion(BaseLLM):
         logging_obj=None,
     ):
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+
             response = await openai_aclient.audio.transcriptions.create(
                 **data, timeout=timeout
             )  # type: ignore
@@ -1197,7 +1240,7 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
     ) -> HttpxBinaryResponseContent:
 
-        if aspeech is not None and aspeech == True:
+        if aspeech is not None and aspeech is True:
             return self.async_audio_speech(
                 model=model,
                 input=input,
@@ -1212,18 +1255,14 @@ class OpenAIChatCompletion(BaseLLM):
             client=client,
         )  # type: ignore
 
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                organization=organization,
-                project=project,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
 
         response = openai_client.audio.speech.create(
             model=model,
@@ -1248,18 +1287,14 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
     ) -> HttpxBinaryResponseContent:
 
-        if client is None:
-            openai_client = AsyncOpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                organization=organization,
-                project=project,
-                http_client=litellm.aclient_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=True,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
 
         response = await openai_client.audio.speech.create(
             model=model,
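For reference, the caching scheme that `_get_openai_client` introduces can be exercised in isolation. The sketch below is illustrative only: `StubClient`, `get_client`, and `_client_cache` are hypothetical stand-ins for the OpenAI SDK client, the new helper, and `litellm.in_memory_llm_clients_cache`; the cache-key format mirrors the one in the patch (a SHA-256 hash of the API key joined with the other constructor arguments).

```python
import hashlib
from typing import Dict, Optional


class StubClient:
    """Hypothetical stand-in for openai.OpenAI / openai.AsyncOpenAI."""

    def __init__(self, api_key: Optional[str], base_url: Optional[str]):
        self.api_key = api_key
        self.base_url = base_url


# Plays the role of litellm.in_memory_llm_clients_cache.
_client_cache: Dict[str, StubClient] = {}


def get_client(api_key: Optional[str], api_base: Optional[str]) -> StubClient:
    # Hash the key so the raw secret never appears in the cache key,
    # mirroring the hashed_api_key logic in _get_openai_client.
    hashed = hashlib.sha256(api_key.encode()).hexdigest() if api_key else None
    cache_key = f"hashed_api_key={hashed},api_base={api_base}"
    if cache_key not in _client_cache:
        _client_cache[cache_key] = StubClient(api_key, api_base)
    return _client_cache[cache_key]


a = get_client("sk-test", "https://api.openai.com/v1")
b = get_client("sk-test", "https://api.openai.com/v1")
c = get_client("sk-other", "https://api.openai.com/v1")
assert a is b      # identical config -> cached client is reused
assert a is not c  # different key -> separate client
```

The apparent motivation is that a reused client keeps its underlying httpx connection pool alive across calls with the same configuration instead of rebuilding it per request, and hashing keeps raw API keys out of anything that inspects or logs the cache's keys (the cached client object itself still holds the key).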