mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

commit bf0f8b824c (parent 3285113d2d)
fix(azure.py): use openai client sdk for handling sync+async calling

7 changed files with 136 additions and 167 deletions
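The change below replaces hand-rolled httpx calls against the Azure REST endpoint with the official openai SDK clients, which expose the same call shape for sync and async use. As a minimal sketch of the pattern the commit adopts (endpoint, key, and deployment values are illustrative placeholders, not values from this diff):

import asyncio
from openai import AzureOpenAI, AsyncAzureOpenAI

# Placeholder configuration -- substitute real Azure OpenAI values.
API_BASE = "https://my-resource.openai.azure.com/"
API_KEY = "my-azure-api-key"
API_VERSION = "2023-07-01-preview"
DEPLOYMENT = "chatgpt-v-2"

messages = [{"role": "user", "content": "Hello, how are you?"}]

# Sync path: the client owns URL construction and auth headers,
# replacing manual f"openai/deployments/{model}/chat/completions" URLs.
client = AzureOpenAI(api_key=API_KEY, api_version=API_VERSION,
                     azure_endpoint=API_BASE, azure_deployment=DEPLOYMENT)
print(client.chat.completions.create(model=DEPLOYMENT, messages=messages))

# Async path: identical call shape via AsyncAzureOpenAI.
async def main():
    aclient = AsyncAzureOpenAI(api_key=API_KEY, api_version=API_VERSION,
                               azure_endpoint=API_BASE, azure_deployment=DEPLOYMENT)
    print(await aclient.chat.completions.create(model=DEPLOYMENT, messages=messages))

asyncio.run(main())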
[file: azure.py]

@@ -4,8 +4,9 @@ from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
 from typing import Callable, Optional
 from litellm import OpenAIConfig
-import litellm
+import litellm, json
 import httpx
+from openai import AzureOpenAI, AsyncAzureOpenAI

 class AzureOpenAIError(Exception):
     def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
@@ -73,12 +74,10 @@ class AzureOpenAIConfig(OpenAIConfig):
               top_p)

 class AzureChatCompletion(BaseLLM):
-    _client_session: Optional[httpx.Client] = None
-    _aclient_session: Optional[httpx.AsyncClient] = None

     def __init__(self) -> None:
         super().__init__()

     def validate_environment(self, api_key, azure_ad_token):
         headers = {
             "content-type": "application/json",
@@ -110,17 +109,12 @@ class AzureChatCompletion(BaseLLM):
-            self._client_session = self.create_client_session()
         exception_mapping_worked = False
         try:
             if headers is None:
                 headers = self.validate_environment(api_key=api_key, azure_ad_token=azure_ad_token)

             if model is None or messages is None:
                 raise AzureOpenAIError(status_code=422, message=f"Missing model or messages")
-            # Ensure api_base ends with a trailing slash
-            if not api_base.endswith('/'):
-                api_base += '/'
-
-            api_base = api_base + f"openai/deployments/{model}/chat/completions?api-version={api_version}"

             data = {
                 "model": model,
                 "messages": messages,
                 **optional_params
             }
@@ -137,41 +131,34 @@ class AzureChatCompletion(BaseLLM):
             )
             if acompletion is True:
                 if optional_params.get("stream", False):
-                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
                 else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token)
             elif "stream" in optional_params and optional_params["stream"] == True:
-                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
             else:
-                response = self._client_session.post(
-                    url=api_base,
-                    json=data,
-                    headers=headers,
-                    timeout=litellm.request_timeout
-                )
-                if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
-
-                ## RESPONSE OBJECT
-                return convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
+                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+                response = azure_client.chat.completions.create(**data) # type: ignore
+                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
         except Exception as e:
             raise e

-    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
+    async def acompletion(self,
+                          api_key: str,
+                          api_version: str,
+                          model: str,
+                          api_base: str,
+                          data: dict,
+                          headers: dict,
+                          model_response: ModelResponse,
+                          azure_ad_token: Optional[str]=None, ):
         try:
-            response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-            ## RESPONSE OBJECT
-            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+            response = await azure_client.chat.completions.create(**data)
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
             if isinstance(e,httpx.TimeoutException):
                 raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
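One idiom worth calling out in the hunk above: the v1 SDK returns pydantic model objects, while convert_to_model_response_object expects the plain dict the old response.json() produced, hence the json.loads(response.model_dump_json()) round-trip. A minimal sketch of just that conversion (the function name is ours, not from the diff):

import json

def sdk_response_to_dict(response):
    # model_dump_json() serializes the pydantic response model; json.loads
    # recovers the same dict shape the raw REST endpoint used to return.
    return json.loads(response.model_dump_json())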
@@ -183,74 +170,52 @@ class AzureChatCompletion(BaseLLM):
     def streaming(self,
                   logging_obj,
                   api_base: str,
+                  api_key: str,
+                  api_version: str,
                   data: dict,
                   headers: dict,
                   model_response: ModelResponse,
-                  model: str
+                  model: str,
+                  azure_ad_token: Optional[str]=None,
     ):
-        if self._client_session is None:
-            self._client_session = self.create_client_session()
-        with self._client_session.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=litellm.request_timeout
-        ) as response:
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message="An error occurred while streaming")
-
-            completion_stream = response.iter_lines()
-            streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
-            for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+        response = azure_client.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
+        for transformed_chunk in streamwrapper:
+            yield transformed_chunk

     async def async_streaming(self,
                               logging_obj,
                               api_base: str,
+                              api_key: str,
+                              api_version: str,
                               data: dict,
                               headers: dict,
                               model_response: ModelResponse,
-                              model: str):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=litellm.request_timeout
-        ) as response:
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text)
-
-            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="azure",logging_obj=logging_obj)
-            async for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+                              model: str,
+                              azure_ad_token: Optional[str]=None):
+        azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+        response = await azure_client.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk

     def embedding(self,
                   model: str,
                   input: list,
                   api_key: str,
                   api_base: str,
-                  azure_ad_token: str,
                   api_version: str,
                   logging_obj=None,
                   model_response=None,
-                  optional_params=None,):
+                  optional_params=None,
+                  azure_ad_token: Optional[str]=None):
         super().embedding()
         exception_mapping_worked = False
-        if self._client_session is None:
-            self._client_session = self.create_client_session()
         try:
             headers = self.validate_environment(api_key, azure_ad_token=azure_ad_token)
-            # Ensure api_base ends with a trailing slash
-            if not api_base.endswith('/'):
-                api_base += '/'
-
-            api_base = api_base + f"openai/deployments/{model}/embeddings?api-version={api_version}"
             model = model
+            azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
             data = {
                 "model": model,
                 "input": input,
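Streaming changes shape the same way in the hunk above: instead of httpx's stream() context manager plus iter_lines()/aiter_lines(), chat.completions.create(stream=True) returns an iterator (an async iterator on the async client) of chunk objects, which CustomStreamWrapper now consumes directly. A hedged sketch of the sync path, reusing the placeholder client from the earlier example:

# Assumes `client`, DEPLOYMENT, and `messages` from the sketch above.
stream = client.chat.completions.create(
    model=DEPLOYMENT,
    messages=messages,
    stream=True,  # yields ChatCompletionChunk objects instead of one response
)
for chunk in stream:
    # Guard against empty choices lists, mirroring the handle_azure_chunk
    # fix later in this commit.
    if chunk.choices and chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="")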
@@ -263,10 +228,8 @@ class AzureChatCompletion(BaseLLM):
                 api_key=api_key,
                 additional_args={"complete_input_dict": data},
             )
-            ## COMPLETION CALL
-            response = self._client_session.post(
-                api_base, headers=headers, json=data, timeout=litellm.request_timeout
-            )
+            ## COMPLETION CALL
+            response = azure_client.embeddings.create(**data) # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -275,9 +238,7 @@ class AzureChatCompletion(BaseLLM):
                 original_response=response,
             )

-            if response.status_code!=200:
-                raise AzureOpenAIError(message=response.text, status_code=response.status_code)
-            embedding_response = response.json()
+            embedding_response = json.loads(response.model_dump_json())
             output_data = []
             for idx, embedding in enumerate(embedding_response["data"]):
                 output_data.append(
[next file: async completion tests]
@@ -36,7 +36,7 @@ def test_sync_response_anyscale():

 # test_sync_response_anyscale()

-def test_async_response():
+def test_async_response_openai():
     import asyncio
     litellm.set_verbose = True
     async def test_get_response():
@@ -44,13 +44,27 @@
         messages = [{"content": user_message, "role": "user"}]
         try:
             response = await acompletion(model="gpt-3.5-turbo", messages=messages)
-            # response = await response
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")

     asyncio.run(test_get_response())
-test_async_response()
+
+def test_async_response_azure():
+    import asyncio
+    litellm.set_verbose = True
+    async def test_get_response():
+        user_message = "Hello, how are you?"
+        messages = [{"content": user_message, "role": "user"}]
+        try:
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
+            print(f"response: {response}")
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+
+    asyncio.run(test_get_response())
+

 def test_async_anyscale_response():
     import asyncio
     litellm.set_verbose = True
@@ -73,7 +87,7 @@ def test_get_response_streaming():
         messages = [{"content": user_message, "role": "user"}]
         try:
             litellm.set_verbose = True
-            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
             print(type(response))

             import inspect
[next file: completion tests]
@@ -481,7 +481,7 @@ def test_completion_openai_litellm_key():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_openai_litellm_key()
+# test_completion_openai_litellm_key()

 def test_completion_openrouter1():
     try:
@@ -562,6 +562,8 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+test_completion_azure()
+
 def test_azure_openai_ad_token():
     # this tests if the azure ad token is set in the request header
     # the request can fail since azure ad tokens expire after 30 mins, but the header MUST have the azure ad token
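On the AD-token path that test exercises: passing azure_ad_token to the SDK client authenticates with an Authorization: Bearer header instead of the api-key header. A minimal sketch (the token value is a placeholder; real tokens come from Azure AD and expire after roughly 30 minutes, as the comment notes):

from openai import AzureOpenAI

# Placeholder -- in practice fetched from Azure AD, e.g. with the
# azure-identity package, and refreshed before it expires.
ad_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9..."

client = AzureOpenAI(
    azure_ad_token=ad_token,  # no api_key needed on this path
    api_version="2023-07-01-preview",
    azure_endpoint="https://my-resource.openai.azure.com/",
    azure_deployment="chatgpt-v-2",
)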
[next file: embedding tests]
@@ -20,7 +20,7 @@ def test_openai_embedding():
         # print(f"response: {str(response)}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_openai_embedding()
+# test_openai_embedding()

 def test_openai_azure_embedding_simple():
     try:
[next file: router load-test script]
@@ -1,69 +1,69 @@
-# import sys, os
-# import traceback
-# from dotenv import load_dotenv
-# import copy
+import sys, os
+import traceback
+from dotenv import load_dotenv
+import copy

-# load_dotenv()
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# )  # Adds the parent directory to the system path
-# import asyncio
-# from litellm import Router, Timeout
+load_dotenv()
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+from litellm import Router, Timeout


-# async def call_acompletion(semaphore, router: Router, input_data):
-#     async with semaphore:
-#         try:
-#             # Use asyncio.wait_for to set a timeout for the task
-#             response = await router.acompletion(**input_data)
-#             # Handle the response as needed
-#             return response
-#         except Timeout:
-#             print(f"Task timed out: {input_data}")
-#             return None  # You may choose to return something else or raise an exception
+async def call_acompletion(semaphore, router: Router, input_data):
+    async with semaphore:
+        try:
+            # Use asyncio.wait_for to set a timeout for the task
+            response = await router.acompletion(**input_data)
+            # Handle the response as needed
+            return response
+        except Timeout:
+            print(f"Task timed out: {input_data}")
+            return None  # You may choose to return something else or raise an exception


-# async def main():
-#     # Initialize the Router
-#     model_list= [{
-#         "model_name": "gpt-3.5-turbo",
-#         "litellm_params": {
-#             "model": "gpt-3.5-turbo",
-#             "api_key": os.getenv("OPENAI_API_KEY"),
-#         },
-#     }, {
-#         "model_name": "gpt-3.5-turbo",
-#         "litellm_params": {
-#             "model": "azure/chatgpt-v-2",
-#             "api_key": os.getenv("AZURE_API_KEY"),
-#             "api_base": os.getenv("AZURE_API_BASE"),
-#             "api_version": os.getenv("AZURE_API_VERSION")
-#         },
-#     }, {
-#         "model_name": "gpt-3.5-turbo",
-#         "litellm_params": {
-#             "model": "azure/chatgpt-functioncalling",
-#             "api_key": os.getenv("AZURE_API_KEY"),
-#             "api_base": os.getenv("AZURE_API_BASE"),
-#             "api_version": os.getenv("AZURE_API_VERSION")
-#         },
-#     }]
-#     router = Router(model_list=model_list, num_retries=3, timeout=10)
+async def main():
+    # Initialize the Router
+    model_list= [{
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    }, {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-v-2",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION")
+        },
+    }, {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-functioncalling",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION")
+        },
+    }]
+    router = Router(model_list=model_list, num_retries=3, timeout=10)

-#     # Create a semaphore with a capacity of 100
-#     semaphore = asyncio.Semaphore(100)
+    # Create a semaphore with a capacity of 100
+    semaphore = asyncio.Semaphore(100)

-#     # List to hold all task references
-#     tasks = []
+    # List to hold all task references
+    tasks = []

-#     # Launch 1000 tasks
-#     for _ in range(1000):
-#         task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
-#         tasks.append(task)
+    # Launch 1000 tasks
+    for _ in range(1000):
+        task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
+        tasks.append(task)

-#     # Wait for all tasks to complete
-#     responses = await asyncio.gather(*tasks)
-#     # Process responses as needed
-#     print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
-# # Run the main function
-# asyncio.run(main())
+    # Wait for all tasks to complete
+    responses = await asyncio.gather(*tasks)
+    # Process responses as needed
+    print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
+# Run the main function
+asyncio.run(main())
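A side note on the script above: its inline comment mentions asyncio.wait_for, but the timeout is actually delegated to Router(timeout=10). If a per-task asyncio deadline were wanted, it would look roughly like this (a sketch, not part of the diff):

import asyncio

async def call_with_deadline(coro, seconds=10):
    # asyncio.wait_for cancels the wrapped coroutine if it overruns.
    try:
        return await asyncio.wait_for(coro, timeout=seconds)
    except asyncio.TimeoutError:
        return None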
[next file: streaming tests]
@@ -374,7 +374,7 @@ def test_completion_azure_stream():
         print(f"completion_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_azure_stream()
+test_completion_azure_stream()

 def test_completion_claude_stream():
     try:
[next file: CustomStreamWrapper (stream handling utils)]
@@ -4505,11 +4505,12 @@ class CustomStreamWrapper:
             text = ""
             is_finished = False
             finish_reason = None
-            if str_line.choices[0].delta.content is not None:
-                text = str_line.choices[0].delta.content
-            if str_line.choices[0].finish_reason:
-                is_finished = True
-                finish_reason = str_line.choices[0].finish_reason
+            if len(str_line.choices) > 0:
+                if str_line.choices[0].delta.content is not None:
+                    text = str_line.choices[0].delta.content
+                if str_line.choices[0].finish_reason:
+                    is_finished = True
+                    finish_reason = str_line.choices[0].finish_reason
             return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
         except Exception as e:
             traceback.print_exc()
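The guard added above covers a real SDK-stream edge case: Azure can emit chunks whose choices list is empty (for example, frames carrying only filter metadata), so indexing choices[0] unconditionally could raise IndexError. The defensive pattern in isolation (the function name is ours):

def extract_delta(chunk):
    # Default to "no text, not finished" when choices is empty.
    text, is_finished, finish_reason = "", False, None
    if len(chunk.choices) > 0:
        choice = chunk.choices[0]
        if choice.delta.content is not None:
            text = choice.delta.content
        if choice.finish_reason:
            is_finished = True
            finish_reason = choice.finish_reason
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}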
@@ -4642,15 +4643,6 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider and self.custom_llm_provider == "azure":
                 response_obj = self.handle_azure_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
-                print_verbose(f"response_obj: {response_obj}")
-                print_verbose(f"completion obj content: {completion_obj['content']}")
-                print_verbose(f"len(completion_obj['content']: {len(completion_obj['content'])}")
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj["finish_reason"]
-                print_verbose(f"model_response finish reason 2: {model_response.choices[0].finish_reason}")
             elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
                 response_obj = self.handle_maritalk_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]