fix(utils.py): support 'drop_params' for 'parallel_tool_calls'

Closes https://github.com/BerriAI/litellm/issues/4584

'parallel_tool_calls' is an OpenAI-only param.
Krrish Dholakia 2024-07-08 07:36:41 -07:00
parent f889a7e4b0
commit a00a1267bc
5 changed files with 86 additions and 24 deletions
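
With this change, 'parallel_tool_calls' is handled like the other OpenAI-only params: when drop_params is set, providers that don't support it drop it instead of erroring. A minimal usage sketch, mirroring the new test added below (Cohere's command-r does not accept the param):

import litellm

# 'parallel_tool_calls' is OpenAI-only. With drop_params=True it is now
# dropped for providers that don't support it, instead of raising an
# unsupported-params error.
response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    parallel_tool_calls=True,
    drop_params=True,
)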

litellm/llms/openai.py

@@ -445,6 +445,7 @@ class OpenAIConfig:
             "functions",
             "max_retries",
             "extra_headers",
+            "parallel_tool_calls",
         ]  # works across all models
         model_specific_params = []

litellm/llms/watsonx.py

@@ -1,28 +1,31 @@
-import json, types, time  # noqa: E401
 import asyncio
+import json  # noqa: E401
+import time
+import types
+from contextlib import asynccontextmanager, contextmanager
 from datetime import datetime
 from enum import Enum
-from contextlib import asynccontextmanager, contextmanager
 from typing import (
+    Any,
+    AsyncContextManager,
+    AsyncGenerator,
+    AsyncIterator,
     Callable,
+    ContextManager,
     Dict,
     Generator,
-    AsyncGenerator,
     Iterator,
-    AsyncIterator,
-    Optional,
-    Any,
-    Union,
     List,
-    ContextManager,
-    AsyncContextManager,
+    Optional,
+    Union,
 )
 import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
-from litellm.utils import ModelResponse, Usage, get_secret
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.utils import ModelResponse, Usage, get_secret
 from .base import BaseLLM
 from .prompt_templates import factory as ptf
@@ -440,7 +443,7 @@ class IBMWatsonXAI(BaseLLM):
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
-        timeout=None
+        timeout=None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -547,7 +550,9 @@ class IBMWatsonXAI(BaseLLM):
         except Exception as e:
             raise WatsonXAIError(status_code=500, message=str(e))

-    def _process_embedding_response(self, json_resp: dict, model_response:Union[ModelResponse,None]=None) -> ModelResponse:
+    def _process_embedding_response(
+        self, json_resp: dict, model_response: Union[ModelResponse, None] = None
+    ) -> ModelResponse:
         if model_response is None:
             model_response = ModelResponse(model=json_resp.get("model_id", None))
         results = json_resp.get("results", [])
@@ -563,10 +568,14 @@ class IBMWatsonXAI(BaseLLM):
         model_response["object"] = "list"
         model_response["data"] = embedding_response
         input_tokens = json_resp.get("input_token_count", 0)
-        model_response.usage = Usage(
-            prompt_tokens=input_tokens,
-            completion_tokens=0,
-            total_tokens=input_tokens,
+        setattr(
+            model_response,
+            "usage",
+            Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
+            ),
         )
         return model_response
@@ -580,7 +589,7 @@ class IBMWatsonXAI(BaseLLM):
         optional_params=None,
         encoding=None,
         print_verbose=None,
-        aembedding=None
+        aembedding=None,
     ):
         """
         Send a text embedding request to the IBM Watsonx.ai API.
@@ -593,7 +602,7 @@ class IBMWatsonXAI(BaseLLM):
             if k not in optional_params:
                 optional_params[k] = v
-        model_response['model'] = model
+        model_response["model"] = model
         # Load auth variables from environment variables
         if isinstance(input, str):
@@ -685,6 +694,7 @@ class IBMWatsonXAI(BaseLLM):
             return json_resp
         return [res["model_id"] for res in json_resp["resources"]]

+
 class RequestManager:
     """
     A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.

litellm/main.py

@@ -108,7 +108,6 @@ from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
-from .llms.watsonx import IBMWatsonXAI
 from .llms.prompt_templates.factory import (
     custom_prompt,
     function_call_prompt,
@@ -119,6 +118,7 @@ from .llms.prompt_templates.factory import (
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_httpx import VertexLLM
+from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import ChatCompletionMessageToolCall
@@ -593,6 +593,7 @@ def completion(
     tool_choice: Optional[Union[str, dict]] = None,
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
+    parallel_tool_calls: Optional[bool] = None,
     deployment_id=None,
     extra_headers: Optional[dict] = None,
     # soon to be deprecated params by OpenAI
@@ -722,6 +723,7 @@ def completion(
         "tools",
         "tool_choice",
         "max_retries",
+        "parallel_tool_calls",
         "logprobs",
         "top_logprobs",
         "extra_headers",
@@ -932,6 +934,7 @@ def completion(
         top_logprobs=top_logprobs,
         extra_headers=extra_headers,
         api_version=api_version,
+        parallel_tool_calls=parallel_tool_calls,
         **non_default_params,
     )

litellm/tests/test_optional_params.py

@@ -269,7 +269,7 @@ def test_dynamic_drop_params(drop_params):
     """
     Make a call to cohere w/ drop params = True vs. false.
     """
-    if drop_params == True:
+    if drop_params is True:
         optional_params = litellm.utils.get_optional_params(
             model="command-r",
             custom_llm_provider="cohere",
@@ -306,6 +306,52 @@ def test_dynamic_drop_params_e2e():
         assert "response_format" not in mock_response.call_args.kwargs["data"]


+@pytest.mark.parametrize(
+    "model, provider, should_drop",
+    [("command-r", "cohere", True), ("gpt-3.5-turbo", "openai", False)],
+)
+def test_drop_params_parallel_tool_calls(model, provider, should_drop):
+    """
+    https://github.com/BerriAI/litellm/issues/4584
+    """
+    response = litellm.utils.get_optional_params(
+        model=model,
+        custom_llm_provider=provider,
+        response_format="json",
+        parallel_tool_calls=True,
+        drop_params=True,
+    )
+
+    print(response)
+
+    if should_drop:
+        assert "response_format" not in response
+        assert "parallel_tool_calls" not in response
+    else:
+        assert "response_format" in response
+        assert "parallel_tool_calls" in response
+
+
+def test_dynamic_drop_params_parallel_tool_calls():
+    """
+    https://github.com/BerriAI/litellm/issues/4584
+    """
+    with patch("requests.post", new=MagicMock()) as mock_response:
+        try:
+            response = litellm.completion(
+                model="command-r",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                parallel_tool_calls=True,
+                drop_params=True,
+            )
+        except Exception as e:
+            pass
+
+        mock_response.assert_called_once()
+        print(mock_response.call_args.kwargs["data"])
+
+        assert "parallel_tool_calls" not in mock_response.call_args.kwargs["data"]
+
+
 @pytest.mark.parametrize("drop_params", [True, False, None])
 def test_dynamic_drop_additional_params(drop_params):
     """

litellm/utils.py

@@ -2393,6 +2393,7 @@ def get_optional_params(
     top_logprobs=None,
     extra_headers=None,
     api_version=None,
+    parallel_tool_calls=None,
     drop_params=None,
     additional_drop_params=None,
     **kwargs,
@@ -2470,6 +2471,7 @@ def get_optional_params(
         "top_logprobs": None,
         "extra_headers": None,
         "api_version": None,
+        "parallel_tool_calls": None,
         "drop_params": None,
         "additional_drop_params": None,
     }
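
A quick check at the param-mapping level, mirroring test_drop_params_parallel_tool_calls above (a sketch, assuming this commit is applied):

import litellm

# Dropped for a provider that doesn't support parallel tool calls...
params = litellm.utils.get_optional_params(
    model="command-r",
    custom_llm_provider="cohere",
    parallel_tool_calls=True,
    drop_params=True,
)
assert "parallel_tool_calls" not in params

# ...but kept for OpenAI, where it is a supported param.
params = litellm.utils.get_optional_params(
    model="gpt-3.5-turbo",
    custom_llm_provider="openai",
    parallel_tool_calls=True,
    drop_params=True,
)
assert "parallel_tool_calls" in params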