fix(acompletion): support fallbacks on acompletion (#7184)

* fix(acompletion): support fallbacks on acompletion

allows health checks for wildcard routes to use fallback models

* test: update cohere generate api testing

* add max tokens to health check (#7000)

* fix: fix health check test

* test: update testing

---------

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
Krish Dholakia authored on 2024-12-11 19:20:54 -08:00 · committed by GitHub
parent 5fe77499d2
commit a9aeb21d0b
8 changed files with 240 additions and 69 deletions
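
The fallback behaviour named in the commit title can be exercised roughly as in the sketch below. This is a hedged usage sketch, not code from this commit: the model names, the message, and the list-of-strings fallbacks value are illustrative assumptions.

# Hedged usage sketch (not part of this commit); model names and the
# fallbacks value are assumptions chosen only to illustrate the call shape.
import asyncio

import litellm


async def main():
    response = await litellm.acompletion(
        model="openai/gpt-4o-mini",  # assumed primary model
        messages=[{"role": "user", "content": "ping"}],
        # assumed fallback list, tried only if the primary call fails
        fallbacks=["anthropic/claude-3-haiku-20240307"],
    )
    print(response.choices[0].message.content)


asyncio.run(main())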


@@ -23,6 +23,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
     HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
@@ -54,10 +55,15 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         encoding: Any,
         api_key: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
-        async_httpx_client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders(custom_llm_provider)
-        )
+        if client is None:
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+
         try:
             response = await async_httpx_client.post(
                 url=api_base,
@@ -97,6 +103,7 @@
         fake_stream: bool = False,
         api_key: Optional[str] = None,
         headers={},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
@@ -149,6 +156,11 @@
                     logging_obj=logging_obj,
                     data=data,
                     fake_stream=fake_stream,
+                    client=(
+                        client
+                        if client is not None and isinstance(client, AsyncHTTPHandler)
+                        else None
+                    ),
                 )
 
             else:
@@ -167,6 +179,11 @@
                     optional_params=optional_params,
                     litellm_params=litellm_params,
                     encoding=encoding,
+                    client=(
+                        client
+                        if client is not None and isinstance(client, AsyncHTTPHandler)
+                        else None
+                    ),
                 )
 
         if stream is True:
@@ -182,6 +199,11 @@
                 logging_obj=logging_obj,
                 timeout=timeout,
                 fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, HTTPHandler)
+                    else None
+                ),
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -190,11 +212,14 @@
                 logging_obj=logging_obj,
             )
 
-        sync_httpx_client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
 
         try:
             response = sync_httpx_client.post(
-                api_base,
+                url=api_base,
                 headers=headers,
                 data=json.dumps(data),
                 timeout=timeout,
@@ -229,8 +254,12 @@
         logging_obj,
         timeout: Optional[Union[float, httpx.Timeout]],
         fake_stream: bool = False,
+        client: Optional[HTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
-        sync_httpx_client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
         try:
             stream = True
             if fake_stream is True:
@@ -289,6 +318,7 @@
         logging_obj: LiteLLMLoggingObj,
         data: dict,
         fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
         completion_stream, _response_headers = await self.make_async_call(
             custom_llm_provider=custom_llm_provider,
@@ -300,6 +330,7 @@
             logging_obj=logging_obj,
             timeout=timeout,
             fake_stream=fake_stream,
+            client=client,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -320,10 +351,14 @@
         logging_obj: LiteLLMLoggingObj,
         timeout: Optional[Union[float, httpx.Timeout]],
         fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
-        async_httpx_client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders(custom_llm_provider)
-        )
+        if client is None:
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
         stream = True
         if fake_stream is True:
             stream = False
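
The hunks above repeat one pattern: reuse a caller-supplied httpx handler when it has the expected type, otherwise construct a default one. Below is a minimal standalone sketch of that pattern; the stub classes and helper names (resolve_sync_client, resolve_async_client) are illustrative stand-ins, not litellm APIs.

# Standalone sketch of the client-reuse pattern in this diff. The stub classes
# and helpers below are illustrative stand-ins, not litellm code.
from typing import Optional, Union


class HTTPHandler:  # stand-in for litellm's sync httpx wrapper
    pass


class AsyncHTTPHandler:  # stand-in for litellm's async httpx wrapper
    pass


def resolve_sync_client(
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> HTTPHandler:
    # Mirrors the sync branches: build a fresh handler unless the caller
    # passed a synchronous HTTPHandler that can be reused directly.
    if client is None or not isinstance(client, HTTPHandler):
        return HTTPHandler()
    return client


def resolve_async_client(
    client: Optional[AsyncHTTPHandler] = None,
) -> AsyncHTTPHandler:
    # Mirrors the async branches: only construct a new async handler when the
    # caller did not supply one.
    if client is None:
        return AsyncHTTPHandler()
    return client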