Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 03:34:10 +00:00
fix(acompletion): support fallbacks on acompletion (#7184)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 45s
* fix(acompletion): support fallbacks on acompletion; allows health checks for wildcard routes to use fallback models
* test: update cohere generate api testing
* add max tokens to health check (#7000)
* fix: fix health check test
* test: update testing

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
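As a rough illustration of what the headline fix enables, here is a minimal sketch of calling acompletion with the fallbacks parameter. The model names are placeholders, and the shape of the fallbacks argument and the retry semantics are assumed from the commit description, not shown in this diff:

import asyncio

import litellm


async def main():
    # Assumed behavior per the commit title: if the primary model errors out,
    # litellm retries the same request against each fallback model in order.
    response = await litellm.acompletion(
        model="cohere/command",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
        fallbacks=["gpt-3.5-turbo"],  # placeholder fallback model
    )
    print(response.choices[0].message.content)


asyncio.run(main())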
This commit is contained in:
parent 5ec649b512
commit 481645e49c
8 changed files with 240 additions and 69 deletions
@@ -20,10 +20,13 @@ from litellm import completion
 from litellm.llms.cohere.completion.transformation import CohereTextConfig
 
 
-@pytest.mark.asyncio
-async def test_cohere_generate_api_completion():
+def test_cohere_generate_api_completion():
     try:
-        litellm.set_verbose = False
+        from litellm.llms.custom_httpx.http_handler import HTTPHandler
+        from unittest.mock import patch, MagicMock
+
+        client = HTTPHandler()
+        litellm.set_verbose = True
         messages = [
             {"role": "system", "content": "You're a good bot"},
             {
@@ -31,12 +34,28 @@ async def test_cohere_generate_api_completion():
                 "content": "Hey",
             },
         ]
         response = completion(
             model="cohere/command-nightly",
             messages=messages,
             max_tokens=10,
         )
         print(response)
+
+        with patch.object(client, "post") as mock_client:
+            try:
+                completion(
+                    model="cohere/command",
+                    messages=messages,
+                    max_tokens=10,
+                    client=client,
+                )
+            except Exception as e:
+                print(e)
+            mock_client.assert_called_once()
+            print("mock_client.call_args.kwargs", mock_client.call_args.kwargs)
+
+            assert (
+                mock_client.call_args.kwargs["url"]
+                == "https://api.cohere.ai/v1/generate"
+            )
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+            assert json_data["model"] == "command"
+            assert json_data["prompt"] == "You're a good bot Hey"
+            assert json_data["max_tokens"] == 10
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
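The updated test above pins down the outbound request by injecting an HTTPHandler and patching its post method, so no real network call is made. A self-contained sketch of that pattern, with an illustrative dummy api_key added so request construction succeeds before the patched call is reached:

import json
from unittest.mock import patch

from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()

with patch.object(client, "post") as mock_post:
    try:
        completion(
            model="cohere/command",
            messages=[{"role": "user", "content": "Hey"}],
            max_tokens=10,
            api_key="fake-key",  # dummy key; the patched post() never reaches Cohere
            client=client,  # inject the handler so the mock intercepts the call
        )
    except Exception:
        pass  # the MagicMock response cannot be parsed, so an error here is expected
    mock_post.assert_called_once()
    body = json.loads(mock_post.call_args.kwargs["data"])
    assert body["model"] == "command"
    assert body["max_tokens"] == 10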
@@ -53,7 +72,7 @@ async def test_cohere_generate_api_stream():
             },
         ]
         response = await litellm.acompletion(
-            model="cohere/command-nightly",
+            model="cohere/command",
             messages=messages,
             max_tokens=10,
             stream=True,
@@ -76,7 +95,7 @@ def test_completion_cohere_stream_bad_key():
             },
         ]
         completion(
-            model="command-nightly",
+            model="command",
             messages=messages,
             stream=True,
             max_tokens=50,
@@ -100,7 +119,7 @@ def test_cohere_transform_request():
     headers = {}
 
     transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
         messages=messages,
         optional_params=optional_params,
         litellm_params={},
@@ -109,7 +128,7 @@ def test_cohere_transform_request():
 
     print("transformed_request", json.dumps(transformed_request, indent=4))
 
-    assert transformed_request["model"] == "command-nightly"
+    assert transformed_request["model"] == "command"
     assert transformed_request["prompt"] == "You're a helpful bot Hello"
     assert transformed_request["max_tokens"] == 10
     assert transformed_request["temperature"] == 0.7
@@ -137,7 +156,7 @@ def test_cohere_transform_request_with_tools():
     optional_params = {"tools": tools}
 
     transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
         messages=messages,
         optional_params=optional_params,
         litellm_params={},
@@ -168,7 +187,7 @@ def test_cohere_map_openai_params():
     mapped_params = config.map_openai_params(
         non_default_params=openai_params,
         optional_params={},
-        model="command-nightly",
+        model="command",
         drop_params=False,
     )
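For the health-check bullets in the commit message (#7000 and fallback models for wildcard routes), a sketch of how a probe might be driven through litellm's async health-check helper. The name and signature of ahealth_check are assumptions here; the commit's contribution is that such probes now run with a small max-tokens cap and that wildcard routes can fall back to other models:

import asyncio

import litellm


async def probe():
    # Assumed helper: ahealth_check issues a minimal completion request.
    # Per this commit, the probe is capped at a small max_tokens so that
    # health checks stay cheap.
    result = await litellm.ahealth_check(
        model_params={
            "model": "cohere/command",  # placeholder model
            "messages": [{"role": "user", "content": "ping"}],
        },
        mode="completion",
    )
    print(result)


asyncio.run(probe())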