Litellm dev 02 18 2025 p3 (#8640)

* fix(team_endpoints.py): cleanup user <-> team association on team delete

Fixes an issue where the user table still listed the team membership after the team was deleted.

* test(test_team.py): update e2e test - ensure user/team membership is deleted on team delete

* fix(base_invoke_transformation.py): fix deepseek r1 transformation

remove deepseek name from model url

* test(test_completion.py): assert model route not in url

* feat(base_invoke_transformation.py): infer region name from model arn

Prevents errors when the region name in the env var differs from the region in the model ARN. A region explicitly set in the call still takes precedence over both.

* test: fix test

* test: skip on internal server error
This commit is contained in:
Krish Dholakia 2025-02-18 19:14:20 -08:00 committed by GitHub
parent bf6c013de0
commit e08e8eda47
9 changed files with 108 additions and 21 deletions

View file

@ -233,6 +233,7 @@ class BaseConfig(ABC):
optional_params: dict,
request_data: dict,
api_base: str,
model: Optional[str] = None,
stream: Optional[bool] = None,
fake_stream: Optional[bool] = None,
) -> dict:

View file

@ -94,7 +94,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_region_name=self._get_aws_region_name(optional_params=optional_params),
aws_region_name=self._get_aws_region_name(
optional_params=optional_params, model=model
),
)
if (stream is not None and stream is True) and provider != "ai21":
@ -114,6 +116,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
optional_params: dict,
request_data: dict,
api_base: str,
model: Optional[str] = None,
stream: Optional[bool] = None,
fake_stream: Optional[bool] = None,
) -> dict:
@ -135,7 +138,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
aws_profile_name = optional_params.get("aws_profile_name", None)
aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
aws_region_name = self._get_aws_region_name(optional_params)
aws_region_name = self._get_aws_region_name(
optional_params=optional_params, model=model
)
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
@ -586,27 +591,60 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
modelId = modelId.replace("invoke/", "", 1)
if provider == "llama" and "llama/" in modelId:
modelId = self._get_model_id_for_llama_like_model(modelId)
modelId = self._get_model_id_from_model_with_spec(modelId, spec="llama")
elif provider == "deepseek_r1" and "deepseek_r1/" in modelId:
modelId = self._get_model_id_from_model_with_spec(
modelId, spec="deepseek_r1"
)
return modelId
def _get_aws_region_name(self, optional_params: dict) -> str:
def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
    """
    Extract the AWS region from a Bedrock model ARN contained in `model`.

    The ARN may be preceded by a route/spec prefix (e.g.
    `deepseek_r1/arn:aws:bedrock:us-west-2:...`). The region is the fourth
    colon-separated field of the ARN itself.

    Args:
        model: The model string passed by the caller, or None.

    Returns:
        The region string, or None when `model` is not a string, does not
        contain `arn:aws:bedrock`, has fewer than four ARN fields, or has an
        empty region field.
    """
    if not isinstance(model, str):
        return None

    arn_start = model.find("arn:aws:bedrock")
    if arn_start == -1:
        return None

    # Split starting at the ARN rather than at the start of the whole string,
    # so a ':' inside any prefix before the ARN cannot shift the region index.
    parts = model[arn_start:].split(":")
    if len(parts) < 4:
        return None

    # An empty region field (e.g. "arn:aws:bedrock::123:...") means "no region".
    return parts[3] or None
