Merge pull request #5296 from BerriAI/litellm_azure_json_schema_support

feat(azure.py): support 'json_schema' for older models
Krish Dholakia, 2024-08-20 11:41:38 -07:00 (committed by GitHub)
commit 02eb6455b2
3 changed files with 101 additions and 31 deletions
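For context, a minimal usage sketch of what this change enables, pieced together from the test added below (the deployment name "azure/chatgpt-v-2", the messages, and the pydantic models are borrowed from that test; credentials and the Azure base URL / api-version are assumed to be configured via the usual environment variables):

import litellm
from pydantic import BaseModel

class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]

class EventsList(BaseModel):
    events: list[CalendarEvent]

# On Azure api-versions without native json_schema support, litellm now converts the
# schema into a single forced tool call and unwraps the tool arguments back into
# message.content, so the caller still receives a plain JSON string.
response = litellm.completion(
    model="azure/chatgpt-v-2",  # older Azure deployment name, taken from the test below
    messages=[{"role": "user", "content": "List 5 important events in the XIX century"}],
    response_format=EventsList,
)
print(response.choices[0].message.content)  # JSON conforming to EventsList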


@@ -47,6 +47,10 @@ from ..types.llms.openai import (
     AsyncAssistantEventHandler,
     AsyncAssistantStreamManager,
     AsyncCursorPage,
+    ChatCompletionToolChoiceFunctionParam,
+    ChatCompletionToolChoiceObjectParam,
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
     HttpxBinaryResponseContent,
     MessageData,
     OpenAICreateThreadParamsMessage,
@@ -204,8 +208,8 @@ class AzureOpenAIConfig:
                 and api_version_day < "01"
             )
         ):
-            if litellm.drop_params == True or (
-                drop_params is not None and drop_params == True
+            if litellm.drop_params is True or (
+                drop_params is not None and drop_params is True
             ):
                 pass
             else:
@@ -227,6 +231,41 @@ class AzureOpenAIConfig:
                     )
                 else:
                     optional_params["tool_choice"] = value
+            if param == "response_format" and isinstance(value, dict):
+                json_schema: Optional[dict] = None
+                schema_name: str = ""
+                if "response_schema" in value:
+                    json_schema = value["response_schema"]
+                    schema_name = "json_tool_call"
+                elif "json_schema" in value:
+                    json_schema = value["json_schema"]["schema"]
+                    schema_name = value["json_schema"]["name"]
+                """
+                Follow similar approach to anthropic - translate to a single tool call.
+
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
+                """
+                if json_schema is not None:
+                    _tool_choice = ChatCompletionToolChoiceObjectParam(
+                        type="function",
+                        function=ChatCompletionToolChoiceFunctionParam(
+                            name=schema_name
+                        ),
+                    )
+
+                    _tool = ChatCompletionToolParam(
+                        type="function",
+                        function=ChatCompletionToolParamFunctionChunk(
+                            name=schema_name, parameters=json_schema
+                        ),
+                    )
+
+                    optional_params["tools"] = [_tool]
+                    optional_params["tool_choice"] = _tool_choice
+                    optional_params["json_mode"] = True
             elif param in supported_openai_params:
                 optional_params[param] = value
         return optional_params
@@ -513,6 +552,7 @@ class AzureChatCompletion(BaseLLM):
            )

        max_retries = optional_params.pop("max_retries", 2)
+       json_mode: Optional[bool] = optional_params.pop("json_mode", False)

        ### CHECK IF CLOUDFLARE AI GATEWAY ###
        ### if so - set the model as part of the base url
@@ -578,6 +618,7 @@ class AzureChatCompletion(BaseLLM):
                    timeout=timeout,
                    client=client,
                    logging_obj=logging_obj,
+                   convert_tool_call_to_json_mode=json_mode,
                )
            elif "stream" in optional_params and optional_params["stream"] == True:
                return self.streaming(
@@ -656,6 +697,7 @@ class AzureChatCompletion(BaseLLM):
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
+               convert_tool_call_to_json_mode=json_mode,
            )
        except AzureOpenAIError as e:
            exception_mapping_worked = True
@@ -677,6 +719,7 @@ class AzureChatCompletion(BaseLLM):
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        azure_ad_token: Optional[str] = None,
+       convert_tool_call_to_json_mode: Optional[bool] = None,
        client=None,  # this is the AsyncAzureOpenAI
    ):
        response = None
@@ -742,11 +785,13 @@ class AzureChatCompletion(BaseLLM):
                original_response=stringified_response,
                additional_args={"complete_input_dict": data},
            )
+
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
                hidden_params={"headers": headers},
                _response_headers=headers,
+               convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
            )
        except AzureOpenAIError as e:
            ## LOGGING
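To make the mapping above concrete, here is a rough illustration of the translation AzureOpenAIConfig now performs when it receives an OpenAI-style json_schema response_format on an older api-version. Plain dicts stand in for the ChatCompletionTool* typed params (assumed to be equivalent on the wire), and the schema itself is an invented example:

value = {
    "type": "json_schema",
    "json_schema": {
        "name": "EventsList",
        "schema": {"type": "object", "properties": {"events": {"type": "array"}}},
    },
}

# The response_format is replaced by a single tool that the model is forced to call,
# plus a json_mode flag consumed later when the response is parsed.
optional_params = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": value["json_schema"]["name"],          # schema_name
                "parameters": value["json_schema"]["schema"],  # json_schema
            },
        }
    ],
    "tool_choice": {
        "type": "function",
        "function": {"name": value["json_schema"]["name"]},
    },
    "json_mode": True,
}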


@@ -2162,37 +2162,44 @@ def test_completion_openai():
         pytest.fail(f"Error occurred: {e}")


-def test_completion_openai_pydantic():
+@pytest.mark.parametrize("model", ["gpt-4o-2024-08-06", "azure/chatgpt-v-2"])
+def test_completion_openai_pydantic(model):
     try:
         litellm.set_verbose = True
         from pydantic import BaseModel

+        messages = [
+            {"role": "user", "content": "List 5 important events in the XIX century"}
+        ]
+
         class CalendarEvent(BaseModel):
             name: str
             date: str
             participants: list[str]

-        print(f"api key: {os.environ['OPENAI_API_KEY']}")
-        litellm.api_key = os.environ["OPENAI_API_KEY"]
-        response = completion(
-            model="gpt-4o-2024-08-06",
-            messages=[{"role": "user", "content": "Hey"}],
-            max_tokens=10,
-            metadata={"hi": "bye"},
-            response_format=CalendarEvent,
-        )
+        class EventsList(BaseModel):
+            events: list[CalendarEvent]
+
+        litellm.enable_json_schema_validation = True
+        for _ in range(3):
+            try:
+                response = completion(
+                    model=model,
+                    messages=messages,
+                    metadata={"hi": "bye"},
+                    response_format=EventsList,
+                )
+                break
+            except litellm.JSONSchemaValidationError:
+                print("ERROR OCCURRED! INVALID JSON")
+
         print("This is the response object\n", response)

         response_str = response["choices"][0]["message"]["content"]
-        response_str_2 = response.choices[0].message.content

-        cost = completion_cost(completion_response=response)
-        print("Cost for completion call with gpt-3.5-turbo: ", f"${float(cost):.10f}")
-        assert response_str == response_str_2
-        assert type(response_str) == str
-        assert len(response_str) > 1
-
-        litellm.api_key = None
+        print(f"response_str: {response_str}")
+        json.loads(response_str)  # check valid json is returned
     except Timeout as e:
         pass
     except Exception as e:
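The retry loop in the updated test relies on litellm.enable_json_schema_validation raising litellm.JSONSchemaValidationError when the returned content does not match the supplied schema. A small, hypothetical helper following the same pattern (the function name and retry count are illustrative, not part of the library):

import json
import litellm

def completion_with_schema_retries(model, messages, response_format, attempts=3):
    # Validate responses against the supplied schema and retry on mismatch.
    litellm.enable_json_schema_validation = True
    for _ in range(attempts):
        try:
            resp = litellm.completion(
                model=model, messages=messages, response_format=response_format
            )
            return json.loads(resp.choices[0].message.content)
        except litellm.JSONSchemaValidationError:
            continue  # invalid or non-conforming JSON - try again
    raise RuntimeError("no schema-conforming response after retries")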


@@ -843,13 +843,13 @@ def client(original_function):
                    and str(original_function.__name__)
                    in litellm.cache.supported_call_types
                ):
-                   print_verbose(f"Checking Cache")
+                   print_verbose("Checking Cache")
                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                    kwargs["preset_cache_key"] = (
                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
                    )
                    cached_result = litellm.cache.get_cache(*args, **kwargs)
-                   if cached_result != None:
+                   if cached_result is not None:
                        if "detail" in cached_result:
                            # implies an error occurred
                            pass
@@ -5907,6 +5907,9 @@ def convert_to_model_response_object(
    end_time=None,
    hidden_params: Optional[dict] = None,
    _response_headers: Optional[dict] = None,
+   convert_tool_call_to_json_mode: Optional[
+       bool
+   ] = None,  # used for supporting 'json_schema' on older models
 ):
    received_args = locals()
    if _response_headers is not None:
@@ -5945,7 +5948,7 @@
    ):
        if response_object is None or model_response_object is None:
            raise Exception("Error in response object format")
-       if stream == True:
+       if stream is True:
            # for returning cached responses, we need to yield a generator
            return convert_to_streaming_response(response_object=response_object)
        choice_list = []
@@ -5955,16 +5958,31 @@
        )

        for idx, choice in enumerate(response_object["choices"]):
-           message = Message(
-               content=choice["message"].get("content", None),
-               role=choice["message"]["role"] or "assistant",
-               function_call=choice["message"].get("function_call", None),
-               tool_calls=choice["message"].get("tool_calls", None),
-           )
-           finish_reason = choice.get("finish_reason", None)
-           if finish_reason == None:
+           ## HANDLE JSON MODE - anthropic returns single function call]
+           tool_calls = choice["message"].get("tool_calls", None)
+           if (
+               convert_tool_call_to_json_mode
+               and tool_calls is not None
+               and len(tool_calls) == 1
+           ):
+               # to support 'json_schema' logic on older models
+               json_mode_content_str: Optional[str] = tool_calls[0][
+                   "function"
+               ].get("arguments")
+               if json_mode_content_str is not None:
+                   message = litellm.Message(content=json_mode_content_str)
+                   finish_reason = "stop"
+           else:
+               message = Message(
+                   content=choice["message"].get("content", None),
+                   role=choice["message"]["role"] or "assistant",
+                   function_call=choice["message"].get("function_call", None),
+                   tool_calls=choice["message"].get("tool_calls", None),
+               )
+               finish_reason = choice.get("finish_reason", None)
+           if finish_reason is None:
                # gpt-4 vision can return 'finish_reason' or 'finish_details'
-               finish_reason = choice.get("finish_details")
+               finish_reason = choice.get("finish_details") or "stop"
            logprobs = choice.get("logprobs", None)
            enhancements = choice.get("enhancements", None)
            choice = Choices(
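And a minimal sketch of the unwrapping step added in convert_to_model_response_object above, assuming a raw choice dict that contains exactly one tool call (the schema name and arguments are invented for illustration):

from typing import Optional

choice = {
    "message": {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "type": "function",
                "function": {
                    "name": "EventsList",
                    "arguments": '{"events": [{"name": "Congress of Vienna", "date": "1815", "participants": []}]}',
                },
            }
        ],
    },
    "finish_reason": "tool_calls",
}

tool_calls = choice["message"].get("tool_calls", None)
if tool_calls is not None and len(tool_calls) == 1:
    # Lift the forced tool call's arguments into message.content and report a
    # normal stop, so callers who asked for json_schema see ordinary JSON text.
    json_mode_content_str: Optional[str] = tool_calls[0]["function"].get("arguments")
    finish_reason = "stop"
    print(json_mode_content_str)  # -> the JSON produced by the forced tool call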