Merge pull request #5296 from BerriAI/litellm_azure_json_schema_support

feat(azure.py): support 'json_schema' for older models
This commit is contained in:
Krish Dholakia 2024-08-20 11:41:38 -07:00 committed by GitHub
commit 02eb6455b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 31 deletions

View file

@ -47,6 +47,10 @@ from ..types.llms.openai import (
AsyncAssistantEventHandler,
AsyncAssistantStreamManager,
AsyncCursorPage,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
HttpxBinaryResponseContent,
MessageData,
OpenAICreateThreadParamsMessage,
@ -204,8 +208,8 @@ class AzureOpenAIConfig:
and api_version_day < "01"
)
):
if litellm.drop_params == True or (
drop_params is not None and drop_params == True
if litellm.drop_params is True or (
drop_params is not None and drop_params is True
):
pass
else:
@ -227,6 +231,41 @@ class AzureOpenAIConfig:
)
else:
optional_params["tool_choice"] = value
if param == "response_format" and isinstance(value, dict):
json_schema: Optional[dict] = None
schema_name: str = ""
if "response_schema" in value:
json_schema = value["response_schema"]
schema_name = "json_tool_call"
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
schema_name = value["json_schema"]["name"]
"""
Follow similar approach to anthropic - translate to a single tool call.
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
if json_schema is not None:
_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(
name=schema_name
),
)
_tool = ChatCompletionToolParam(
type="function",
function=ChatCompletionToolParamFunctionChunk(
name=schema_name, parameters=json_schema
),
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
elif param in supported_openai_params:
optional_params[param] = value
return optional_params
@ -513,6 +552,7 @@ class AzureChatCompletion(BaseLLM):
)
max_retries = optional_params.pop("max_retries", 2)
json_mode: Optional[bool] = optional_params.pop("json_mode", False)
### CHECK IF CLOUDFLARE AI GATEWAY ###
### if so - set the model as part of the base url
@ -578,6 +618,7 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
logging_obj=logging_obj,
convert_tool_call_to_json_mode=json_mode,
)
elif "stream" in optional_params and optional_params["stream"] == True:
return self.streaming(
@ -656,6 +697,7 @@ class AzureChatCompletion(BaseLLM):
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
convert_tool_call_to_json_mode=json_mode,
)
except AzureOpenAIError as e:
exception_mapping_worked = True
@ -677,6 +719,7 @@ class AzureChatCompletion(BaseLLM):
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None,
convert_tool_call_to_json_mode: Optional[bool] = None,
client=None, # this is the AsyncAzureOpenAI
):
response = None
@ -742,11 +785,13 @@ class AzureChatCompletion(BaseLLM):
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
hidden_params={"headers": headers},
_response_headers=headers,
convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
)
except AzureOpenAIError as e:
## LOGGING

View file

@ -2162,37 +2162,44 @@ def test_completion_openai():
pytest.fail(f"Error occurred: {e}")
def test_completion_openai_pydantic():
@pytest.mark.parametrize("model", ["gpt-4o-2024-08-06", "azure/chatgpt-v-2"])
def test_completion_openai_pydantic(model):
try:
litellm.set_verbose = True
from pydantic import BaseModel
messages = [
{"role": "user", "content": "List 5 important events in the XIX century"}
]
class CalendarEvent(BaseModel):
name: str
date: str
participants: list[str]
print(f"api key: {os.environ['OPENAI_API_KEY']}")
litellm.api_key = os.environ["OPENAI_API_KEY"]
response = completion(
model="gpt-4o-2024-08-06",
messages=[{"role": "user", "content": "Hey"}],
max_tokens=10,
metadata={"hi": "bye"},
response_format=CalendarEvent,
)
class EventsList(BaseModel):
events: list[CalendarEvent]
litellm.enable_json_schema_validation = True
for _ in range(3):
try:
response = completion(
model=model,
messages=messages,
metadata={"hi": "bye"},
response_format=EventsList,
)
break
except litellm.JSONSchemaValidationError:
print("ERROR OCCURRED! INVALID JSON")
print("This is the response object\n", response)
response_str = response["choices"][0]["message"]["content"]
response_str_2 = response.choices[0].message.content
cost = completion_cost(completion_response=response)
print("Cost for completion call with gpt-3.5-turbo: ", f"${float(cost):.10f}")
assert response_str == response_str_2
assert type(response_str) == str
assert len(response_str) > 1
print(f"response_str: {response_str}")
json.loads(response_str) # check valid json is returned
litellm.api_key = None
except Timeout as e:
pass
except Exception as e:

View file

@ -843,13 +843,13 @@ def client(original_function):
and str(original_function.__name__)
in litellm.cache.supported_call_types
):
print_verbose(f"Checking Cache")
print_verbose("Checking Cache")
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result != None:
if cached_result is not None:
if "detail" in cached_result:
# implies an error occurred
pass
@ -5907,6 +5907,9 @@ def convert_to_model_response_object(
end_time=None,
hidden_params: Optional[dict] = None,
_response_headers: Optional[dict] = None,
convert_tool_call_to_json_mode: Optional[
bool
] = None, # used for supporting 'json_schema' on older models
):
received_args = locals()
if _response_headers is not None:
@ -5945,7 +5948,7 @@ def convert_to_model_response_object(
):
if response_object is None or model_response_object is None:
raise Exception("Error in response object format")
if stream == True:
if stream is True:
# for returning cached responses, we need to yield a generator
return convert_to_streaming_response(response_object=response_object)
choice_list = []
@ -5955,16 +5958,31 @@ def convert_to_model_response_object(
)
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["message"].get("content", None),
role=choice["message"]["role"] or "assistant",
function_call=choice["message"].get("function_call", None),
tool_calls=choice["message"].get("tool_calls", None),
)
finish_reason = choice.get("finish_reason", None)
if finish_reason == None:
## HANDLE JSON MODE - anthropic returns single function call]
tool_calls = choice["message"].get("tool_calls", None)
if (
convert_tool_call_to_json_mode
and tool_calls is not None
and len(tool_calls) == 1
):
# to support 'json_schema' logic on older models
json_mode_content_str: Optional[str] = tool_calls[0][
"function"
].get("arguments")
if json_mode_content_str is not None:
message = litellm.Message(content=json_mode_content_str)
finish_reason = "stop"
else:
message = Message(
content=choice["message"].get("content", None),
role=choice["message"]["role"] or "assistant",
function_call=choice["message"].get("function_call", None),
tool_calls=choice["message"].get("tool_calls", None),
)
finish_reason = choice.get("finish_reason", None)
if finish_reason is None:
# gpt-4 vision can return 'finish_reason' or 'finish_details'
finish_reason = choice.get("finish_details")
finish_reason = choice.get("finish_details") or "stop"
logprobs = choice.get("logprobs", None)
enhancements = choice.get("enhancements", None)
choice = Choices(