fix(azure.py): use openai client sdk for handling sync+async calling

Author: Krrish Dholakia
Date: 2023-11-16 12:08:04 -08:00
parent 3285113d2d
commit bf0f8b824c
7 changed files with 136 additions and 167 deletions
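In short: instead of hand-building the Azure REST URL and POSTing it through a shared httpx session, completion calls now go through the openai SDK's AzureOpenAI / AsyncAzureOpenAI clients. A minimal sketch of the new call pattern follows; the endpoint, deployment, and key values are placeholders, not values from this repo, and passing the deployment via model= is one supported way to target it (the diff itself sets azure_deployment on the client):

    import asyncio
    from openai import AzureOpenAI, AsyncAzureOpenAI

    # sync path -- mirrors what AzureChatCompletion.completion now does
    client = AzureOpenAI(
        api_key="my-azure-key",                                 # placeholder
        api_version="2023-07-01-preview",                       # placeholder
        azure_endpoint="https://my-endpoint.openai.azure.com",  # placeholder
    )
    response = client.chat.completions.create(
        model="my-deployment",  # placeholder Azure deployment name
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )

    # async path -- mirrors AzureChatCompletion.acompletion
    async def main():
        aclient = AsyncAzureOpenAI(
            api_key="my-azure-key",
            api_version="2023-07-01-preview",
            azure_endpoint="https://my-endpoint.openai.azure.com",
        )
        return await aclient.chat.completions.create(
            model="my-deployment",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
        )

    asyncio.run(main())

The SDK owns URL construction, auth headers, and timeouts, which is what lets the hand-rolled api_base manipulation and status-code checks below be deleted.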

View file

@@ -4,8 +4,9 @@ from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
 from typing import Callable, Optional
 from litellm import OpenAIConfig
-import litellm
+import litellm, json
 import httpx
+from openai import AzureOpenAI, AsyncAzureOpenAI

 class AzureOpenAIError(Exception):
     def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
@@ -73,8 +74,6 @@ class AzureOpenAIConfig(OpenAIConfig):
                   top_p)

 class AzureChatCompletion(BaseLLM):
-    _client_session: Optional[httpx.Client] = None
-    _aclient_session: Optional[httpx.AsyncClient] = None

     def __init__(self) -> None:
         super().__init__()
@@ -110,17 +109,12 @@ class AzureChatCompletion(BaseLLM):
             self._client_session = self.create_client_session()
         exception_mapping_worked = False
         try:
-            if headers is None:
-                headers = self.validate_environment(api_key=api_key, azure_ad_token=azure_ad_token)
             if model is None or messages is None:
                 raise AzureOpenAIError(status_code=422, message=f"Missing model or messages")
-            # Ensure api_base ends with a trailing slash
-            if not api_base.endswith('/'):
-                api_base += '/'
-            api_base = api_base + f"openai/deployments/{model}/chat/completions?api-version={api_version}"
             data = {
-                "model": model,
                 "messages": messages,
                 **optional_params
             }
@@ -137,41 +131,34 @@ class AzureChatCompletion(BaseLLM):
             )
             if acompletion is True:
                 if optional_params.get("stream", False):
-                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
                 else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token)
             elif "stream" in optional_params and optional_params["stream"] == True:
-                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
             else:
-                response = self._client_session.post(
-                    url=api_base,
-                    json=data,
-                    headers=headers,
-                    timeout=litellm.request_timeout
-                )
-                if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
-
-                ## RESPONSE OBJECT
-                return convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
+                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+                response = azure_client.chat.completions.create(**data) # type: ignore
+                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
         except Exception as e:
             raise e

-    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
+    async def acompletion(self,
+                          api_key: str,
+                          api_version: str,
+                          model: str,
+                          api_base: str,
+                          data: dict,
+                          headers: dict,
+                          model_response: ModelResponse,
+                          azure_ad_token: Optional[str]=None, ):
         try:
-            response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-            ## RESPONSE OBJECT
-            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+            response = await azure_client.chat.completions.create(**data)
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
             if isinstance(e,httpx.TimeoutException):
                 raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
@@ -183,49 +170,33 @@ class AzureChatCompletion(BaseLLM):
     def streaming(self,
                   logging_obj,
                   api_base: str,
+                  api_key: str,
+                  api_version: str,
                   data: dict,
                   headers: dict,
                   model_response: ModelResponse,
-                  model: str
+                  model: str,
+                  azure_ad_token: Optional[str]=None,
     ):
-        if self._client_session is None:
-            self._client_session = self.create_client_session()
-        with self._client_session.stream(
-                url=f"{api_base}",
-                json=data,
-                headers=headers,
-                method="POST",
-                timeout=litellm.request_timeout
-            ) as response:
-                if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message="An error occurred while streaming")
-
-                completion_stream = response.iter_lines()
-                streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
+        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+        response = azure_client.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
         for transformed_chunk in streamwrapper:
             yield transformed_chunk

     async def async_streaming(self,
                           logging_obj,
                           api_base: str,
+                          api_key: str,
+                          api_version: str,
                           data: dict,
                           headers: dict,
                           model_response: ModelResponse,
-                          model: str):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
-        async with client.stream(
-                url=f"{api_base}",
-                json=data,
-                headers=headers,
-                method="POST",
-                timeout=litellm.request_timeout
-            ) as response:
-                if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
-
-                streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="azure",logging_obj=logging_obj)
+                          model: str,
+                          azure_ad_token: Optional[str]=None):
+        azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+        response = await azure_client.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
         async for transformed_chunk in streamwrapper:
             yield transformed_chunk
@@ -234,23 +205,17 @@ class AzureChatCompletion(BaseLLM):
                  input: list,
                  api_key: str,
                  api_base: str,
-                 azure_ad_token: str,
                  api_version: str,
                  logging_obj=None,
                  model_response=None,
-                 optional_params=None,):
+                 optional_params=None,
+                 azure_ad_token: Optional[str]=None):
         super().embedding()
         exception_mapping_worked = False
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
-            headers = self.validate_environment(api_key, azure_ad_token=azure_ad_token)
-            # Ensure api_base ends with a trailing slash
-            if not api_base.endswith('/'):
-                api_base += '/'
-            api_base = api_base + f"openai/deployments/{model}/embeddings?api-version={api_version}"
-            model = model
+            azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
             data = {
                 "model": model,
                 "input": input,
@@ -264,9 +229,7 @@ class AzureChatCompletion(BaseLLM):
                     additional_args={"complete_input_dict": data},
                 )
             ## COMPLETION CALL
-            response = self._client_session.post(
-                api_base, headers=headers, json=data, timeout=litellm.request_timeout
-            )
+            response = azure_client.embeddings.create(**data) # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -275,9 +238,7 @@ class AzureChatCompletion(BaseLLM):
                 original_response=response,
             )
-            if response.status_code!=200:
-                raise AzureOpenAIError(message=response.text, status_code=response.status_code)
-            embedding_response = response.json()
+            embedding_response = json.loads(response.model_dump_json())
             output_data = []
             for idx, embedding in enumerate(embedding_response["data"]):
                 output_data.append(
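One consequence of the switch worth spelling out: with stream=True the SDK call returns an iterator (or async iterator) of parsed ChatCompletionChunk objects, and that object is now handed straight to CustomStreamWrapper, where previously the wrapper got raw SSE lines from response.iter_lines(). A rough sketch of what the wrapper now consumes; all client values are placeholders:

    from openai import AzureOpenAI

    client = AzureOpenAI(
        api_key="my-azure-key",                                 # placeholder
        api_version="2023-07-01-preview",                       # placeholder
        azure_endpoint="https://my-endpoint.openai.azure.com",  # placeholder
    )
    stream = client.chat.completions.create(
        model="my-deployment",  # placeholder deployment name
        messages=[{"role": "user", "content": "Tell me a joke."}],
        stream=True,
    )
    for chunk in stream:
        # each chunk is a parsed ChatCompletionChunk, not a raw SSE line
        if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
            print(chunk.choices[0].delta.content, end="")

This is also why the SSE-line-based handle_azure_chunk branch is deleted from CustomStreamWrapper further down in this commit.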

View file

@@ -36,7 +36,7 @@ def test_sync_response_anyscale():
 # test_sync_response_anyscale()

-def test_async_response():
+def test_async_response_openai():
     import asyncio
     litellm.set_verbose = True
     async def test_get_response():
@@ -44,13 +44,27 @@ def test_async_response():
         messages = [{"content": user_message, "role": "user"}]
         try:
             response = await acompletion(model="gpt-3.5-turbo", messages=messages)
-            # response = await response
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")

     asyncio.run(test_get_response())
-test_async_response()

+def test_async_response_azure():
+    import asyncio
+    litellm.set_verbose = True
+    async def test_get_response():
+        user_message = "Hello, how are you?"
+        messages = [{"content": user_message, "role": "user"}]
+        try:
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
+            print(f"response: {response}")
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+
+    asyncio.run(test_get_response())
+
 def test_async_anyscale_response():
     import asyncio
     litellm.set_verbose = True
@@ -73,7 +87,7 @@ def test_get_response_streaming():
         messages = [{"content": user_message, "role": "user"}]
         try:
             litellm.set_verbose = True
-            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
             print(type(response))

             import inspect

View file

@@ -481,7 +481,7 @@ def test_completion_openai_litellm_key():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_openai_litellm_key()
+# test_completion_openai_litellm_key()

 def test_completion_openrouter1():
     try:
@@ -562,6 +562,8 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+test_completion_azure()
+
 def test_azure_openai_ad_token():
     # this tests if the azure ad token is set in the request header
     # the request can fail since azure ad tokens expire after 30 mins, but the header MUST have the azure ad token
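For reference, the ad-token path this test exercises looks roughly like the following; the token string and endpoint are placeholders, and azure_ad_token is the keyword the new signatures in azure.py thread through to the AzureOpenAI client:

    from litellm import completion

    response = completion(
        model="azure/chatgpt-v-2",
        messages=[{"role": "user", "content": "hi"}],
        azure_ad_token="eyJ0eXAiOiJKV1QiLCJhbGciOi...",         # placeholder AAD token
        api_base="https://my-endpoint.openai.azure.com",        # placeholder
        api_version="2023-07-01-preview",                       # placeholder
    )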

View file

@@ -20,7 +20,7 @@ def test_openai_embedding():
         # print(f"response: {str(response)}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_openai_embedding()
+# test_openai_embedding()

 def test_openai_azure_embedding_simple():
     try:

View file

@@ -1,69 +1,69 @@
-# import sys, os
-# import traceback
-# from dotenv import load_dotenv
-# import copy
-# load_dotenv()
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# )  # Adds the parent directory to the system path
-# import asyncio
-# from litellm import Router, Timeout
-
-# async def call_acompletion(semaphore, router: Router, input_data):
-#     async with semaphore:
-#         try:
-#             # Use asyncio.wait_for to set a timeout for the task
-#             response = await router.acompletion(**input_data)
-#             # Handle the response as needed
-#             return response
-#         except Timeout:
-#             print(f"Task timed out: {input_data}")
-#             return None # You may choose to return something else or raise an exception
-
-# async def main():
-#     # Initialize the Router
-#     model_list= [{
-#         "model_name": "gpt-3.5-turbo",
-#         "litellm_params": {
-#             "model": "gpt-3.5-turbo",
-#             "api_key": os.getenv("OPENAI_API_KEY"),
-#         },
-#     }, {
-#         "model_name": "gpt-3.5-turbo",
-#         "litellm_params": {
-#             "model": "azure/chatgpt-v-2",
-#             "api_key": os.getenv("AZURE_API_KEY"),
-#             "api_base": os.getenv("AZURE_API_BASE"),
-#             "api_version": os.getenv("AZURE_API_VERSION")
-#         },
-#     }, {
-#         "model_name": "gpt-3.5-turbo",
-#         "litellm_params": {
-#             "model": "azure/chatgpt-functioncalling",
-#             "api_key": os.getenv("AZURE_API_KEY"),
-#             "api_base": os.getenv("AZURE_API_BASE"),
-#             "api_version": os.getenv("AZURE_API_VERSION")
-#         },
-#     }]
-#     router = Router(model_list=model_list, num_retries=3, timeout=10)
-#     # Create a semaphore with a capacity of 100
-#     semaphore = asyncio.Semaphore(100)
-#     # List to hold all task references
-#     tasks = []
-#     # Launch 1000 tasks
-#     for _ in range(1000):
-#         task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
-#         tasks.append(task)
-#     # Wait for all tasks to complete
-#     responses = await asyncio.gather(*tasks)
-#     # Process responses as needed
-#     print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
-# # Run the main function
-# asyncio.run(main())
+import sys, os
+import traceback
+from dotenv import load_dotenv
+import copy
+load_dotenv()
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+from litellm import Router, Timeout
+
+async def call_acompletion(semaphore, router: Router, input_data):
+    async with semaphore:
+        try:
+            # Use asyncio.wait_for to set a timeout for the task
+            response = await router.acompletion(**input_data)
+            # Handle the response as needed
+            return response
+        except Timeout:
+            print(f"Task timed out: {input_data}")
+            return None # You may choose to return something else or raise an exception
+
+async def main():
+    # Initialize the Router
+    model_list= [{
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    }, {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-v-2",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION")
+        },
+    }, {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-functioncalling",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION")
+        },
+    }]
+    router = Router(model_list=model_list, num_retries=3, timeout=10)
+    # Create a semaphore with a capacity of 100
+    semaphore = asyncio.Semaphore(100)
+    # List to hold all task references
+    tasks = []
+    # Launch 1000 tasks
+    for _ in range(1000):
+        task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
+        tasks.append(task)
+    # Wait for all tasks to complete
+    responses = await asyncio.gather(*tasks)
+    # Process responses as needed
+    print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
+# Run the main function
+asyncio.run(main())

View file

@@ -374,7 +374,7 @@ def test_completion_azure_stream():
         print(f"completion_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-# test_completion_azure_stream()
+test_completion_azure_stream()

 def test_completion_claude_stream():
     try:

View file

@@ -4505,6 +4505,7 @@ class CustomStreamWrapper:
             text = ""
             is_finished = False
             finish_reason = None
+            if len(str_line.choices) > 0:
                 if str_line.choices[0].delta.content is not None:
                     text = str_line.choices[0].delta.content
                 if str_line.choices[0].finish_reason:
@@ -4642,15 +4643,6 @@ class CustomStreamWrapper:
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
-                elif self.custom_llm_provider and self.custom_llm_provider == "azure":
-                    response_obj = self.handle_azure_chunk(chunk)
-                    completion_obj["content"] = response_obj["text"]
-                    print_verbose(f"response_obj: {response_obj}")
-                    print_verbose(f"completion obj content: {completion_obj['content']}")
-                    print_verbose(f"len(completion_obj['content']: {len(completion_obj['content'])}")
-                    if response_obj["is_finished"]:
-                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
-                        print_verbose(f"model_response finish reason 2: {model_response.choices[0].finish_reason}")
                 elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
                     response_obj = self.handle_maritalk_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
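The len(str_line.choices) > 0 guard added in the first utils.py hunk matters once SDK chunk objects flow through here: streamed responses can include chunks whose choices list is empty (Azure, for instance, can emit metadata-only chunks), so indexing choices[0] unconditionally would raise IndexError. A minimal sketch of the guarded parse, assuming chunk is an SDK ChatCompletionChunk; parse_chunk is a hypothetical helper, not a function from this commit:

    def parse_chunk(chunk):
        # Extract (text, is_finished, finish_reason) from a ChatCompletionChunk.
        text, is_finished, finish_reason = "", False, None
        if len(chunk.choices) > 0:  # some chunks carry no choices at all
            if chunk.choices[0].delta.content is not None:
                text = chunk.choices[0].delta.content
            if chunk.choices[0].finish_reason:
                is_finished = True
                finish_reason = chunk.choices[0].finish_reason
        return text, is_finished, finish_reason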