fix(utils.py): support 'drop_params' for 'parallel_tool_calls'
Closes https://github.com/BerriAI/litellm/issues/4584. 'parallel_tool_calls' is an OpenAI-only param.
parent 40a045cb72
commit bb905d7243
5 changed files with 86 additions and 24 deletions
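
For context, here is roughly what the change enables at the call site, sketched from the tests added in this commit (model name and message content are illustrative, taken from the new e2e test below, not a prescribed usage):

import litellm

# 'parallel_tool_calls' is only understood by OpenAI-style APIs. With
# drop_params=True, litellm should quietly drop it for providers that do
# not support it (e.g. Cohere's command-r) instead of forwarding an
# unknown field to the provider.
response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    parallel_tool_calls=True,
    drop_params=True,
)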
@@ -445,6 +445,7 @@ class OpenAIConfig:
             "functions",
             "max_retries",
             "extra_headers",
+            "parallel_tool_calls",
         ] # works across all models

         model_specific_params = []
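
The list above is what OpenAIConfig advertises as params supported across all OpenAI models; providers whose own supported-params list omits 'parallel_tool_calls' can then have it stripped when drop_params is set. A minimal sketch of that drop behavior follows; the helper name and signature are hypothetical, not litellm's actual implementation:

def drop_unsupported_params(
    non_default_params: dict, supported_params: list, drop_params: bool
) -> dict:
    # Keep params the provider supports; drop the rest only when allowed.
    optional_params = {}
    for key, value in non_default_params.items():
        if key in supported_params:
            optional_params[key] = value
        elif drop_params:
            continue  # e.g. 'parallel_tool_calls' sent to Cohere is dropped
        else:
            raise ValueError(f"{key} is not supported by this provider")
    return optional_params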
@@ -1,28 +1,31 @@
-import json, types, time  # noqa: E401
 import asyncio
+import json  # noqa: E401
+import time
+import types
+from contextlib import asynccontextmanager, contextmanager
 from datetime import datetime
 from enum import Enum
-from contextlib import asynccontextmanager, contextmanager
 from typing import (
+    Any,
+    AsyncContextManager,
+    AsyncGenerator,
+    AsyncIterator,
     Callable,
+    ContextManager,
     Dict,
     Generator,
-    AsyncGenerator,
     Iterator,
-    AsyncIterator,
-    Optional,
-    Any,
-    Union,
     List,
-    ContextManager,
-    AsyncContextManager,
+    Optional,
+    Union,
 )

 import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
-from litellm.utils import ModelResponse, Usage, get_secret
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.utils import ModelResponse, Usage, get_secret
+
 from .base import BaseLLM
 from .prompt_templates import factory as ptf
@@ -440,7 +443,7 @@ class IBMWatsonXAI(BaseLLM):
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
-        timeout=None
+        timeout=None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -547,7 +550,9 @@ class IBMWatsonXAI(BaseLLM):
         except Exception as e:
             raise WatsonXAIError(status_code=500, message=str(e))

-    def _process_embedding_response(self, json_resp: dict, model_response:Union[ModelResponse,None]=None) -> ModelResponse:
+    def _process_embedding_response(
+        self, json_resp: dict, model_response: Union[ModelResponse, None] = None
+    ) -> ModelResponse:
         if model_response is None:
             model_response = ModelResponse(model=json_resp.get("model_id", None))
         results = json_resp.get("results", [])
@@ -563,10 +568,14 @@ class IBMWatsonXAI(BaseLLM):
         model_response["object"] = "list"
         model_response["data"] = embedding_response
         input_tokens = json_resp.get("input_token_count", 0)
-        model_response.usage = Usage(
-            prompt_tokens=input_tokens,
-            completion_tokens=0,
-            total_tokens=input_tokens,
+        setattr(
+            model_response,
+            "usage",
+            Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
+            ),
         )
         return model_response

@@ -580,7 +589,7 @@ class IBMWatsonXAI(BaseLLM):
         optional_params=None,
         encoding=None,
         print_verbose=None,
-        aembedding=None
+        aembedding=None,
     ):
         """
         Send a text embedding request to the IBM Watsonx.ai API.
@@ -593,7 +602,7 @@ class IBMWatsonXAI(BaseLLM):
             if k not in optional_params:
                 optional_params[k] = v

-        model_response['model'] = model
+        model_response["model"] = model

         # Load auth variables from environment variables
         if isinstance(input, str):
@@ -685,6 +694,7 @@ class IBMWatsonXAI(BaseLLM):
             return json_resp
         return [res["model_id"] for res in json_resp["resources"]]

+
 class RequestManager:
     """
     A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.
@@ -108,7 +108,6 @@ from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
-from .llms.watsonx import IBMWatsonXAI
 from .llms.prompt_templates.factory import (
     custom_prompt,
     function_call_prompt,
@@ -119,6 +118,7 @@ from .llms.prompt_templates.factory import (
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_httpx import VertexLLM
+from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import ChatCompletionMessageToolCall

@@ -593,6 +593,7 @@ def completion(
     tool_choice: Optional[Union[str, dict]] = None,
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
+    parallel_tool_calls: Optional[bool] = None,
     deployment_id=None,
     extra_headers: Optional[dict] = None,
     # soon to be deprecated params by OpenAI
@@ -722,6 +723,7 @@ def completion(
         "tools",
         "tool_choice",
         "max_retries",
+        "parallel_tool_calls",
         "logprobs",
         "top_logprobs",
         "extra_headers",
@@ -932,6 +934,7 @@ def completion(
             top_logprobs=top_logprobs,
             extra_headers=extra_headers,
             api_version=api_version,
+            parallel_tool_calls=parallel_tool_calls,
             **non_default_params,
         )

@@ -269,7 +269,7 @@ def test_dynamic_drop_params(drop_params):
     """
     Make a call to cohere w/ drop params = True vs. false.
     """
-    if drop_params == True:
+    if drop_params is True:
         optional_params = litellm.utils.get_optional_params(
             model="command-r",
             custom_llm_provider="cohere",
@@ -306,6 +306,52 @@ def test_dynamic_drop_params_e2e():
         assert "response_format" not in mock_response.call_args.kwargs["data"]


+@pytest.mark.parametrize(
+    "model, provider, should_drop",
+    [("command-r", "cohere", True), ("gpt-3.5-turbo", "openai", False)],
+)
+def test_drop_params_parallel_tool_calls(model, provider, should_drop):
+    """
+    https://github.com/BerriAI/litellm/issues/4584
+    """
+    response = litellm.utils.get_optional_params(
+        model=model,
+        custom_llm_provider=provider,
+        response_format="json",
+        parallel_tool_calls=True,
+        drop_params=True,
+    )
+
+    print(response)
+
+    if should_drop:
+        assert "response_format" not in response
+        assert "parallel_tool_calls" not in response
+    else:
+        assert "response_format" in response
+        assert "parallel_tool_calls" in response
+
+
+def test_dynamic_drop_params_parallel_tool_calls():
+    """
+    https://github.com/BerriAI/litellm/issues/4584
+    """
+    with patch("requests.post", new=MagicMock()) as mock_response:
+        try:
+            response = litellm.completion(
+                model="command-r",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                parallel_tool_calls=True,
+                drop_params=True,
+            )
+        except Exception as e:
+            pass
+
+        mock_response.assert_called_once()
+        print(mock_response.call_args.kwargs["data"])
+        assert "parallel_tool_calls" not in mock_response.call_args.kwargs["data"]
+
+
 @pytest.mark.parametrize("drop_params", [True, False, None])
 def test_dynamic_drop_additional_params(drop_params):
     """
@@ -2393,6 +2393,7 @@ def get_optional_params(
     top_logprobs=None,
     extra_headers=None,
     api_version=None,
+    parallel_tool_calls=None,
     drop_params=None,
     additional_drop_params=None,
     **kwargs,
@@ -2470,6 +2471,7 @@ def get_optional_params(
         "top_logprobs": None,
         "extra_headers": None,
         "api_version": None,
+        "parallel_tool_calls": None,
         "drop_params": None,
         "additional_drop_params": None,
     }
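
With 'parallel_tool_calls' registered in get_optional_params, the drop can also be checked at the param-mapping level; a hedged sketch mirroring the new parametrized test above (exact keys in the returned dict depend on each provider's mapping):

import litellm

# Cohere does not support 'parallel_tool_calls'; with drop_params=True it
# should simply be absent from the mapped params rather than raising.
cohere_params = litellm.utils.get_optional_params(
    model="command-r",
    custom_llm_provider="cohere",
    parallel_tool_calls=True,
    drop_params=True,
)
assert "parallel_tool_calls" not in cohere_params

# OpenAI supports it, so it is passed through unchanged.
openai_params = litellm.utils.get_optional_params(
    model="gpt-3.5-turbo",
    custom_llm_provider="openai",
    parallel_tool_calls=True,
    drop_params=True,
)
assert "parallel_tool_calls" in openai_params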