mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Litellm dev 04 05 2025 p2 (#9774)
* test: move test to just checking async
* fix(transformation.py): handle function call with no schema
* fix(utils.py): handle pydantic base model in message tool calls
  Fix https://github.com/BerriAI/litellm/issues/9321
* fix(vertex_and_google_ai_studio.py): handle tools=[]
  Fixes https://github.com/BerriAI/litellm/issues/9080
* test: remove max token restriction
* test: fix basic test
* fix(get_supported_openai_params.py): fix check
* fix(converse_transformation.py): support fake streaming for meta.llama3-3-70b-instruct-v1:0
* fix: fix test
* fix: parse out empty dictionary on dbrx streaming + tool calls
* fix(handle-'strict'-param-when-calling-fireworks-ai): fireworks ai does not support 'strict' param
* fix: fix ruff check
* fix: handle no strict in function
* fix: revert bedrock change - handle in separate PR
This commit is contained in:
parent
d8f47fc9e5
commit
fcf17d114f
10 changed files with 214 additions and 11 deletions
|
@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
|
|||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
ChatCompletionUserMessage,
|
||||
OpenAIChatCompletionToolParam,
|
||||
OpenAIMessageContentListBlock,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage
|
||||
|
@ -211,6 +212,23 @@ class AmazonConverseConfig(BaseConfig):
|
|||
)
|
||||
return _tool
|
||||
|
||||
def _apply_tool_call_transformation(
    self,
    tools: List[OpenAIChatCompletionToolParam],
    model: str,
    non_default_params: dict,
    optional_params: dict,
) -> dict:
    """
    Add OpenAI-format tools to the request params and apply model-specific tool handling.

    Args:
        tools: tool definitions taken from the caller's `tools` param.
        model: model identifier; matched by substring.
        non_default_params: params supplied by the caller; only `stream` is read here.
        optional_params: provider params being built up; updated in place.

    Returns:
        The `optional_params` dict that was passed in. Callers may rely on the
        in-place update or rebind to the return value.
    """
    updated_params = self._add_tools_to_optional_params(
        optional_params=optional_params, tools=tools
    )
    if isinstance(updated_params, dict) and updated_params is not optional_params:
        # Existing callers ignore the return value, so make sure whatever the helper
        # produced also lands in the dict they hold on to.
        optional_params.update(updated_params)

    # tools + stream=True on this model is flagged for fake streaming
    if (
        "meta.llama3-3-70b-instruct-v1:0" in model
        and non_default_params.get("stream", False) is True
    ):
        optional_params["fake_stream"] = True

    return optional_params
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
|
@ -286,8 +304,11 @@ class AmazonConverseConfig(BaseConfig):
|
|||
if param == "top_p":
|
||||
optional_params["topP"] = value
|
||||
if param == "tools" and isinstance(value, list):
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params=optional_params, tools=value
|
||||
self._apply_tool_call_transformation(
|
||||
tools=cast(List[OpenAIChatCompletionToolParam], value),
|
||||
model=model,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
if param == "tool_choice":
|
||||
_tool_choice_value = self.map_tool_choice_values(
|
||||
|
|
|
@ -27,7 +27,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
|||
strip_name_from_messages,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.types.llms.anthropic import AnthropicMessagesTool
|
||||
from litellm.types.llms.anthropic import AllAnthropicToolsValues
|
||||
from litellm.types.llms.databricks import (
|
||||
AllDatabricksContentValues,
|
||||
DatabricksChoice,
|
||||
|
@ -160,7 +160,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
|||
]
|
||||
|
||||
def convert_anthropic_tool_to_databricks_tool(
|
||||
self, tool: Optional[AnthropicMessagesTool]
|
||||
self, tool: Optional[AllAnthropicToolsValues]
|
||||
) -> Optional[DatabricksTool]:
|
||||
if tool is None:
|
||||
return None
|
||||
|
@ -173,6 +173,19 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
|||
),
|
||||
)
|
||||
|
||||
def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]:
    """
    Map OpenAI-format tools to Databricks tools.

    When "claude" is not part of `model`, `tools` is returned untouched.
    Otherwise each tool is converted with `_map_tools` (anthropic format) and then
    with `convert_anthropic_tool_to_databricks_tool`.
    """
    is_claude_model = "claude" in model
    if not is_claude_model:
        # non-claude models take the tools as is
        return tools

    # claude: openai tool -> anthropic tool -> databricks tool
    converted: List[DatabricksTool] = []
    for anthropic_tool in self._map_tools(tools=tools):
        dbrx_tool = self.convert_anthropic_tool_to_databricks_tool(anthropic_tool)
        converted.append(cast(DatabricksTool, dbrx_tool))
    return converted
|
||||
|
||||
def map_response_format_to_databricks_tool(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -202,6 +215,10 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
|||
mapped_params = super().map_openai_params(
|
||||
non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
if "tools" in mapped_params:
|
||||
mapped_params["tools"] = self._map_openai_to_dbrx_tool(
|
||||
model=model, tools=mapped_params["tools"]
|
||||
)
|
||||
if (
|
||||
"max_completion_tokens" in non_default_params
|
||||
and replace_max_completion_tokens_with_max_tokens
|
||||
|
@ -499,7 +516,10 @@ class DatabricksChatResponseIterator(BaseModelResponseIterator):
|
|||
message.content = ""
|
||||
choice["delta"]["content"] = message.content
|
||||
choice["delta"]["tool_calls"] = None
|
||||
|
||||
elif tool_calls:
|
||||
for _tc in tool_calls:
|
||||
if _tc.get("function", {}).get("arguments") == "{}":
|
||||
_tc["function"]["arguments"] = "" # avoid invalid json
|
||||
# extract the content str
|
||||
content_str = DatabricksConfig.extract_content_str(
|
||||
choice["delta"].get("content")
|
||||
|
|
|
@ -2,7 +2,11 @@ from typing import List, Literal, Optional, Tuple, Union, cast
|
|||
|
||||
import litellm
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionImageObject,
|
||||
OpenAIChatCompletionToolParam,
|
||||
)
|
||||
from litellm.types.utils import ProviderSpecificModelInfo
|
||||
|
||||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
@ -150,6 +154,14 @@ class FireworksAIConfig(OpenAIGPTConfig):
|
|||
] = f"{content['image_url']['url']}#transform=inline"
|
||||
return content
|
||||
|
||||
def _transform_tools(
    self, tools: List[OpenAIChatCompletionToolParam]
) -> List[OpenAIChatCompletionToolParam]:
    """
    Return `tools` with the `strict` flag removed from every function tool.

    Fireworks AI does not support the `strict` param on function definitions.
    The caller's tool dicts are left untouched: a tool that needs changing is
    replaced by a sanitized copy, so the same tool definitions can still be
    reused elsewhere with `strict` intact.

    Args:
        tools: OpenAI-format tool definitions.

    Returns:
        A new list; entries that needed no change are the original objects.
    """
    sanitized: List[OpenAIChatCompletionToolParam] = []
    for tool in tools:
        function = tool.get("function")
        # Only function tools with an actual function dict can carry `strict`
        if tool.get("type") == "function" and isinstance(function, dict):
            function_without_strict = {
                key: value for key, value in function.items() if key != "strict"
            }
            tool = cast(
                OpenAIChatCompletionToolParam,
                {**tool, "function": function_without_strict},
            )
        sanitized.append(tool)
    return sanitized
|
||||
|
||||
def _transform_messages_helper(
|
||||
self, messages: List[AllMessageValues], model: str, litellm_params: dict
|
||||
) -> List[AllMessageValues]:
|
||||
|
@ -196,6 +208,9 @@ class FireworksAIConfig(OpenAIGPTConfig):
|
|||
messages = self._transform_messages_helper(
|
||||
messages=messages, model=model, litellm_params=litellm_params
|
||||
)
|
||||
if "tools" in optional_params and optional_params["tools"] is not None:
|
||||
tools = self._transform_tools(tools=optional_params["tools"])
|
||||
optional_params["tools"] = tools
|
||||
return super().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
|
@ -374,7 +374,11 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
optional_params["responseLogprobs"] = value
|
||||
elif param == "top_logprobs":
|
||||
optional_params["logprobs"] = value
|
||||
elif (param == "tools" or param == "functions") and isinstance(value, list):
|
||||
elif (
|
||||
(param == "tools" or param == "functions")
|
||||
and isinstance(value, list)
|
||||
and value
|
||||
):
|
||||
optional_params["tools"] = self._map_function(value=value)
|
||||
optional_params["litellm_param_is_function_call"] = (
|
||||
True if param == "functions" else False
|
||||
|
|
|
@ -695,6 +695,7 @@ class ChatCompletionToolParamFunctionChunk(TypedDict, total=False):
|
|||
name: Required[str]
|
||||
description: str
|
||||
parameters: dict
|
||||
strict: bool
|
||||
|
||||
|
||||
class OpenAIChatCompletionToolParam(TypedDict):
|
||||
|
|
|
@ -6112,6 +6112,8 @@ def validate_and_fix_openai_messages(messages: List):
|
|||
for message in messages:
|
||||
if not message.get("role"):
|
||||
message["role"] = "assistant"
|
||||
if message.get("tool_calls"):
|
||||
message["tool_calls"] = jsonify_tools(tools=message["tool_calls"])
|
||||
return validate_chat_completion_messages(messages=messages)
|
||||
|
||||
|
||||
|
@ -6705,3 +6707,20 @@ def return_raw_request(endpoint: CallTypes, kwargs: dict) -> RawRequestTypedDict
|
|||
return RawRequestTypedDict(
|
||||
error=received_exception,
|
||||
)
|
||||
|
||||
|
||||
def jsonify_tools(tools: List[Any]) -> List[Dict]:
    """
    Fixes https://github.com/BerriAI/litellm/issues/9321

    Where user passes in a pydantic base model

    Returns a list of plain dicts: pydantic models are dumped with
    `exclude_none=True`, dicts are shallow-copied, and any other entry is left out.
    """
    jsonified: List[Dict] = []
    for entry in tools:
        if isinstance(entry, BaseModel):
            as_dict: Any = entry.model_dump(exclude_none=True)
        elif isinstance(entry, dict):
            as_dict = entry.copy()
        else:
            # neither a pydantic model nor a dict - not carried over
            continue
        if isinstance(as_dict, dict):
            jsonified.append(as_dict)
    return jsonified
|
||||
|
|
|
@ -804,6 +804,35 @@ class BaseLLMChatTest(ABC):
|
|||
url = f"data:application/pdf;base64,{encoded_file}"
|
||||
|
||||
return url
|
||||
|
||||
def test_empty_tools(self):
    """
    Related Issue: https://github.com/BerriAI/litellm/issues/9080

    Calls `completion` with `tools=[]` and only checks that a response comes back.
    Skipped when the model does not support function calling or is overloaded;
    rate-limit errors are tolerated; any other exception fails the test.
    """
    try:
        from litellm import completion, ModelResponse

        litellm.set_verbose = True
        litellm._turn_on_debug()
        from litellm.utils import supports_function_calling

        # NOTE(review): presumably switches litellm to its bundled model cost map - confirm
        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
        litellm.model_cost = litellm.get_model_cost_map(url="")

        call_args = self.get_base_completion_call_args()
        model_name = call_args["model"]
        if not supports_function_calling(model_name, None):
            print("Model does not support function calling")
            pytest.skip("Model does not support function calling")

        user_message = {"role": "user", "content": "Hello, how are you?"}
        # just make sure call doesn't fail
        response = completion(**call_args, messages=[user_message], tools=[])
        print("response: ", response)
        assert response is not None
    except litellm.InternalServerError:
        pytest.skip("Model is overloaded")
    except litellm.RateLimitError:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
def test_basic_tool_calling(self):
|
||||
try:
|
||||
|
@ -1003,6 +1032,101 @@ class BaseLLMChatTest(ABC):
|
|||
elif input_type == "audio_url":
|
||||
assert test_file_id in json.dumps(raw_request), "Audio URL not sent to gemini"
|
||||
|
||||
|
||||
def test_function_calling_with_tool_response(self):
    """
    Stream a tool-calling conversation and feed the tool result back to the model.

    Skipped when the model under test does not support function calling.
    Up to three streaming calls are made; after each one, any `get_weather` tool
    call is answered locally and appended to the conversation. The loop stops early
    once text content has been received. Nothing is asserted - the collected
    response is only printed.
    """
    from litellm.utils import supports_function_calling
    from litellm import completion

    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    call_args = self.get_base_completion_call_args()
    if not supports_function_calling(call_args["model"], None):
        print("Model does not support function calling")
        pytest.skip("Model does not support function calling")

    def get_weather(city: str):
        return f"City: {city}, Weather: Sunny with 34 degree Celcius"

    city_schema = {
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "The city to get the weather for",
            }
        },
        "required": ["city"],
        "additionalProperties": False,
    }
    weather_tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the weather in a city",
                "parameters": city_schema,
                "strict": True,
            },
        }
    ]

    messages = [{"content": "How is the weather in Mumbai?", "role": "user"}]
    response = ""
    # At most three round trips; the `else` branch runs only when all three were used.
    for _ in range(3):
        if response:
            break

        # Create a streaming response with tool calling enabled
        stream = completion(
            **call_args,
            messages=messages,
            tools=weather_tools,
            stream=True,
        )

        # Tool-call chunks are stitched together per index
        calls_by_index = {}
        for chunk in stream:
            delta = chunk.choices[0].delta
            print(delta)
            if delta.content:
                response += delta.content
                continue
            if not delta.tool_calls:
                continue
            for call_chunk in delta.tool_calls:
                position = call_chunk.index
                if position in calls_by_index:
                    calls_by_index[
                        position
                    ].function.arguments += call_chunk.function.arguments
                else:
                    calls_by_index[position] = call_chunk

        for tool_call in calls_by_index.values():
            if tool_call.function.name != "get_weather":
                continue
            requested_city = json.loads(tool_call.function.arguments)["city"]
            weather_report = get_weather(requested_city)
            assistant_turn = {
                "role": "assistant",
                "tool_calls": [tool_call],
                "content": None,
            }
            tool_turn = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": weather_report,
            }
            messages.append(assistant_turn)
            messages.append(tool_turn)
    else:
        print("Something went wrong!")

    print(response)
|
||||
|
||||
|
||||
|
||||
class BaseOSeriesModelsTest(ABC): # test across azure/openai
|
||||
@abstractmethod
|
||||
def get_base_completion_call_args(self):
|
||||
|
|
0
tests/llm_translation/log.xt
Normal file
0
tests/llm_translation/log.xt
Normal file
|
@ -3368,7 +3368,6 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
|
|||
messages=[{"role": "user", "content": "Hey! how's it going?"}],
|
||||
temperature=0.2,
|
||||
max_tokens=200,
|
||||
stop=["stop sequence"],
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
|
@ -3380,7 +3379,6 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
|
|||
messages=[{"role": "user", "content": "Hey! how's it going?"}],
|
||||
temperature=0.2,
|
||||
max_tokens=100,
|
||||
stop=["stop sequence"],
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
|
|
|
@ -562,8 +562,9 @@ def test_groq_parallel_function_call():
|
|||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-haiku-20240307",
|
||||
# "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
# "claude-3-haiku-20240307",
|
||||
"databricks/databricks-claude-3-7-sonnet"
|
||||
],
|
||||
)
|
||||
def test_anthropic_function_call_with_no_schema(model):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue