feat: global client for sync + async calls (openai + Azure only)

This commit is contained in:
Krrish Dholakia 2023-11-16 14:44:06 -08:00
parent 5fd4376802
commit 51bf637656
4 changed files with 22 additions and 14 deletions

View file

@ -203,7 +203,7 @@ class OpenAIChatCompletion(BaseLLM):
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
else:
openai_client = OpenAI(api_key=api_key, base_url=api_base)
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
response = openai_client.chat.completions.create(**data) # type: ignore
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except Exception as e:
@ -238,7 +238,7 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str]=None):
response = None
try:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
response = await openai_aclient.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except Exception as e:
@ -254,7 +254,7 @@ class OpenAIChatCompletion(BaseLLM):
api_key: Optional[str]=None,
api_base: Optional[str]=None
):
openai_client = OpenAI(api_key=api_key, base_url=api_base)
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
response = openai_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
for transformed_chunk in streamwrapper:
@ -266,7 +266,7 @@ class OpenAIChatCompletion(BaseLLM):
model: str,
api_key: Optional[str]=None,
api_base: Optional[str]=None):
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
response = await openai_aclient.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
@ -283,7 +283,7 @@ class OpenAIChatCompletion(BaseLLM):
super().embedding()
exception_mapping_worked = False
try:
openai_client = OpenAI(api_key=api_key, base_url=api_base)
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
model = model
data = {
"model": model,