fix(acompletion): support fallbacks on acompletion (#7184)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 45s

* fix(acompletion): support fallbacks on acompletion

allows health checks for wildcard routes to use fallback models

* test: update cohere generate api testing

* add max tokens to health check (#7000)

* fix: fix health check test

* test: update testing

---------

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2024-12-11 19:20:54 -08:00 committed by GitHub
parent 5ec649b512
commit 481645e49c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 240 additions and 69 deletions

View file

@ -20,10 +20,13 @@ from litellm import completion
from litellm.llms.cohere.completion.transformation import CohereTextConfig
@pytest.mark.asyncio
async def test_cohere_generate_api_completion():
def test_cohere_generate_api_completion():
try:
litellm.set_verbose = False
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, MagicMock
client = HTTPHandler()
litellm.set_verbose = True
messages = [
{"role": "system", "content": "You're a good bot"},
{
@ -31,12 +34,28 @@ async def test_cohere_generate_api_completion():
"content": "Hey",
},
]
response = completion(
model="cohere/command-nightly",
messages=messages,
max_tokens=10,
)
print(response)
with patch.object(client, "post") as mock_client:
try:
completion(
model="cohere/command",
messages=messages,
max_tokens=10,
client=client,
)
except Exception as e:
print(e)
mock_client.assert_called_once()
print("mock_client.call_args.kwargs", mock_client.call_args.kwargs)
assert (
mock_client.call_args.kwargs["url"]
== "https://api.cohere.ai/v1/generate"
)
json_data = json.loads(mock_client.call_args.kwargs["data"])
assert json_data["model"] == "command"
assert json_data["prompt"] == "You're a good bot Hey"
assert json_data["max_tokens"] == 10
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@ -53,7 +72,7 @@ async def test_cohere_generate_api_stream():
},
]
response = await litellm.acompletion(
model="cohere/command-nightly",
model="cohere/command",
messages=messages,
max_tokens=10,
stream=True,
@ -76,7 +95,7 @@ def test_completion_cohere_stream_bad_key():
},
]
completion(
model="command-nightly",
model="command",
messages=messages,
stream=True,
max_tokens=50,
@ -100,7 +119,7 @@ def test_cohere_transform_request():
headers = {}
transformed_request = config.transform_request(
model="command-nightly",
model="command",
messages=messages,
optional_params=optional_params,
litellm_params={},
@ -109,7 +128,7 @@ def test_cohere_transform_request():
print("transformed_request", json.dumps(transformed_request, indent=4))
assert transformed_request["model"] == "command-nightly"
assert transformed_request["model"] == "command"
assert transformed_request["prompt"] == "You're a helpful bot Hello"
assert transformed_request["max_tokens"] == 10
assert transformed_request["temperature"] == 0.7
@ -137,7 +156,7 @@ def test_cohere_transform_request_with_tools():
optional_params = {"tools": tools}
transformed_request = config.transform_request(
model="command-nightly",
model="command",
messages=messages,
optional_params=optional_params,
litellm_params={},
@ -168,7 +187,7 @@ def test_cohere_map_openai_params():
mapped_params = config.map_openai_params(
non_default_params=openai_params,
optional_params={},
model="command-nightly",
model="command",
drop_params=False,
)