fix(utils.py): support 'drop_params' for 'parallel_tool_calls'

Closes https://github.com/BerriAI/litellm/issues/4584

'parallel_tool_calls' is an OpenAI-only param; with 'drop_params' enabled it is now dropped for providers that don't support it.
Krrish Dholakia 2024-07-08 07:36:41 -07:00
parent f889a7e4b0
commit a00a1267bc
5 changed files with 86 additions and 24 deletions
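
For context, a minimal usage sketch of what this change enables (the model and message mirror the new test added below; nothing else here is part of the diff):

import litellm

# 'parallel_tool_calls' is supported by OpenAI but not by providers like Cohere.
# With drop_params=True, litellm drops the param for providers that don't support
# it instead of raising an unsupported-parameter error.
response = litellm.completion(
    model="command-r",  # Cohere model -> 'parallel_tool_calls' is dropped
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    parallel_tool_calls=True,
    drop_params=True,
)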


@@ -445,6 +445,7 @@ class OpenAIConfig:
             "functions",
             "max_retries",
             "extra_headers",
+            "parallel_tool_calls",
         ]  # works across all models
 
         model_specific_params = []


@@ -1,28 +1,31 @@
-import json, types, time  # noqa: E401
 import asyncio
+import json  # noqa: E401
+import time
+import types
+from contextlib import asynccontextmanager, contextmanager
 from datetime import datetime
 from enum import Enum
-from contextlib import asynccontextmanager, contextmanager
 from typing import (
+    Any,
+    AsyncContextManager,
+    AsyncGenerator,
+    AsyncIterator,
     Callable,
+    ContextManager,
     Dict,
     Generator,
-    AsyncGenerator,
     Iterator,
-    AsyncIterator,
-    Optional,
-    Any,
-    Union,
     List,
-    ContextManager,
-    AsyncContextManager,
+    Optional,
+    Union,
 )
 
 import httpx  # type: ignore
 import requests  # type: ignore
 
 import litellm
-from litellm.utils import ModelResponse, Usage, get_secret
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.utils import ModelResponse, Usage, get_secret
+
 from .base import BaseLLM
 from .prompt_templates import factory as ptf
@@ -287,8 +290,8 @@ class IBMWatsonXAI(BaseLLM):
         )
 
     def _get_api_params(
         self,
         params: dict,
         print_verbose: Optional[Callable] = None,
         generate_token: Optional[bool] = True,
     ) -> dict:
@@ -440,7 +443,7 @@ class IBMWatsonXAI(BaseLLM):
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
-        timeout=None
+        timeout=None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -547,7 +550,9 @@ class IBMWatsonXAI(BaseLLM):
         except Exception as e:
             raise WatsonXAIError(status_code=500, message=str(e))
 
-    def _process_embedding_response(self, json_resp: dict, model_response:Union[ModelResponse,None]=None) -> ModelResponse:
+    def _process_embedding_response(
+        self, json_resp: dict, model_response: Union[ModelResponse, None] = None
+    ) -> ModelResponse:
         if model_response is None:
             model_response = ModelResponse(model=json_resp.get("model_id", None))
         results = json_resp.get("results", [])
@@ -563,10 +568,14 @@ class IBMWatsonXAI(BaseLLM):
         model_response["object"] = "list"
         model_response["data"] = embedding_response
         input_tokens = json_resp.get("input_token_count", 0)
-        model_response.usage = Usage(
-            prompt_tokens=input_tokens,
-            completion_tokens=0,
-            total_tokens=input_tokens,
+        setattr(
+            model_response,
+            "usage",
+            Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
+            ),
         )
         return model_response
 
@@ -580,7 +589,7 @@ class IBMWatsonXAI(BaseLLM):
         optional_params=None,
         encoding=None,
         print_verbose=None,
-        aembedding=None
+        aembedding=None,
     ):
         """
         Send a text embedding request to the IBM Watsonx.ai API.
@@ -593,8 +602,8 @@ class IBMWatsonXAI(BaseLLM):
             if k not in optional_params:
                 optional_params[k] = v
 
-        model_response['model'] = model
+        model_response["model"] = model
 
         # Load auth variables from environment variables
         if isinstance(input, str):
             input = [input]
@@ -685,6 +694,7 @@ class IBMWatsonXAI(BaseLLM):
                 return json_resp
             return [res["model_id"] for res in json_resp["resources"]]
 
+
 class RequestManager:
     """
     A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.
@@ -744,7 +754,7 @@ class RequestManager:
 
     @contextmanager
     def request(
         self,
         request_params: dict,
         stream: bool = False,
         input: Optional[Any] = None,


@@ -108,7 +108,6 @@ from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
-from .llms.watsonx import IBMWatsonXAI
 from .llms.prompt_templates.factory import (
     custom_prompt,
     function_call_prompt,
@@ -119,6 +118,7 @@ from .llms.prompt_templates.factory import (
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_httpx import VertexLLM
+from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import ChatCompletionMessageToolCall
@@ -593,6 +593,7 @@ def completion(
     tool_choice: Optional[Union[str, dict]] = None,
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
+    parallel_tool_calls: Optional[bool] = None,
     deployment_id=None,
     extra_headers: Optional[dict] = None,
     # soon to be deprecated params by OpenAI
@@ -722,6 +723,7 @@ def completion(
         "tools",
         "tool_choice",
         "max_retries",
+        "parallel_tool_calls",
         "logprobs",
         "top_logprobs",
         "extra_headers",
@@ -932,6 +934,7 @@ def completion(
             top_logprobs=top_logprobs,
             extra_headers=extra_headers,
             api_version=api_version,
+            parallel_tool_calls=parallel_tool_calls,
             **non_default_params,
         )


@@ -269,7 +269,7 @@ def test_dynamic_drop_params(drop_params):
     """
     Make a call to cohere w/ drop params = True vs. false.
     """
-    if drop_params == True:
+    if drop_params is True:
         optional_params = litellm.utils.get_optional_params(
             model="command-r",
             custom_llm_provider="cohere",
@@ -306,6 +306,52 @@ def test_dynamic_drop_params_e2e():
         assert "response_format" not in mock_response.call_args.kwargs["data"]
 
 
+@pytest.mark.parametrize(
+    "model, provider, should_drop",
+    [("command-r", "cohere", True), ("gpt-3.5-turbo", "openai", False)],
+)
+def test_drop_params_parallel_tool_calls(model, provider, should_drop):
+    """
+    https://github.com/BerriAI/litellm/issues/4584
+    """
+    response = litellm.utils.get_optional_params(
+        model=model,
+        custom_llm_provider=provider,
+        response_format="json",
+        parallel_tool_calls=True,
+        drop_params=True,
+    )
+
+    print(response)
+
+    if should_drop:
+        assert "response_format" not in response
+        assert "parallel_tool_calls" not in response
+    else:
+        assert "response_format" in response
+        assert "parallel_tool_calls" in response
+
+
+def test_dynamic_drop_params_parallel_tool_calls():
+    """
+    https://github.com/BerriAI/litellm/issues/4584
+    """
+    with patch("requests.post", new=MagicMock()) as mock_response:
+        try:
+            response = litellm.completion(
+                model="command-r",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                parallel_tool_calls=True,
+                drop_params=True,
+            )
+        except Exception as e:
+            pass
+
+        mock_response.assert_called_once()
+        print(mock_response.call_args.kwargs["data"])
+        assert "parallel_tool_calls" not in mock_response.call_args.kwargs["data"]
+
+
 @pytest.mark.parametrize("drop_params", [True, False, None])
 def test_dynamic_drop_additional_params(drop_params):
     """


@@ -2393,6 +2393,7 @@ def get_optional_params(
     top_logprobs=None,
     extra_headers=None,
     api_version=None,
+    parallel_tool_calls=None,
     drop_params=None,
     additional_drop_params=None,
     **kwargs,
@@ -2470,6 +2471,7 @@ def get_optional_params(
     "top_logprobs": None,
     "extra_headers": None,
     "api_version": None,
+    "parallel_tool_calls": None,
     "drop_params": None,
     "additional_drop_params": None,
 }
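
Taken together with the parametrized test above, the intended get_optional_params behavior can be sketched like this (the assertions mirror the test; any other keys in the returned dict are provider-dependent and not asserted here):

import litellm

# Cohere does not support 'parallel_tool_calls', so with drop_params=True it is removed.
dropped = litellm.utils.get_optional_params(
    model="command-r",
    custom_llm_provider="cohere",
    parallel_tool_calls=True,
    drop_params=True,
)
assert "parallel_tool_calls" not in dropped

# OpenAI supports it, so the param is passed through.
kept = litellm.utils.get_optional_params(
    model="gpt-3.5-turbo",
    custom_llm_provider="openai",
    parallel_tool_calls=True,
    drop_params=True,
)
assert "parallel_tool_calls" in kept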