fix(azure.py-+-proxy_server.py): fix function calling response object + support router on proxy

Krrish Dholakia 2023-11-15 13:15:09 -08:00
parent 24cc1f620d
commit e5929f2f7e
8 changed files with 54 additions and 59 deletions

.gitignore

@@ -14,3 +14,4 @@ litellm/proxy/api_log.json
 .idea/
 router_config.yaml
 litellm_server/config.yaml
+litellm/proxy/_secret_config.yaml


@@ -1,8 +1,6 @@
 import Image from '@theme/IdealImage';
-# Reliability - Fallbacks, Azure Deployments, etc.
-## Manage Multiple Deployments
+# Manage Multiple Deployments
 Use this if you're trying to load-balance across multiple deployments (e.g. Azure/OpenAI).


@@ -110,12 +110,13 @@ class APIError(APIError):  # type: ignore
 # raised if an invalid request (not get, delete, put, post) is made
 class APIConnectionError(APIConnectionError):  # type: ignore
-    def __init__(self, message, llm_provider, model):
+    def __init__(self, message, llm_provider, model, request: httpx.Request):
         self.message = message
         self.llm_provider = llm_provider
         self.model = model
         super().__init__(
-            self.message
+            message=self.message,
+            request=request
         )
 
 class OpenAIError(OpenAIError):  # type: ignore
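
For context, a minimal standalone sketch of the pattern this hunk introduces, simplified to subclass Exception instead of the upstream openai class, and with a placeholder URL, provider, and model: the connection error now carries the originating httpx.Request so downstream handlers can inspect it.

import httpx

# Simplified stand-in for the litellm exception above: same fields, but
# without the openai base class, so it runs on its own.
class APIConnectionErrorSketch(Exception):
    def __init__(self, message, llm_provider, model, request: httpx.Request):
        self.message = message
        self.llm_provider = llm_provider
        self.model = model
        self.request = request
        super().__init__(self.message)

# Hypothetical usage: build the request object the exception will carry.
req = httpx.Request(method="POST", url="http://0.0.0.0:8000/chat/completions")
try:
    raise APIConnectionErrorSketch(
        message="connection refused",
        llm_provider="vllm",
        model="facebook/opt-125m",
        request=req,
    )
except APIConnectionErrorSketch as e:
    print(e.llm_provider, e.model, e.request.url)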


@@ -195,7 +195,7 @@ class AzureChatCompletion(BaseLLM):
                 method="POST"
             ) as response:
                 if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
+                    raise AzureOpenAIError(status_code=response.status_code, message="An error occurred while streaming")
                 completion_stream = response.iter_lines()
                 streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="azure", logging_obj=logging_obj)
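
As a rough, self-contained sketch of the streaming pattern in this hunk, written against httpx directly (the endpoint URL, deployment, api-key header, and payload are placeholders, not litellm's actual request construction): the status code is checked before the line iterator is consumed, and a fixed error message is used because response.text may not be readable yet on a response that is still streaming.

import httpx

with httpx.Client() as client:
    with client.stream(
        "POST",
        "https://my-endpoint.openai.azure.com/openai/deployments/my-deployment/chat/completions",  # placeholder
        headers={"api-key": "my-azure-key"},  # placeholder
        json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
    ) as response:
        # fail fast on non-200 before wrapping the stream
        if response.status_code != 200:
            raise RuntimeError("An error occurred while streaming")
        # hand the raw SSE lines to whatever wraps the stream downstream
        for line in response.iter_lines():
            print(line)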


@@ -2,7 +2,7 @@ import os
 import json
 from enum import Enum
 import requests
-import time
+import time, httpx
 from typing import Callable, Any
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -11,6 +11,8 @@ class VLLMError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.request = httpx.Request(method="POST", url="http://0.0.0.0:8000")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
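
Shown below as a self-contained sketch (the class body mirrors the hunk above; only the usage at the bottom is added): because a failure from a local vLLM server arrives without any httpx objects attached, the error fabricates a placeholder request/response pair, which is what lets the later exception mapping in utils.py read .request off the original exception.

import httpx

class VLLMErrorSketch(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # synthetic request/response so downstream code expecting httpx objects keeps working
        self.request = httpx.Request(method="POST", url="http://0.0.0.0:8000")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)

err = VLLMErrorSketch(status_code=500, message="vLLM server returned an error")
print(err.status_code, err.request.url, err.response.status_code)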


@@ -453,25 +453,18 @@ def litellm_completion(*args, **kwargs):
         kwargs["max_tokens"] = user_max_tokens
     if user_api_base:
         kwargs["api_base"] = user_api_base
-    ## CHECK CONFIG ##
-    if llm_model_list != None:
-        llm_models = [m["model_name"] for m in llm_model_list]
-        if kwargs["model"] in llm_models:
-            for m in llm_model_list:
-                if kwargs["model"] == m["model_name"]: # if user has specified a config, this will use the config
-                    for key, value in m["litellm_params"].items():
-                        kwargs[key] = value
-                    break
-        else:
-            print_verbose("user sent model not in config, using default config model")
-            default_model = llm_model_list[0]
-            litellm_params = default_model.get('litellm_params', None)
-            for key, value in litellm_params.items():
-                kwargs[key] = value
-    if call_type == "chat_completion":
-        response = litellm.completion(*args, **kwargs)
-    elif call_type == "text_completion":
-        response = litellm.text_completion(*args, **kwargs)
+    ## ROUTE TO CORRECT ENDPOINT ##
+    router_model_names = [m["model_name"] for m in llm_model_list]
+    if llm_router is not None and kwargs["model"] in router_model_names: # model in router model list
+        if call_type == "chat_completion":
+            response = llm_router.completion(*args, **kwargs)
+        elif call_type == "text_completion":
+            response = llm_router.text_completion(*args, **kwargs)
+    else:
+        if call_type == "chat_completion":
+            response = litellm.completion(*args, **kwargs)
+        elif call_type == "text_completion":
+            response = litellm.text_completion(*args, **kwargs)
     if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response
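
A rough usage sketch of what this routing change enables, assuming the litellm Router API of this version; the deployment name, api_base, and api_key below are placeholders. Models listed in the router's model_list get load-balanced through llm_router.completion(), while anything else falls through to a plain litellm.completion() call, matching the branch above.

import litellm

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # the name clients send to the proxy
        "litellm_params": {
            "model": "azure/chatgpt-v-2",                        # placeholder deployment
            "api_base": "https://my-endpoint.openai.azure.com",  # placeholder
            "api_key": "my-azure-key",                           # placeholder
        },
    },
]

llm_router = litellm.Router(model_list=model_list)
router_model_names = [m["model_name"] for m in model_list]

kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hi"}],
}

# same decision as the proxy hunk: router-managed models go through the router
if llm_router is not None and kwargs["model"] in router_model_names:
    response = llm_router.completion(**kwargs)
else:
    response = litellm.completion(**kwargs)
print(response)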


@@ -579,36 +579,34 @@ def test_completion_openai_with_more_optional_params():
         pytest.fail(f"Error occurred: {e}")
 # test_completion_openai_with_more_optional_params()
-# def test_completion_openai_azure_with_functions():
-#     function1 = [
-#         {
-#             "name": "get_current_weather",
-#             "description": "Get the current weather in a given location",
-#             "parameters": {
-#                 "type": "object",
-#                 "properties": {
-#                     "location": {
-#                         "type": "string",
-#                         "description": "The city and state, e.g. San Francisco, CA",
-#                     },
-#                     "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
-#                 },
-#                 "required": ["location"],
-#             },
-#         }
-#     ]
-#     try:
-#         response = completion(
-#             model="azure/chatgpt-functioncalling", messages=messages, stream=True
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#         for chunk in response:
-#             print(chunk)
-#             print(chunk["choices"][0]["finish_reason"])
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-# test_completion_openai_azure_with_functions()
+def test_completion_openai_azure_with_functions():
+    function1 = [
+        {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                },
+                "required": ["location"],
+            },
+        }
+    ]
+    try:
+        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+        response = completion(
+            model="azure/chatgpt-functioncalling", messages=messages, functions=function1
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+test_completion_openai_azure_with_functions()
 
 def test_completion_azure():


@@ -2896,7 +2896,7 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
             raise Exception("Error in response object format")
         choice_list=[]
         for idx, choice in enumerate(response_object["choices"]):
-            message = Message(content=choice["message"]["content"], role=choice["message"]["role"], function_call=choice["message"].get("function_call", None))
+            message = Message(content=choice["message"].get("content", None), role=choice["message"]["role"], function_call=choice["message"].get("function_call", None))
             finish_reason = choice.get("finish_reason", None)
             if finish_reason == None:
                 # gpt-4 vision can return 'finish_reason' or 'finish_details'
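
A minimal illustration of why the switch from ["content"] to .get("content", None) matters; the raw dict below is a hand-written stand-in for a provider response, not captured output. A function-calling choice can omit "content" entirely, so direct indexing raises KeyError while .get() degrades to None and lets the function_call payload through.

# stand-in for one entry of response_object["choices"] on a function call
raw_choice = {
    "message": {
        "role": "assistant",
        "function_call": {
            "name": "get_current_weather",
            "arguments": '{"location": "Boston, MA"}',
        },
        # note: no "content" key at all
    },
    "finish_reason": "function_call",
}

try:
    content = raw_choice["message"]["content"]            # old behaviour: KeyError
except KeyError:
    content = raw_choice["message"].get("content", None)  # new behaviour: None

print(content, raw_choice["message"]["function_call"]["name"])
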
@@ -4018,7 +4018,8 @@ def exception_type(
                 raise APIConnectionError(
                     message=f"VLLMException - {original_exception.message}",
                     llm_provider="vllm",
-                    model=model
+                    model=model,
+                    request=original_exception.request
                 )
         elif custom_llm_provider == "azure":
             if "This model's maximum context length is" in error_str:
@@ -4093,7 +4094,8 @@ def exception_type(
         raise APIConnectionError(
             message=f"{str(original_exception)}",
             llm_provider=custom_llm_provider,
-            model=model
+            model=model,
+            request=original_exception.request
         )
     except Exception as e:
         # LOGGING