mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Litellm dev 02 18 2025 p3 (#8640)
* fix(team_endpoints.py): cleanup user <-> team association on team delete Fixes issue where user table still listed team membership post delete * test(test_team.py): update e2e test - ensure user/team membership is deleted on team delete * fix(base_invoke_transformation.py): fix deepseek r1 transformation remove deepseek name from model url * test(test_completion.py): assert model route not in url * feat(base_invoke_transformation.py): infer region name from model arn prevent errors due to different region name in env var vs. model arn, respect if explicitly set in call though * test: fix test * test: skip on internal server error
This commit is contained in:
parent
bf6c013de0
commit
e08e8eda47
9 changed files with 108 additions and 21 deletions
|
@ -233,6 +233,7 @@ class BaseConfig(ABC):
|
|||
optional_params: dict,
|
||||
request_data: dict,
|
||||
api_base: str,
|
||||
model: Optional[str] = None,
|
||||
stream: Optional[bool] = None,
|
||||
fake_stream: Optional[bool] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -94,7 +94,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=self._get_aws_region_name(optional_params=optional_params),
|
||||
aws_region_name=self._get_aws_region_name(
|
||||
optional_params=optional_params, model=model
|
||||
),
|
||||
)
|
||||
|
||||
if (stream is not None and stream is True) and provider != "ai21":
|
||||
|
@ -114,6 +116,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
optional_params: dict,
|
||||
request_data: dict,
|
||||
api_base: str,
|
||||
model: Optional[str] = None,
|
||||
stream: Optional[bool] = None,
|
||||
fake_stream: Optional[bool] = None,
|
||||
) -> dict:
|
||||
|
@ -135,7 +138,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
aws_profile_name = optional_params.get("aws_profile_name", None)
|
||||
aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
|
||||
aws_region_name = self._get_aws_region_name(optional_params)
|
||||
aws_region_name = self._get_aws_region_name(
|
||||
optional_params=optional_params, model=model
|
||||
)
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
|
@ -586,27 +591,60 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
|
||||
modelId = modelId.replace("invoke/", "", 1)
|
||||
if provider == "llama" and "llama/" in modelId:
|
||||
modelId = self._get_model_id_for_llama_like_model(modelId)
|
||||
modelId = self._get_model_id_from_model_with_spec(modelId, spec="llama")
|
||||
elif provider == "deepseek_r1" and "deepseek_r1/" in modelId:
|
||||
modelId = self._get_model_id_from_model_with_spec(
|
||||
modelId, spec="deepseek_r1"
|
||||
)
|
||||
return modelId
|
||||
|
||||
def _get_aws_region_name(self, optional_params: dict) -> str:
|
||||
def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
# First check if the string contains the expected prefix
|
||||
if not isinstance(model, str) or "arn:aws:bedrock" not in model:
|
||||
return None
|
||||
|
||||
# Split the ARN and check if we have enough parts
|
||||
parts = model.split(":")
|
||||
if len(parts) < 4:
|
||||
return None
|
||||
|
||||
# Get the region from the correct position
|
||||
region = parts[3]
|
||||
if not region: # Check if region is empty
|
||||
return None
|
||||
|
||||
return region
|
||||
except Exception:
|
||||
# Catch any unexpected errors and return None
|
||||
return None
|
||||
|
||||
def _get_aws_region_name(
|
||||
self, optional_params: dict, model: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Get the AWS region name from the environment variables
|
||||
"""
|
||||
aws_region_name = optional_params.get("aws_region_name", None)
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check model arn #
|
||||
aws_region_name = self.get_aws_region_from_model_arn(model)
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if litellm_aws_region_name is not None and isinstance(
|
||||
litellm_aws_region_name, str
|
||||
if (
|
||||
aws_region_name is None
|
||||
and litellm_aws_region_name is not None
|
||||
and isinstance(litellm_aws_region_name, str)
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if standard_aws_region_name is not None and isinstance(
|
||||
standard_aws_region_name, str
|
||||
if (
|
||||
aws_region_name is None
|
||||
and standard_aws_region_name is not None
|
||||
and isinstance(standard_aws_region_name, str)
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
|
@ -615,14 +653,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
|
||||
return aws_region_name
|
||||
|
||||
def _get_model_id_for_llama_like_model(
|
||||
def _get_model_id_from_model_with_spec(
|
||||
self,
|
||||
model: str,
|
||||
spec: str,
|
||||
) -> str:
|
||||
"""
|
||||
Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models
|
||||
"""
|
||||
model_id = model.replace("llama/", "")
|
||||
model_id = model.replace(spec + "/", "")
|
||||
return self.encode_model_id(model_id=model_id)
|
||||
|
||||
def encode_model_id(self, model_id: str) -> str:
|
||||
|
|
|
@ -247,6 +247,7 @@ class BaseLLMHTTPHandler:
|
|||
api_base=api_base,
|
||||
stream=stream,
|
||||
fake_stream=fake_stream,
|
||||
model=model,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
|
|
|
@ -814,7 +814,6 @@ async def team_member_add(
|
|||
@management_endpoint_wrapper
|
||||
async def team_member_delete(
|
||||
data: TeamMemberDeleteRequest,
|
||||
http_request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
|
@ -1128,15 +1127,21 @@ async def delete_team(
|
|||
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
|
||||
|
||||
# check that all teams passed exist
|
||||
team_rows: List[LiteLLM_TeamTable] = []
|
||||
for team_id in data.team_ids:
|
||||
team_row = await prisma_client.get_data( # type: ignore
|
||||
team_id=team_id, table_name="team", query_type="find_unique"
|
||||
)
|
||||
if team_row is None:
|
||||
try:
|
||||
team_row_base: BaseModel = (
|
||||
await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
status_code=400,
|
||||
detail={"error": f"Team not found, passed team_id={team_id}"},
|
||||
)
|
||||
team_row_pydantic = LiteLLM_TeamTable(**team_row_base.model_dump())
|
||||
team_rows.append(team_row_pydantic)
|
||||
|
||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
|
||||
|
@ -1174,6 +1179,26 @@ async def delete_team(
|
|||
|
||||
## DELETE ASSOCIATED KEYS
|
||||
await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key")
|
||||
|
||||
# ## DELETE TEAM MEMBERSHIPS
|
||||
for team_row in team_rows:
|
||||
### get all team members
|
||||
team_members = team_row.members_with_roles
|
||||
### call team_member_delete for each team member
|
||||
tasks = []
|
||||
for team_member in team_members:
|
||||
tasks.append(
|
||||
team_member_delete(
|
||||
data=TeamMemberDeleteRequest(
|
||||
team_id=team_row.team_id,
|
||||
user_id=team_member.user_id,
|
||||
user_email=team_member.user_email,
|
||||
),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
## DELETE TEAMS
|
||||
deleted_teams = await prisma_client.delete_data(
|
||||
team_id_list=data.team_ids, table_name="team"
|
||||
|
|
|
@ -2614,7 +2614,7 @@ def test_bedrock_custom_deepseek():
|
|||
# Verify the URL
|
||||
assert (
|
||||
mock_post.call_args.kwargs["url"]
|
||||
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
|
||||
== "https://bedrock-runtime.us-east-1.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
|
||||
)
|
||||
|
||||
# Verify the request body format
|
||||
|
|
|
@ -294,7 +294,7 @@ class TestOpenAIChatCompletion(BaseLLMChatTest):
|
|||
)
|
||||
assert response is not None
|
||||
except litellm.InternalServerError:
|
||||
pytest.skip("Skipping test due to InternalServerError")
|
||||
pytest.skip("OpenAI API is raising internal server errors")
|
||||
|
||||
|
||||
def test_completion_bad_org():
|
||||
|
|
|
@ -3311,12 +3311,16 @@ def test_bedrock_deepseek_custom_prompt_dict():
|
|||
)
|
||||
|
||||
|
||||
def test_bedrock_deepseek_known_tokenizer_config():
|
||||
model = "deepseek_r1/arn:aws:bedrock:us-east-1:1234:imported-model/45d34re"
|
||||
def test_bedrock_deepseek_known_tokenizer_config(monkeypatch):
|
||||
model = (
|
||||
"deepseek_r1/arn:aws:bedrock:us-west-2:888602223428:imported-model/bnnr6463ejgf"
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from unittest.mock import Mock
|
||||
import httpx
|
||||
|
||||
monkeypatch.setenv("AWS_REGION", "us-east-1")
|
||||
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {
|
||||
|
@ -3350,6 +3354,10 @@ def test_bedrock_deepseek_known_tokenizer_config():
|
|||
|
||||
mock_post.assert_called_once()
|
||||
print(mock_post.call_args.kwargs)
|
||||
url = mock_post.call_args.kwargs["url"]
|
||||
assert "deepseek_r1" not in url
|
||||
assert "us-east-1" not in url
|
||||
assert "us-west-2" in url
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert (
|
||||
json_data["prompt"].rstrip()
|
||||
|
|
|
@ -538,6 +538,13 @@ async def test_team_delete():
|
|||
{"role": "user", "user_id": normal_user},
|
||||
]
|
||||
team_data = await new_team(session=session, i=0, member_list=member_list)
|
||||
|
||||
## ASSERT USER MEMBERSHIP IS CREATED
|
||||
user_info = await get_user_info(
|
||||
session=session, get_user=normal_user, call_user="sk-1234"
|
||||
)
|
||||
assert len(user_info["teams"]) == 1
|
||||
|
||||
## Create key
|
||||
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
|
||||
key = key_gen["key"]
|
||||
|
@ -546,6 +553,12 @@ async def test_team_delete():
|
|||
## Delete team
|
||||
await delete_team(session=session, i=0, team_id=team_data["team_id"])
|
||||
|
||||
## ASSERT USER MEMBERSHIP IS DELETED
|
||||
user_info = await get_user_info(
|
||||
session=session, get_user=normal_user, call_user="sk-1234"
|
||||
)
|
||||
assert len(user_info["teams"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dimension", ["user_id", "user_email"])
|
||||
@pytest.mark.asyncio
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue