diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 0dd23e13d..fbb3a6583 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -4,8 +4,9 @@ from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
 from typing import Callable, Optional
 from litellm import OpenAIConfig
-import litellm
+import litellm, json
 import httpx
+from openai import AzureOpenAI, AsyncAzureOpenAI
 
 class AzureOpenAIError(Exception):
     def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
@@ -73,12 +74,10 @@ class AzureOpenAIConfig(OpenAIConfig):
                     top_p)
 
 class AzureChatCompletion(BaseLLM):
-    _client_session: Optional[httpx.Client] = None
-    _aclient_session: Optional[httpx.AsyncClient] = None
 
     def __init__(self) -> None:
         super().__init__()
-    
+
     def validate_environment(self, api_key, azure_ad_token):
         headers = {
             "content-type": "application/json",
@@ -110,17 +109,12 @@ class AzureChatCompletion(BaseLLM):
             self._client_session = self.create_client_session()
         exception_mapping_worked = False
         try:
-            if headers is None:
-                headers = self.validate_environment(api_key=api_key, azure_ad_token=azure_ad_token)
             if model is None or messages is None:
                 raise AzureOpenAIError(status_code=422, message=f"Missing model or messages")
-            # Ensure api_base ends with a trailing slash
-            if not api_base.endswith('/'):
-                api_base += '/'
-
-            api_base = api_base + f"openai/deployments/{model}/chat/completions?api-version={api_version}"
+
             data = {
+                "model": model,
                 "messages": messages,
                 **optional_params
             }
@@ -137,41 +131,34 @@ class AzureChatCompletion(BaseLLM):
             )
             if acompletion is True:
                 if optional_params.get("stream", False):
-                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
                 else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token)
             elif "stream" in optional_params and optional_params["stream"] == True:
-                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
             else:
-                response = self._client_session.post(
-                    url=api_base,
-                    json=data,
-                    headers=headers,
-                    timeout=litellm.request_timeout
-                )
-                if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
-
-                ## RESPONSE OBJECT
-                return convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
+                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+                response = azure_client.chat.completions.create(**data) # type: ignore
+                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            raise e
 
-    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
+    async def acompletion(self,
+                          api_key: str,
+                          api_version: str,
+                          model: str,
+                          api_base: str,
+                          data: dict,
+                          headers: dict,
+                          model_response: ModelResponse,
+                          azure_ad_token: Optional[str]=None, ):
        try:
-            response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-            ## RESPONSE OBJECT
-            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+            response = await azure_client.chat.completions.create(**data)
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
        except Exception as e:
            if isinstance(e,httpx.TimeoutException):
                raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
@@ -183,74 +170,52 @@ class AzureChatCompletion(BaseLLM):
     def streaming(self,
                   logging_obj,
                   api_base: str,
+                  api_key: str,
+                  api_version: str,
                   data: dict,
                   headers: dict,
                   model_response: ModelResponse,
-                  model: str
+                  model: str,
+                  azure_ad_token: Optional[str]=None,
     ):
-        if self._client_session is None:
-            self._client_session = self.create_client_session()
-        with self._client_session.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=litellm.request_timeout
-        ) as response:
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message="An error occurred while streaming")
-
-            completion_stream = response.iter_lines()
-            streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
-            for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+        response = azure_client.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
+        for transformed_chunk in streamwrapper:
+            yield transformed_chunk
 
     async def async_streaming(self,
                           logging_obj,
                           api_base: str,
+                          api_key: str,
+                          api_version: str,
                           data: dict,
                           headers: dict,
                           model_response: ModelResponse,
-                          model: str):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=litellm.request_timeout
-        ) as response:
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text)
-
-            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="azure",logging_obj=logging_obj)
-            async for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+                          model: str,
+                          azure_ad_token: Optional[str]=None):
+        azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+        response = await azure_client.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk
 
     def embedding(self,
                   model: str,
                   input: list,
                   api_key: str,
                   api_base: str,
-                  azure_ad_token: str,
                   api_version: str,
                   logging_obj=None,
                   model_response=None,
-                  optional_params=None,):
+                  optional_params=None,
+                  azure_ad_token: Optional[str]=None):
         super().embedding()
         exception_mapping_worked = False
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
-            headers = self.validate_environment(api_key, azure_ad_token=azure_ad_token)
-            # Ensure api_base ends with a trailing slash
-            if not api_base.endswith('/'):
-                api_base += '/'
-
-            api_base = api_base + f"openai/deployments/{model}/embeddings?api-version={api_version}"
-            model = model
+            azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
             data = {
                 "model": model,
                 "input": input,
             }
