Merge pull request #5296 from BerriAI/litellm_azure_json_schema_support
feat(azure.py): support 'json_schema' for older models
Commit 02eb6455b2

3 changed files with 101 additions and 31 deletions
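For context, a minimal usage sketch of what this change enables (the deployment name, API version, and schema below are illustrative placeholders, not taken from the commit): a `json_schema` response_format sent to an Azure deployment whose API version predates native structured-output support is translated into a single forced tool call instead of being dropped or rejected.

```python
import litellm

# Hypothetical deployment name and an older api_version; both are assumptions
# for illustration only.
response = litellm.completion(
    model="azure/my-gpt-4-deployment",
    api_version="2023-07-01-preview",
    messages=[{"role": "user", "content": "Give me a calendar event as JSON."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "calendar_event",
            "schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "date": {"type": "string"},
                },
                "required": ["name", "date"],
            },
        },
    },
)
# The content should be a JSON string matching the schema above.
print(response.choices[0].message.content)
```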
@@ -47,6 +47,10 @@ from ..types.llms.openai import (
     AsyncAssistantEventHandler,
     AsyncAssistantStreamManager,
     AsyncCursorPage,
+    ChatCompletionToolChoiceFunctionParam,
+    ChatCompletionToolChoiceObjectParam,
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
     HttpxBinaryResponseContent,
     MessageData,
     OpenAICreateThreadParamsMessage,
@@ -204,8 +208,8 @@ class AzureOpenAIConfig:
                 and api_version_day < "01"
             )
         ):
-            if litellm.drop_params == True or (
-                drop_params is not None and drop_params == True
+            if litellm.drop_params is True or (
+                drop_params is not None and drop_params is True
             ):
                 pass
             else:
@@ -227,6 +231,41 @@ class AzureOpenAIConfig:
                     )
                 else:
                     optional_params["tool_choice"] = value
+            if param == "response_format" and isinstance(value, dict):
+                json_schema: Optional[dict] = None
+                schema_name: str = ""
+                if "response_schema" in value:
+                    json_schema = value["response_schema"]
+                    schema_name = "json_tool_call"
+                elif "json_schema" in value:
+                    json_schema = value["json_schema"]["schema"]
+                    schema_name = value["json_schema"]["name"]
+                """
+                Follow similar approach to anthropic - translate to a single tool call.
+
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
+                """
+                if json_schema is not None:
+                    _tool_choice = ChatCompletionToolChoiceObjectParam(
+                        type="function",
+                        function=ChatCompletionToolChoiceFunctionParam(
+                            name=schema_name
+                        ),
+                    )
+
+                    _tool = ChatCompletionToolParam(
+                        type="function",
+                        function=ChatCompletionToolParamFunctionChunk(
+                            name=schema_name, parameters=json_schema
+                        ),
+                    )
+
+                    optional_params["tools"] = [_tool]
+                    optional_params["tool_choice"] = _tool_choice
+                    optional_params["json_mode"] = True
             elif param in supported_openai_params:
                 optional_params[param] = value
         return optional_params
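To make the mapping above concrete, here is a hedged, dependency-free sketch of the same translation using plain dicts in place of the typed ChatCompletionTool*Param classes; the helper name `_translate_json_schema_to_tool` is invented for illustration and is not part of the commit.

```python
from typing import Optional


def _translate_json_schema_to_tool(value: dict, optional_params: dict) -> dict:
    """Mirror of the mapping in the hunk above, simplified to plain dicts."""
    json_schema: Optional[dict] = None
    schema_name: str = ""
    if "response_schema" in value:  # alternative key used by some providers
        json_schema = value["response_schema"]
        schema_name = "json_tool_call"
    elif "json_schema" in value:  # OpenAI structured-outputs shape
        json_schema = value["json_schema"]["schema"]
        schema_name = value["json_schema"]["name"]

    if json_schema is not None:
        # One tool, forced via tool_choice; downstream code unwraps the call.
        optional_params["tools"] = [
            {
                "type": "function",
                "function": {"name": schema_name, "parameters": json_schema},
            }
        ]
        optional_params["tool_choice"] = {
            "type": "function",
            "function": {"name": schema_name},
        }
        optional_params["json_mode"] = True
    return optional_params


params = _translate_json_schema_to_tool(
    {
        "type": "json_schema",
        "json_schema": {
            "name": "calendar_event",
            "schema": {"type": "object", "properties": {"name": {"type": "string"}}},
        },
    },
    {},
)
print(params["tool_choice"])
# {'type': 'function', 'function': {'name': 'calendar_event'}}
```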
@@ -513,6 +552,7 @@ class AzureChatCompletion(BaseLLM):
            )

        max_retries = optional_params.pop("max_retries", 2)
+        json_mode: Optional[bool] = optional_params.pop("json_mode", False)

        ### CHECK IF CLOUDFLARE AI GATEWAY ###
        ### if so - set the model as part of the base url
@@ -578,6 +618,7 @@ class AzureChatCompletion(BaseLLM):
                    timeout=timeout,
                    client=client,
                    logging_obj=logging_obj,
+                    convert_tool_call_to_json_mode=json_mode,
                )
            elif "stream" in optional_params and optional_params["stream"] == True:
                return self.streaming(
@@ -656,6 +697,7 @@ class AzureChatCompletion(BaseLLM):
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
+                convert_tool_call_to_json_mode=json_mode,
            )
        except AzureOpenAIError as e:
            exception_mapping_worked = True
@@ -677,6 +719,7 @@ class AzureChatCompletion(BaseLLM):
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        azure_ad_token: Optional[str] = None,
+        convert_tool_call_to_json_mode: Optional[bool] = None,
        client=None,  # this is the AsyncAzureOpenAI
    ):
        response = None
@@ -742,11 +785,13 @@ class AzureChatCompletion(BaseLLM):
                original_response=stringified_response,
                additional_args={"complete_input_dict": data},
            )
+
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
                hidden_params={"headers": headers},
                _response_headers=headers,
+                convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
            )
        except AzureOpenAIError as e:
            ## LOGGING
@@ -2162,37 +2162,44 @@ def test_completion_openai():
        pytest.fail(f"Error occurred: {e}")


-def test_completion_openai_pydantic():
+@pytest.mark.parametrize("model", ["gpt-4o-2024-08-06", "azure/chatgpt-v-2"])
+def test_completion_openai_pydantic(model):
    try:
        litellm.set_verbose = True
        from pydantic import BaseModel

+        messages = [
+            {"role": "user", "content": "List 5 important events in the XIX century"}
+        ]
+
        class CalendarEvent(BaseModel):
            name: str
            date: str
            participants: list[str]

-        print(f"api key: {os.environ['OPENAI_API_KEY']}")
-        litellm.api_key = os.environ["OPENAI_API_KEY"]
-        response = completion(
-            model="gpt-4o-2024-08-06",
-            messages=[{"role": "user", "content": "Hey"}],
-            max_tokens=10,
-            metadata={"hi": "bye"},
-            response_format=CalendarEvent,
-        )
+        class EventsList(BaseModel):
+            events: list[CalendarEvent]
+
+        litellm.enable_json_schema_validation = True
+        for _ in range(3):
+            try:
+                response = completion(
+                    model=model,
+                    messages=messages,
+                    metadata={"hi": "bye"},
+                    response_format=EventsList,
+                )
+                break
+            except litellm.JSONSchemaValidationError:
+                print("ERROR OCCURRED! INVALID JSON")
+
        print("This is the response object\n", response)

        response_str = response["choices"][0]["message"]["content"]
-        response_str_2 = response.choices[0].message.content

-        cost = completion_cost(completion_response=response)
-        print("Cost for completion call with gpt-3.5-turbo: ", f"${float(cost):.10f}")
-        assert response_str == response_str_2
-        assert type(response_str) == str
-        assert len(response_str) > 1
+        print(f"response_str: {response_str}")
+        json.loads(response_str)  # check valid json is returned

-        litellm.api_key = None
    except Timeout as e:
        pass
    except Exception as e:
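The retry loop in the updated test leans on `litellm.enable_json_schema_validation` raising `JSONSchemaValidationError` when the model's output does not match the supplied schema. A manual equivalent of that check with pydantic v2 might look like the sketch below; this is not the library's internal implementation, and the sample JSON string is made up.

```python
from pydantic import BaseModel, ValidationError


class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]


class EventsList(BaseModel):
    events: list[CalendarEvent]


raw = '{"events": [{"name": "Congress of Vienna", "date": "1815", "participants": ["Metternich"]}]}'
try:
    events = EventsList.model_validate_json(raw)  # parse + validate in one step
    print(events.events[0].name)
except ValidationError:
    print("model returned JSON that does not match the schema")
```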
@@ -843,13 +843,13 @@ def client(original_function):
                and str(original_function.__name__)
                in litellm.cache.supported_call_types
            ):
-                print_verbose(f"Checking Cache")
+                print_verbose("Checking Cache")
                preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                kwargs["preset_cache_key"] = (
                    preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
                )
                cached_result = litellm.cache.get_cache(*args, **kwargs)
-                if cached_result != None:
+                if cached_result is not None:
                    if "detail" in cached_result:
                        # implies an error occurred
                        pass
@@ -5907,6 +5907,9 @@ def convert_to_model_response_object(
    end_time=None,
    hidden_params: Optional[dict] = None,
    _response_headers: Optional[dict] = None,
+    convert_tool_call_to_json_mode: Optional[
+        bool
+    ] = None,  # used for supporting 'json_schema' on older models
):
    received_args = locals()
    if _response_headers is not None:
@@ -5945,7 +5948,7 @@ def convert_to_model_response_object(
    ):
        if response_object is None or model_response_object is None:
            raise Exception("Error in response object format")
-        if stream == True:
+        if stream is True:
            # for returning cached responses, we need to yield a generator
            return convert_to_streaming_response(response_object=response_object)
        choice_list = []
@@ -5955,16 +5958,31 @@ def convert_to_model_response_object(
            )

        for idx, choice in enumerate(response_object["choices"]):
-            message = Message(
-                content=choice["message"].get("content", None),
-                role=choice["message"]["role"] or "assistant",
-                function_call=choice["message"].get("function_call", None),
-                tool_calls=choice["message"].get("tool_calls", None),
-            )
-            finish_reason = choice.get("finish_reason", None)
-            if finish_reason == None:
-                # gpt-4 vision can return 'finish_reason' or 'finish_details'
-                finish_reason = choice.get("finish_details")
+            ## HANDLE JSON MODE - anthropic returns single function call]
+            tool_calls = choice["message"].get("tool_calls", None)
+            if (
+                convert_tool_call_to_json_mode
+                and tool_calls is not None
+                and len(tool_calls) == 1
+            ):
+                # to support 'json_schema' logic on older models
+                json_mode_content_str: Optional[str] = tool_calls[0][
+                    "function"
+                ].get("arguments")
+                if json_mode_content_str is not None:
+                    message = litellm.Message(content=json_mode_content_str)
+                    finish_reason = "stop"
+            else:
+                message = Message(
+                    content=choice["message"].get("content", None),
+                    role=choice["message"]["role"] or "assistant",
+                    function_call=choice["message"].get("function_call", None),
+                    tool_calls=choice["message"].get("tool_calls", None),
+                )
+                finish_reason = choice.get("finish_reason", None)
+                if finish_reason is None:
+                    # gpt-4 vision can return 'finish_reason' or 'finish_details'
+                    finish_reason = choice.get("finish_details") or "stop"
            logprobs = choice.get("logprobs", None)
            enhancements = choice.get("enhancements", None)
            choice = Choices(
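A standalone sketch of the unwrap introduced above: when `convert_tool_call_to_json_mode` is set and the response contains exactly one tool call, the tool's `arguments` string is promoted to the assistant message content and the finish reason becomes "stop". The helper name and the mocked `choice` dict below are illustrative only, not the library's code.

```python
from typing import Optional


def unwrap_json_tool_call(choice: dict, convert_tool_call_to_json_mode: bool) -> dict:
    tool_calls = choice["message"].get("tool_calls")
    if convert_tool_call_to_json_mode and tool_calls is not None and len(tool_calls) == 1:
        # The forced tool call carries the JSON payload in its arguments string.
        arguments: Optional[str] = tool_calls[0]["function"].get("arguments")
        if arguments is not None:
            return {"content": arguments, "finish_reason": "stop"}
    # Otherwise fall back to the normal message fields.
    return {
        "content": choice["message"].get("content"),
        "finish_reason": choice.get("finish_reason") or choice.get("finish_details") or "stop",
    }


choice = {
    "message": {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "function": {
                    "name": "json_tool_call",
                    "arguments": '{"name": "Waterloo", "date": "1815"}',
                }
            }
        ],
    },
    "finish_reason": "tool_calls",
}
print(unwrap_json_tool_call(choice, convert_tool_call_to_json_mode=True))
# {'content': '{"name": "Waterloo", "date": "1815"}', 'finish_reason': 'stop'}
```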