mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge 940810f526
into b82af5b826
This commit is contained in:
commit
5aae062946
5 changed files with 98 additions and 16 deletions
|
@ -1041,11 +1041,11 @@ def convert_to_gemini_tool_call_invoke(
|
||||||
if tool_calls is not None:
|
if tool_calls is not None:
|
||||||
for tool in tool_calls:
|
for tool in tool_calls:
|
||||||
if "function" in tool:
|
if "function" in tool:
|
||||||
gemini_function_call: Optional[
|
gemini_function_call: Optional[VertexFunctionCall] = (
|
||||||
VertexFunctionCall
|
_gemini_tool_call_invoke_helper(
|
||||||
] = _gemini_tool_call_invoke_helper(
|
|
||||||
function_call_params=tool["function"]
|
function_call_params=tool["function"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if gemini_function_call is not None:
|
if gemini_function_call is not None:
|
||||||
_parts_list.append(
|
_parts_list.append(
|
||||||
VertexPartType(function_call=gemini_function_call)
|
VertexPartType(function_call=gemini_function_call)
|
||||||
|
@ -1139,7 +1139,7 @@ def convert_to_gemini_tool_call_result(
|
||||||
|
|
||||||
|
|
||||||
def convert_to_anthropic_tool_result(
|
def convert_to_anthropic_tool_result(
|
||||||
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
|
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
|
||||||
) -> AnthropicMessagesToolResultParam:
|
) -> AnthropicMessagesToolResultParam:
|
||||||
"""
|
"""
|
||||||
OpenAI message with a tool result looks like:
|
OpenAI message with a tool result looks like:
|
||||||
|
@ -1449,9 +1449,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
|
|
||||||
if "cache_control" in _content_element:
|
if "cache_control" in _content_element:
|
||||||
_anthropic_content_element[
|
_anthropic_content_element["cache_control"] = (
|
||||||
"cache_control"
|
_content_element["cache_control"]
|
||||||
] = _content_element["cache_control"]
|
)
|
||||||
user_content.append(_anthropic_content_element)
|
user_content.append(_anthropic_content_element)
|
||||||
elif m.get("type", "") == "text":
|
elif m.get("type", "") == "text":
|
||||||
m = cast(ChatCompletionTextObject, m)
|
m = cast(ChatCompletionTextObject, m)
|
||||||
|
@ -1502,9 +1502,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
|
|
||||||
if "cache_control" in _content_element:
|
if "cache_control" in _content_element:
|
||||||
_anthropic_content_text_element[
|
_anthropic_content_text_element["cache_control"] = (
|
||||||
"cache_control"
|
_content_element["cache_control"]
|
||||||
] = _content_element["cache_control"]
|
)
|
||||||
|
|
||||||
user_content.append(_anthropic_content_text_element)
|
user_content.append(_anthropic_content_text_element)
|
||||||
|
|
||||||
|
@ -2244,6 +2244,7 @@ from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock
|
||||||
from litellm.types.llms.bedrock import (
|
from litellm.types.llms.bedrock import (
|
||||||
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
|
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
|
||||||
)
|
)
|
||||||
|
from litellm.types.llms.bedrock import ToolJsonSchemaBlock as BedrockToolJsonSchemaBlock
|
||||||
from litellm.types.llms.bedrock import ToolResultBlock as BedrockToolResultBlock
|
from litellm.types.llms.bedrock import ToolResultBlock as BedrockToolResultBlock
|
||||||
from litellm.types.llms.bedrock import (
|
from litellm.types.llms.bedrock import (
|
||||||
ToolResultContentBlock as BedrockToolResultContentBlock,
|
ToolResultContentBlock as BedrockToolResultContentBlock,
|
||||||
|
@ -2499,7 +2500,7 @@ def _convert_to_bedrock_tool_call_invoke(
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_bedrock_tool_call_result(
|
def _convert_to_bedrock_tool_call_result(
|
||||||
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
|
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
|
||||||
) -> BedrockContentBlock:
|
) -> BedrockContentBlock:
|
||||||
"""
|
"""
|
||||||
OpenAI message with a tool result looks like:
|
OpenAI message with a tool result looks like:
|
||||||
|
@ -2672,7 +2673,7 @@ def get_user_message_block_or_continue_message(
|
||||||
def return_assistant_continue_message(
|
def return_assistant_continue_message(
|
||||||
assistant_continue_message: Optional[
|
assistant_continue_message: Optional[
|
||||||
Union[str, ChatCompletionAssistantMessage]
|
Union[str, ChatCompletionAssistantMessage]
|
||||||
] = None
|
] = None,
|
||||||
) -> ChatCompletionAssistantMessage:
|
) -> ChatCompletionAssistantMessage:
|
||||||
if assistant_continue_message and isinstance(assistant_continue_message, str):
|
if assistant_continue_message and isinstance(assistant_continue_message, str):
|
||||||
return ChatCompletionAssistantMessage(
|
return ChatCompletionAssistantMessage(
|
||||||
|
@ -3470,7 +3471,13 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
|
||||||
for _, value in defs_copy.items():
|
for _, value in defs_copy.items():
|
||||||
unpack_defs(value, defs_copy)
|
unpack_defs(value, defs_copy)
|
||||||
unpack_defs(parameters, defs_copy)
|
unpack_defs(parameters, defs_copy)
|
||||||
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
|
tool_input_schema = BedrockToolInputSchemaBlock(
|
||||||
|
json=BedrockToolJsonSchemaBlock(
|
||||||
|
type=parameters.get("type", ""),
|
||||||
|
properties=parameters.get("properties", {}),
|
||||||
|
required=parameters.get("required", []),
|
||||||
|
)
|
||||||
|
)
|
||||||
tool_spec = BedrockToolSpecBlock(
|
tool_spec = BedrockToolSpecBlock(
|
||||||
inputSchema=tool_input_schema, name=name, description=description
|
inputSchema=tool_input_schema, name=name, description=description
|
||||||
)
|
)
|
||||||
|
|
|
@ -125,8 +125,14 @@ class ConverseResponseBlock(TypedDict):
|
||||||
usage: ConverseTokenUsageBlock
|
usage: ConverseTokenUsageBlock
|
||||||
|
|
||||||
|
|
||||||
|
class ToolJsonSchemaBlock(TypedDict, total=False):
|
||||||
|
type: Literal["object"]
|
||||||
|
properties: dict
|
||||||
|
required: List[str]
|
||||||
|
|
||||||
|
|
||||||
class ToolInputSchemaBlock(TypedDict):
|
class ToolInputSchemaBlock(TypedDict):
|
||||||
json: Optional[dict]
|
json: Optional[ToolJsonSchemaBlock]
|
||||||
|
|
||||||
|
|
||||||
class ToolSpecBlock(TypedDict, total=False):
|
class ToolSpecBlock(TypedDict, total=False):
|
||||||
|
|
|
@ -1055,6 +1055,7 @@ class BaseLLMChatTest(ABC):
|
||||||
def test_function_calling_with_tool_response(self):
|
def test_function_calling_with_tool_response(self):
|
||||||
from litellm.utils import supports_function_calling
|
from litellm.utils import supports_function_calling
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
|
litellm._turn_on_debug()
|
||||||
|
|
||||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||||
|
@ -1074,6 +1075,8 @@ class BaseLLMChatTest(ABC):
|
||||||
"name": "get_weather",
|
"name": "get_weather",
|
||||||
"description": "Get the weather in a city",
|
"description": "Get the weather in a city",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
|
"$id": "https://some/internal/name",
|
||||||
|
"$schema": "https://json-schema.org/draft-07/schema",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"city": {
|
"city": {
|
||||||
|
|
|
@ -1264,6 +1264,50 @@ def test_bedrock_tools_pt_invalid_names():
|
||||||
assert result[1]["toolSpec"]["name"] == "another_invalid_name"
|
assert result[1]["toolSpec"]["name"] == "another_invalid_name"
|
||||||
|
|
||||||
|
|
||||||
|
def test_bedrock_tools_transformation_valid_params():
|
||||||
|
from litellm.types.llms.bedrock import ToolJsonSchemaBlock
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "123-invalid@name",
|
||||||
|
"description": "Invalid name test",
|
||||||
|
"parameters": {
|
||||||
|
"$id": "https://some/internal/name",
|
||||||
|
"type": "object",
|
||||||
|
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||||
|
"properties": {
|
||||||
|
"test": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["test"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
result = _bedrock_tools_pt(tools)
|
||||||
|
|
||||||
|
print("bedrock tools after prompt formatting=", result)
|
||||||
|
# Ensure the keys for properties in the response is a subset of keys in ToolJsonSchemaBlock
|
||||||
|
toolJsonSchema = result[0]["toolSpec"]["inputSchema"]["json"]
|
||||||
|
assert toolJsonSchema is not None
|
||||||
|
print("transformed toolJsonSchema keys=", toolJsonSchema.keys())
|
||||||
|
print("allowed ToolJsonSchemaBlock keys=", ToolJsonSchemaBlock.__annotations__.keys())
|
||||||
|
assert set(toolJsonSchema.keys()).issubset(set(ToolJsonSchemaBlock.__annotations__.keys()))
|
||||||
|
|
||||||
|
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "toolSpec" in result[0]
|
||||||
|
assert result[0]["toolSpec"]["name"] == "a123_invalid_name"
|
||||||
|
assert result[0]["toolSpec"]["description"] == "Invalid name test"
|
||||||
|
assert "inputSchema" in result[0]["toolSpec"]
|
||||||
|
assert "json" in result[0]["toolSpec"]["inputSchema"]
|
||||||
|
assert result[0]["toolSpec"]["inputSchema"]["json"]["properties"]["test"]["type"] == "string"
|
||||||
|
assert "test" in result[0]["toolSpec"]["inputSchema"]["json"]["required"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_not_found_error():
|
def test_not_found_error():
|
||||||
with pytest.raises(litellm.NotFoundError):
|
with pytest.raises(litellm.NotFoundError):
|
||||||
completion(
|
completion(
|
||||||
|
@ -2226,6 +2270,28 @@ class TestBedrockConverseChatNormal(BaseLLMChatTest):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class TestBedrockConverseNovaTestSuite(BaseLLMChatTest):
|
||||||
|
def get_base_completion_call_args(self) -> dict:
|
||||||
|
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||||
|
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||||
|
litellm.add_known_models()
|
||||||
|
return {
|
||||||
|
"model": "bedrock/us.amazon.nova-lite-v1:0",
|
||||||
|
"aws_region_name": "us-east-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||||
|
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_multilingual_requests(self):
|
||||||
|
"""
|
||||||
|
Bedrock API raises a 400 BadRequest error when the request contains invalid utf-8 sequences.
|
||||||
|
|
||||||
|
Todo: if litellm.modify_params is True ensure it's a valid utf-8 sequence
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestBedrockRerank(BaseLLMRerankTest):
|
class TestBedrockRerank(BaseLLMRerankTest):
|
||||||
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue