From 6feeff1f31d91b48c80a2666271dc49a2edafc64 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 31 May 2024 21:22:06 -0700
Subject: [PATCH] feat - cache openai clients

---
 litellm/llms/openai.py | 122 ++++++++++++++++++++++++++---------------
 1 file changed, 78 insertions(+), 44 deletions(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 1c65333591..761bbe7c8f 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -27,6 +27,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt
 from openai import OpenAI, AsyncOpenAI
 from ..types.llms.openai import *
 import openai
+from functools import lru_cache


 class OpenAIError(Exception):
@@ -504,6 +505,46 @@ class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

+    @lru_cache(maxsize=10)
+    def _get_openai_client(
+        self,
+        is_async: bool,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+        max_retries: Optional[int] = None,
+        organization: Optional[str] = None,
+        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+    ):
+        if client is None:
+            if not isinstance(max_retries, int):
+                raise OpenAIError(
+                    status_code=422,
+                    message="max retries must be an int. Passed in value: {}".format(
+                        max_retries
+                    ),
+                )
+            if is_async:
+                return AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+            else:
+                return OpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.client_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+        else:
+            return client
+
     def completion(
         self,
         model_response: ModelResponse,
@@ -610,17 +651,16 @@
                        raise OpenAIError(
                            status_code=422, message="max retries must be an int"
                        )
-                    if client is None:
-                        openai_client = OpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            http_client=litellm.client_session,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                            organization=organization,
-                        )
-                    else:
-                        openai_client = client
+
+                    openai_client = self._get_openai_client(
+                        is_async=False,
+                        api_key=api_key,
+                        api_base=api_base,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        organization=organization,
+                        client=client,
+                    )

                    ## LOGGING
                    logging_obj.pre_call(
@@ -700,17 +740,15 @@
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                    organization=organization,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )

             ## LOGGING
             logging_obj.pre_call(
@@ -754,17 +792,15 @@
         max_retries=None,
         headers=None,
     ):
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
         ## LOGGING
         logging_obj.pre_call(
             input=data["messages"],
@@ -801,17 +837,15 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                    organization=organization,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )

             ## LOGGING
             logging_obj.pre_call(
                 input=data["messages"],
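
Note on the caching approach (illustration only, not part of the patch): functools.lru_cache memoizes _get_openai_client on its full argument tuple, so repeated calls with the same credentials, base URL, timeout, retry count, and organization return the one cached client instead of building a new HTTP connection pool per request. Because the decorator wraps an instance method, self is part of the cache key as well, and every argument, including the timeout object and any pre-built client passed through, must be hashable for a lookup to succeed. The sketch below is a minimal, self-contained illustration of that behavior; get_client and FakeClient are hypothetical stand-ins for the new helper and for openai.OpenAI / openai.AsyncOpenAI, not litellm code.

from functools import lru_cache
from typing import Optional


class FakeClient:
    """Hypothetical stand-in for openai.OpenAI / openai.AsyncOpenAI."""

    def __init__(self, api_key: Optional[str], api_base: Optional[str], is_async: bool):
        self.api_key = api_key
        self.api_base = api_base
        self.is_async = is_async


@lru_cache(maxsize=10)
def get_client(
    is_async: bool,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    max_retries: int = 2,
    organization: Optional[str] = None,
) -> FakeClient:
    # lru_cache keys this call on the argument values, so all of them must be
    # hashable; identical arguments return the previously constructed client.
    return FakeClient(api_key=api_key, api_base=api_base, is_async=is_async)


if __name__ == "__main__":
    a = get_client(is_async=False, api_key="sk-test", api_base=None, max_retries=2)
    b = get_client(is_async=False, api_key="sk-test", api_base=None, max_retries=2)
    c = get_client(is_async=True, api_key="sk-test", api_base=None, max_retries=2)

    print(a is b)                   # True  -> cache hit, same client object re-used
    print(a is c)                   # False -> different arguments, new client built
    print(get_client.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=10, currsize=2)

With maxsize=10, at most ten distinct argument combinations stay cached at once; the least recently used entry is evicted when an eleventh arrives, and get_client.cache_clear() empties the cache entirely.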