diff --git a/litellm/main.py b/litellm/main.py
index 31cb8e364f..90feb9aa85 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -41,6 +41,7 @@ from litellm.utils import (
     get_optional_params_embeddings,
     get_optional_params_image_gen,
     supports_httpx_timeout,
+    ChatCompletionMessageToolCall,
 )
 from .llms import (
     anthropic_text,
@@ -398,6 +399,7 @@ def mock_completion(
     messages: List,
     stream: Optional[bool] = False,
     mock_response: Union[str, Exception] = "This is a mock request",
+    mock_tool_calls: Optional[List] = None,
     logging=None,
     **kwargs,
 ):
@@ -465,6 +467,11 @@ def mock_completion(
         model_response["created"] = int(time.time())
         model_response["model"] = model
 
+        if mock_tool_calls:
+            model_response["choices"][0]["message"]["tool_calls"] = [
+                ChatCompletionMessageToolCall(**tool_call) for tool_call in mock_tool_calls
+            ]
+
         setattr(
             model_response,
             "usage",
@@ -578,6 +585,7 @@ def completion(
     args = locals()
     api_base = kwargs.get("api_base", None)
     mock_response = kwargs.get("mock_response", None)
+    mock_tool_calls = kwargs.get("mock_tool_calls", None)
     force_timeout = kwargs.get("force_timeout", 600)  ## deprecated
     logger_fn = kwargs.get("logger_fn", None)
     verbose = kwargs.get("verbose", False)
@@ -896,12 +904,13 @@ def completion(
             litellm_params=litellm_params,
             custom_llm_provider=custom_llm_provider,
         )
-        if mock_response:
+        if mock_response or mock_tool_calls:
             return mock_completion(
                 model,
                 messages,
                 stream=stream,
                 mock_response=mock_response,
+                mock_tool_calls=mock_tool_calls,
                 logging=logging,
                 acompletion=acompletion,
                 mock_delay=kwargs.get("mock_delay", None),
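
Usage sketch (not part of the diff): with this change applied, a caller could exercise the mocked tool calls roughly as below. The payload shape is an assumption — each dict is unpacked via ChatCompletionMessageToolCall(**tool_call), so the keys are written to mirror the OpenAI tool-call schema (id / type / function); the model name and function payload are hypothetical.

    import litellm

    # Hypothetical payload; each dict is unpacked into
    # ChatCompletionMessageToolCall, so the keys here assume the
    # OpenAI tool-call schema (id, type, function).
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "What is the weather in Boston?"}],
        mock_tool_calls=[
            {
                "id": "call_abc123",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "Boston, MA"}',
                },
            }
        ],
    )
    # The mocked response now carries tool_calls on the first choice's
    # message, alongside the default mock text content.
    print(response.choices[0].message.tool_calls)

Note that mock_tool_calls is read from **kwargs in completion() and, per the final hunk, is sufficient on its own to route the call through mock_completion() even when mock_response is not set.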