forked from phoenix/litellm-mirror
fix(utils.py): support 'drop_params' for 'parallel_tool_calls'
Closes https://github.com/BerriAI/litellm/issues/4584 OpenAI-only param
This commit is contained in:
parent
40a045cb72
commit
bb905d7243
5 changed files with 86 additions and 24 deletions
|
@ -445,6 +445,7 @@ class OpenAIConfig:
|
|||
"functions",
|
||||
"max_retries",
|
||||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
] # works across all models
|
||||
|
||||
model_specific_params = []
|
||||
|
|
|
@ -1,28 +1,31 @@
|
|||
import json, types, time # noqa: E401
|
||||
import asyncio
|
||||
import json # noqa: E401
|
||||
import time
|
||||
import types
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Dict,
|
||||
Generator,
|
||||
AsyncGenerator,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
Optional,
|
||||
Any,
|
||||
Union,
|
||||
List,
|
||||
ContextManager,
|
||||
AsyncContextManager,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, Usage, get_secret
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.utils import ModelResponse, Usage, get_secret
|
||||
|
||||
from .base import BaseLLM
|
||||
from .prompt_templates import factory as ptf
|
||||
|
@ -440,7 +443,7 @@ class IBMWatsonXAI(BaseLLM):
|
|||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
timeout=None
|
||||
timeout=None,
|
||||
):
|
||||
"""
|
||||
Send a text generation request to the IBM Watsonx.ai API.
|
||||
|
@ -547,7 +550,9 @@ class IBMWatsonXAI(BaseLLM):
|
|||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
|
||||
def _process_embedding_response(self, json_resp: dict, model_response:Union[ModelResponse,None]=None) -> ModelResponse:
|
||||
def _process_embedding_response(
|
||||
self, json_resp: dict, model_response: Union[ModelResponse, None] = None
|
||||
) -> ModelResponse:
|
||||
if model_response is None:
|
||||
model_response = ModelResponse(model=json_resp.get("model_id", None))
|
||||
results = json_resp.get("results", [])
|
||||
|
@ -563,10 +568,14 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model_response["object"] = "list"
|
||||
model_response["data"] = embedding_response
|
||||
input_tokens = json_resp.get("input_token_count", 0)
|
||||
model_response.usage = Usage(
|
||||
setattr(
|
||||
model_response,
|
||||
"usage",
|
||||
Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=input_tokens,
|
||||
),
|
||||
)
|
||||
return model_response
|
||||
|
||||
|
@ -580,7 +589,7 @@ class IBMWatsonXAI(BaseLLM):
|
|||
optional_params=None,
|
||||
encoding=None,
|
||||
print_verbose=None,
|
||||
aembedding=None
|
||||
aembedding=None,
|
||||
):
|
||||
"""
|
||||
Send a text embedding request to the IBM Watsonx.ai API.
|
||||
|
@ -593,7 +602,7 @@ class IBMWatsonXAI(BaseLLM):
|
|||
if k not in optional_params:
|
||||
optional_params[k] = v
|
||||
|
||||
model_response['model'] = model
|
||||
model_response["model"] = model
|
||||
|
||||
# Load auth variables from environment variables
|
||||
if isinstance(input, str):
|
||||
|
@ -685,6 +694,7 @@ class IBMWatsonXAI(BaseLLM):
|
|||
return json_resp
|
||||
return [res["model_id"] for res in json_resp["resources"]]
|
||||
|
||||
|
||||
class RequestManager:
|
||||
"""
|
||||
A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.
|
||||
|
|
|
@ -108,7 +108,6 @@ from .llms.databricks import DatabricksChatCompletion
|
|||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.watsonx import IBMWatsonXAI
|
||||
from .llms.prompt_templates.factory import (
|
||||
custom_prompt,
|
||||
function_call_prompt,
|
||||
|
@ -119,6 +118,7 @@ from .llms.prompt_templates.factory import (
|
|||
from .llms.text_completion_codestral import CodestralTextCompletion
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.vertex_httpx import VertexLLM
|
||||
from .llms.watsonx import IBMWatsonXAI
|
||||
from .types.llms.openai import HttpxBinaryResponseContent
|
||||
from .types.utils import ChatCompletionMessageToolCall
|
||||
|
||||
|
@ -593,6 +593,7 @@ def completion(
|
|||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
deployment_id=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
# soon to be deprecated params by OpenAI
|
||||
|
@ -722,6 +723,7 @@ def completion(
|
|||
"tools",
|
||||
"tool_choice",
|
||||
"max_retries",
|
||||
"parallel_tool_calls",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"extra_headers",
|
||||
|
@ -932,6 +934,7 @@ def completion(
|
|||
top_logprobs=top_logprobs,
|
||||
extra_headers=extra_headers,
|
||||
api_version=api_version,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
**non_default_params,
|
||||
)
|
||||
|
||||
|
|
|
@ -269,7 +269,7 @@ def test_dynamic_drop_params(drop_params):
|
|||
"""
|
||||
Make a call to cohere w/ drop params = True vs. false.
|
||||
"""
|
||||
if drop_params == True:
|
||||
if drop_params is True:
|
||||
optional_params = litellm.utils.get_optional_params(
|
||||
model="command-r",
|
||||
custom_llm_provider="cohere",
|
||||
|
@ -306,6 +306,52 @@ def test_dynamic_drop_params_e2e():
|
|||
assert "response_format" not in mock_response.call_args.kwargs["data"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, provider, should_drop",
|
||||
[("command-r", "cohere", True), ("gpt-3.5-turbo", "openai", False)],
|
||||
)
|
||||
def test_drop_params_parallel_tool_calls(model, provider, should_drop):
|
||||
"""
|
||||
https://github.com/BerriAI/litellm/issues/4584
|
||||
"""
|
||||
response = litellm.utils.get_optional_params(
|
||||
model=model,
|
||||
custom_llm_provider=provider,
|
||||
response_format="json",
|
||||
parallel_tool_calls=True,
|
||||
drop_params=True,
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
if should_drop:
|
||||
assert "response_format" not in response
|
||||
assert "parallel_tool_calls" not in response
|
||||
else:
|
||||
assert "response_format" in response
|
||||
assert "parallel_tool_calls" in response
|
||||
|
||||
|
||||
def test_dynamic_drop_params_parallel_tool_calls():
|
||||
"""
|
||||
https://github.com/BerriAI/litellm/issues/4584
|
||||
"""
|
||||
with patch("requests.post", new=MagicMock()) as mock_response:
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="command-r",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
parallel_tool_calls=True,
|
||||
drop_params=True,
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
mock_response.assert_called_once()
|
||||
print(mock_response.call_args.kwargs["data"])
|
||||
assert "parallel_tool_calls" not in mock_response.call_args.kwargs["data"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("drop_params", [True, False, None])
|
||||
def test_dynamic_drop_additional_params(drop_params):
|
||||
"""
|
||||
|
|
|
@ -2393,6 +2393,7 @@ def get_optional_params(
|
|||
top_logprobs=None,
|
||||
extra_headers=None,
|
||||
api_version=None,
|
||||
parallel_tool_calls=None,
|
||||
drop_params=None,
|
||||
additional_drop_params=None,
|
||||
**kwargs,
|
||||
|
@ -2470,6 +2471,7 @@ def get_optional_params(
|
|||
"top_logprobs": None,
|
||||
"extra_headers": None,
|
||||
"api_version": None,
|
||||
"parallel_tool_calls": None,
|
||||
"drop_params": None,
|
||||
"additional_drop_params": None,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue