Merge pull request #3956 from BerriAI/litellm_cache_openai_clients

[FEAT] Perf improvements - litellm.completion / litellm.acompletion - Cache OpenAI client
This commit is contained in:
Ishaan Jaff 2024-06-01 09:46:42 -07:00 committed by GitHub
commit d83b4a00d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,6 +6,7 @@ from typing import (
Literal, Literal,
Iterable, Iterable,
) )
import hashlib
from typing_extensions import override, overload from typing_extensions import override, overload
from pydantic import BaseModel from pydantic import BaseModel
import types, time, json, traceback import types, time, json, traceback
@ -504,6 +505,62 @@ class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def _get_openai_client(
    self,
    is_async: bool,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
    max_retries: Optional[int] = None,
    organization: Optional[str] = None,
    client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
    """Return an OpenAI client for this request, reusing a cached one when possible.

    If *client* was passed in by the caller it is returned unchanged.
    Otherwise the in-memory cache (``litellm.in_memory_llm_clients_cache``)
    is consulted, keyed on every parameter that affects client construction;
    on a miss a new ``OpenAI`` / ``AsyncOpenAI`` client is built, cached and
    returned.

    Args:
        is_async: build/return an ``AsyncOpenAI`` client when True, else ``OpenAI``.
        api_key: OpenAI API key; stored in the cache key only as a SHA-256 hash.
        api_base: optional override for the API base URL.
        timeout: request timeout forwarded to the client constructor.
        max_retries: client retry count; must be an int when a client is created.
        organization: OpenAI organization id.
        client: pre-built client supplied by the caller, short-circuits everything.

    Raises:
        OpenAIError: (status 422) if a client must be created and
            *max_retries* is not an int.
    """
    if client is not None:
        return client

    if not isinstance(max_retries, int):
        raise OpenAIError(
            status_code=422,
            message="max retries must be an int. Passed in value: {}".format(
                max_retries
            ),
        )

    # Hash the API key so the raw secret is never stored in the cache key.
    hashed_api_key = None
    if api_key is not None:
        hashed_api_key = hashlib.sha256(api_key.encode()).hexdigest()

    # BUGFIX: `is_async` must be part of the cache key. Without it, a sync
    # OpenAI client cached for these parameters would be returned for a
    # later async request with the same parameters (and vice versa).
    _cache_key = f"hashed_api_key={hashed_api_key},is_async={is_async},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization}"

    if _cache_key in litellm.in_memory_llm_clients_cache:
        return litellm.in_memory_llm_clients_cache[_cache_key]

    if is_async:
        _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
            api_key=api_key,
            base_url=api_base,
            http_client=litellm.aclient_session,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
        )
    else:
        _new_client = OpenAI(
            api_key=api_key,
            base_url=api_base,
            http_client=litellm.client_session,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
        )

    # Cache for reuse by subsequent calls with identical parameters.
    litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
    return _new_client