def _get_aws_region_name(
self, optional_params: dict, model: Optional[str] = None
) -> str:
"""
Get the AWS region name from the environment variables
"""
aws_region_name = optional_params.get("aws_region_name", None)
### SET REGION NAME ###
if aws_region_name is None:
# check model arn #
aws_region_name = self.get_aws_region_from_model_arn(model)
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
if (
aws_region_name is None
and litellm_aws_region_name is not None
and isinstance(litellm_aws_region_name, str)
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
if (
aws_region_name is None
and standard_aws_region_name is not None
and isinstance(standard_aws_region_name, str)
):
aws_region_name = standard_aws_region_name
@ -615,14 +653,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
return aws_region_name
def _get_model_id_for_llama_like_model(
def _get_model_id_from_model_with_spec(
self,
model: str,
spec: str,
) -> str:
"""
Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models
"""
model_id = model.replace("llama/", "")
model_id = model.replace(spec + "/", "")
return self.encode_model_id(model_id=model_id)
def encode_model_id(self, model_id: str) -> str:

View file

@ -247,6 +247,7 @@ class BaseLLMHTTPHandler:
api_base=api_base,
stream=stream,
fake_stream=fake_stream,
model=model,
)
## LOGGING

View file

@ -814,7 +814,6 @@ async def team_member_add(
@management_endpoint_wrapper
async def team_member_delete(
data: TeamMemberDeleteRequest,
http_request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
@ -1128,15 +1127,21 @@ async def delete_team(
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
# check that all teams passed exist
team_rows: List[LiteLLM_TeamTable] = []
for team_id in data.team_ids:
team_row = await prisma_client.get_data( # type: ignore
team_id=team_id, table_name="team", query_type="find_unique"
try:
team_row_base: BaseModel = (
await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
if team_row is None:
)
except Exception:
raise HTTPException(
status_code=404,
status_code=400,
detail={"error": f"Team not found, passed team_id={team_id}"},
)
team_row_pydantic = LiteLLM_TeamTable(**team_row_base.model_dump())
team_rows.append(team_row_pydantic)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
@ -1174,6 +1179,26 @@ async def delete_team(
## DELETE ASSOCIATED KEYS
await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key")
# ## DELETE TEAM MEMBERSHIPS
for team_row in team_rows:
### get all team members
team_members = team_row.members_with_roles
### call team_member_delete for each team member
tasks = []
for team_member in team_members:
tasks.append(
team_member_delete(
data=TeamMemberDeleteRequest(
team_id=team_row.team_id,
user_id=team_member.user_id,
user_email=team_member.user_email,
),
user_api_key_dict=user_api_key_dict,
)
)
await asyncio.gather(*tasks)
## DELETE TEAMS
deleted_teams = await prisma_client.delete_data(
team_id_list=data.team_ids, table_name="team"

View file

@ -2614,7 +2614,7 @@ def test_bedrock_custom_deepseek():
# Verify the URL
assert (
mock_post.call_args.kwargs["url"]
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
== "https://bedrock-runtime.us-east-1.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
)
# Verify the request body format

View file

@ -294,7 +294,7 @@ class TestOpenAIChatCompletion(BaseLLMChatTest):
)
assert response is not None
except litellm.InternalServerError:
pytest.skip("Skipping test due to InternalServerError")
pytest.skip("OpenAI API is raising internal server errors")
def test_completion_bad_org():

View file

@ -3311,12 +3311,16 @@ def test_bedrock_deepseek_custom_prompt_dict():
)
def test_bedrock_deepseek_known_tokenizer_config():
model = "deepseek_r1/arn:aws:bedrock:us-east-1:1234:imported-model/45d34re"
def test_bedrock_deepseek_known_tokenizer_config(monkeypatch):
model = (
"deepseek_r1/arn:aws:bedrock:us-west-2:888602223428:imported-model/bnnr6463ejgf"
)
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import Mock
import httpx
monkeypatch.setenv("AWS_REGION", "us-east-1")
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.headers = {
@ -3350,6 +3354,10 @@ def test_bedrock_deepseek_known_tokenizer_config():
mock_post.assert_called_once()
print(mock_post.call_args.kwargs)
url = mock_post.call_args.kwargs["url"]
assert "deepseek_r1" not in url
assert "us-east-1" not in url
assert "us-west-2" in url
json_data = json.loads(mock_post.call_args.kwargs["data"])
assert (
json_data["prompt"].rstrip()

View file

@ -538,6 +538,13 @@ async def test_team_delete():
{"role": "user", "user_id": normal_user},
]
team_data = await new_team(session=session, i=0, member_list=member_list)
## ASSERT USER MEMBERSHIP IS CREATED
user_info = await get_user_info(
session=session, get_user=normal_user, call_user="sk-1234"
)
assert len(user_info["teams"]) == 1
## Create key
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
key = key_gen["key"]
@ -546,6 +553,12 @@ async def test_team_delete():
## Delete team
await delete_team(session=session, i=0, team_id=team_data["team_id"])
## ASSERT USER MEMBERSHIP IS DELETED
user_info = await get_user_info(
session=session, get_user=normal_user, call_user="sk-1234"
)
assert len(user_info["teams"]) == 0
@pytest.mark.parametrize("dimension", ["user_id", "user_email"])
@pytest.mark.asyncio