test_cohere_request_body_with_allowed_params

This commit is contained in:
Ishaan Jaff 2025-04-01 21:30:24 -07:00
parent 4080fe54d5
commit 4b99f833bb

View file

@ -17,6 +17,9 @@ import pytest
import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from unittest.mock import AsyncMock, patch
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
litellm.num_retries = 3
@ -224,3 +227,57 @@ async def test_chat_completion_cohere_stream(sync_mode):
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_cohere_request_body_with_allowed_params():
"""
Test to validate that when allowed_openai_params is provided, the request body contains
the correct response_format and reasoning_effort values.
"""
# Define test parameters
test_response_format = {"type": "json"}
test_reasoning_effort = "low"
test_tools = [{
"type": "function",
"function": {
"name": "get_current_time",
"description": "Get the current time in a given location.",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city name, e.g. San Francisco"}
},
"required": ["location"]
}
}
}]
client = AsyncHTTPHandler()
# Mock the post method
with patch.object(client, "post", new=AsyncMock()) as mock_post:
try:
await litellm.acompletion(
model="cohere/command",
messages=[{"content": "what llm are you", "role": "user"}],
allowed_openai_params=["tools", "response_format", "reasoning_effort"],
response_format=test_response_format,
reasoning_effort=test_reasoning_effort,
tools=test_tools,
client=client
)
except Exception:
pass # We only care about the request body validation
# Verify the API call was made
mock_post.assert_called_once()
# Get and parse the request body
request_data = json.loads(mock_post.call_args.kwargs["data"])
print(f"request_data: {request_data}")
# Validate request contains our specified parameters
assert "allowed_openai_params" not in request_data
assert request_data["response_format"] == test_response_format
assert request_data["reasoning_effort"] == test_reasoning_effort