forked from phoenix/litellm-mirror
feat(vertex_ai_anthropic.py): support response_schema for vertex ai anthropic calls
Allows passing response_schema for Anthropic calls (including Anthropic models on Vertex AI) and supports validating the returned JSON against that schema.
parent f8bdfe7cc3
commit f2401d6d5e
6 changed files with 189 additions and 48 deletions
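
The feature can be exercised end to end roughly as sketched below. This is an illustrative example based on the tests in this commit, not part of the diff itself; the schema, messages, and the "json_object" / enforce_validation keys in response_format are assumptions about the calling convention, and project/credential setup is omitted.

import litellm

# Hypothetical schema describing the structured output we want back.
response_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

resp = litellm.completion(
    model="vertex_ai/claude-3-5-sonnet@20240620",  # model id taken from the test matrix below
    messages=[{"role": "user", "content": "List 5 cookie recipes"}],
    response_format={
        "type": "json_object",
        "response_schema": response_schema,  # mapped to a forced "json_tool_call" tool for Anthropic on Vertex AI
        "enforce_validation": True,  # validate_schema() raises litellm.JSONSchemaValidationError on mismatch
    },
    vertex_location="us-east5",
)
print(resp.choices[0].message.content)  # JSON string matching response_schema
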
@@ -13,7 +13,12 @@ def validate_schema(schema: dict, response: str):
     from litellm import JSONSchemaValidationError

-    response_dict = json.loads(response)
+    try:
+        response_dict = json.loads(response)
+    except json.JSONDecodeError:
+        raise JSONSchemaValidationError(
+            model="", llm_provider="", raw_response=response, schema=response
+        )

     try:
         validate(response_dict, schema=schema)
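
A quick sketch of the hardened validation path above (illustrative, added for this write-up; the module path of validate_schema is an assumption, since the file name is not visible in this diff):

from litellm import JSONSchemaValidationError
from litellm.litellm_core_utils.json_validation_rule import validate_schema  # assumed module path

try:
    # A non-JSON response now raises JSONSchemaValidationError
    # instead of an unhandled json.JSONDecodeError.
    validate_schema(schema={"type": "object"}, response="this is not json")
except JSONSchemaValidationError:
    print("response did not match the expected schema")
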
@@ -16,6 +16,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
+    HTTPHandler,
     _get_async_httpx_client,
     _get_httpx_client,
 )
@@ -538,7 +539,7 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

-    def process_response(
+    def _process_response(
         self,
         model: str,
         response: Union[requests.Response, httpx.Response],
@@ -551,6 +552,7 @@ class AnthropicChatCompletion(BaseLLM):
         messages: List,
         print_verbose,
         encoding,
+        json_mode: bool,
     ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
@@ -574,27 +576,40 @@ class AnthropicChatCompletion(BaseLLM):
             )
         else:
             text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
+            tool_calls: List[ChatCompletionToolCallChunk] = []
+            for idx, content in enumerate(completion_response["content"]):
                 if content["type"] == "text":
                     text_content += content["text"]
                 ## TOOL CALLING
                 elif content["type"] == "tool_use":
                     tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
+                        ChatCompletionToolCallChunk(
+                            id=content["id"],
+                            type="function",
+                            function=ChatCompletionToolCallFunctionChunk(
+                                name=content["name"],
+                                arguments=json.dumps(content["input"]),
+                            ),
+                            index=idx,
+                        )
                     )

             _message = litellm.Message(
                 tool_calls=tool_calls,
                 content=text_content or None,
             )

+            ## HANDLE JSON MODE - anthropic returns single function call
+            if json_mode and len(tool_calls) == 1:
+                json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+                    "arguments"
+                )
+                if json_mode_content_str is not None:
+                    args = json.loads(json_mode_content_str)
+                    values: Optional[dict] = args.get("values")
+                    if values is not None:
+                        _message = litellm.Message(content=json.dumps(values))
+                        completion_response["stop_reason"] = "stop"
             model_response.choices[0].message = _message  # type: ignore
             model_response._hidden_params["original_response"] = completion_response[
                 "content"
@@ -687,9 +702,11 @@ class AnthropicChatCompletion(BaseLLM):
         _is_function_call,
         data: dict,
         optional_params: dict,
+        json_mode: bool,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         async_handler = _get_async_httpx_client()

@@ -705,7 +722,7 @@ class AnthropicChatCompletion(BaseLLM):
             )
             raise e

-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -717,6 +734,7 @@ class AnthropicChatCompletion(BaseLLM):
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )

     def completion(
@@ -731,10 +749,12 @@ class AnthropicChatCompletion(BaseLLM):
         api_key,
         logging_obj,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ):
         headers = validate_environment(api_key, headers, model)
         _is_function_call = False
@@ -787,14 +807,18 @@ class AnthropicChatCompletion(BaseLLM):

             anthropic_tools = []
             for tool in optional_params["tools"]:
-                new_tool = tool["function"]
-                new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
-                anthropic_tools.append(new_tool)
+                if "input_schema" in tool:  # assume in anthropic format
+                    anthropic_tools.append(tool)
+                else:  # assume openai tool call
+                    new_tool = tool["function"]
+                    new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
+                    anthropic_tools.append(new_tool)

             optional_params["tools"] = anthropic_tools

         stream = optional_params.pop("stream", None)
         is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
+        json_mode: bool = optional_params.pop("json_mode", False)

         data = {
             "messages": messages,
@@ -815,7 +839,7 @@ class AnthropicChatCompletion(BaseLLM):
             },
         )
         print_verbose(f"_is_function_call: {_is_function_call}")
-        if acompletion == True:
+        if acompletion is True:
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -857,15 +881,21 @@ class AnthropicChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    client=client,
+                    json_mode=json_mode,
                 )
         else:
             ## COMPLETION CALL
+            if client is None or isinstance(client, AsyncHTTPHandler):
+                client = HTTPHandler(timeout=timeout)  # type: ignore
+            else:
+                client = client
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes anthropic streaming POST request")
                 data["stream"] = stream
-                response = requests.post(
+                response = client.post(
                     api_base,
                     headers=headers,
                     data=json.dumps(data),
@@ -889,15 +919,13 @@ class AnthropicChatCompletion(BaseLLM):
                 return streaming_response

             else:
-                response = requests.post(
-                    api_base, headers=headers, data=json.dumps(data)
-                )
+                response = client.post(api_base, headers=headers, data=json.dumps(data))
                 if response.status_code != 200:
                     raise AnthropicError(
                         status_code=response.status_code, message=response.text
                     )

-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -909,6 +937,7 @@ class AnthropicChatCompletion(BaseLLM):
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )

     def embedding(self):
@@ -7,7 +7,7 @@ import time
 import types
 import uuid
 from enum import Enum
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -15,7 +15,14 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
+from litellm.types.llms.anthropic import (
+    AnthropicMessagesTool,
+    AnthropicMessagesToolChoice,
+)
+from litellm.types.llms.openai import (
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

@@ -142,7 +149,27 @@ class VertexAIAnthropicConfig:
             if param == "top_p":
                 optional_params["top_p"] = value
             if param == "response_format" and "response_schema" in value:
-                optional_params["response_format"] = ResponseFormatChunk(**value)  # type: ignore
+                """
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
+                """
+                _tool_choice = None
+                _tool_choice = {"name": "json_tool_call", "type": "tool"}
+
+                _tool = AnthropicMessagesTool(
+                    name="json_tool_call",
+                    input_schema={
+                        "type": "object",
+                        "properties": {"values": value["response_schema"]},  # type: ignore
+                    },
+                )
+
+                optional_params["tools"] = [_tool]
+                optional_params["tool_choice"] = _tool_choice
+                optional_params["json_mode"] = True

         return optional_params

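
For illustration (added for this write-up, not part of the diff): the mapping above turns a response_schema into a single forced tool on the Anthropic request, roughly as sketched here with a placeholder schema.

# Given response_format={"type": "json_object", "response_schema": schema},
# map_openai_params() (above) produces optional_params approximately like:
schema = {"type": "array", "items": {"type": "object"}}  # placeholder

optional_params = {
    "tools": [
        {
            "name": "json_tool_call",
            "input_schema": {"type": "object", "properties": {"values": schema}},
        }
    ],
    "tool_choice": {"name": "json_tool_call", "type": "tool"},
    "json_mode": True,
}
# _process_response() (earlier hunk) then parses tool_calls[0]["function"]["arguments"],
# extracts "values", returns it as the message content, and sets stop_reason to "stop".
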
@@ -222,6 +249,7 @@ def completion(
     optional_params: dict,
     custom_prompt_dict: dict,
     headers: Optional[dict],
+    timeout: Union[float, httpx.Timeout],
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
@@ -301,6 +329,8 @@ def completion(
             litellm_params=litellm_params,
             logger_fn=logger_fn,
             headers=vertex_headers,
+            client=client,
+            timeout=timeout,
         )

     except Exception as e:
@@ -1528,6 +1528,8 @@ def completion(
                 api_key=api_key,
                 logging_obj=logging,
                 headers=headers,
+                timeout=timeout,
+                client=client,
             )
             if optional_params.get("stream", False) or acompletion == True:
                 ## LOGGING
@@ -2046,7 +2048,10 @@ def completion(
                     acompletion=acompletion,
                     headers=headers,
                     custom_prompt_dict=custom_prompt_dict,
+                    timeout=timeout,
+                    client=client,
                 )

             else:
                 model_response = vertex_ai.completion(
                     model=model,
@@ -1,5 +1,13 @@
 model_list:
-  - model_name: llama-3
+  - model_name: bad-azure-model
     litellm_params:
-      model: gpt-4
-      request_timeout: 1
+      model: azure/chatgpt-v-2
+      azure_ad_token: ""
+      api_base: os.environ/AZURE_API_BASE
+
+  - model_name: good-openai-model
+    litellm_params:
+      model: gpt-3.5-turbo
+
+litellm_settings:
+  fallbacks: [{"bad-azure-model": ["good-openai-model"]}]
@@ -1128,6 +1128,39 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs):
     return mock_response


+def vertex_httpx_mock_post_valid_response_anthropic(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3-5-sonnet-20240620",
+        "content": [
+            {
+                "type": "tool_use",
+                "id": "toolu_vrtx_01YMnYZrToPPfcmY2myP2gEB",
+                "name": "json_tool_call",
+                "input": {
+                    "values": [
+                        {"recipe_name": "Chocolate Chip Cookies"},
+                        {"recipe_name": "Oatmeal Raisin Cookies"},
+                        {"recipe_name": "Peanut Butter Cookies"},
+                        {"recipe_name": "Snickerdoodle Cookies"},
+                        {"recipe_name": "Sugar Cookies"},
+                    ]
+                },
+            }
+        ],
+        "stop_reason": "tool_use",
+        "stop_sequence": None,
+        "usage": {"input_tokens": 368, "output_tokens": 118},
+    }
+
+    return mock_response
+
+
 def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
     mock_response = MagicMock()
     mock_response.status_code = 200
@@ -1183,11 +1216,29 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
     return mock_response


+def vertex_httpx_mock_post_invalid_schema_response_anthropic(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3-5-sonnet-20240620",
+        "content": [{"text": "Hi! My name is Claude.", "type": "text"}],
+        "stop_reason": "end_turn",
+        "stop_sequence": None,
+        "usage": {"input_tokens": 368, "output_tokens": 118},
+    }
+    return mock_response
+
+
 @pytest.mark.parametrize(
     "model, vertex_location, supports_response_schema",
     [
         ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
         ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
+        ("vertex_ai/claude-3-5-sonnet@20240620", "us-east5", False),
     ],
 )
 @pytest.mark.parametrize(
@@ -1231,12 +1282,21 @@ async def test_gemini_pro_json_schema_args_sent_httpx(

     httpx_response = MagicMock()
     if invalid_response is True:
-        httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
+        if "claude" in model:
+            httpx_response.side_effect = (
+                vertex_httpx_mock_post_invalid_schema_response_anthropic
+            )
+        else:
+            httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
     else:
-        httpx_response.side_effect = vertex_httpx_mock_post_valid_response
+        if "claude" in model:
+            httpx_response.side_effect = vertex_httpx_mock_post_valid_response_anthropic
+        else:
+            httpx_response.side_effect = vertex_httpx_mock_post_valid_response
     with patch.object(client, "post", new=httpx_response) as mock_call:
+        print("SENDING CLIENT POST={}".format(client.post))
         try:
-            _ = completion(
+            resp = completion(
                 model=model,
                 messages=messages,
                 response_format={
@@ -1247,30 +1307,34 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
                 vertex_location=vertex_location,
                 client=client,
             )
+            print("Received={}".format(resp))
             if invalid_response is True and enforce_validation is True:
                 pytest.fail("Expected this to fail")
         except litellm.JSONSchemaValidationError as e:
-            if invalid_response is False and "claude-3" not in model:
+            if invalid_response is False:
                 pytest.fail("Expected this to pass. Got={}".format(e))

         mock_call.assert_called_once()
-        print(mock_call.call_args.kwargs)
-        print(mock_call.call_args.kwargs["json"]["generationConfig"])
-
-        if supports_response_schema:
-            assert (
-                "response_schema"
-                in mock_call.call_args.kwargs["json"]["generationConfig"]
-            )
-        else:
-            assert (
-                "response_schema"
-                not in mock_call.call_args.kwargs["json"]["generationConfig"]
-            )
-            assert (
-                "Use this JSON schema:"
-                in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"]
-            )
+        if "claude" not in model:
+            print(mock_call.call_args.kwargs)
+            print(mock_call.call_args.kwargs["json"]["generationConfig"])
+
+            if supports_response_schema:
+                assert (
+                    "response_schema"
+                    in mock_call.call_args.kwargs["json"]["generationConfig"]
+                )
+            else:
+                assert (
+                    "response_schema"
+                    not in mock_call.call_args.kwargs["json"]["generationConfig"]
+                )
+                assert (
+                    "Use this JSON schema:"
+                    in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1][
+                        "text"
+                    ]
+                )


 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",