forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (11/06/2024) (#6624)
* refactor(proxy_server.py): add debug logging around license check event (refactor position in startup_event logic)
* fix(proxy/_types.py): allow admin_allowed_routes to be any str
* fix(router.py): raise 400-status code error for no 'model_name' error on router
  Fixes issue with status code when an unknown model name is passed with pattern matching enabled
* fix(converse_handler.py): add claude 3-5 haiku to bedrock converse models
* test: update testing to replace claude-instant-1.2
* fix(router.py): fix router.moderation calls
* test: update test to remove claude-instant-1
* fix(router.py): support model_list values in router.moderation
* test: fix test
* test: fix test
This commit is contained in:
parent
136693cac4
commit
0c204d33bc
15 changed files with 180 additions and 130 deletions
@@ -19,6 +19,7 @@ from ..common_utils import BedrockError
 from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call

 BEDROCK_CONVERSE_MODELS = [
+    "anthropic.claude-3-5-haiku-20241022-v1:0",
     "anthropic.claude-3-5-sonnet-20241022-v2:0",
     "anthropic.claude-3-5-sonnet-20240620-v1:0",
     "anthropic.claude-3-opus-20240229-v1:0",

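With "anthropic.claude-3-5-haiku-20241022-v1:0" registered in BEDROCK_CONVERSE_MODELS, requests for that model id go through the Bedrock converse handler. A minimal usage sketch, assuming AWS credentials and region are already configured in the environment; everything besides the model id from the hunk above is illustrative:

import litellm

# Routed via the converse API now that the haiku model id is in BEDROCK_CONVERSE_MODELS.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
    messages=[{"role": "user", "content": "Say hello"}],
)
print(response.choices[0].message.content)
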
@@ -4319,9 +4319,9 @@ async def amoderation(
     else:
         _openai_client = openai_client
     if model is not None:
-        response = await openai_client.moderations.create(input=input, model=model)
+        response = await _openai_client.moderations.create(input=input, model=model)
     else:
-        response = await openai_client.moderations.create(input=input)
+        response = await _openai_client.moderations.create(input=input)
     return response

@@ -23,6 +23,31 @@ model_list:
       model: openai/my-fake-model
       api_key: my-fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  ## bedrock chat completions
+  - model_name: "*anthropic.claude*"
+    litellm_params:
+      model: bedrock/*anthropic.claude*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+      guardrailConfig:
+        "guardrailIdentifier": "h4dsqwhp6j66"
+        "guardrailVersion": "2"
+        "trace": "enabled"
+
+  ## bedrock embeddings
+  - model_name: "*amazon.titan-embed-*"
+    litellm_params:
+      model: bedrock/amazon.titan-embed-*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+  - model_name: "*cohere.embed-*"
+    litellm_params:
+      model: bedrock/cohere.embed-*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+
   - model_name: gpt-4
     litellm_params:

@@ -33,6 +58,7 @@ model_list:
      rpm: 480
      timeout: 300
      stream_timeout: 60
+
 # litellm_settings:
 #   fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
 #   callbacks: ["otel", "prometheus"]

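The new wildcard entries route any matching model name to Bedrock chat or embeddings. A hedged sketch of calling the proxy with the OpenAI SDK, assuming the proxy is running locally on port 4000 with a placeholder master key; the concrete model ids are illustrative examples that match the wildcards:

from openai import OpenAI

# Point the OpenAI client at the LiteLLM proxy (assumed local address and key).
client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# Matches the "*anthropic.claude*" wildcard route above (guardrailConfig applies).
chat = client.chat.completions.create(
    model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    messages=[{"role": "user", "content": "Hi"}],
)

# Matches the "*amazon.titan-embed-*" wildcard route.
emb = client.embeddings.create(
    model="amazon.titan-embed-text-v2:0",
    input="Hello, world!",
)
print(chat.choices[0].message.content, len(emb.data[0].embedding))
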
@@ -436,15 +436,7 @@ class LiteLLM_JWTAuth(LiteLLMBase):
     """

     admin_jwt_scope: str = "litellm_proxy_admin"
-    admin_allowed_routes: List[
-        Literal[
-            "openai_routes",
-            "info_routes",
-            "management_routes",
-            "spend_tracking_routes",
-            "global_spend_tracking_routes",
-        ]
-    ] = [
+    admin_allowed_routes: List[str] = [
         "management_routes",
         "spend_tracking_routes",
         "global_spend_tracking_routes",

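Because admin_allowed_routes is now List[str] instead of a closed Literal set, custom route-group strings such as "ui_routes" validate cleanly. A minimal sketch mirroring the test further down in this diff; the field values are illustrative:

from litellm.proxy._types import LiteLLM_JWTAuth

# Any route-group string is accepted now; previously only the fixed Literal values passed validation.
jwt_auth = LiteLLM_JWTAuth(
    team_id_jwt_field="client_id",
    admin_allowed_routes=["ui_routes", "management_routes"],
)
print(jwt_auth.admin_allowed_routes)
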
@@ -5,6 +5,9 @@ import json
 import os
 import traceback
 from datetime import datetime
+from typing import Optional
+
+import httpx

 from litellm._logging import verbose_proxy_logger
 from litellm.llms.custom_httpx.http_handler import HTTPHandler

@@ -44,23 +47,46 @@ class LicenseCheck:
             verbose_proxy_logger.error(f"Error reading public key: {str(e)}")

     def _verify(self, license_str: str) -> bool:
+
+        verbose_proxy_logger.debug(
+            "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
+                self.base_url, license_str
+            )
+        )
         url = "{}/verify_license/{}".format(self.base_url, license_str)

+        response: Optional[httpx.Response] = None
         try:  # don't impact user, if call fails
-            response = self.http_handler.get(url=url)
-
-            response.raise_for_status()
+            num_retries = 3
+            for i in range(num_retries):
+                try:
+                    response = self.http_handler.get(url=url)
+                    if response is None:
+                        raise Exception("No response from license server")
+                    response.raise_for_status()
+                except httpx.HTTPStatusError:
+                    if i == num_retries - 1:
+                        raise
+
+            if response is None:
+                raise Exception("No response from license server")

             response_json = response.json()

             premium = response_json["verify"]

             assert isinstance(premium, bool)

+            verbose_proxy_logger.debug(
+                "litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
+                    license_str, premium
+                )
+            )
             return premium
         except Exception as e:
-            verbose_proxy_logger.error(
-                "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License via api. - {}".format(
-                    str(e)
+            verbose_proxy_logger.exception(
+                "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
+                    license_str, str(e)
                 )
             )
             return False

@@ -72,7 +98,7 @@ class LicenseCheck:
         """
         try:
             verbose_proxy_logger.debug(
-                "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - {}".format(
+                "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
                     self.license_str
                 )
             )

@@ -694,6 +694,9 @@ def run_server(  # noqa: PLR0915

         import litellm

+        if detailed_debug is True:
+            litellm._turn_on_debug()
+
         # DO NOT DELETE - enables global variables to work across files
         from litellm.proxy.proxy_server import app  # noqa

@@ -3074,6 +3074,15 @@ async def startup_event():
         user_api_key_cache=user_api_key_cache,
     )

+    ## CHECK PREMIUM USER
+    verbose_proxy_logger.debug(
+        "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
+            premium_user
+        )
+    )
+    if premium_user is False:
+        premium_user = _license_check.is_premium()
+
     ### LOAD CONFIG ###
     worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
     env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")

@@ -3121,21 +3130,6 @@ async def startup_event():
     if isinstance(worker_config, dict):
         await initialize(**worker_config)

-    ## CHECK PREMIUM USER
-    verbose_proxy_logger.debug(
-        "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
-            premium_user
-        )
-    )
-    if premium_user is False:
-        premium_user = _license_check.is_premium()
-
-    verbose_proxy_logger.debug(
-        "litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format(
-            premium_user
-        )
-    )
-
     ProxyStartupEvent._initialize_startup_logging(
         llm_router=llm_router,
         proxy_logging_obj=proxy_logging_obj,

@@ -65,6 +65,7 @@ async def route_request(
     Common helper to route the request
+
     """

     router_model_names = llm_router.model_names if llm_router is not None else []
     if "api_key" in data or "api_base" in data:
         return getattr(litellm, f"{route_type}")(**data)

@@ -556,6 +556,10 @@ class Router:

         self.initialize_assistants_endpoint()

+        self.amoderation = self.factory_function(
+            litellm.amoderation, call_type="moderation"
+        )
+
     def initialize_assistants_endpoint(self):
         ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
         self.acreate_assistants = self.factory_function(litellm.acreate_assistants)

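With amoderation now built by factory_function(..., call_type="moderation"), a moderation call can name a model_list entry and the router swaps in that deployment's underlying model before calling litellm.amoderation. A hedged usage sketch; the deployment mapping is illustrative and an OPENAI_API_KEY is assumed to be set in the environment:

import asyncio
import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "text-moderation-stable",
            "litellm_params": {"model": "openai/omni-moderation-latest"},
        }
    ]
)

async def main():
    # "text-moderation-stable" resolves to the deployment's litellm_params model above.
    result = await router.amoderation(
        model="text-moderation-stable", input="this is valid good text"
    )
    print(result)

asyncio.run(main())
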
@@ -1683,78 +1687,6 @@ class Router:
             )
             raise e

-    async def amoderation(self, model: str, input: str, **kwargs):
-        try:
-            kwargs["model"] = model
-            kwargs["input"] = input
-            kwargs["original_function"] = self._amoderation
-            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
-            kwargs.setdefault("metadata", {}).update({"model_group": model})
-
-            response = await self.async_function_with_fallbacks(**kwargs)
-
-            return response
-        except Exception as e:
-            asyncio.create_task(
-                send_llm_exception_alert(
-                    litellm_router_instance=self,
-                    request_kwargs=kwargs,
-                    error_traceback_str=traceback.format_exc(),
-                    original_exception=e,
-                )
-            )
-            raise e
-
-    async def _amoderation(self, model: str, input: str, **kwargs):
-        model_name = None
-        try:
-            verbose_router_logger.debug(
-                f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
-            )
-            deployment = await self.async_get_available_deployment(
-                model=model,
-                input=input,
-                specific_deployment=kwargs.pop("specific_deployment", None),
-            )
-            self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
-            data = deployment["litellm_params"].copy()
-            model_name = data["model"]
-            model_client = self._get_async_openai_model_client(
-                deployment=deployment,
-                kwargs=kwargs,
-            )
-            self.total_calls[model_name] += 1
-
-            timeout: Optional[Union[float, int]] = self._get_timeout(
-                kwargs=kwargs,
-                data=data,
-            )
-
-            response = await litellm.amoderation(
-                **{
-                    **data,
-                    "input": input,
-                    "caching": self.cache_responses,
-                    "client": model_client,
-                    "timeout": timeout,
-                    **kwargs,
-                }
-            )
-
-            self.success_calls[model_name] += 1
-            verbose_router_logger.info(
-                f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
-            )
-            return response
-        except Exception as e:
-            verbose_router_logger.info(
-                f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
-            )
-            if model_name is not None:
-                self.fail_calls[model_name] += 1
-            raise e
-
     async def arerank(self, model: str, **kwargs):
         try:
             kwargs["model"] = model

@@ -2610,20 +2542,46 @@ class Router:

         return final_results

-    #### ASSISTANTS API ####
+    #### PASSTHROUGH API ####

-    def factory_function(self, original_function: Callable):
+    async def _pass_through_moderation_endpoint_factory(
+        self,
+        original_function: Callable,
+        **kwargs,
+    ):
+        if (
+            "model" in kwargs
+            and self.get_model_list(model_name=kwargs["model"]) is not None
+        ):
+            deployment = await self.async_get_available_deployment(
+                model=kwargs["model"]
+            )
+            kwargs["model"] = deployment["litellm_params"]["model"]
+        return await original_function(**kwargs)
+
+    def factory_function(
+        self,
+        original_function: Callable,
+        call_type: Literal["assistants", "moderation"] = "assistants",
+    ):
         async def new_function(
             custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
             client: Optional["AsyncOpenAI"] = None,
             **kwargs,
         ):
-            return await self._pass_through_assistants_endpoint_factory(
-                original_function=original_function,
-                custom_llm_provider=custom_llm_provider,
-                client=client,
-                **kwargs,
-            )
+            if call_type == "assistants":
+                return await self._pass_through_assistants_endpoint_factory(
+                    original_function=original_function,
+                    custom_llm_provider=custom_llm_provider,
+                    client=client,
+                    **kwargs,
+                )
+            elif call_type == "moderation":
+                return await self._pass_through_moderation_endpoint_factory(  # type: ignore
+                    original_function=original_function,
+                    **kwargs,
+                )

         return new_function

@@ -5052,10 +5010,12 @@ class Router:
         )

         if len(healthy_deployments) == 0:
-            raise ValueError(
-                "{}. You passed in model={}. There is no 'model_name' with this string ".format(
-                    RouterErrors.no_deployments_available.value, model
-                )
+            raise litellm.BadRequestError(
+                message="You passed in model={}. There is no 'model_name' with this string ".format(
+                    model
+                ),
+                model=model,
+                llm_provider="",
             )

         if litellm.model_alias_map and model in litellm.model_alias_map:

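Since an unknown model name now raises litellm.BadRequestError (a 400-status error) instead of a bare ValueError, callers can branch on the exception type. A minimal sketch that mirrors the pattern-matching test later in this diff, using a router that only has a wildcard Bedrock route:

import asyncio
import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "*meta.llama3*",
            "litellm_params": {"model": "bedrock/meta.llama3*"},
        }
    ]
)

async def main():
    try:
        await router.acompletion(
            model="my-fake-model",  # no 'model_name' or pattern matches this
            messages=[{"role": "user", "content": "Hello"}],
            mock_response="Works",
        )
    except litellm.BadRequestError as e:
        # Previously surfaced as ValueError; now a 400-status BadRequestError.
        print("got 400:", e)

asyncio.run(main())
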
@@ -1043,6 +1043,7 @@ def client(original_function):  # noqa: PLR0915
             if (
                 call_type != CallTypes.aimage_generation.value  # model optional
                 and call_type != CallTypes.atext_completion.value  # can also be engine
+                and call_type != CallTypes.amoderation.value
             ):
                 raise ValueError("model param not passed in.")

@@ -689,9 +689,10 @@ async def aaaatest_user_token_output(
     assert team_result.user_id == user_id


+@pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]])
 @pytest.mark.parametrize("audience", [None, "litellm-proxy"])
 @pytest.mark.asyncio
-async def test_allowed_routes_admin(prisma_client, audience):
+async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_routes):
     """
     Add a check to make sure jwt proxy admin scope can access all allowed admin routes

@@ -754,12 +755,17 @@ async def test_allowed_routes_admin(prisma_client, audience):

     jwt_handler.user_api_key_cache = cache

-    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
+    if admin_allowed_routes:
+        jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
+            team_id_jwt_field="client_id", admin_allowed_routes=admin_allowed_routes
+        )
+    else:
+        jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")

     # VALID TOKEN
     ## GENERATE A TOKEN
     # Assuming the current time is in UTC
-    expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
+    expiration_time = int((datetime.now() + timedelta(minutes=10)).timestamp())

     # Generate the JWT token
     # But before, you should convert bytes to string

@@ -777,6 +783,7 @@ async def test_allowed_routes_admin(prisma_client, audience):
     # verify token

+    print(f"admin_token: {admin_token}")
     response = await jwt_handler.auth_jwt(token=admin_token)

     ## RUN IT THROUGH USER API KEY AUTH

@@ -1866,16 +1866,9 @@ async def test_router_amoderation():
     router = Router(model_list=model_list)
     ## Test 1: user facing function
     result = await router.amoderation(
-        model="openai-moderations", input="this is valid good text"
+        model="text-moderation-stable", input="this is valid good text"
     )

-    ## Test 2: underlying function
-    result = await router._amoderation(
-        model="openai-moderations", input="this is valid good text"
-    )
-
-    print("moderation result", result)
-

 def test_router_add_deployment():
     initial_model_list = [

@@ -1226,9 +1226,7 @@ async def test_using_default_fallback(sync_mode):
         pytest.fail(f"Expected call to fail we passed model=openai/foo")
     except Exception as e:
         print("got exception = ", e)
-        from litellm.types.router import RouterErrors
-
-        assert RouterErrors.no_deployments_available.value in str(e)
+        assert "BadRequestError" in str(e)


 @pytest.mark.parametrize("sync_mode", [False])

@@ -158,6 +158,46 @@ def test_route_with_exception():
     assert result is None


+@pytest.mark.asyncio
+async def test_route_with_no_matching_pattern():
+    """
+    Tests that the router returns None when there is no matching pattern
+    """
+    from litellm.types.router import RouterErrors
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "*meta.llama3*",
+                "litellm_params": {"model": "bedrock/meta.llama3*"},
+            }
+        ]
+    )
+
+    ## WORKS
+    result = await router.acompletion(
+        model="bedrock/meta.llama3-70b",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response="Works",
+    )
+    assert result.choices[0].message.content == "Works"
+
+    ## FAILS
+    with pytest.raises(litellm.BadRequestError) as e:
+        await router.acompletion(
+            model="my-fake-model",
+            messages=[{"role": "user", "content": "Hello, world!"}],
+            mock_response="Works",
+        )
+
+    assert RouterErrors.no_deployments_available.value not in str(e.value)
+
+    with pytest.raises(litellm.BadRequestError):
+        await router.aembedding(
+            model="my-fake-model",
+            input="Hello, world!",
+        )
+
+
 def test_router_pattern_match_e2e():
     """
     Tests the end to end flow of the router

@@ -188,3 +228,4 @@ def test_router_pattern_match_e2e():
         "model": "gpt-4o",
         "messages": [{"role": "user", "content": "Hello, how are you?"}],
     }
+

@@ -999,3 +999,10 @@ def test_pattern_match_deployment_set_model_name(
     for model in updated_models:
         assert model["litellm_params"]["model"] == expected_model

+
+@pytest.mark.asyncio
+async def test_pass_through_moderation_endpoint_factory(model_list):
+    router = Router(model_list=model_list)
+    response = await router._pass_through_moderation_endpoint_factory(
+        original_function=litellm.amoderation, input="this is valid good text"
+    )