From a00a1267bcf43c9af04c0f860caafa0107a53dd5 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 8 Jul 2024 07:36:41 -0700
Subject: [PATCH] fix(utils.py): support 'drop_params' for 'parallel_tool_calls'

Closes https://github.com/BerriAI/litellm/issues/4584

OpenAI-only param
---
 litellm/llms/openai.py                |  1 +
 litellm/llms/watsonx.py               | 54 ++++++++++++++++-----------
 litellm/main.py                       |  5 ++-
 litellm/tests/test_optional_params.py | 48 +++++++++++++++++++++++-
 litellm/utils.py                      |  2 +
 5 files changed, 86 insertions(+), 24 deletions(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 990ef2faeb..25641c4b8d 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -445,6 +445,7 @@ class OpenAIConfig:
             "functions",
             "max_retries",
             "extra_headers",
+            "parallel_tool_calls",
         ]  # works across all models
         model_specific_params = []
 
diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
index 3151d3f9c9..5649b714a0 100644
--- a/litellm/llms/watsonx.py
+++ b/litellm/llms/watsonx.py
@@ -1,28 +1,31 @@
-import json, types, time  # noqa: E401
 import asyncio
+import json  # noqa: E401
+import time
+import types
+from contextlib import asynccontextmanager, contextmanager
 from datetime import datetime
 from enum import Enum
-from contextlib import asynccontextmanager, contextmanager
 from typing import (
+    Any,
+    AsyncContextManager,
+    AsyncGenerator,
+    AsyncIterator,
     Callable,
+    ContextManager,
     Dict,
     Generator,
-    AsyncGenerator,
     Iterator,
-    AsyncIterator,
-    Optional,
-    Any,
-    Union,
     List,
-    ContextManager,
-    AsyncContextManager,
+    Optional,
+    Union,
 )
 
 import httpx  # type: ignore
 import requests  # type: ignore
+
 import litellm
-from litellm.utils import ModelResponse, Usage, get_secret
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.utils import ModelResponse, Usage, get_secret
 
 from .base import BaseLLM
 from .prompt_templates import factory as ptf
@@ -287,8 +290,8 @@ class IBMWatsonXAI(BaseLLM):
         )
 
     def _get_api_params(
-        self, 
-        params: dict, 
+        self,
+        params: dict,
         print_verbose: Optional[Callable] = None,
         generate_token: Optional[bool] = True,
     ) -> dict:
@@ -440,7 +443,7 @@ class IBMWatsonXAI(BaseLLM):
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
-        timeout=None
+        timeout=None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -547,7 +550,9 @@ class IBMWatsonXAI(BaseLLM):
         except Exception as e:
             raise WatsonXAIError(status_code=500, message=str(e))
 
-    def _process_embedding_response(self, json_resp: dict, model_response:Union[ModelResponse,None]=None) -> ModelResponse:
+    def _process_embedding_response(
+        self, json_resp: dict, model_response: Union[ModelResponse, None] = None
+    ) -> ModelResponse:
         if model_response is None:
             model_response = ModelResponse(model=json_resp.get("model_id", None))
         results = json_resp.get("results", [])
@@ -563,10 +568,14 @@ class IBMWatsonXAI(BaseLLM):
         model_response["object"] = "list"
         model_response["data"] = embedding_response
         input_tokens = json_resp.get("input_token_count", 0)
-        model_response.usage = Usage(
-            prompt_tokens=input_tokens,
-            completion_tokens=0,
-            total_tokens=input_tokens,
+        setattr(
+            model_response,
+            "usage",
+            Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
+            ),
         )
         return model_response
 
@@ -580,7 +589,7 @@ class IBMWatsonXAI(BaseLLM):
         optional_params=None,
         encoding=None,
         print_verbose=None,
-        aembedding=None
+        aembedding=None,
     ):
         """
         Send a text embedding request to the IBM Watsonx.ai API.
@@ -593,8 +602,8 @@ class IBMWatsonXAI(BaseLLM):
             if k not in optional_params:
                 optional_params[k] = v
 
-        model_response['model'] = model
-        
+        model_response["model"] = model
+
         # Load auth variables from environment variables
         if isinstance(input, str):
             input = [input]
@@ -685,6 +694,7 @@ class IBMWatsonXAI(BaseLLM):
             return json_resp
         return [res["model_id"] for res in json_resp["resources"]]
 
+
 class RequestManager:
     """
     A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.
