mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Add mock_tool_calls
This commit is contained in:
parent
57d3b591ba
commit
9df0091b97
1 changed files with 10 additions and 1 deletions
|
@ -41,6 +41,7 @@ from litellm.utils import (
|
||||||
get_optional_params_embeddings,
|
get_optional_params_embeddings,
|
||||||
get_optional_params_image_gen,
|
get_optional_params_image_gen,
|
||||||
supports_httpx_timeout,
|
supports_httpx_timeout,
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
)
|
)
|
||||||
from .llms import (
|
from .llms import (
|
||||||
anthropic_text,
|
anthropic_text,
|
||||||
|
@ -398,6 +399,7 @@ def mock_completion(
|
||||||
messages: List,
|
messages: List,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
mock_response: Union[str, Exception] = "This is a mock request",
|
mock_response: Union[str, Exception] = "This is a mock request",
|
||||||
|
mock_tool_calls: Optional[List] = None,
|
||||||
logging=None,
|
logging=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
@ -465,6 +467,11 @@ def mock_completion(
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = model
|
model_response["model"] = model
|
||||||
|
|
||||||
|
if mock_tool_calls:
|
||||||
|
model_response["choices"][0]["message"]["tool_calls"] = [
|
||||||
|
ChatCompletionMessageToolCall(**tool_call) for tool_call in mock_tool_calls
|
||||||
|
]
|
||||||
|
|
||||||
setattr(
|
setattr(
|
||||||
model_response,
|
model_response,
|
||||||
"usage",
|
"usage",
|
||||||
|
@ -578,6 +585,7 @@ def completion(
|
||||||
args = locals()
|
args = locals()
|
||||||
api_base = kwargs.get("api_base", None)
|
api_base = kwargs.get("api_base", None)
|
||||||
mock_response = kwargs.get("mock_response", None)
|
mock_response = kwargs.get("mock_response", None)
|
||||||
|
mock_tool_calls = kwargs.get("mock_tool_calls", None)
|
||||||
force_timeout = kwargs.get("force_timeout", 600) ## deprecated
|
force_timeout = kwargs.get("force_timeout", 600) ## deprecated
|
||||||
logger_fn = kwargs.get("logger_fn", None)
|
logger_fn = kwargs.get("logger_fn", None)
|
||||||
verbose = kwargs.get("verbose", False)
|
verbose = kwargs.get("verbose", False)
|
||||||
|
@ -896,12 +904,13 @@ def completion(
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
custom_llm_provider=custom_llm_provider,
|
custom_llm_provider=custom_llm_provider,
|
||||||
)
|
)
|
||||||
if mock_response:
|
if mock_response or mock_tool_calls:
|
||||||
return mock_completion(
|
return mock_completion(
|
||||||
model,
|
model,
|
||||||
messages,
|
messages,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
mock_response=mock_response,
|
mock_response=mock_response,
|
||||||
|
mock_tool_calls=mock_tool_calls,
|
||||||
logging=logging,
|
logging=logging,
|
||||||
acompletion=acompletion,
|
acompletion=acompletion,
|
||||||
mock_delay=kwargs.get("mock_delay", None),
|
mock_delay=kwargs.get("mock_delay", None),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue