Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(acompletion): support fallbacks on acompletion (#7184)
* fix(acompletion): support fallbacks on acompletion; allows health checks for wildcard routes to use fallback models
* test: update cohere generate api testing
* add max tokens to health check (#7000)
* fix: fix health check test
* test: update testing

---------

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
Parent: 5fe77499d2
Commit: a9aeb21d0b
8 changed files with 240 additions and 69 deletions
@@ -534,7 +534,7 @@ def add_known_models():
             gemini_models.append(key)
         elif value.get("litellm_provider") == "fireworks_ai":
             # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
-            if "-to-" not in key:
+            if "-to-" not in key and "fireworks-ai-default" not in key:
                 fireworks_ai_models.append(key)
         elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
             # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.

@@ -189,7 +189,11 @@ def exception_type(  # type: ignore  # noqa: PLR0915
         #################### Start of Provider Exception mapping ####################
         ################################################################################

-        if "Request Timeout Error" in error_str or "Request timed out" in error_str:
+        if (
+            "Request Timeout Error" in error_str
+            or "Request timed out" in error_str
+            or "Timed out generating response" in error_str
+        ):
             exception_mapping_worked = True
             raise Timeout(
                 message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}",
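
The hunk above adds "Timed out generating response" to the provider error strings that get mapped to litellm.Timeout. A minimal sketch of what this means for callers, assuming an OpenAI-style model name and a deliberately tiny client timeout (both illustrative, not taken from the diff):

import litellm

try:
    litellm.completion(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "hi"}],
        timeout=0.001,  # force a timeout so the mapped exception is raised
    )
except litellm.Timeout as e:
    # provider errors containing "Request Timeout Error", "Request timed out",
    # or (new in this commit) "Timed out generating response" all land here
    print(f"request timed out: {e}")
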
@@ -23,6 +23,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
     HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,

@@ -54,10 +55,15 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         encoding: Any,
         api_key: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
-        async_httpx_client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders(custom_llm_provider)
-        )
+        if client is None:
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+
         try:
             response = await async_httpx_client.post(
                 url=api_base,

@@ -97,6 +103,7 @@ class BaseLLMHTTPHandler:
         fake_stream: bool = False,
         api_key: Optional[str] = None,
         headers={},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)

@@ -149,6 +156,11 @@ class BaseLLMHTTPHandler:
                 logging_obj=logging_obj,
                 data=data,
                 fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, AsyncHTTPHandler)
+                    else None
+                ),
             )

         else:

@@ -167,6 +179,11 @@ class BaseLLMHTTPHandler:
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 encoding=encoding,
+                client=(
+                    client
+                    if client is not None and isinstance(client, AsyncHTTPHandler)
+                    else None
+                ),
             )

         if stream is True:

@@ -182,6 +199,11 @@ class BaseLLMHTTPHandler:
                 logging_obj=logging_obj,
                 timeout=timeout,
                 fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, HTTPHandler)
+                    else None
+                ),
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,

@@ -190,11 +212,14 @@ class BaseLLMHTTPHandler:
                 logging_obj=logging_obj,
             )

-        sync_httpx_client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client

         try:
             response = sync_httpx_client.post(
-                api_base,
+                url=api_base,
                 headers=headers,
                 data=json.dumps(data),
                 timeout=timeout,

@@ -229,8 +254,12 @@ class BaseLLMHTTPHandler:
         logging_obj,
         timeout: Optional[Union[float, httpx.Timeout]],
         fake_stream: bool = False,
+        client: Optional[HTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
-        sync_httpx_client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
         try:
             stream = True
             if fake_stream is True:

@@ -289,6 +318,7 @@ class BaseLLMHTTPHandler:
         logging_obj: LiteLLMLoggingObj,
         data: dict,
         fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
         completion_stream, _response_headers = await self.make_async_call(
             custom_llm_provider=custom_llm_provider,

@@ -300,6 +330,7 @@ class BaseLLMHTTPHandler:
             logging_obj=logging_obj,
             timeout=timeout,
             fake_stream=fake_stream,
+            client=client,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,

@@ -320,10 +351,14 @@ class BaseLLMHTTPHandler:
         logging_obj: LiteLLMLoggingObj,
         timeout: Optional[Union[float, httpx.Timeout]],
         fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
-        async_httpx_client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders(custom_llm_provider)
-        )
+        if client is None:
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
         stream = True
         if fake_stream is True:
             stream = False
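
The BaseLLMHTTPHandler changes above thread an optional client through the sync and async completion and streaming paths, so a caller (or a test) can supply its own HTTPHandler or AsyncHTTPHandler instead of the default module-level httpx client. A minimal usage sketch, mirroring the updated Cohere test further down; the model name and prompt are illustrative and a Cohere API key is assumed to be set in the environment:

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

# reuse one handler across calls; it is only used when its type matches the
# code path (HTTPHandler for sync calls, AsyncHTTPHandler for async ones)
client = HTTPHandler()

response = litellm.completion(
    model="cohere/command",  # any provider routed through BaseLLMHTTPHandler
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
    client=client,
)
print(response.choices[0].message.content)
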
@@ -64,6 +64,7 @@ from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
+    async_completion_with_fallbacks,
     async_mock_completion_streaming_obj,
     completion_with_fallbacks,
     convert_to_model_response_object,

@@ -364,6 +365,8 @@ async def acompletion(
     - The `completion` function is called using `run_in_executor` to execute synchronously in the event loop.
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
+    fallbacks = kwargs.get("fallbacks", None)
+
     loop = asyncio.get_event_loop()
     custom_llm_provider = kwargs.get("custom_llm_provider", None)
     # Adjusted to use explicit arguments instead of *args and **kwargs

@@ -407,6 +410,18 @@ async def acompletion(
         _, custom_llm_provider, _, _ = get_llm_provider(
             model=model, api_base=completion_kwargs.get("base_url", None)
         )

+    fallbacks = fallbacks or litellm.model_fallbacks
+    if fallbacks is not None:
+        response = await async_completion_with_fallbacks(
+            **completion_kwargs, kwargs={"fallbacks": fallbacks}
+        )
+        if response is None:
+            raise Exception(
+                "No response from fallbacks. Got none. Turn on `litellm.set_verbose=True` to see more details."
+            )
+        return response
+
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, **completion_kwargs, **kwargs)
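
With the branch above, acompletion now honors a fallbacks kwarg (or litellm.model_fallbacks) by delegating to async_completion_with_fallbacks instead of running the request in an executor. A minimal sketch; the model names are illustrative and assume the matching API keys are set in the environment:

import asyncio
import litellm

async def main():
    # if the primary model errors, each fallback is tried in turn
    # (a failed model cools down for 60s, with a 45s overall budget)
    response = await litellm.acompletion(
        model="groq/llama3-8b-8192",  # illustrative primary model
        messages=[{"role": "user", "content": "What's 1 + 1?"}],
        fallbacks=["gpt-4o-mini"],  # illustrative fallback model
        max_tokens=10,
    )
    print(response.choices[0].message.content)

asyncio.run(main())
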
@@ -1884,6 +1899,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                 encoding=encoding,
                 api_key=cohere_key,
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+                client=client,
             )
         elif custom_llm_provider == "cohere_chat":
             cohere_key = (

@@ -4997,6 +5013,38 @@ def speech(
 ##### Health Endpoints #######################


+async def ahealth_check_chat_models(
+    model: str, custom_llm_provider: str, model_params: dict
+) -> dict:
+    if "*" in model:
+        from litellm.litellm_core_utils.llm_request_utils import (
+            pick_cheapest_chat_model_from_llm_provider,
+        )
+
+        # this is a wildcard model, we need to pick a random model from the provider
+        cheapest_model = pick_cheapest_chat_model_from_llm_provider(
+            custom_llm_provider=custom_llm_provider
+        )
+        fallback_models: Optional[List] = None
+        if custom_llm_provider in litellm.models_by_provider:
+            models = litellm.models_by_provider[custom_llm_provider]
+            random.shuffle(models)  # Shuffle the models list in place
+            fallback_models = models[
+                :2
+            ]  # Pick the first 2 models from the shuffled list
+        model_params["model"] = cheapest_model
+        model_params["fallbacks"] = fallback_models
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response: dict = {}  # args like remaining ratelimit etc.
+    else:  # default to completion calls
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response = {}  # args like remaining ratelimit etc.
+
+    return response
+
+
 async def ahealth_check(  # noqa: PLR0915
     model_params: dict,
     mode: Optional[

@@ -5128,21 +5176,12 @@ async def ahealth_check(  # noqa: PLR0915
             model_params["documents"] = ["my sample text"]
             await litellm.arerank(**model_params)
             response = {}
-        elif "*" in model:
-            from litellm.litellm_core_utils.llm_request_utils import (
-                pick_cheapest_chat_model_from_llm_provider,
-            )
-
-            # this is a wildcard model, we need to pick a random model from the provider
-            cheapest_model = pick_cheapest_chat_model_from_llm_provider(
-                custom_llm_provider=custom_llm_provider
-            )
-            model_params["model"] = cheapest_model
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
-        else:  # default to completion calls
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
+        else:
+            response = await ahealth_check_chat_models(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                model_params=model_params,
+            )
         return response
     except Exception as e:
         stack_trace = traceback.format_exc()
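
ahealth_check_chat_models is what lets a wildcard route pass a health check: it probes the provider's cheapest chat model with max_tokens=1 and hands acompletion up to two shuffled sibling models as fallbacks. A minimal sketch of the call site, modeled on the updated test at the bottom of this diff (assumes GROQ_API_KEY is set in the environment):

import asyncio
import os
import litellm

async def main():
    response = await litellm.ahealth_check(
        model_params={
            "api_key": os.environ.get("GROQ_API_KEY"),
            "model": "groq/*",  # wildcard route
            "messages": [{"role": "user", "content": "What's 1 + 1?"}],
        },
        mode=None,
    )
    print(response)  # empty dict on success, error details otherwise

asyncio.run(main())
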
@@ -5697,6 +5697,72 @@ def completion_with_fallbacks(**kwargs):
     return response


+async def async_completion_with_fallbacks(**kwargs):
+    nested_kwargs = kwargs.pop("kwargs", {})
+    response = None
+    rate_limited_models = set()
+    model_expiration_times = {}
+    start_time = time.time()
+    original_model = kwargs["model"]
+    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
+    if "fallbacks" in nested_kwargs:
+        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
+    if "acompletion" in kwargs:
+        del kwargs[
+            "acompletion"
+        ]  # remove acompletion so it doesn't lead to keyword errors
+    litellm_call_id = str(uuid.uuid4())
+
+    # max time to process a request with fallbacks: default 45s
+    while response is None and time.time() - start_time < 45:
+        for model in fallbacks:
+            # loop thru all models
+            try:
+                # check if it's dict or new model string
+                if isinstance(
+                    model, dict
+                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
+                    kwargs["api_key"] = model.get("api_key", None)
+                    kwargs["api_base"] = model.get("api_base", None)
+                    model = model.get("model", original_model)
+                elif (
+                    model in rate_limited_models
+                ):  # check if model is currently cooling down
+                    if (
+                        model_expiration_times.get(model)
+                        and time.time() >= model_expiration_times[model]
+                    ):
+                        rate_limited_models.remove(
+                            model
+                        )  # check if it's been 60s of cool down and remove model
+                    else:
+                        continue  # skip model
+
+                # delete model from kwargs if it exists
+                if kwargs.get("model"):
+                    del kwargs["model"]
+
+                print_verbose(f"trying to make completion call with model: {model}")
+                kwargs["litellm_call_id"] = litellm_call_id
+                kwargs = {
+                    **kwargs,
+                    **nested_kwargs,
+                }  # combine the openai + litellm params at the same level
+                response = await litellm.acompletion(**kwargs, model=model)
+                print_verbose(f"response: {response}")
+                if response is not None:
+                    return response
+
+            except Exception as e:
+                print_verbose(f"error: {e}")
+                rate_limited_models.add(model)
+                model_expiration_times[model] = (
+                    time.time() + 60
+                )  # cool down this selected model
+                pass
+
+    return response
+
+
 def process_system_message(system_message, max_tokens, model):
     system_message_event = {"role": "system", "content": system_message}
     system_message_tokens = get_token_count([system_message_event], model)
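
As the inline comment in async_completion_with_fallbacks notes, a fallback entry can also be a dict that carries its own api_key and api_base for that attempt. A hedged sketch of that form; the endpoint and environment variable names below are placeholders, not values from the diff:

import asyncio
import os
import litellm

async def main():
    response = await litellm.acompletion(
        model="gpt-4o-mini",  # illustrative primary model
        messages=[{"role": "user", "content": "hello"}],
        fallbacks=[
            # dict fallbacks override api_key / api_base for that attempt
            {
                "model": "gpt-3.5-turbo",
                "api_key": os.environ.get("BACKUP_OPENAI_API_KEY"),  # placeholder env var
                "api_base": "https://backup-endpoint.example.com/v1",  # placeholder endpoint
            },
        ],
    )
    print(response)

asyncio.run(main())
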
@@ -20,10 +20,13 @@ from litellm import completion
 from litellm.llms.cohere.completion.transformation import CohereTextConfig


-@pytest.mark.asyncio
-async def test_cohere_generate_api_completion():
+def test_cohere_generate_api_completion():
     try:
-        litellm.set_verbose = False
+        from litellm.llms.custom_httpx.http_handler import HTTPHandler
+        from unittest.mock import patch, MagicMock
+
+        client = HTTPHandler()
+        litellm.set_verbose = True
         messages = [
             {"role": "system", "content": "You're a good bot"},
             {

@@ -31,12 +34,28 @@ async def test_cohere_generate_api_completion():
                 "content": "Hey",
             },
         ]
-        response = completion(
-            model="cohere/command-nightly",
-            messages=messages,
-            max_tokens=10,
-        )
-        print(response)
+        with patch.object(client, "post") as mock_client:
+            try:
+                completion(
+                    model="cohere/command",
+                    messages=messages,
+                    max_tokens=10,
+                    client=client,
+                )
+            except Exception as e:
+                print(e)
+            mock_client.assert_called_once()
+            print("mock_client.call_args.kwargs", mock_client.call_args.kwargs)
+
+            assert (
+                mock_client.call_args.kwargs["url"]
+                == "https://api.cohere.ai/v1/generate"
+            )
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+            assert json_data["model"] == "command"
+            assert json_data["prompt"] == "You're a good bot Hey"
+            assert json_data["max_tokens"] == 10
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


@@ -53,7 +72,7 @@ async def test_cohere_generate_api_stream():
             },
         ]
         response = await litellm.acompletion(
-            model="cohere/command-nightly",
+            model="cohere/command",
             messages=messages,
             max_tokens=10,
             stream=True,

@@ -76,7 +95,7 @@ def test_completion_cohere_stream_bad_key():
            },
        ]
        completion(
-            model="command-nightly",
+            model="command",
            messages=messages,
            stream=True,
            max_tokens=50,

@@ -100,7 +119,7 @@ def test_cohere_transform_request():
    headers = {}

    transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
        messages=messages,
        optional_params=optional_params,
        litellm_params={},

@@ -109,7 +128,7 @@ def test_cohere_transform_request():

    print("transformed_request", json.dumps(transformed_request, indent=4))

-    assert transformed_request["model"] == "command-nightly"
+    assert transformed_request["model"] == "command"
    assert transformed_request["prompt"] == "You're a helpful bot Hello"
    assert transformed_request["max_tokens"] == 10
    assert transformed_request["temperature"] == 0.7

@@ -137,7 +156,7 @@ def test_cohere_transform_request_with_tools():
    optional_params = {"tools": tools}

    transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
        messages=messages,
        optional_params=optional_params,
        litellm_params={},

@@ -168,7 +187,7 @@ def test_cohere_map_openai_params():
    mapped_params = config.map_openai_params(
        non_default_params=openai_params,
        optional_params={},
-        model="command-nightly",
+        model="command",
        drop_params=False,
    )


@@ -56,13 +56,17 @@ async def test_audio_output_from_model(stream):
    if stream is False:
        audio_format = "wav"
    litellm.set_verbose = False
-    completion = await litellm.acompletion(
-        model="gpt-4o-audio-preview",
-        modalities=["text", "audio"],
-        audio={"voice": "alloy", "format": "pcm16"},
-        messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
-        stream=stream,
-    )
+    try:
+        completion = await litellm.acompletion(
+            model="gpt-4o-audio-preview",
+            modalities=["text", "audio"],
+            audio={"voice": "alloy", "format": "pcm16"},
+            messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
+            stream=stream,
+        )
+    except litellm.Timeout as e:
+        print(e)
+        pytest.skip("Skipping test due to timeout")

    if stream is True:
        await check_streaming_response(completion)

@@ -88,25 +92,28 @@ async def test_audio_input_to_model(stream):
    response.raise_for_status()
    wav_data = response.content
    encoded_string = base64.b64encode(wav_data).decode("utf-8")
-    completion = await litellm.acompletion(
-        model="gpt-4o-audio-preview",
-        modalities=["text", "audio"],
-        audio={"voice": "alloy", "format": audio_format},
-        stream=stream,
-        messages=[
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": "What is in this recording?"},
-                    {
-                        "type": "input_audio",
-                        "input_audio": {"data": encoded_string, "format": "wav"},
-                    },
-                ],
-            },
-        ],
-    )
+    try:
+        completion = await litellm.acompletion(
+            model="gpt-4o-audio-preview",
+            modalities=["text", "audio"],
+            audio={"voice": "alloy", "format": audio_format},
+            stream=stream,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is in this recording?"},
+                        {
+                            "type": "input_audio",
+                            "input_audio": {"data": encoded_string, "format": "wav"},
+                        },
+                    ],
+                },
+            ],
+        )
+    except litellm.Timeout as e:
+        print(e)
+        pytest.skip("Skipping test due to timeout")

    if stream is True:
        await check_streaming_response(completion)

@@ -112,16 +112,17 @@ async def test_sagemaker_embedding_health_check():


 @pytest.mark.asyncio
-async def test_fireworks_health_check():
+async def test_groq_health_check():
     """
     This should not fail

     ensure that provider wildcard model passes health check
     """
+    litellm.set_verbose = True
     response = await litellm.ahealth_check(
         model_params={
-            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
-            "model": "fireworks_ai/*",
+            "api_key": os.environ.get("GROQ_API_KEY"),
+            "model": "groq/*",
             "messages": [{"role": "user", "content": "What's 1 + 1?"}],
         },
         mode=None,