@@ -744,7 +754,7 @@ class RequestManager:
 
     @contextmanager
     def request(
-        self, 
+        self,
         request_params: dict,
         stream: bool = False,
         input: Optional[Any] = None,
diff --git a/litellm/main.py b/litellm/main.py
index 0a3a0a64ae..097c867073 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -108,7 +108,6 @@ from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
-from .llms.watsonx import IBMWatsonXAI
 from .llms.prompt_templates.factory import (
     custom_prompt,
     function_call_prompt,
@@ -119,6 +118,7 @@ from .llms.prompt_templates.factory import (
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_httpx import VertexLLM
+from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import ChatCompletionMessageToolCall
 
@@ -593,6 +593,7 @@ def completion(
     tool_choice: Optional[Union[str, dict]] = None,
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
+    parallel_tool_calls: Optional[bool] = None,
     deployment_id=None,
     extra_headers: Optional[dict] = None,
     # soon to be deprecated params by OpenAI
@@ -722,6 +723,7 @@ def completion(
         "tools",
         "tool_choice",
         "max_retries",
+        "parallel_tool_calls",
         "logprobs",
         "top_logprobs",
         "extra_headers",
@@ -932,6 +934,7 @@ def completion(
             top_logprobs=top_logprobs,
             extra_headers=extra_headers,
             api_version=api_version,
+            parallel_tool_calls=parallel_tool_calls,
             **non_default_params,
         )
 
diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py
index a6fa6334b9..bbfc88710f 100644
--- a/litellm/tests/test_optional_params.py
+++ b/litellm/tests/test_optional_params.py
@@ -269,7 +269,7 @@ def test_dynamic_drop_params(drop_params):
     """
     Make a call to cohere w/ drop params = True vs. false.
""" - if drop_params == True: + if drop_params is True: optional_params = litellm.utils.get_optional_params( model="command-r", custom_llm_provider="cohere", @@ -306,6 +306,52 @@ def test_dynamic_drop_params_e2e(): assert "response_format" not in mock_response.call_args.kwargs["data"] +@pytest.mark.parametrize( + "model, provider, should_drop", + [("command-r", "cohere", True), ("gpt-3.5-turbo", "openai", False)], +) +def test_drop_params_parallel_tool_calls(model, provider, should_drop): + """ + https://github.com/BerriAI/litellm/issues/4584 + """ + response = litellm.utils.get_optional_params( + model=model, + custom_llm_provider=provider, + response_format="json", + parallel_tool_calls=True, + drop_params=True, + ) + + print(response) + + if should_drop: + assert "response_format" not in response + assert "parallel_tool_calls" not in response + else: + assert "response_format" in response + assert "parallel_tool_calls" in response + + +def test_dynamic_drop_params_parallel_tool_calls(): + """ + https://github.com/BerriAI/litellm/issues/4584 + """ + with patch("requests.post", new=MagicMock()) as mock_response: + try: + response = litellm.completion( + model="command-r", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + parallel_tool_calls=True, + drop_params=True, + ) + except Exception as e: + pass + + mock_response.assert_called_once() + print(mock_response.call_args.kwargs["data"]) + assert "parallel_tool_calls" not in mock_response.call_args.kwargs["data"] + + @pytest.mark.parametrize("drop_params", [True, False, None]) def test_dynamic_drop_additional_params(drop_params): """ diff --git a/litellm/utils.py b/litellm/utils.py index b5c1f4a316..3849f2b299 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2393,6 +2393,7 @@ def get_optional_params( top_logprobs=None, extra_headers=None, api_version=None, + parallel_tool_calls=None, drop_params=None, additional_drop_params=None, **kwargs, @@ -2470,6 +2471,7 @@ def get_optional_params( "top_logprobs": None, "extra_headers": None, "api_version": None, + "parallel_tool_calls": None, "drop_params": None, "additional_drop_params": None, }