diff --git a/tests/llm_translation/test_cohere.py b/tests/llm_translation/test_cohere.py index 124a5c8788..6b4d3a2045 100644 --- a/tests/llm_translation/test_cohere.py +++ b/tests/llm_translation/test_cohere.py @@ -17,6 +17,9 @@ import pytest import litellm from litellm import RateLimitError, Timeout, completion, completion_cost, embedding +from unittest.mock import AsyncMock, patch +from litellm import RateLimitError, Timeout, completion, completion_cost, embedding +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler litellm.num_retries = 3 @@ -224,3 +227,57 @@ async def test_chat_completion_cohere_stream(sync_mode): pass except Exception as e: pytest.fail(f"Error occurred: {e}") + + +@pytest.mark.asyncio +async def test_cohere_request_body_with_allowed_params(): + """ + Test to validate that when allowed_openai_params is provided, the request body contains + the correct response_format and reasoning_effort values. + """ + # Define test parameters + test_response_format = {"type": "json"} + test_reasoning_effort = "low" + test_tools = [{ + "type": "function", + "function": { + "name": "get_current_time", + "description": "Get the current time in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city name, e.g. San Francisco"} + }, + "required": ["location"] + } + } + }] + + client = AsyncHTTPHandler() + + # Mock the post method + with patch.object(client, "post", new=AsyncMock()) as mock_post: + try: + await litellm.acompletion( + model="cohere/command", + messages=[{"content": "what llm are you", "role": "user"}], + allowed_openai_params=["tools", "response_format", "reasoning_effort"], + response_format=test_response_format, + reasoning_effort=test_reasoning_effort, + tools=test_tools, + client=client + ) + except Exception: + pass # We only care about the request body validation + + # Verify the API call was made + mock_post.assert_called_once() + + # Get and parse the request body + request_data = json.loads(mock_post.call_args.kwargs["data"]) + print(f"request_data: {request_data}") + + # Validate request contains our specified parameters + assert "allowed_openai_params" not in request_data + assert request_data["response_format"] == test_response_format + assert request_data["reasoning_effort"] == test_reasoning_effort