@@ -263,10 +228,8 @@ class AzureChatCompletion(BaseLLM):
                 api_key=api_key,
                 additional_args={"complete_input_dict": data},
             )
-            ## COMPLETION CALL
-            response = self._client_session.post(
-                api_base, headers=headers, json=data, timeout=litellm.request_timeout
-            )
+            ## COMPLETION CALL
+            response = azure_client.embeddings.create(**data) # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
                 api_key=api_key,
                 additional_args={"complete_input_dict": data},
                 original_response=response,
             )
@@ -275,9 +238,7 @@ class AzureChatCompletion(BaseLLM):
 
-        if response.status_code!=200:
-            raise AzureOpenAIError(message=response.text, status_code=response.status_code)
-        embedding_response = response.json()
+        embedding_response = json.loads(response.model_dump_json())
         output_data = []
         for idx, embedding in enumerate(embedding_response["data"]):
             output_data.append(
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index 8bd0b2daa..09e7aba3c 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -36,7 +36,7 @@ def test_sync_response_anyscale():
 
 # test_sync_response_anyscale()
 
-def test_async_response():
+def test_async_response_openai():
     import asyncio
     litellm.set_verbose = True
     async def test_get_response():
@@ -44,13 +44,27 @@
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
             response = await acompletion(model="gpt-3.5-turbo", messages=messages)
-            # response = await response
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
     asyncio.run(test_get_response())
-test_async_response()
+
+def test_async_response_azure():
+    import asyncio
+    litellm.set_verbose = True
+    async def test_get_response():
+        user_message = "Hello, how are you?"
+ messages = [{"content": user_message, "role": "user"}] + try: + response = await acompletion(model="azure/chatgpt-v-2", messages=messages) + print(f"response: {response}") + except Exception as e: + pytest.fail(f"An exception occurred: {e}") + + asyncio.run(test_get_response()) + + def test_async_anyscale_response(): import asyncio litellm.set_verbose = True @@ -73,7 +87,7 @@ def test_get_response_streaming(): messages = [{"content": user_message, "role": "user"}] try: litellm.set_verbose = True - response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True) + response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True) print(type(response)) import inspect diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 98d116d0d..eb7f05065 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -481,7 +481,7 @@ def test_completion_openai_litellm_key(): except Exception as e: pytest.fail(f"Error occurred: {e}") -test_completion_openai_litellm_key() +# test_completion_openai_litellm_key() def test_completion_openrouter1(): try: @@ -562,6 +562,8 @@ def test_completion_azure(): except Exception as e: pytest.fail(f"Error occurred: {e}") +test_completion_azure() + def test_azure_openai_ad_token(): # this tests if the azure ad token is set in the request header # the request can fail since azure ad tokens expire after 30 mins, but the header MUST have the azure ad token diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index d9f1c9f67..d533b7c42 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -20,7 +20,7 @@ def test_openai_embedding(): # print(f"response: {str(response)}") except Exception as e: pytest.fail(f"Error occurred: {e}") -test_openai_embedding() +# test_openai_embedding() def test_openai_azure_embedding_simple(): try: diff --git a/litellm/tests/test_loadtest_router.py b/litellm/tests/test_loadtest_router.py index da031be69..325164515 100644 --- a/litellm/tests/test_loadtest_router.py +++ b/litellm/tests/test_loadtest_router.py @@ -1,69 +1,69 @@ -# import sys, os -# import traceback -# from dotenv import load_dotenv -# import copy +import sys, os +import traceback +from dotenv import load_dotenv +import copy -# load_dotenv() -# sys.path.insert( -# 0, os.path.abspath("../..") -# ) # Adds the parent directory to the system path -# import asyncio -# from litellm import Router, Timeout +load_dotenv() +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import asyncio +from litellm import Router, Timeout -# async def call_acompletion(semaphore, router: Router, input_data): -# async with semaphore: -# try: -# # Use asyncio.wait_for to set a timeout for the task -# response = await router.acompletion(**input_data) -# # Handle the response as needed -# return response -# except Timeout: -# print(f"Task timed out: {input_data}") -# return None # You may choose to return something else or raise an exception +async def call_acompletion(semaphore, router: Router, input_data): + async with semaphore: + try: + # Use asyncio.wait_for to set a timeout for the task + response = await router.acompletion(**input_data) + # Handle the response as needed + return response + except Timeout: + print(f"Task timed out: {input_data}") + return None # You may choose to return something else or raise an exception -# async def main(): -# # Initialize the Router -# model_list= [{ -# "model_name": 
"gpt-3.5-turbo", -# "litellm_params": { -# "model": "gpt-3.5-turbo", -# "api_key": os.getenv("OPENAI_API_KEY"), -# }, -# }, { -# "model_name": "gpt-3.5-turbo", -# "litellm_params": { -# "model": "azure/chatgpt-v-2", -# "api_key": os.getenv("AZURE_API_KEY"), -# "api_base": os.getenv("AZURE_API_BASE"), -# "api_version": os.getenv("AZURE_API_VERSION") -# }, -# }, { -# "model_name": "gpt-3.5-turbo", -# "litellm_params": { -# "model": "azure/chatgpt-functioncalling", -# "api_key": os.getenv("AZURE_API_KEY"), -# "api_base": os.getenv("AZURE_API_BASE"), -# "api_version": os.getenv("AZURE_API_VERSION") -# }, -# }] -# router = Router(model_list=model_list, num_retries=3, timeout=10) +async def main(): + # Initialize the Router + model_list= [{ + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + }, { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION") + }, + }, { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-functioncalling", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION") + }, + }] + router = Router(model_list=model_list, num_retries=3, timeout=10) -# # Create a semaphore with a capacity of 100 -# semaphore = asyncio.Semaphore(100) + # Create a semaphore with a capacity of 100 + semaphore = asyncio.Semaphore(100) -# # List to hold all task references -# tasks = [] + # List to hold all task references + tasks = [] -# # Launch 1000 tasks -# for _ in range(1000): -# task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]})) -# tasks.append(task) + # Launch 1000 tasks + for _ in range(1000): + task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]})) + tasks.append(task) -# # Wait for all tasks to complete -# responses = await asyncio.gather(*tasks) -# # Process responses as needed -# print(f"NUMBER OF COMPLETED TASKS: {len(responses)}") -# # Run the main function -# asyncio.run(main()) + # Wait for all tasks to complete + responses = await asyncio.gather(*tasks) + # Process responses as needed + print(f"NUMBER OF COMPLETED TASKS: {len(responses)}") +# Run the main function +asyncio.run(main()) diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index a7f777268..d77c4e797 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -374,7 +374,7 @@ def test_completion_azure_stream(): print(f"completion_response: {complete_response}") except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_azure_stream() +test_completion_azure_stream() def test_completion_claude_stream(): try: diff --git a/litellm/utils.py b/litellm/utils.py index a9f23656e..2a01eadbf 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4505,11 +4505,12 @@ class CustomStreamWrapper: text = "" is_finished = False finish_reason = None - if str_line.choices[0].delta.content is not None: - text = str_line.choices[0].delta.content - if str_line.choices[0].finish_reason: - is_finished = True - finish_reason = str_line.choices[0].finish_reason + if len(str_line.choices) > 0: + if str_line.choices[0].delta.content 
is not None: + text = str_line.choices[0].delta.content + if str_line.choices[0].finish_reason: + is_finished = True + finish_reason = str_line.choices[0].finish_reason return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} except Exception as e: traceback.print_exc() @@ -4642,15 +4643,6 @@ class CustomStreamWrapper: completion_obj["content"] = response_obj["text"] if response_obj["is_finished"]: model_response.choices[0].finish_reason = response_obj["finish_reason"] - elif self.custom_llm_provider and self.custom_llm_provider == "azure": - response_obj = self.handle_azure_chunk(chunk) - completion_obj["content"] = response_obj["text"] - print_verbose(f"response_obj: {response_obj}") - print_verbose(f"completion obj content: {completion_obj['content']}") - print_verbose(f"len(completion_obj['content']: {len(completion_obj['content'])}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] - print_verbose(f"model_response finish reason 2: {model_response.choices[0].finish_reason}") elif self.custom_llm_provider and self.custom_llm_provider == "maritalk": response_obj = self.handle_maritalk_chunk(chunk) completion_obj["content"] = response_obj["text"]