mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
test_cohere_request_body_with_allowed_params
This commit is contained in:
parent
4080fe54d5
commit
4b99f833bb
1 changed files with 57 additions and 0 deletions
|
@ -17,6 +17,9 @@ import pytest
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
|
||||||
litellm.num_retries = 3
|
litellm.num_retries = 3
|
||||||
|
|
||||||
|
@ -224,3 +227,57 @@ async def test_chat_completion_cohere_stream(sync_mode):
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cohere_request_body_with_allowed_params():
|
||||||
|
"""
|
||||||
|
Test to validate that when allowed_openai_params is provided, the request body contains
|
||||||
|
the correct response_format and reasoning_effort values.
|
||||||
|
"""
|
||||||
|
# Define test parameters
|
||||||
|
test_response_format = {"type": "json"}
|
||||||
|
test_reasoning_effort = "low"
|
||||||
|
test_tools = [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_time",
|
||||||
|
"description": "Get the current time in a given location.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {"type": "string", "description": "The city name, e.g. San Francisco"}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
|
||||||
|
client = AsyncHTTPHandler()
|
||||||
|
|
||||||
|
# Mock the post method
|
||||||
|
with patch.object(client, "post", new=AsyncMock()) as mock_post:
|
||||||
|
try:
|
||||||
|
await litellm.acompletion(
|
||||||
|
model="cohere/command",
|
||||||
|
messages=[{"content": "what llm are you", "role": "user"}],
|
||||||
|
allowed_openai_params=["tools", "response_format", "reasoning_effort"],
|
||||||
|
response_format=test_response_format,
|
||||||
|
reasoning_effort=test_reasoning_effort,
|
||||||
|
tools=test_tools,
|
||||||
|
client=client
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # We only care about the request body validation
|
||||||
|
|
||||||
|
# Verify the API call was made
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
|
||||||
|
# Get and parse the request body
|
||||||
|
request_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||||
|
print(f"request_data: {request_data}")
|
||||||
|
|
||||||
|
# Validate request contains our specified parameters
|
||||||
|
assert "allowed_openai_params" not in request_data
|
||||||
|
assert request_data["response_format"] == test_response_format
|
||||||
|
assert request_data["reasoning_effort"] == test_reasoning_effort
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue