Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
feat: global client for sync + async calls (openai + Azure only)
This commit is contained in:
parent 5fd4376802
commit 51bf637656

4 changed files with 22 additions and 14 deletions
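Concretely, the commit lazily creates one long-lived httpx.Client (and one httpx.AsyncClient) and hands it to every OpenAI/Azure SDK client through the http_client parameter, so repeated calls reuse pooled keep-alive connections instead of opening a new connection per request. Below is a minimal standalone sketch of that pattern outside of litellm; the API key placeholder and model name are assumptions for illustration, not values from the diff.

import httpx
from openai import OpenAI

# One shared HTTP client, with the same pool limits and timeouts used in the diff below.
shared_client = httpx.Client(
    limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
    timeout=httpx.Timeout(timeout=600.0, connect=5.0),
)

# The OpenAI SDK accepts an externally managed httpx client via http_client.
client = OpenAI(api_key="sk-...", http_client=shared_client)  # placeholder key (assumption)
response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # assumed model name for illustration
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)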
@@ -105,8 +105,6 @@ class AzureChatCompletion(BaseLLM):
                acompletion: bool = False,
                headers: Optional[dict]=None):
        super().completion()
-       if self._client_session is None:
-           self._client_session = self.create_client_session()
        exception_mapping_worked = False
        try:

@@ -137,7 +135,7 @@ class AzureChatCompletion(BaseLLM):
            elif "stream" in optional_params and optional_params["stream"] == True:
                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
            else:
-               azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+               azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
                response = azure_client.chat.completions.create(**data) # type: ignore
                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
        except AzureOpenAIError as e:

@@ -155,7 +153,7 @@ class AzureChatCompletion(BaseLLM):
                        model_response: ModelResponse,
                        azure_ad_token: Optional[str]=None, ):
        try:
-           azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+           azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
            response = await azure_client.chat.completions.create(**data)
            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
        except Exception as e:

@@ -175,7 +173,7 @@ class AzureChatCompletion(BaseLLM):
                  model: str,
                  azure_ad_token: Optional[str]=None,
                  ):
-       azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+       azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
        response = azure_client.chat.completions.create(**data)
        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
        for transformed_chunk in streamwrapper:

@@ -189,7 +187,7 @@ class AzureChatCompletion(BaseLLM):
                        data: dict,
                        model: str,
                        azure_ad_token: Optional[str]=None):
-       azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+       azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
        response = await azure_client.chat.completions.create(**data)
        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
        async for transformed_chunk in streamwrapper:

@@ -210,7 +208,7 @@ class AzureChatCompletion(BaseLLM):
        if self._client_session is None:
            self._client_session = self.create_client_session()
        try:
-           azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+           azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
            data = {
                "model": model,
                "input": input,
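The async paths above (acompletion and async streaming) follow the same idea with the asynchronous client: AsyncAzureOpenAI is constructed with litellm.aclient_session, an httpx.AsyncClient. A minimal sketch of that pairing follows; the API key, API version, endpoint, and deployment name are placeholders assumed for illustration.

import asyncio
import httpx
from openai import AsyncAzureOpenAI

# Shared async HTTP client, mirroring litellm.aclient_session in the diff above.
shared_aclient = httpx.AsyncClient(
    limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
    timeout=httpx.Timeout(timeout=600.0, connect=5.0),
)

async def main():
    azure_client = AsyncAzureOpenAI(
        api_key="...",                                          # placeholder (assumption)
        api_version="2023-07-01-preview",                       # assumed API version
        azure_endpoint="https://my-endpoint.openai.azure.com",  # assumed endpoint
        azure_deployment="my-gpt-35-deployment",                # assumed deployment name
        http_client=shared_aclient,                             # reuse pooled connections
    )
    response = await azure_client.chat.completions.create(
        model="my-gpt-35-deployment",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())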
@@ -203,7 +203,7 @@ class OpenAIChatCompletion(BaseLLM):
            elif optional_params.get("stream", False):
                return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
            else:
-               openai_client = OpenAI(api_key=api_key, base_url=api_base)
+               openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
                response = openai_client.chat.completions.create(**data) # type: ignore
                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
        except Exception as e:

@@ -238,7 +238,7 @@ class OpenAIChatCompletion(BaseLLM):
                          api_base: Optional[str]=None):
        response = None
        try:
-           openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
+           openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
            response = await openai_aclient.chat.completions.create(**data)
            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
        except Exception as e:

@@ -254,7 +254,7 @@ class OpenAIChatCompletion(BaseLLM):
                  api_key: Optional[str]=None,
                  api_base: Optional[str]=None
                  ):
-       openai_client = OpenAI(api_key=api_key, base_url=api_base)
+       openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
        response = openai_client.chat.completions.create(**data)
        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
        for transformed_chunk in streamwrapper:

@@ -266,7 +266,7 @@ class OpenAIChatCompletion(BaseLLM):
                          model: str,
                          api_key: Optional[str]=None,
                          api_base: Optional[str]=None):
-       openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
+       openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
        response = await openai_aclient.chat.completions.create(**data)
        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
        async for transformed_chunk in streamwrapper:

@@ -283,7 +283,7 @@ class OpenAIChatCompletion(BaseLLM):
        super().embedding()
        exception_mapping_worked = False
        try:
-           openai_client = OpenAI(api_key=api_key, base_url=api_base)
+           openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
            model = model
            data = {
                "model": model,
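From the caller's side nothing changes: litellm's completion entrypoint is wrapped by the client decorator (see the utils hunk further below), which initializes litellm.client_session and litellm.aclient_session on the first call. A hypothetical caller-side sketch, assuming an OPENAI_API_KEY in the environment and an assumed model name:

import litellm
from litellm import completion

messages = [{"content": "Hello, how are you?", "role": "user"}]

# The first call lazily creates the shared httpx clients inside the decorator.
first = completion(model="gpt-3.5-turbo", messages=messages)

# Later calls reuse litellm.client_session, so connections stay pooled.
second = completion(model="gpt-3.5-turbo", messages=messages)

print(litellm.client_session is not None)  # expected: True after the first call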
@@ -16,7 +16,7 @@ from litellm import completion, acompletion, acreate
litellm.num_retries = 3

def test_sync_response():
-   litellm.set_verbose = True
+   litellm.set_verbose = False
    user_message = "Hello, how are you?"
    messages = [{"content": user_message, "role": "user"}]
    try:

@@ -146,4 +146,4 @@ def test_get_response_non_openai_streaming():
        return response
    asyncio.run(test_async_call())

-test_get_response_non_openai_streaming()
+# test_get_response_non_openai_streaming()
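The async counterpart mirrors the test file above: acompletion goes through the same code path and picks up litellm.aclient_session. A short sketch, with the model name assumed for illustration:

import asyncio
from litellm import acompletion

async def test_async_call():
    messages = [{"content": "Hello, how are you?", "role": "user"}]
    # Awaited calls route through AsyncOpenAI/AsyncAzureOpenAI using the shared httpx.AsyncClient.
    response = await acompletion(model="gpt-3.5-turbo", messages=messages)
    return response

asyncio.run(test_async_call())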
@@ -1054,6 +1054,16 @@ def client(original_function):
        try:
            global callback_list, add_breadcrumb, user_logger_fn, Logging
            function_id = kwargs["id"] if "id" in kwargs else None
+           if litellm.client_session is None:
+               litellm.client_session = httpx.Client(
+                   limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
+                   timeout = httpx.Timeout(timeout=600.0, connect=5.0)
+               )
+           if litellm.aclient_session is None:
+               litellm.aclient_session = httpx.AsyncClient(
+                   limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
+                   timeout = httpx.Timeout(timeout=600.0, connect=5.0)
+               )
            if litellm.use_client or ("use_client" in kwargs and kwargs["use_client"] == True):
                print_verbose(f"litedebugger initialized")
                if "lite_debugger" not in litellm.input_callback:
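Pulled out of the decorator, the added initialization amounts to the sketch below; the helper name _ensure_global_clients is hypothetical and only groups the two lazy checks from the hunk above.

import httpx
import litellm

def _ensure_global_clients():
    # Create the shared sync client once; later calls find it already set.
    if litellm.client_session is None:
        litellm.client_session = httpx.Client(
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
            timeout=httpx.Timeout(timeout=600.0, connect=5.0),
        )
    # Same lazy creation for the async client used by acompletion / async streaming.
    if litellm.aclient_session is None:
        litellm.aclient_session = httpx.AsyncClient(
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
            timeout=httpx.Timeout(timeout=600.0, connect=5.0),
        )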