Litellm dev 04 05 2025 p2 (#9774)

* test: move test to just checking async

* fix(transformation.py): handle function call with no schema

* fix(utils.py): handle pydantic base model in message tool calls

Fix https://github.com/BerriAI/litellm/issues/9321

* fix(vertex_and_google_ai_studio.py): handle tools=[]

Fixes https://github.com/BerriAI/litellm/issues/9080

* test: remove max token restriction

* test: fix basic test

* fix(get_supported_openai_params.py): fix check

* fix(converse_transformation.py): support fake streaming for meta.llama3-3-70b-instruct-v1:0

* fix: fix test

* fix: parse out empty dictionary on dbrx streaming + tool calls

* fix(fireworks_ai): handle 'strict' param - Fireworks AI does not support it, so strip it from tool definitions

* fix: fix ruff check

* fix: handle functions defined without a 'strict' field

* fix: revert bedrock change - handle in separate PR
Krish Dholakia 2025-04-07 21:02:52 -07:00 committed by GitHub
parent d8f47fc9e5
commit fcf17d114f
10 changed files with 214 additions and 11 deletions

View file

@@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
ChatCompletionUserMessage,
OpenAIChatCompletionToolParam,
OpenAIMessageContentListBlock,
)
from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage
@@ -211,6 +212,23 @@ class AmazonConverseConfig(BaseConfig):
)
return _tool
def _apply_tool_call_transformation(
self,
tools: List[OpenAIChatCompletionToolParam],
model: str,
non_default_params: dict,
optional_params: dict,
):
optional_params = self._add_tools_to_optional_params(
optional_params=optional_params, tools=tools
)
if (
"meta.llama3-3-70b-instruct-v1:0" in model
and non_default_params.get("stream", False) is True
):
optional_params["fake_stream"] = True
def map_openai_params(
self,
non_default_params: dict,
@@ -286,8 +304,11 @@ class AmazonConverseConfig(BaseConfig):
if param == "top_p":
optional_params["topP"] = value
if param == "tools" and isinstance(value, list):
optional_params = self._add_tools_to_optional_params(
optional_params=optional_params, tools=value
self._apply_tool_call_transformation(
tools=cast(List[OpenAIChatCompletionToolParam], value),
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
)
if param == "tool_choice":
_tool_choice_value = self.map_tool_choice_values(
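Note: a minimal sketch of how the new _apply_tool_call_transformation helper behaves, not code from this PR; the import path and tool definition are assumptions for illustration.

from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig  # path assumed

config = AmazonConverseConfig()
optional_params: dict = {}
config._apply_tool_call_transformation(
    tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {}}}}],
    model="meta.llama3-3-70b-instruct-v1:0",
    non_default_params={"stream": True},
    optional_params=optional_params,
)
# Assuming _add_tools_to_optional_params mutates the dict in place, optional_params now holds the
# Converse tool config, and optional_params["fake_stream"] is True because stream=True was requested
# for this model.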

View file

@@ -27,7 +27,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
strip_name_from_messages,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.anthropic import AnthropicMessagesTool
from litellm.types.llms.anthropic import AllAnthropicToolsValues
from litellm.types.llms.databricks import (
AllDatabricksContentValues,
DatabricksChoice,
@@ -160,7 +160,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
]
def convert_anthropic_tool_to_databricks_tool(
self, tool: Optional[AnthropicMessagesTool]
self, tool: Optional[AllAnthropicToolsValues]
) -> Optional[DatabricksTool]:
if tool is None:
return None
@@ -173,6 +173,19 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
),
)
def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]:
# if not claude, send as is
if "claude" not in model:
return tools
# if claude, convert to anthropic tool and then to databricks tool
anthropic_tools = self._map_tools(tools=tools)
databricks_tools = [
cast(DatabricksTool, self.convert_anthropic_tool_to_databricks_tool(tool))
for tool in anthropic_tools
]
return databricks_tools
def map_response_format_to_databricks_tool(
self,
model: str,
@@ -202,6 +215,10 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
mapped_params = super().map_openai_params(
non_default_params, optional_params, model, drop_params
)
if "tools" in mapped_params:
mapped_params["tools"] = self._map_openai_to_dbrx_tool(
model=model, tools=mapped_params["tools"]
)
if (
"max_completion_tokens" in non_default_params
and replace_max_completion_tokens_with_max_tokens
@@ -499,7 +516,10 @@ class DatabricksChatResponseIterator(BaseModelResponseIterator):
message.content = ""
choice["delta"]["content"] = message.content
choice["delta"]["tool_calls"] = None
elif tool_calls:
for _tc in tool_calls:
if _tc.get("function", {}).get("arguments") == "{}":
_tc["function"]["arguments"] = "" # avoid invalid json
# extract the content str
content_str = DatabricksConfig.extract_content_str(
choice["delta"].get("content")

View file

@@ -2,7 +2,11 @@ from typing import List, Literal, Optional, Tuple, Union, cast
import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionImageObject,
OpenAIChatCompletionToolParam,
)
from litellm.types.utils import ProviderSpecificModelInfo
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
@@ -150,6 +154,14 @@ class FireworksAIConfig(OpenAIGPTConfig):
] = f"{content['image_url']['url']}#transform=inline"
return content
def _transform_tools(
self, tools: List[OpenAIChatCompletionToolParam]
) -> List[OpenAIChatCompletionToolParam]:
for tool in tools:
if tool.get("type") == "function":
tool["function"].pop("strict", None)
return tools
def _transform_messages_helper(
self, messages: List[AllMessageValues], model: str, litellm_params: dict
) -> List[AllMessageValues]:
@@ -196,6 +208,9 @@ class FireworksAIConfig(OpenAIGPTConfig):
messages = self._transform_messages_helper(
messages=messages, model=model, litellm_params=litellm_params
)
if "tools" in optional_params and optional_params["tools"] is not None:
tools = self._transform_tools(tools=optional_params["tools"])
optional_params["tools"] = tools
return super().transform_request(
model=model,
messages=messages,
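Note: a minimal sketch of what _transform_tools does to a tool definition; the import path is assumed and the tool is illustrative.

from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig  # path assumed

tools = [
    {
        "type": "function",
        "function": {"name": "get_weather", "parameters": {"type": "object"}, "strict": True},
    }
]
# Fireworks AI rejects OpenAI's 'strict' flag, so it is popped before the request is built.
print(FireworksAIConfig()._transform_tools(tools=tools))
# [{'type': 'function', 'function': {'name': 'get_weather', 'parameters': {'type': 'object'}}}]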

View file

@@ -374,7 +374,11 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
optional_params["responseLogprobs"] = value
elif param == "top_logprobs":
optional_params["logprobs"] = value
elif (param == "tools" or param == "functions") and isinstance(value, list):
elif (
(param == "tools" or param == "functions")
and isinstance(value, list)
and value
):
optional_params["tools"] = self._map_function(value=value)
optional_params["litellm_param_is_function_call"] = (
True if param == "functions" else False

View file

@@ -695,6 +695,7 @@ class ChatCompletionToolParamFunctionChunk(TypedDict, total=False):
name: Required[str]
description: str
parameters: dict
strict: bool
class OpenAIChatCompletionToolParam(TypedDict):
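Note: a sketch of a tool definition exercising the newly typed field; it assumes ChatCompletionToolParam pairs a "type" with a ChatCompletionToolParamFunctionChunk, as in the OpenAI spec.

from litellm.types.llms.openai import ChatCompletionToolParam

tool: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the weather in a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        "strict": True,  # now a recognized key on ChatCompletionToolParamFunctionChunk
    },
}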

View file

@@ -6112,6 +6112,8 @@ def validate_and_fix_openai_messages(messages: List):
for message in messages:
if not message.get("role"):
message["role"] = "assistant"
if message.get("tool_calls"):
message["tool_calls"] = jsonify_tools(tools=message["tool_calls"])
return validate_chat_completion_messages(messages=messages)
@@ -6705,3 +6707,20 @@ def return_raw_request(endpoint: CallTypes, kwargs: dict) -> RawRequestTypedDict
return RawRequestTypedDict(
error=received_exception,
)
def jsonify_tools(tools: List[Any]) -> List[Dict]:
"""
Fixes https://github.com/BerriAI/litellm/issues/9321
Where user passes in a pydantic base model
"""
new_tools: List[Dict] = []
for tool in tools:
if isinstance(tool, BaseModel):
tool = tool.model_dump(exclude_none=True)
elif isinstance(tool, dict):
tool = tool.copy()
if isinstance(tool, dict):
new_tools.append(tool)
return new_tools
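Note: a usage sketch for jsonify_tools covering the scenario from issue 9321 (a pydantic tool call inside an assistant message); the import locations are assumptions.

from litellm.types.utils import ChatCompletionMessageToolCall, Function  # locations assumed
from litellm.utils import jsonify_tools

tool_call = ChatCompletionMessageToolCall(
    id="call_abc123",
    type="function",
    function=Function(name="get_weather", arguments='{"city": "Mumbai"}'),
)
# The pydantic model is dumped to a plain dict, so the message's tool_calls stay
# JSON-serializable when the conversation is re-sent to a provider.
print(jsonify_tools(tools=[tool_call]))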

View file

@@ -804,6 +804,35 @@ class BaseLLMChatTest(ABC):
url = f"data:application/pdf;base64,{encoded_file}"
return url
def test_empty_tools(self):
"""
Related Issue: https://github.com/BerriAI/litellm/issues/9080
"""
try:
from litellm import completion, ModelResponse
litellm.set_verbose = True
litellm._turn_on_debug()
from litellm.utils import supports_function_calling
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
base_completion_call_args = self.get_base_completion_call_args()
if not supports_function_calling(base_completion_call_args["model"], None):
print("Model does not support function calling")
pytest.skip("Model does not support function calling")
response = completion(**base_completion_call_args, messages=[{"role": "user", "content": "Hello, how are you?"}], tools=[]) # just make sure call doesn't fail
print("response: ", response)
assert response is not None
except litellm.InternalServerError:
pytest.skip("Model is overloaded")
except litellm.RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_basic_tool_calling(self):
try:
@@ -1003,6 +1032,101 @@
elif input_type == "audio_url":
assert test_file_id in json.dumps(raw_request), "Audio URL not sent to gemini"
def test_function_calling_with_tool_response(self):
from litellm.utils import supports_function_calling
from litellm import completion
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
base_completion_call_args = self.get_base_completion_call_args()
if not supports_function_calling(base_completion_call_args["model"], None):
print("Model does not support function calling")
pytest.skip("Model does not support function calling")
def get_weather(city: str):
return f"City: {city}, Weather: Sunny with 34 degree Celcius"
TOOLS = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for",
}
},
"required": ["city"],
"additionalProperties": False,
},
"strict": True,
},
}
]
messages = [{ "content": "How is the weather in Mumbai?","role": "user"}]
response, iteration = "", 0
while True:
if response:
break
# Create a streaming response with tool calling enabled
stream = completion(
**base_completion_call_args,
messages=messages,
tools=TOOLS,
stream=True,
)
final_tool_calls = {}
for chunk in stream:
delta = chunk.choices[0].delta
print(delta)
if delta.content:
response += delta.content
elif delta.tool_calls:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
else:
final_tool_calls[
index
].function.arguments += tool_call.function.arguments
if final_tool_calls:
for tool_call in final_tool_calls.values():
if tool_call.function.name == "get_weather":
city = json.loads(tool_call.function.arguments)["city"]
tool_response = get_weather(city)
messages.append(
{
"role": "assistant",
"tool_calls": [tool_call],
"content": None,
}
)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_response,
}
)
iteration += 1
if iteration > 2:
print("Something went wrong!")
break
print(response)
class BaseOSeriesModelsTest(ABC): # test across azure/openai
@abstractmethod
def get_base_completion_call_args(self):

View file

@@ -3368,7 +3368,6 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
messages=[{"role": "user", "content": "Hey! how's it going?"}],
temperature=0.2,
max_tokens=200,
stop=["stop sequence"],
)
assert isinstance(response, litellm.ModelResponse)
@@ -3380,7 +3379,6 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
messages=[{"role": "user", "content": "Hey! how's it going?"}],
temperature=0.2,
max_tokens=100,
stop=["stop sequence"],
)
assert isinstance(response, litellm.ModelResponse)

View file

@@ -562,8 +562,9 @@ def test_groq_parallel_function_call():
@pytest.mark.parametrize(
"model",
[
"anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-haiku-20240307",
# "anthropic.claude-3-sonnet-20240229-v1:0",
# "claude-3-haiku-20240307",
"databricks/databricks-claude-3-7-sonnet"
],
)
def test_anthropic_function_call_with_no_schema(model):