Add mock_tool_calls

This commit is contained in:
= 2024-06-14 15:19:11 +02:00
parent 57d3b591ba
commit 9df0091b97

View file

@ -41,6 +41,7 @@ from litellm.utils import (
get_optional_params_embeddings, get_optional_params_embeddings,
get_optional_params_image_gen, get_optional_params_image_gen,
supports_httpx_timeout, supports_httpx_timeout,
ChatCompletionMessageToolCall,
) )
from .llms import ( from .llms import (
anthropic_text, anthropic_text,
@ -398,6 +399,7 @@ def mock_completion(
messages: List, messages: List,
stream: Optional[bool] = False, stream: Optional[bool] = False,
mock_response: Union[str, Exception] = "This is a mock request", mock_response: Union[str, Exception] = "This is a mock request",
mock_tool_calls: Optional[List] = None,
logging=None, logging=None,
**kwargs, **kwargs,
): ):
@ -465,6 +467,11 @@ def mock_completion(
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
if mock_tool_calls:
model_response["choices"][0]["message"]["tool_calls"] = [
ChatCompletionMessageToolCall(**tool_call) for tool_call in mock_tool_calls
]
setattr( setattr(
model_response, model_response,
"usage", "usage",
@ -578,6 +585,7 @@ def completion(
args = locals() args = locals()
api_base = kwargs.get("api_base", None) api_base = kwargs.get("api_base", None)
mock_response = kwargs.get("mock_response", None) mock_response = kwargs.get("mock_response", None)
mock_tool_calls = kwargs.get("mock_tool_calls", None)
force_timeout = kwargs.get("force_timeout", 600) ## deprecated force_timeout = kwargs.get("force_timeout", 600) ## deprecated
logger_fn = kwargs.get("logger_fn", None) logger_fn = kwargs.get("logger_fn", None)
verbose = kwargs.get("verbose", False) verbose = kwargs.get("verbose", False)
@ -896,12 +904,13 @@ def completion(
litellm_params=litellm_params, litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
) )
if mock_response: if mock_response or mock_tool_calls:
return mock_completion( return mock_completion(
model, model,
messages, messages,
stream=stream, stream=stream,
mock_response=mock_response, mock_response=mock_response,
mock_tool_calls=mock_tool_calls,
logging=logging, logging=logging,
acompletion=acompletion, acompletion=acompletion,
mock_delay=kwargs.get("mock_delay", None), mock_delay=kwargs.get("mock_delay", None),