fix cache openai client for embeddings, text, speech

This commit is contained in:
Ishaan Jaff 2024-05-31 21:35:03 -07:00
parent cedeb10a08
commit 1c16904566

View file

@ -899,16 +899,14 @@ class OpenAIChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
if client is None: openai_aclient = self._get_openai_client(
openai_aclient = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.aclient_session, timeout=timeout,
timeout=timeout, max_retries=max_retries,
max_retries=max_retries, client=client,
) )
else:
openai_aclient = client
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump() stringified_response = response.model_dump()
## LOGGING ## LOGGING
@ -956,19 +954,18 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data, "api_base": api_base}, additional_args={"complete_input_dict": data, "api_base": api_base},
) )
if aembedding == True: if aembedding is True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response return response
if client is None:
openai_client = OpenAI( openai_client = self._get_openai_client(
api_key=api_key, is_async=False,
base_url=api_base, api_key=api_key,
http_client=litellm.client_session, api_base=api_base,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
) client=client,
else: )
openai_client = client
## COMPLETION CALL ## COMPLETION CALL
response = openai_client.embeddings.create(**data, timeout=timeout) # type: ignore response = openai_client.embeddings.create(**data, timeout=timeout) # type: ignore
@ -1004,16 +1001,16 @@ class OpenAIChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
if client is None:
openai_aclient = AsyncOpenAI( openai_aclient = self._get_openai_client(
api_key=api_key, is_async=True,
base_url=api_base, api_key=api_key,
http_client=litellm.aclient_session, api_base=api_base,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
) client=client,
else: )
openai_aclient = client
response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump() stringified_response = response.model_dump()
## LOGGING ## LOGGING
@ -1058,16 +1055,14 @@ class OpenAIChatCompletion(BaseLLM):
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response return response
if client is None: openai_client = self._get_openai_client(
openai_client = OpenAI( is_async=False,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.client_session, timeout=timeout,
timeout=timeout, max_retries=max_retries,
max_retries=max_retries, client=client,
) )
else:
openai_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -1132,7 +1127,7 @@ class OpenAIChatCompletion(BaseLLM):
atranscription: bool = False, atranscription: bool = False,
): ):
data = {"model": model, "file": audio_file, **optional_params} data = {"model": model, "file": audio_file, **optional_params}
if atranscription == True: if atranscription is True:
return self.async_audio_transcriptions( return self.async_audio_transcriptions(
audio_file=audio_file, audio_file=audio_file,
data=data, data=data,
@ -1144,16 +1139,14 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=max_retries, max_retries=max_retries,
logging_obj=logging_obj, logging_obj=logging_obj,
) )
if client is None:
openai_client = OpenAI( openai_client = self._get_openai_client(
api_key=api_key, is_async=False,
base_url=api_base, api_key=api_key,
http_client=litellm.client_session, api_base=api_base,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
) )
else:
openai_client = client
response = openai_client.audio.transcriptions.create( response = openai_client.audio.transcriptions.create(
**data, timeout=timeout # type: ignore **data, timeout=timeout # type: ignore
) )
@ -1183,16 +1176,15 @@ class OpenAIChatCompletion(BaseLLM):
logging_obj=None, logging_obj=None,
): ):
try: try:
if client is None: openai_aclient = self._get_openai_client(
openai_aclient = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
http_client=litellm.aclient_session, timeout=timeout,
timeout=timeout, max_retries=max_retries,
max_retries=max_retries, client=client,
) )
else:
openai_aclient = client
response = await openai_aclient.audio.transcriptions.create( response = await openai_aclient.audio.transcriptions.create(
**data, timeout=timeout **data, timeout=timeout
) # type: ignore ) # type: ignore
@ -1231,7 +1223,7 @@ class OpenAIChatCompletion(BaseLLM):
client=None, client=None,
) -> HttpxBinaryResponseContent: ) -> HttpxBinaryResponseContent:
if aspeech is not None and aspeech == True: if aspeech is not None and aspeech is True:
return self.async_audio_speech( return self.async_audio_speech(
model=model, model=model,
input=input, input=input,
@ -1246,18 +1238,14 @@ class OpenAIChatCompletion(BaseLLM):
client=client, client=client,
) # type: ignore ) # type: ignore
if client is None: openai_client = self._get_openai_client(
openai_client = OpenAI( is_async=False,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
organization=organization, timeout=timeout,
project=project, max_retries=max_retries,
http_client=litellm.client_session, client=client,
timeout=timeout, )
max_retries=max_retries,
)
else:
openai_client = client
response = openai_client.audio.speech.create( response = openai_client.audio.speech.create(
model=model, model=model,
@ -1282,18 +1270,14 @@ class OpenAIChatCompletion(BaseLLM):
client=None, client=None,
) -> HttpxBinaryResponseContent: ) -> HttpxBinaryResponseContent:
if client is None: openai_client = self._get_openai_client(
openai_client = AsyncOpenAI( is_async=True,
api_key=api_key, api_key=api_key,
base_url=api_base, api_base=api_base,
organization=organization, timeout=timeout,
project=project, max_retries=max_retries,
http_client=litellm.aclient_session, client=client,
timeout=timeout, )
max_retries=max_retries,
)
else:
openai_client = client
response = await openai_client.audio.speech.create( response = await openai_client.audio.speech.create(
model=model, model=model,