mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
add test_function_calling_with_tool_response to base llm tests
This commit is contained in:
parent
a743b6fc1f
commit
aa1e13f65b
4 changed files with 62 additions and 15 deletions
|
@ -15,6 +15,12 @@ from litellm import verbose_logger
|
|||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
|
||||
from litellm.types.llms.anthropic import *
|
||||
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
|
||||
from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock
|
||||
from litellm.types.llms.bedrock import (
|
||||
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
|
||||
)
|
||||
from litellm.types.llms.bedrock import ToolJsonSchemaBlock as BedrockToolJsonSchemaBlock
|
||||
from litellm.types.llms.bedrock import ToolSpecBlock as BedrockToolSpecBlock
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.llms.ollama import OllamaVisionModelObject
|
||||
from litellm.types.llms.openai import (
|
||||
|
@ -1041,10 +1047,10 @@ def convert_to_gemini_tool_call_invoke(
|
|||
if tool_calls is not None:
|
||||
for tool in tool_calls:
|
||||
if "function" in tool:
|
||||
gemini_function_call: Optional[
|
||||
VertexFunctionCall
|
||||
] = _gemini_tool_call_invoke_helper(
|
||||
function_call_params=tool["function"]
|
||||
gemini_function_call: Optional[VertexFunctionCall] = (
|
||||
_gemini_tool_call_invoke_helper(
|
||||
function_call_params=tool["function"]
|
||||
)
|
||||
)
|
||||
if gemini_function_call is not None:
|
||||
_parts_list.append(
|
||||
|
@ -1139,7 +1145,7 @@ def convert_to_gemini_tool_call_result(
|
|||
|
||||
|
||||
def convert_to_anthropic_tool_result(
|
||||
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
|
||||
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
|
||||
) -> AnthropicMessagesToolResultParam:
|
||||
"""
|
||||
OpenAI message with a tool result looks like:
|
||||
|
@ -1449,9 +1455,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
|||
)
|
||||
|
||||
if "cache_control" in _content_element:
|
||||
_anthropic_content_element[
|
||||
"cache_control"
|
||||
] = _content_element["cache_control"]
|
||||
_anthropic_content_element["cache_control"] = (
|
||||
_content_element["cache_control"]
|
||||
)
|
||||
user_content.append(_anthropic_content_element)
|
||||
elif m.get("type", "") == "text":
|
||||
m = cast(ChatCompletionTextObject, m)
|
||||
|
@ -1502,9 +1508,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
|||
)
|
||||
|
||||
if "cache_control" in _content_element:
|
||||
_anthropic_content_text_element[
|
||||
"cache_control"
|
||||
] = _content_element["cache_control"]
|
||||
_anthropic_content_text_element["cache_control"] = (
|
||||
_content_element["cache_control"]
|
||||
)
|
||||
|
||||
user_content.append(_anthropic_content_text_element)
|
||||
|
||||
|
@ -2491,7 +2497,7 @@ def _convert_to_bedrock_tool_call_invoke(
|
|||
|
||||
|
||||
def _convert_to_bedrock_tool_call_result(
|
||||
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
|
||||
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
|
||||
) -> BedrockContentBlock:
|
||||
"""
|
||||
OpenAI message with a tool result looks like:
|
||||
|
@ -2664,7 +2670,7 @@ def get_user_message_block_or_continue_message(
|
|||
def return_assistant_continue_message(
|
||||
assistant_continue_message: Optional[
|
||||
Union[str, ChatCompletionAssistantMessage]
|
||||
] = None
|
||||
] = None,
|
||||
) -> ChatCompletionAssistantMessage:
|
||||
if assistant_continue_message and isinstance(assistant_continue_message, str):
|
||||
return ChatCompletionAssistantMessage(
|
||||
|
@ -3462,7 +3468,13 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
|
|||
for _, value in defs_copy.items():
|
||||
unpack_defs(value, defs_copy)
|
||||
unpack_defs(parameters, defs_copy)
|
||||
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
|
||||
tool_input_schema = BedrockToolInputSchemaBlock(
|
||||
json=BedrockToolJsonSchemaBlock(
|
||||
type=parameters.get("type", ""),
|
||||
properties=parameters.get("properties", {}),
|
||||
required=parameters.get("required", []),
|
||||
)
|
||||
)
|
||||
tool_spec = BedrockToolSpecBlock(
|
||||
inputSchema=tool_input_schema, name=name, description=description
|
||||
)
|
||||
|
|
|
@ -125,8 +125,19 @@ class ConverseResponseBlock(TypedDict):
|
|||
usage: ConverseTokenUsageBlock
|
||||
|
||||
|
||||
class ToolJsonArgsBlock(TypedDict, total=False):
|
||||
type: str
|
||||
description: str
|
||||
|
||||
|
||||
class ToolJsonSchemaBlock(TypedDict, total=False):
|
||||
type: Literal["object"]
|
||||
properties: dict
|
||||
required: List[str]
|
||||
|
||||
|
||||
class ToolInputSchemaBlock(TypedDict):
|
||||
json: Optional[dict]
|
||||
json: Optional[ToolJsonSchemaBlock]
|
||||
|
||||
|
||||
class ToolSpecBlock(TypedDict, total=False):
|
||||
|
|
|
@ -1037,6 +1037,7 @@ class BaseLLMChatTest(ABC):
|
|||
def test_function_calling_with_tool_response(self):
|
||||
from litellm.utils import supports_function_calling
|
||||
from litellm import completion
|
||||
litellm._turn_on_debug()
|
||||
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
|
@ -1056,6 +1057,7 @@ class BaseLLMChatTest(ABC):
|
|||
"name": "get_weather",
|
||||
"description": "Get the weather in a city",
|
||||
"parameters": {
|
||||
"$id": "https://some/internal/name",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
|
|
|
@ -2226,6 +2226,28 @@ class TestBedrockConverseChatNormal(BaseLLMChatTest):
|
|||
"""
|
||||
pass
|
||||
|
||||
class TestBedrockConverseNovaTestSuite(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
litellm.add_known_models()
|
||||
return {
|
||||
"model": "bedrock/us.amazon.nova-lite-v1:0",
|
||||
"aws_region_name": "us-east-1",
|
||||
}
|
||||
|
||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
pass
|
||||
|
||||
def test_multilingual_requests(self):
|
||||
"""
|
||||
Bedrock API raises a 400 BadRequest error when the request contains invalid utf-8 sequences.
|
||||
|
||||
Todo: if litellm.modify_params is True ensure it's a valid utf-8 sequence
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TestBedrockRerank(BaseLLMRerankTest):
|
||||
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue