refactor(azure.py): working azure completion calls with openai v1 sdk

Krrish Dholakia 2023-11-11 16:44:39 -08:00
parent d0bd932b3c
commit 39c2597c33
9 changed files with 70 additions and 58 deletions
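For context, a minimal sketch of the call path this commit targets, from the caller's side. The deployment name `azure/chatgpt-v-2` is taken from the updated tests below; the environment-variable setup is an assumption about the test environment, not something shown in this diff.

    import asyncio
    import litellm

    async def main():
        # Assumes AZURE_API_KEY, AZURE_API_BASE and AZURE_API_VERSION are set,
        # and that "chatgpt-v-2" is a deployed Azure OpenAI model (as in the tests).
        response = await litellm.acompletion(
            model="azure/chatgpt-v-2",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            stream=True,
        )
        async for chunk in response:  # streamed chunks come from the new httpx-based path
            print(chunk)

    asyncio.run(main())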

View file

@@ -2,6 +2,7 @@
 import threading, requests
 from typing import Callable, List, Optional, Dict, Union
 from litellm.caching import Cache
+import httpx

 input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
@@ -44,7 +45,7 @@ max_budget: float = 0.0 # set the max budget across all providers
 _current_cost = 0 # private variable, used if max budget is set
 error_logs: Dict = {}
 add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
-client_session: Optional[requests.Session] = None
+client_session: Optional[httpx.Client] = None
 model_fallbacks: Optional[List] = None
 model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
 num_retries: Optional[int] = None
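Since the module-level `client_session` now expects an `httpx.Client` rather than a `requests.Session`, a caller who wants to pin connection settings would build it with httpx. A minimal sketch; the timeout values are placeholders, and whether every provider path consults this attribute is not shown in this diff.

    import httpx
    import litellm

    # Hypothetical configuration: reuse one httpx.Client across litellm calls.
    litellm.client_session = httpx.Client(
        timeout=httpx.Timeout(600.0, connect=5.0),  # placeholder values
    )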

View file

@@ -4,12 +4,14 @@ from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
 from typing import Callable, Optional
 from litellm import OpenAIConfig
-import aiohttp
+import httpx

 class AzureOpenAIError(Exception):
-    def __init__(self, status_code, message):
+    def __init__(self, status_code, message, request: httpx.Request, response: httpx.Response):
         self.status_code = status_code
         self.message = message
+        self.request = request
+        self.response = response
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
@@ -64,7 +66,7 @@ class AzureOpenAIConfig(OpenAIConfig):
                     top_p)

 class AzureChatCompletion(BaseLLM):
-    _client_session: requests.Session
+    _client_session: httpx.Client

     def __init__(self) -> None:
         super().__init__()
@@ -126,17 +128,7 @@ class AzureChatCompletion(BaseLLM):
                 else:
                     return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
             elif "stream" in optional_params and optional_params["stream"] == True:
-                response = self._client_session.post(
-                    url=api_base,
-                    json=data,
-                    headers=headers,
-                    stream=optional_params["stream"]
-                )
-                if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
-                ## RESPONSE OBJECT
-                return response.iter_lines()
+                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
             else:
                 response = self._client_session.post(
                     url=api_base,
@@ -144,7 +136,7 @@ class AzureChatCompletion(BaseLLM):
                     headers=headers,
                 )
                 if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
+                    raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
                 ## RESPONSE OBJECT
                 return convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
@@ -152,39 +144,61 @@ class AzureChatCompletion(BaseLLM):
             exception_mapping_worked = True
             raise e
         except Exception as e:
-            if exception_mapping_worked:
-                raise e
-            else:
-                import traceback
-                raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
+            raise e

     async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
-        async with aiohttp.ClientSession() as session:
-            async with session.post(api_base, json=data, headers=headers, ssl=False) as response:
-                response_json = await response.json()
-                if response.status != 200:
-                    raise AzureOpenAIError(status_code=response.status, message=response.text)
+        async with httpx.AsyncClient(timeout=600) as client:
+            response = await client.post(api_base, json=data, headers=headers)
+            response_json = response.json()
+            if response.status_code != 200:
+                raise AzureOpenAIError(status_code=response.status_code, message=response.text)
             ## RESPONSE OBJECT
             return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)

+    def streaming(self,
+                  logging_obj,
+                  api_base: str,
+                  data: dict,
+                  headers: dict,
+                  model_response: ModelResponse,
+                  model: str
+    ):
+        with self._client_session.stream(
+                    url=f"{api_base}",
+                    json=data,
+                    headers=headers,
+                    method="POST"
+                ) as response:
+            if response.status_code != 200:
+                raise AzureOpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+            completion_stream = response.iter_lines()
+            streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
+            for transformed_chunk in streamwrapper:
+                yield transformed_chunk

     async def async_streaming(self,
                           logging_obj,
                           api_base: str,
-                          data: dict, headers: dict,
+                          data: dict,
+                          headers: dict,
                           model_response: ModelResponse,
                           model: str):
-        async with aiohttp.ClientSession() as session:
-            async with session.post(api_base, json=data, headers=headers, ssl=False) as response:
-                # Check if the request was successful (status code 200)
-                if response.status != 200:
-                    raise AzureOpenAIError(status_code=response.status, message=await response.text())
-
-                # Handle the streamed response
-                stream_wrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
-                async for transformed_chunk in stream_wrapper:
-                    yield transformed_chunk
+        client = httpx.AsyncClient()
+        async with client.stream(
+                    url=f"{api_base}",
+                    json=data,
+                    headers=headers,
+                    method="POST"
+                ) as response:
+            if response.status_code != 200:
+                raise AzureOpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="azure",logging_obj=logging_obj)
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk

     def embedding(self,
                   model: str,
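The synchronous path above now streams over httpx's line iterator instead of requests' `iter_lines`. A standalone sketch of that pattern outside the class, with placeholder endpoint, headers, and payload (not the actual deployment or auth values used anywhere in this commit):

    import httpx

    # Placeholder endpoint and payload; real calls go through AzureChatCompletion above.
    url = "https://example-resource.openai.azure.com/openai/deployments/example/chat/completions?api-version=2023-07-01-preview"
    headers = {"api-key": "placeholder-key", "Content-Type": "application/json"}
    data = {"messages": [{"role": "user", "content": "hi"}], "stream": True}

    with httpx.Client(timeout=600) as client:
        with client.stream("POST", url, json=data, headers=headers) as response:
            if response.status_code != 200:
                raise RuntimeError(f"request failed: {response.status_code}")
            for line in response.iter_lines():  # SSE lines, e.g. 'data: {...}'
                print(line)

httpx requires the streamed response to stay inside its context manager, which is why the new `streaming` and `async_streaming` methods are generators: the response is held open while `CustomStreamWrapper` consumes chunks.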

View file

@@ -249,7 +249,7 @@ class OpenAIChatCompletion(BaseLLM):
             response = await client.post(api_base, json=data, headers=headers)
             response_json = response.json()
             if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text)
+                raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
             ## RESPONSE OBJECT

View file

@@ -447,9 +447,7 @@ def completion(
                 logging_obj=logging,
                 acompletion=acompletion
             )
-            if optional_params.get("stream", False) and acompletion is False:
-                response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
-                return response
             ## LOGGING
             logging.post_call(
                 input=messages,

View file

@@ -28,13 +28,13 @@ def test_async_response():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")

     asyncio.run(test_get_response())

-test_async_response()
+# test_async_response()

 def test_get_response_streaming():
     import asyncio
@@ -42,7 +42,7 @@ def test_get_response_streaming():
         user_message = "write a short poem in one sentence"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
             print(type(response))

             import inspect
@@ -65,7 +65,7 @@ def test_get_response_streaming():
     asyncio.run(test_async_call())

-# test_get_response_streaming()
+test_get_response_streaming()

 def test_get_response_non_openai_streaming():
     import asyncio

View file

@@ -421,7 +421,7 @@ def test_completion_openai():
         pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_openai()
+# test_completion_openai()

 def test_completion_text_openai():
     try:
@@ -634,7 +634,7 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-# test_completion_azure()
+test_completion_azure()

 def test_completion_azure2():
     # test if we can pass api_base, api_version and api_key in compleition()
     try:

View file

@@ -59,7 +59,7 @@ def test_context_window(model):
 @pytest.mark.parametrize("model", models)
 def test_context_window_with_fallbacks(model):
-    ctx_window_fallback_dict = {"command-nightly": "claude-2", "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k"}
+    ctx_window_fallback_dict = {"command-nightly": "claude-2", "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k", "azure/chatgpt-v-2": "gpt-3.5-turbo-16k"}
     sample_text = "how does a court case get to the Supreme Court?" * 1000
     messages = [{"content": sample_text, "role": "user"}]
@@ -67,8 +67,8 @@ def test_context_window_with_fallbacks(model):
 # for model in litellm.models_by_provider["bedrock"]:
 # test_context_window(model=model)
-test_context_window(model="azure/chatgpt-v-2")
-# test_context_window_with_fallbacks(model="gpt-3.5-turbo")
+# test_context_window(model="azure/chatgpt-v-2")
+# test_context_window_with_fallbacks(model="azure/chatgpt-v-2")

 # Test 2: InvalidAuth Errors
 @pytest.mark.parametrize("model", models)
 def invalid_auth(model):  # set the model key to an invalid key, depending on the model
@@ -163,7 +163,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
 # for model in litellm.models_by_provider["bedrock"]:
 # invalid_auth(model=model)
-# invalid_auth(model="gpt-3.5-turbo")
+# invalid_auth(model="azure/chatgpt-v-2")

 # Test 3: Invalid Request Error
 @pytest.mark.parametrize("model", models)
@@ -173,7 +173,7 @@ def test_invalid_request_error(model):
     with pytest.raises(BadRequestError):
         completion(model=model, messages=messages, max_tokens="hello world")

-# test_invalid_request_error(model="gpt-3.5-turbo")
+# test_invalid_request_error(model="azure/chatgpt-v-2")

 # Test 3: Rate Limit Errors
 # def test_model_call(model):
 #     try:

View file

@@ -372,7 +372,7 @@ def test_completion_azure_stream():
         print(f"completion_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-# test_completion_azure_stream()
+test_completion_azure_stream()

 def test_completion_claude_stream():
     try:

View file

@@ -3845,7 +3845,8 @@ def exception_type(
                 raise AuthenticationError(
                     message=f"AzureException - {original_exception.message}",
                     llm_provider="azure",
-                    model=model
+                    model=model,
+                    response=original_exception.response
                 )
             elif original_exception.status_code == 408:
                 exception_mapping_worked = True
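With the mapped exception now carrying the underlying response, a caller can inspect it after catching. A sketch, mirroring the invalid_auth test above; it assumes the Azure base URL and API version are configured via environment variables and that the `response` passed to the constructor is reachable on the caught exception.

    import litellm
    from litellm.exceptions import AuthenticationError

    try:
        litellm.completion(
            model="azure/chatgpt-v-2",  # deployment name taken from the tests in this commit
            messages=[{"role": "user", "content": "hi"}],
            api_key="bad-key",          # deliberately invalid, as in the invalid_auth test
        )
    except AuthenticationError as e:
        # Assumption: the httpx response set in exception_type is exposed as an attribute.
        print(e.status_code, getattr(e, "response", None))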
@@ -4225,7 +4226,6 @@ class CustomStreamWrapper:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")

     def handle_azure_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
         is_finished = False
         finish_reason = ""
         text = ""
@@ -4299,7 +4299,6 @@ class CustomStreamWrapper:
     def handle_openai_text_completion_chunk(self, chunk):
         try:
-            # str_line = chunk.decode("utf-8")  # Convert bytes to string
             str_line = chunk
             text = ""
             is_finished = False
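The dropped `.decode("utf-8")` calls follow from the client swap: requests' `iter_lines()` yields `bytes`, while httpx's `iter_lines()` yields `str`, so the chunk handlers now receive text directly. A quick illustration of the difference; the URL is a placeholder, not an endpoint used by litellm.

    import httpx
    import requests

    url = "https://example.com/stream"  # placeholder streaming endpoint

    with requests.get(url, stream=True) as r:
        first = next(iter(r.iter_lines()), b"")
        print(type(first))  # <class 'bytes'>

    with httpx.stream("GET", url) as r:
        first = next(iter(r.iter_lines()), "")
        print(type(first))  # <class 'str'>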