feat: cache OpenAI clients

This commit is contained in:
Ishaan Jaff 2024-05-31 21:22:06 -07:00
parent b8df5d1a01
commit 6feeff1f31

View file

@@ -27,6 +27,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt
 from openai import OpenAI, AsyncOpenAI
 from ..types.llms.openai import *
 import openai
+from functools import lru_cache

 class OpenAIError(Exception):
@@ -504,6 +505,46 @@ class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
# Class-level cache of constructed SDK clients, bounded to mirror the
# @lru_cache(maxsize=10) it replaces.  A plain dict keyed on the
# constructor parameters avoids two problems with @lru_cache on an
# instance method (ruff B019):
#   * lru_cache keys on (and keeps alive) `self`, leaking instances;
#   * lru_cache requires hashable arguments, but httpx.Timeout defines
#     __eq__ without __hash__ and so raises TypeError at call time.
_client_cache: dict = {}
_CLIENT_CACHE_MAX = 10

def _get_openai_client(
    self,
    is_async: bool,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
    max_retries: Optional[int] = None,
    organization: Optional[str] = None,
    client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
    """Return an OpenAI / AsyncOpenAI client for the given settings.

    A caller-supplied ``client`` is returned as-is and never cached.
    Otherwise a client is constructed once per unique parameter set and
    reused on subsequent calls.

    Args:
        is_async: build an ``AsyncOpenAI`` client when True, ``OpenAI``
            otherwise.
        api_key / api_base / timeout / max_retries / organization:
            forwarded to the SDK client constructor.
        client: pre-built client to use instead of constructing one.

    Raises:
        OpenAIError: status 422 when ``max_retries`` is not an int and
            no ``client`` was supplied.
    """
    if client is not None:
        return client

    if not isinstance(max_retries, int):
        raise OpenAIError(
            status_code=422,
            message="max retries must be an int. Passed in value: {}".format(
                max_retries
            ),
        )

    # str(timeout) keeps the key hashable for httpx.Timeout values.
    cache_key = (is_async, api_key, api_base, str(timeout), max_retries, organization)
    cached_client = self._client_cache.get(cache_key)
    if cached_client is not None:
        return cached_client

    if is_async:
        new_client = AsyncOpenAI(
            api_key=api_key,
            base_url=api_base,
            http_client=litellm.aclient_session,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
        )
    else:
        new_client = OpenAI(
            api_key=api_key,
            base_url=api_base,
            http_client=litellm.client_session,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
        )

    # Evict the oldest entry (dicts preserve insertion order) so the
    # cache stays bounded, approximating lru_cache(maxsize=10).
    if len(self._client_cache) >= self._CLIENT_CACHE_MAX:
        self._client_cache.pop(next(iter(self._client_cache)))
    self._client_cache[cache_key] = new_client
    return new_client
     def completion(
         self,
         model_response: ModelResponse,
@ -610,17 +651,16 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError( raise OpenAIError(
status_code=422, message="max retries must be an int" status_code=422, message="max retries must be an int"
) )
if client is None:
openai_client = OpenAI( openai_client = self._get_openai_client(
is_async=False,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.client_session,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
client=client,
) )
else:
openai_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -700,17 +740,15 @@ class OpenAIChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
if client is None: openai_aclient = self._get_openai_client(
openai_aclient = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.aclient_session,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
client=client,
) )
else:
openai_aclient = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -754,17 +792,15 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None, max_retries=None,
headers=None, headers=None,
): ):
if client is None: openai_client = self._get_openai_client(
openai_client = OpenAI( is_async=False,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.client_session,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
client=client,
) )
else:
openai_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["messages"], input=data["messages"],
@ -801,17 +837,16 @@ class OpenAIChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
if client is None: openai_aclient = self._get_openai_client(
openai_aclient = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.aclient_session,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
client=client,
is_async=True,
) )
else:
openai_aclient = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["messages"], input=data["messages"],