Merge pull request #5296 from BerriAI/litellm_azure_json_schema_support

feat(azure.py): support 'json_schema' for older models
This commit is contained in:
Krish Dholakia 2024-08-20 11:41:38 -07:00 committed by GitHub
commit 02eb6455b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 31 deletions

View file

@ -47,6 +47,10 @@ from ..types.llms.openai import (
AsyncAssistantEventHandler,
AsyncAssistantStreamManager,
AsyncCursorPage,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
HttpxBinaryResponseContent,
MessageData,
OpenAICreateThreadParamsMessage,
@ -204,8 +208,8 @@ class AzureOpenAIConfig:
and api_version_day < "01"
)
):
if litellm.drop_params == True or (
drop_params is not None and drop_params == True
if litellm.drop_params is True or (
drop_params is not None and drop_params is True
):
pass
else:
@ -227,6 +231,41 @@ class AzureOpenAIConfig:
)
else:
optional_params["tool_choice"] = value
if param == "response_format" and isinstance(value, dict):
json_schema: Optional[dict] = None
schema_name: str = ""
if "response_schema" in value:
json_schema = value["response_schema"]
schema_name = "json_tool_call"
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
schema_name = value["json_schema"]["name"]
"""
Follow a similar approach to Anthropic - translate to a single tool call.
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
"""
if json_schema is not None:
_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(
name=schema_name
),
)
_tool = ChatCompletionToolParam(
type="function",
function=ChatCompletionToolParamFunctionChunk(
name=schema_name, parameters=json_schema
),
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
elif param in supported_openai_params:
optional_params[param] = value
return optional_params
@ -513,6 +552,7 @@ class AzureChatCompletion(BaseLLM):
)
max_retries = optional_params.pop("max_retries", 2)
json_mode: Optional[bool] = optional_params.pop("json_mode", False)
### CHECK IF CLOUDFLARE AI GATEWAY ###
### if so - set the model as part of the base url
@ -578,6 +618,7 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
logging_obj=logging_obj,
convert_tool_call_to_json_mode=json_mode,
)
elif "stream" in optional_params and optional_params["stream"] == True:
return self.streaming(
@ -656,6 +697,7 @@ class AzureChatCompletion(BaseLLM):
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
convert_tool_call_to_json_mode=json_mode,
)
except AzureOpenAIError as e:
exception_mapping_worked = True
@ -677,6 +719,7 @@ class AzureChatCompletion(BaseLLM):
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None,
convert_tool_call_to_json_mode: Optional[bool] = None,
client=None, # this is the AsyncAzureOpenAI
):
response = None
@ -742,11 +785,13 @@ class AzureChatCompletion(BaseLLM):
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
hidden_params={"headers": headers},
_response_headers=headers,
convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
)
except AzureOpenAIError as e:
## LOGGING