LiteLLM Minor Fixes & Improvements (11/06/2024) (#6624)

* refactor(proxy_server.py): add debug logging around the license check event (reposition it within the startup_event logic)

* fix(proxy/_types.py): allow admin_allowed_routes to be any str

* fix(router.py): raise a 400-status error when no matching 'model_name' is found on the router

Fixes the status code returned when an unknown model name is passed with pattern matching enabled

* fix(converse_handler.py): add Claude 3.5 Haiku to the Bedrock Converse models list

* test: update tests to replace claude-instant-1.2

* fix(router.py): fix router.moderation calls

* test: update test to remove claude-instant-1

* fix(router.py): support model_list values in router.moderation

* test: fix test

* test: fix test
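
Taken together, the router.moderation changes mean router.amoderation now resolves a model group from the router's model_list before calling litellm.amoderation. A minimal usage sketch, assuming a hypothetical model_list entry and an OpenAI API key in the environment:

import asyncio

from litellm import Router

# Hypothetical model group; the underlying model name and credentials are assumptions.
router = Router(
    model_list=[
        {
            "model_name": "text-moderation-stable",
            "litellm_params": {"model": "text-moderation-stable"},
        }
    ]
)


async def main():
    # Goes through the new moderation factory function, which swaps the
    # model group for the deployment's litellm_params model.
    response = await router.amoderation(
        model="text-moderation-stable", input="this is valid good text"
    )
    print(response)


asyncio.run(main())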
Krish Dholakia 2024-11-07 04:37:32 +05:30 committed by GitHub
parent 136693cac4
commit 0c204d33bc
15 changed files with 180 additions and 130 deletions


@@ -19,6 +19,7 @@ from ..common_utils import BedrockError
 from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call

 BEDROCK_CONVERSE_MODELS = [
+    "anthropic.claude-3-5-haiku-20241022-v1:0",
     "anthropic.claude-3-5-sonnet-20241022-v2:0",
     "anthropic.claude-3-5-sonnet-20240620-v1:0",
     "anthropic.claude-3-opus-20240229-v1:0",


@@ -4319,9 +4319,9 @@ async def amoderation(
     else:
         _openai_client = openai_client
     if model is not None:
-        response = await openai_client.moderations.create(input=input, model=model)
+        response = await _openai_client.moderations.create(input=input, model=model)
     else:
-        response = await openai_client.moderations.create(input=input)
+        response = await _openai_client.moderations.create(input=input)
     return response


@@ -23,6 +23,31 @@ model_list:
       model: openai/my-fake-model
       api_key: my-fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  ## bedrock chat completions
+  - model_name: "*anthropic.claude*"
+    litellm_params:
+      model: bedrock/*anthropic.claude*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+      guardrailConfig:
+        "guardrailIdentifier": "h4dsqwhp6j66"
+        "guardrailVersion": "2"
+        "trace": "enabled"
+
+  ## bedrock embeddings
+  - model_name: "*amazon.titan-embed-*"
+    litellm_params:
+      model: bedrock/amazon.titan-embed-*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+  - model_name: "*cohere.embed-*"
+    litellm_params:
+      model: bedrock/cohere.embed-*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
   - model_name: gpt-4
     litellm_params:
@@ -33,6 +58,7 @@ model_list:
       rpm: 480
       timeout: 300
       stream_timeout: 60
+
 # litellm_settings:
 #   fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
 #   callbacks: ["otel", "prometheus"]
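
As a rough usage sketch (not part of the diff), a client request that the "*anthropic.claude*" wildcard route above would match, assuming the proxy is running locally on port 4000 with a placeholder key:

from openai import OpenAI

# Base URL and API key are placeholders for a locally running proxy.
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

# "anthropic.claude-3-5-haiku-20241022-v1:0" matches "*anthropic.claude*",
# so the proxy forwards it to the bedrock/*anthropic.claude* deployment.
response = client.chat.completions.create(
    model="anthropic.claude-3-5-haiku-20241022-v1:0",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)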


@@ -436,15 +436,7 @@ class LiteLLM_JWTAuth(LiteLLMBase):
     """

     admin_jwt_scope: str = "litellm_proxy_admin"
-    admin_allowed_routes: List[
-        Literal[
-            "openai_routes",
-            "info_routes",
-            "management_routes",
-            "spend_tracking_routes",
-            "global_spend_tracking_routes",
-        ]
-    ] = [
+    admin_allowed_routes: List[str] = [
         "management_routes",
         "spend_tracking_routes",
         "global_spend_tracking_routes",


@@ -5,6 +5,9 @@ import json
 import os
 import traceback
 from datetime import datetime
+from typing import Optional
+
+import httpx

 from litellm._logging import verbose_proxy_logger
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
@@ -44,23 +47,46 @@ class LicenseCheck:
             verbose_proxy_logger.error(f"Error reading public key: {str(e)}")

     def _verify(self, license_str: str) -> bool:
+        verbose_proxy_logger.debug(
+            "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
+                self.base_url, license_str
+            )
+        )
         url = "{}/verify_license/{}".format(self.base_url, license_str)
+        response: Optional[httpx.Response] = None
         try:  # don't impact user, if call fails
-            response = self.http_handler.get(url=url)
-            response.raise_for_status()
+            num_retries = 3
+            for i in range(num_retries):
+                try:
+                    response = self.http_handler.get(url=url)
+                    if response is None:
+                        raise Exception("No response from license server")
+                    response.raise_for_status()
+                except httpx.HTTPStatusError:
+                    if i == num_retries - 1:
+                        raise
+
+            if response is None:
+                raise Exception("No response from license server")

             response_json = response.json()
             premium = response_json["verify"]
             assert isinstance(premium, bool)
+
+            verbose_proxy_logger.debug(
+                "litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
+                    license_str, premium
+                )
+            )
             return premium
+
         except Exception as e:
-            verbose_proxy_logger.error(
-                "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License via api. - {}".format(
-                    str(e)
+            verbose_proxy_logger.exception(
+                "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
+                    license_str, str(e)
                 )
             )
             return False
@@ -72,7 +98,7 @@ class LicenseCheck:
         """
         try:
             verbose_proxy_logger.debug(
-                "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - {}".format(
+                "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
                     self.license_str
                 )
             )


@@ -694,6 +694,9 @@ def run_server(  # noqa: PLR0915
     import litellm

+    if detailed_debug is True:
+        litellm._turn_on_debug()
+
     # DO NOT DELETE - enables global variables to work across files
     from litellm.proxy.proxy_server import app  # noqa


@@ -3074,6 +3074,15 @@ async def startup_event():
         user_api_key_cache=user_api_key_cache,
     )

+    ## CHECK PREMIUM USER
+    verbose_proxy_logger.debug(
+        "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
+            premium_user
+        )
+    )
+    if premium_user is False:
+        premium_user = _license_check.is_premium()
+
     ### LOAD CONFIG ###
     worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
     env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
@@ -3121,21 +3130,6 @@ async def startup_event():
     if isinstance(worker_config, dict):
         await initialize(**worker_config)

-    ## CHECK PREMIUM USER
-    verbose_proxy_logger.debug(
-        "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
-            premium_user
-        )
-    )
-    if premium_user is False:
-        premium_user = _license_check.is_premium()
-
-    verbose_proxy_logger.debug(
-        "litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format(
-            premium_user
-        )
-    )
     ProxyStartupEvent._initialize_startup_logging(
         llm_router=llm_router,
         proxy_logging_obj=proxy_logging_obj,


@@ -65,6 +65,7 @@ async def route_request(
     Common helper to route the request
     """
     router_model_names = llm_router.model_names if llm_router is not None else []
+
     if "api_key" in data or "api_base" in data:
         return getattr(litellm, f"{route_type}")(**data)


@@ -556,6 +556,10 @@ class Router:
         self.initialize_assistants_endpoint()

+        self.amoderation = self.factory_function(
+            litellm.amoderation, call_type="moderation"
+        )
+
     def initialize_assistants_endpoint(self):
         ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
         self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
@@ -1683,78 +1687,6 @@ class Router:
             )
             raise e

-    async def amoderation(self, model: str, input: str, **kwargs):
-        try:
-            kwargs["model"] = model
-            kwargs["input"] = input
-            kwargs["original_function"] = self._amoderation
-            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
-            kwargs.setdefault("metadata", {}).update({"model_group": model})
-
-            response = await self.async_function_with_fallbacks(**kwargs)
-
-            return response
-        except Exception as e:
-            asyncio.create_task(
-                send_llm_exception_alert(
-                    litellm_router_instance=self,
-                    request_kwargs=kwargs,
-                    error_traceback_str=traceback.format_exc(),
-                    original_exception=e,
-                )
-            )
-            raise e
-
-    async def _amoderation(self, model: str, input: str, **kwargs):
-        model_name = None
-        try:
-            verbose_router_logger.debug(
-                f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
-            )
-            deployment = await self.async_get_available_deployment(
-                model=model,
-                input=input,
-                specific_deployment=kwargs.pop("specific_deployment", None),
-            )
-            self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
-            data = deployment["litellm_params"].copy()
-            model_name = data["model"]
-            model_client = self._get_async_openai_model_client(
-                deployment=deployment,
-                kwargs=kwargs,
-            )
-            self.total_calls[model_name] += 1
-
-            timeout: Optional[Union[float, int]] = self._get_timeout(
-                kwargs=kwargs,
-                data=data,
-            )
-
-            response = await litellm.amoderation(
-                **{
-                    **data,
-                    "input": input,
-                    "caching": self.cache_responses,
-                    "client": model_client,
-                    "timeout": timeout,
-                    **kwargs,
-                }
-            )
-
-            self.success_calls[model_name] += 1
-            verbose_router_logger.info(
-                f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
-            )
-            return response
-        except Exception as e:
-            verbose_router_logger.info(
-                f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
-            )
-            if model_name is not None:
-                self.fail_calls[model_name] += 1
-            raise e
-
     async def arerank(self, model: str, **kwargs):
         try:
             kwargs["model"] = model
@@ -2610,20 +2542,46 @@ class Router:
         return final_results

-    #### ASSISTANTS API ####
+    #### PASSTHROUGH API ####

-    def factory_function(self, original_function: Callable):
+    async def _pass_through_moderation_endpoint_factory(
+        self,
+        original_function: Callable,
+        **kwargs,
+    ):
+        if (
+            "model" in kwargs
+            and self.get_model_list(model_name=kwargs["model"]) is not None
+        ):
+            deployment = await self.async_get_available_deployment(
+                model=kwargs["model"]
+            )
+            kwargs["model"] = deployment["litellm_params"]["model"]
+        return await original_function(**kwargs)
+
+    def factory_function(
+        self,
+        original_function: Callable,
+        call_type: Literal["assistants", "moderation"] = "assistants",
+    ):
         async def new_function(
             custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
             client: Optional["AsyncOpenAI"] = None,
             **kwargs,
         ):
-            return await self._pass_through_assistants_endpoint_factory(
-                original_function=original_function,
-                custom_llm_provider=custom_llm_provider,
-                client=client,
-                **kwargs,
-            )
+            if call_type == "assistants":
+                return await self._pass_through_assistants_endpoint_factory(
+                    original_function=original_function,
+                    custom_llm_provider=custom_llm_provider,
+                    client=client,
+                    **kwargs,
+                )
+            elif call_type == "moderation":
+                return await self._pass_through_moderation_endpoint_factory(  # type: ignore
+                    original_function=original_function,
+                    **kwargs,
+                )

         return new_function
@@ -5052,10 +5010,12 @@ class Router:
             )

         if len(healthy_deployments) == 0:
-            raise ValueError(
-                "{}. You passed in model={}. There is no 'model_name' with this string ".format(
-                    RouterErrors.no_deployments_available.value, model
-                )
+            raise litellm.BadRequestError(
+                message="You passed in model={}. There is no 'model_name' with this string ".format(
+                    model
+                ),
+                model=model,
+                llm_provider="",
             )

         if litellm.model_alias_map and model in litellm.model_alias_map:
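
A rough caller-side sketch of the new behavior (it mirrors the pattern-matching test further down): an unknown model name now surfaces as litellm.BadRequestError (400) instead of a bare ValueError:

import asyncio

import litellm
from litellm import Router

# Wildcard deployment taken from the test config below; everything else is illustrative.
router = Router(
    model_list=[
        {
            "model_name": "*meta.llama3*",
            "litellm_params": {"model": "bedrock/meta.llama3*"},
        }
    ]
)


async def main():
    try:
        await router.acompletion(
            model="my-fake-model",  # matches no deployment or pattern
            messages=[{"role": "user", "content": "Hello!"}],
            mock_response="Works",
        )
    except litellm.BadRequestError as e:
        # Previously a ValueError; now a 400-style BadRequestError.
        print("got expected error:", e)


asyncio.run(main())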


@@ -1043,6 +1043,7 @@ def client(original_function):  # noqa: PLR0915
             if (
                 call_type != CallTypes.aimage_generation.value  # model optional
                 and call_type != CallTypes.atext_completion.value  # can also be engine
+                and call_type != CallTypes.amoderation.value
             ):
                 raise ValueError("model param not passed in.")
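
With CallTypes.amoderation exempt from the 'model param not passed in' check, litellm.amoderation can be called without a model and fall back to the provider default; a hedged sketch assuming OPENAI_API_KEY is set in the environment:

import asyncio

import litellm


async def main():
    # No "model" argument: the client() wrapper no longer rejects the call,
    # and the provider's default moderation model is used (assumption).
    response = await litellm.amoderation(input="this is valid good text")
    print(response.results[0].flagged)


asyncio.run(main())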


@@ -689,9 +689,10 @@ async def aaaatest_user_token_output(
     assert team_result.user_id == user_id


+@pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]])
 @pytest.mark.parametrize("audience", [None, "litellm-proxy"])
 @pytest.mark.asyncio
-async def test_allowed_routes_admin(prisma_client, audience):
+async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_routes):
     """
     Add a check to make sure jwt proxy admin scope can access all allowed admin routes
@@ -754,12 +755,17 @@ async def test_allowed_routes_admin(prisma_client, audience):
     jwt_handler.user_api_key_cache = cache

-    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
+    if admin_allowed_routes:
+        jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
+            team_id_jwt_field="client_id", admin_allowed_routes=admin_allowed_routes
+        )
+    else:
+        jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")

     # VALID TOKEN
     ## GENERATE A TOKEN
     # Assuming the current time is in UTC
-    expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
+    expiration_time = int((datetime.now() + timedelta(minutes=10)).timestamp())

     # Generate the JWT token
     # But before, you should convert bytes to string
@@ -777,6 +783,7 @@ async def test_allowed_routes_admin(prisma_client, audience):
     # verify token

+    print(f"admin_token: {admin_token}")
     response = await jwt_handler.auth_jwt(token=admin_token)

     ## RUN IT THROUGH USER API KEY AUTH


@@ -1866,16 +1866,9 @@ async def test_router_amoderation():
     router = Router(model_list=model_list)
     ## Test 1: user facing function
     result = await router.amoderation(
-        model="openai-moderations", input="this is valid good text"
+        model="text-moderation-stable", input="this is valid good text"
     )

-    ## Test 2: underlying function
-    result = await router._amoderation(
-        model="openai-moderations", input="this is valid good text"
-    )
-    print("moderation result", result)
-

 def test_router_add_deployment():
     initial_model_list = [


@@ -1226,9 +1226,7 @@ async def test_using_default_fallback(sync_mode):
         pytest.fail(f"Expected call to fail we passed model=openai/foo")
     except Exception as e:
         print("got exception = ", e)
-        from litellm.types.router import RouterErrors
-
-        assert RouterErrors.no_deployments_available.value in str(e)
+        assert "BadRequestError" in str(e)


 @pytest.mark.parametrize("sync_mode", [False])


@@ -158,6 +158,46 @@ def test_route_with_exception():
     assert result is None


+@pytest.mark.asyncio
+async def test_route_with_no_matching_pattern():
+    """
+    Tests that the router returns None when there is no matching pattern
+    """
+    from litellm.types.router import RouterErrors
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "*meta.llama3*",
+                "litellm_params": {"model": "bedrock/meta.llama3*"},
+            }
+        ]
+    )
+
+    ## WORKS
+    result = await router.acompletion(
+        model="bedrock/meta.llama3-70b",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response="Works",
+    )
+    assert result.choices[0].message.content == "Works"
+
+    ## FAILS
+    with pytest.raises(litellm.BadRequestError) as e:
+        await router.acompletion(
+            model="my-fake-model",
+            messages=[{"role": "user", "content": "Hello, world!"}],
+            mock_response="Works",
+        )
+
+    assert RouterErrors.no_deployments_available.value not in str(e.value)
+
+    with pytest.raises(litellm.BadRequestError):
+        await router.aembedding(
+            model="my-fake-model",
+            input="Hello, world!",
+        )
+
+
 def test_router_pattern_match_e2e():
     """
     Tests the end to end flow of the router
@@ -188,3 +228,4 @@ def test_router_pattern_match_e2e():
         "model": "gpt-4o",
         "messages": [{"role": "user", "content": "Hello, how are you?"}],
     }


@@ -999,3 +999,10 @@ def test_pattern_match_deployment_set_model_name(
     for model in updated_models:
         assert model["litellm_params"]["model"] == expected_model
+
+
+@pytest.mark.asyncio
+async def test_pass_through_moderation_endpoint_factory(model_list):
+    router = Router(model_list=model_list)
+    response = await router._pass_through_moderation_endpoint_factory(
+        original_function=litellm.amoderation, input="this is valid good text"
+    )