Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00

Merge pull request #3956 from BerriAI/litellm_cache_openai_clients
[FEAT] Perf improvements - litellm.completion / litellm.acompletion - Cache OpenAI client

Commit d83b4a00d3: 1 changed file with 165 additions and 131 deletions
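This diff adds an in-memory cache for OpenAI / AsyncOpenAI client objects, so litellm.completion and litellm.acompletion stop constructing a fresh client (and its HTTP connection pool) on every request. A new helper, OpenAIChatCompletion._get_openai_client, keys cached clients on a SHA-256 hash of the API key together with the api_base, timeout, max_retries, and organization settings; the sync and async call sites for completion, streaming, embeddings, image generation, audio transcription, and speech are all rewritten to go through it.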
@@ -6,6 +6,7 @@ from typing import (
     Literal,
     Iterable,
 )
+import hashlib
 from typing_extensions import override, overload
 from pydantic import BaseModel
 import types, time, json, traceback
@@ -504,6 +505,62 @@ class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
+    def _get_openai_client(
+        self,
+        is_async: bool,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+        max_retries: Optional[int] = None,
+        organization: Optional[str] = None,
+        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+    ):
+        if client is None:
+            if not isinstance(max_retries, int):
+                raise OpenAIError(
+                    status_code=422,
+                    message="max retries must be an int. Passed in value: {}".format(
+                        max_retries
+                    ),
+                )
+            # Creating a new OpenAI Client
+            # check in memory cache before creating a new one
+            # Convert the API key to bytes
+            hashed_api_key = None
+            if api_key is not None:
+                hash_object = hashlib.sha256(api_key.encode())
+                # Hexadecimal representation of the hash
+                hashed_api_key = hash_object.hexdigest()
+
+            _cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization}"
+
+            if _cache_key in litellm.in_memory_llm_clients_cache:
+                return litellm.in_memory_llm_clients_cache[_cache_key]
+            if is_async:
+                _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+            else:
+                _new_client = OpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.client_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+
+            litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
+            return _new_client
+
+        else:
+            return client
+
     def completion(
         self,
         model_response: ModelResponse,
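The caching pattern in the new helper is independent of litellm itself: hash the secret, build a composite key from the connection settings, and construct a client only on a cache miss. Below is a minimal standalone sketch of the same pattern, assuming nothing about litellm internals (ClientCache and its factory callback are illustrative names, not litellm or OpenAI SDK APIs):

# Standalone sketch of the caching pattern implemented by _get_openai_client.
# ClientCache and the factory argument are illustrative, not litellm APIs.
import hashlib
from typing import Any, Callable, Dict, Optional


class ClientCache:
    def __init__(self) -> None:
        self._cache: Dict[str, Any] = {}

    @staticmethod
    def _key(api_key: Optional[str], api_base: Optional[str],
             timeout: float, max_retries: int) -> str:
        # Hash the API key so the raw secret never appears in the cache key.
        hashed = hashlib.sha256(api_key.encode()).hexdigest() if api_key else None
        return f"hashed_api_key={hashed},api_base={api_base},timeout={timeout},max_retries={max_retries}"

    def get_or_create(self, api_key: Optional[str], api_base: Optional[str],
                      timeout: float, max_retries: int,
                      factory: Callable[[], Any]) -> Any:
        key = self._key(api_key, api_base, timeout, max_retries)
        if key not in self._cache:
            # Pay the construction cost only on a cache miss.
            self._cache[key] = factory()
        return self._cache[key]


cache = ClientCache()
c1 = cache.get_or_create("sk-test", None, 600.0, 2, factory=lambda: object())
c2 = cache.get_or_create("sk-test", None, 600.0, 2, factory=lambda: object())
assert c1 is c2  # identical settings reuse the same cached client object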
@@ -610,17 +667,16 @@ class OpenAIChatCompletion(BaseLLM):
                     raise OpenAIError(
                         status_code=422, message="max retries must be an int"
                     )
-                if client is None:
-                    openai_client = OpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        http_client=litellm.client_session,
-                        timeout=timeout,
-                        max_retries=max_retries,
-                        organization=organization,
-                    )
-                else:
-                    openai_client = client
+                openai_client = self._get_openai_client(
+                    is_async=False,
+                    api_key=api_key,
+                    api_base=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                    client=client,
+                )
+
 
                 ## LOGGING
                 logging_obj.pre_call(
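The remaining hunks apply the same mechanical change: each entry point (async completion, sync and async streaming, embeddings, image generation, audio transcription, and speech) replaces its inline OpenAI(...) / AsyncOpenAI(...) construction with a call to _get_openai_client.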
@@ -700,17 +756,15 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                    organization=organization,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
 
             ## LOGGING
             logging_obj.pre_call(
@@ -754,17 +808,15 @@ class OpenAIChatCompletion(BaseLLM):
         max_retries=None,
         headers=None,
     ):
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
         ## LOGGING
         logging_obj.pre_call(
             input=data["messages"],
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
if client is None:
|
openai_aclient = self._get_openai_client(
|
||||||
openai_aclient = AsyncOpenAI(
|
is_async=True,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
api_base=api_base,
|
||||||
http_client=litellm.aclient_session,
|
timeout=timeout,
|
||||||
timeout=timeout,
|
max_retries=max_retries,
|
||||||
max_retries=max_retries,
|
organization=organization,
|
||||||
organization=organization,
|
client=client,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
openai_aclient = client
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=data["messages"],
|
input=data["messages"],
|
||||||
|
@@ -865,16 +915,14 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
             response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
             stringified_response = response.model_dump()
             ## LOGGING
@@ -922,19 +970,18 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data, "api_base": api_base},
             )
 
-            if aembedding == True:
+            if aembedding is True:
                 response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
                 return response
-            if client is None:
-                openai_client = OpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.client_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_client = client
+
+            openai_client = self._get_openai_client(
+                is_async=False,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
 
             ## COMPLETION CALL
             response = openai_client.embeddings.create(**data, timeout=timeout)  # type: ignore
@@ -970,16 +1017,16 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_aclient = client
+
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+
             response = await openai_aclient.images.generate(**data, timeout=timeout)  # type: ignore
             stringified_response = response.model_dump()
             ## LOGGING
@@ -1024,16 +1071,14 @@ class OpenAIChatCompletion(BaseLLM):
                 response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
                 return response
 
-            if client is None:
-                openai_client = OpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.client_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_client = client
+            openai_client = self._get_openai_client(
+                is_async=False,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
 
             ## LOGGING
             logging_obj.pre_call(
@@ -1098,7 +1143,7 @@ class OpenAIChatCompletion(BaseLLM):
         atranscription: bool = False,
     ):
         data = {"model": model, "file": audio_file, **optional_params}
-        if atranscription == True:
+        if atranscription is True:
             return self.async_audio_transcriptions(
                 audio_file=audio_file,
                 data=data,
|
@ -1110,16 +1155,14 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
if client is None:
|
|
||||||
openai_client = OpenAI(
|
openai_client = self._get_openai_client(
|
||||||
api_key=api_key,
|
is_async=False,
|
||||||
base_url=api_base,
|
api_key=api_key,
|
||||||
http_client=litellm.client_session,
|
api_base=api_base,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
openai_client = client
|
|
||||||
response = openai_client.audio.transcriptions.create(
|
response = openai_client.audio.transcriptions.create(
|
||||||
**data, timeout=timeout # type: ignore
|
**data, timeout=timeout # type: ignore
|
||||||
)
|
)
|
||||||
|
@@ -1149,16 +1192,15 @@ class OpenAIChatCompletion(BaseLLM):
         logging_obj=None,
     ):
         try:
-            if client is None:
-                openai_aclient = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=api_base,
-                    http_client=litellm.aclient_session,
-                    timeout=timeout,
-                    max_retries=max_retries,
-                )
-            else:
-                openai_aclient = client
+            openai_aclient = self._get_openai_client(
+                is_async=True,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+
             response = await openai_aclient.audio.transcriptions.create(
                 **data, timeout=timeout
             )  # type: ignore
@@ -1197,7 +1239,7 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
     ) -> HttpxBinaryResponseContent:
 
-        if aspeech is not None and aspeech == True:
+        if aspeech is not None and aspeech is True:
             return self.async_audio_speech(
                 model=model,
                 input=input,
@@ -1212,18 +1254,14 @@ class OpenAIChatCompletion(BaseLLM):
             client=client,
         )  # type: ignore
 
-        if client is None:
-            openai_client = OpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                organization=organization,
-                project=project,
-                http_client=litellm.client_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=False,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
 
         response = openai_client.audio.speech.create(
             model=model,
@@ -1248,18 +1286,14 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
     ) -> HttpxBinaryResponseContent:
 
-        if client is None:
-            openai_client = AsyncOpenAI(
-                api_key=api_key,
-                base_url=api_base,
-                organization=organization,
-                project=project,
-                http_client=litellm.aclient_session,
-                timeout=timeout,
-                max_retries=max_retries,
-            )
-        else:
-            openai_client = client
+        openai_client = self._get_openai_client(
+            is_async=True,
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
 
         response = await openai_client.audio.speech.create(
             model=model,
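With the cache in place, repeated calls that share an API key, base URL, timeout, and retry count resolve to one cached client and reuse its connection pool. A usage sketch, assuming OPENAI_API_KEY is set in the environment and using a placeholder model name:

# Both iterations resolve to the same cached OpenAI client, since the
# hashed API key, api_base, timeout, and max_retries all match.
import litellm

for _ in range(2):
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)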