def completion( def completion(
self, self,
model_response: ModelResponse, model_response: ModelResponse,
@ -610,17 +667,16 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError( raise OpenAIError(
status_code=422, message="max retries must be an int" status_code=422, message="max retries must be an int"
) )
if client is None:
openai_client = OpenAI( openai_client = self._get_openai_client(
api_key=api_key, is_async=False,
base_url=api_base, api_key=api_key,
http_client=litellm.client_session, api_base=api_base,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
) client=client,
else: )
openai_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -700,17 +756,15 @@ class OpenAIChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
if client is None: openai_aclient = self._get_openai_client(
openai_aclient = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.aclient_session, timeout=timeout,
timeout=timeout, max_retries=max_retries,
max_retries=max_retries, organization=organization,
organization=organization, client=client,
) )
else:
openai_aclient = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -754,17 +808,15 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None, max_retries=None,
headers=None, headers=None,
): ):
if client is None: openai_client = self._get_openai_client(
openai_client = OpenAI( is_async=False,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.client_session, timeout=timeout,
timeout=timeout, max_retries=max_retries,
max_retries=max_retries, organization=organization,
organization=organization, client=client,
) )
else:
openai_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["messages"], input=data["messages"],
@ -801,17 +853,15 @@ class OpenAIChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
if client is None: openai_aclient = self._get_openai_client(
openai_aclient = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.aclient_session, timeout=timeout,
timeout=timeout, max_retries=max_retries,
max_retries=max_retries, organization=organization,
organization=organization, client=client,
) )
else:
openai_aclient = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["messages"], input=data["messages"],
@ -865,16 +915,14 @@ class OpenAIChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
if client is None: openai_aclient = self._get_openai_client(
openai_aclient = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.aclient_session, timeout=timeout,
timeout=timeout, max_retries=max_retries,
max_retries=max_retries, client=client,
) )
else:
openai_aclient = client
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump() stringified_response = response.model_dump()
## LOGGING ## LOGGING
@ -922,19 +970,18 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data, "api_base": api_base}, additional_args={"complete_input_dict": data, "api_base": api_base},
) )
if aembedding == True: if aembedding is True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response return response
if client is None:
openai_client = OpenAI( openai_client = self._get_openai_client(
api_key=api_key, is_async=False,
base_url=api_base, api_key=api_key,
http_client=litellm.client_session, api_base=api_base,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
) client=client,
else: )
openai_client = client
## COMPLETION CALL ## COMPLETION CALL
response = openai_client.embeddings.create(**data, timeout=timeout) # type: ignore response = openai_client.embeddings.create(**data, timeout=timeout) # type: ignore
@ -970,16 +1017,16 @@ class OpenAIChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
if client is None:
openai_aclient = AsyncOpenAI( openai_aclient = self._get_openai_client(
api_key=api_key, is_async=True,
base_url=api_base, api_key=api_key,
http_client=litellm.aclient_session, api_base=api_base,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
) client=client,
else: )
openai_aclient = client
response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump() stringified_response = response.model_dump()
## LOGGING ## LOGGING
@ -1024,16 +1071,14 @@ class OpenAIChatCompletion(BaseLLM):
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response return response
if client is None: openai_client = self._get_openai_client(
openai_client = OpenAI( is_async=False,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.client_session, timeout=timeout,
timeout=timeout, max_retries=max_retries,
max_retries=max_retries, client=client,
) )
else:
openai_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -1098,7 +1143,7 @@ class OpenAIChatCompletion(BaseLLM):
atranscription: bool = False, atranscription: bool = False,
): ):
data = {"model": model, "file": audio_file, **optional_params} data = {"model": model, "file": audio_file, **optional_params}
if atranscription == True: if atranscription is True:
return self.async_audio_transcriptions( return self.async_audio_transcriptions(
audio_file=audio_file, audio_file=audio_file,
data=data, data=data,
@ -1110,16 +1155,14 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=max_retries, max_retries=max_retries,
logging_obj=logging_obj, logging_obj=logging_obj,
) )
if client is None:
openai_client = OpenAI( openai_client = self._get_openai_client(
api_key=api_key, is_async=False,
base_url=api_base, api_key=api_key,
http_client=litellm.client_session, api_base=api_base,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
) )
else:
openai_client = client
response = openai_client.audio.transcriptions.create( response = openai_client.audio.transcriptions.create(
**data, timeout=timeout # type: ignore **data, timeout=timeout # type: ignore
) )
@ -1149,16 +1192,15 @@ class OpenAIChatCompletion(BaseLLM):
logging_obj=None, logging_obj=None,
): ):
try: try:
if client is None: openai_aclient = self._get_openai_client(
openai_aclient = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.aclient_session, timeout=timeout,
timeout=timeout, max_retries=max_retries,
max_retries=max_retries, client=client,
) )
else:
openai_aclient = client
response = await openai_aclient.audio.transcriptions.create( response = await openai_aclient.audio.transcriptions.create(
**data, timeout=timeout **data, timeout=timeout
) # type: ignore ) # type: ignore
@ -1197,7 +1239,7 @@ class OpenAIChatCompletion(BaseLLM):
client=None, client=None,
) -> HttpxBinaryResponseContent: ) -> HttpxBinaryResponseContent:
if aspeech is not None and aspeech == True: if aspeech is not None and aspeech is True:
return self.async_audio_speech( return self.async_audio_speech(
model=model, model=model,
input=input, input=input,
@ -1212,18 +1254,14 @@ class OpenAIChatCompletion(BaseLLM):
client=client, client=client,
) # type: ignore ) # type: ignore
if client is None: openai_client = self._get_openai_client(
openai_client = OpenAI( is_async=False,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
organization=organization, timeout=timeout,
project=project, max_retries=max_retries,
http_client=litellm.client_session, client=client,
timeout=timeout, )
max_retries=max_retries,
)
else:
openai_client = client
response = openai_client.audio.speech.create( response = openai_client.audio.speech.create(
model=model, model=model,
@ -1248,18 +1286,14 @@ class OpenAIChatCompletion(BaseLLM):
client=None, client=None,
) -> HttpxBinaryResponseContent: ) -> HttpxBinaryResponseContent:
if client is None: openai_client = self._get_openai_client(
openai_client = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
organization=organization, timeout=timeout,
project=project, max_retries=max_retries,
http_client=litellm.aclient_session, client=client,
timeout=timeout, )
max_retries=max_retries,
)
else:
openai_client = client
response = await openai_client.audio.speech.create( response = await openai_client.audio.speech.create(
model=model, model=model,