forked from phoenix/litellm-mirror
Merge pull request #5296 from BerriAI/litellm_azure_json_schema_support
feat(azure.py): support 'json_schema' for older models
This commit is contained in:
commit
02eb6455b2
3 changed files with 101 additions and 31 deletions
|
@ -47,6 +47,10 @@ from ..types.llms.openai import (
|
|||
AsyncAssistantEventHandler,
|
||||
AsyncAssistantStreamManager,
|
||||
AsyncCursorPage,
|
||||
ChatCompletionToolChoiceFunctionParam,
|
||||
ChatCompletionToolChoiceObjectParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
HttpxBinaryResponseContent,
|
||||
MessageData,
|
||||
OpenAICreateThreadParamsMessage,
|
||||
|
@ -204,8 +208,8 @@ class AzureOpenAIConfig:
|
|||
and api_version_day < "01"
|
||||
)
|
||||
):
|
||||
if litellm.drop_params == True or (
|
||||
drop_params is not None and drop_params == True
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
pass
|
||||
else:
|
||||
|
@ -227,6 +231,41 @@ class AzureOpenAIConfig:
|
|||
)
|
||||
else:
|
||||
optional_params["tool_choice"] = value
|
||||
if param == "response_format" and isinstance(value, dict):
|
||||
json_schema: Optional[dict] = None
|
||||
schema_name: str = ""
|
||||
if "response_schema" in value:
|
||||
json_schema = value["response_schema"]
|
||||
schema_name = "json_tool_call"
|
||||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
schema_name = value["json_schema"]["name"]
|
||||
"""
|
||||
Follow similar approach to anthropic - translate to a single tool call.
|
||||
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||
"""
|
||||
if json_schema is not None:
|
||||
_tool_choice = ChatCompletionToolChoiceObjectParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolChoiceFunctionParam(
|
||||
name=schema_name
|
||||
),
|
||||
)
|
||||
|
||||
_tool = ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name=schema_name, parameters=json_schema
|
||||
),
|
||||
)
|
||||
|
||||
optional_params["tools"] = [_tool]
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
optional_params["json_mode"] = True
|
||||
elif param in supported_openai_params:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
|
@ -513,6 +552,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
)
|
||||
|
||||
max_retries = optional_params.pop("max_retries", 2)
|
||||
json_mode: Optional[bool] = optional_params.pop("json_mode", False)
|
||||
|
||||
### CHECK IF CLOUDFLARE AI GATEWAY ###
|
||||
### if so - set the model as part of the base url
|
||||
|
@ -578,6 +618,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
timeout=timeout,
|
||||
client=client,
|
||||
logging_obj=logging_obj,
|
||||
convert_tool_call_to_json_mode=json_mode,
|
||||
)
|
||||
elif "stream" in optional_params and optional_params["stream"] == True:
|
||||
return self.streaming(
|
||||
|
@ -656,6 +697,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
return convert_to_model_response_object(
|
||||
response_object=stringified_response,
|
||||
model_response_object=model_response,
|
||||
convert_tool_call_to_json_mode=json_mode,
|
||||
)
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
|
@ -677,6 +719,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
convert_tool_call_to_json_mode: Optional[bool] = None,
|
||||
client=None, # this is the AsyncAzureOpenAI
|
||||
):
|
||||
response = None
|
||||
|
@ -742,11 +785,13 @@ class AzureChatCompletion(BaseLLM):
|
|||
original_response=stringified_response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return convert_to_model_response_object(
|
||||
response_object=stringified_response,
|
||||
model_response_object=model_response,
|
||||
hidden_params={"headers": headers},
|
||||
_response_headers=headers,
|
||||
convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
|
||||
)
|
||||
except AzureOpenAIError as e:
|
||||
## LOGGING
|
||||
|
|
|
@ -2162,37 +2162,44 @@ def test_completion_openai():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_openai_pydantic():
|
||||
@pytest.mark.parametrize("model", ["gpt-4o-2024-08-06", "azure/chatgpt-v-2"])
|
||||
def test_completion_openai_pydantic(model):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
from pydantic import BaseModel
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "List 5 important events in the XIX century"}
|
||||
]
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
name: str
|
||||
date: str
|
||||
participants: list[str]
|
||||
|
||||
print(f"api key: {os.environ['OPENAI_API_KEY']}")
|
||||
litellm.api_key = os.environ["OPENAI_API_KEY"]
|
||||
response = completion(
|
||||
model="gpt-4o-2024-08-06",
|
||||
messages=[{"role": "user", "content": "Hey"}],
|
||||
max_tokens=10,
|
||||
metadata={"hi": "bye"},
|
||||
response_format=CalendarEvent,
|
||||
)
|
||||
class EventsList(BaseModel):
|
||||
events: list[CalendarEvent]
|
||||
|
||||
litellm.enable_json_schema_validation = True
|
||||
for _ in range(3):
|
||||
try:
|
||||
response = completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
metadata={"hi": "bye"},
|
||||
response_format=EventsList,
|
||||
)
|
||||
break
|
||||
except litellm.JSONSchemaValidationError:
|
||||
print("ERROR OCCURRED! INVALID JSON")
|
||||
|
||||
print("This is the response object\n", response)
|
||||
|
||||
response_str = response["choices"][0]["message"]["content"]
|
||||
response_str_2 = response.choices[0].message.content
|
||||
|
||||
cost = completion_cost(completion_response=response)
|
||||
print("Cost for completion call with gpt-3.5-turbo: ", f"${float(cost):.10f}")
|
||||
assert response_str == response_str_2
|
||||
assert type(response_str) == str
|
||||
assert len(response_str) > 1
|
||||
print(f"response_str: {response_str}")
|
||||
json.loads(response_str) # check valid json is returned
|
||||
|
||||
litellm.api_key = None
|
||||
except Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
|
|
@ -843,13 +843,13 @@ def client(original_function):
|
|||
and str(original_function.__name__)
|
||||
in litellm.cache.supported_call_types
|
||||
):
|
||||
print_verbose(f"Checking Cache")
|
||||
print_verbose("Checking Cache")
|
||||
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
|
||||
kwargs["preset_cache_key"] = (
|
||||
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
|
||||
)
|
||||
cached_result = litellm.cache.get_cache(*args, **kwargs)
|
||||
if cached_result != None:
|
||||
if cached_result is not None:
|
||||
if "detail" in cached_result:
|
||||
# implies an error occurred
|
||||
pass
|
||||
|
@ -5907,6 +5907,9 @@ def convert_to_model_response_object(
|
|||
end_time=None,
|
||||
hidden_params: Optional[dict] = None,
|
||||
_response_headers: Optional[dict] = None,
|
||||
convert_tool_call_to_json_mode: Optional[
|
||||
bool
|
||||
] = None, # used for supporting 'json_schema' on older models
|
||||
):
|
||||
received_args = locals()
|
||||
if _response_headers is not None:
|
||||
|
@ -5945,7 +5948,7 @@ def convert_to_model_response_object(
|
|||
):
|
||||
if response_object is None or model_response_object is None:
|
||||
raise Exception("Error in response object format")
|
||||
if stream == True:
|
||||
if stream is True:
|
||||
# for returning cached responses, we need to yield a generator
|
||||
return convert_to_streaming_response(response_object=response_object)
|
||||
choice_list = []
|
||||
|
@ -5955,16 +5958,31 @@ def convert_to_model_response_object(
|
|||
)
|
||||
|
||||
for idx, choice in enumerate(response_object["choices"]):
|
||||
message = Message(
|
||||
content=choice["message"].get("content", None),
|
||||
role=choice["message"]["role"] or "assistant",
|
||||
function_call=choice["message"].get("function_call", None),
|
||||
tool_calls=choice["message"].get("tool_calls", None),
|
||||
)
|
||||
finish_reason = choice.get("finish_reason", None)
|
||||
if finish_reason == None:
|
||||
## HANDLE JSON MODE - anthropic returns single function call]
|
||||
tool_calls = choice["message"].get("tool_calls", None)
|
||||
if (
|
||||
convert_tool_call_to_json_mode
|
||||
and tool_calls is not None
|
||||
and len(tool_calls) == 1
|
||||
):
|
||||
# to support 'json_schema' logic on older models
|
||||
json_mode_content_str: Optional[str] = tool_calls[0][
|
||||
"function"
|
||||
].get("arguments")
|
||||
if json_mode_content_str is not None:
|
||||
message = litellm.Message(content=json_mode_content_str)
|
||||
finish_reason = "stop"
|
||||
else:
|
||||
message = Message(
|
||||
content=choice["message"].get("content", None),
|
||||
role=choice["message"]["role"] or "assistant",
|
||||
function_call=choice["message"].get("function_call", None),
|
||||
tool_calls=choice["message"].get("tool_calls", None),
|
||||
)
|
||||
finish_reason = choice.get("finish_reason", None)
|
||||
if finish_reason is None:
|
||||
# gpt-4 vision can return 'finish_reason' or 'finish_details'
|
||||
finish_reason = choice.get("finish_details")
|
||||
finish_reason = choice.get("finish_details") or "stop"
|
||||
logprobs = choice.get("logprobs", None)
|
||||
enhancements = choice.get("enhancements", None)
|
||||
choice = Choices(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue