mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm dev 02 18 2025 p3 (#8640)
* fix(team_endpoints.py): cleanup user <-> team association on team delete. Fixes issue where user table still listed team membership post delete.
* test(test_team.py): update e2e test - ensure user/team membership is deleted on team delete
* fix(base_invoke_transformation.py): fix deepseek r1 transformation - remove deepseek name from model url
* test(test_completion.py): assert model route not in url
* feat(base_invoke_transformation.py): infer region name from model arn - prevent errors due to different region name in env var vs. model arn; respect if explicitly set in call though
* test: fix test
* test: skip on internal server error
This commit is contained in:
parent
bf6c013de0
commit
e08e8eda47
9 changed files with 108 additions and 21 deletions
|
@ -233,6 +233,7 @@ class BaseConfig(ABC):
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
api_base: str,
|
api_base: str,
|
||||||
|
model: Optional[str] = None,
|
||||||
stream: Optional[bool] = None,
|
stream: Optional[bool] = None,
|
||||||
fake_stream: Optional[bool] = None,
|
fake_stream: Optional[bool] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -94,7 +94,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
aws_region_name=self._get_aws_region_name(optional_params=optional_params),
|
aws_region_name=self._get_aws_region_name(
|
||||||
|
optional_params=optional_params, model=model
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if (stream is not None and stream is True) and provider != "ai21":
|
if (stream is not None and stream is True) and provider != "ai21":
|
||||||
|
@ -114,6 +116,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
api_base: str,
|
api_base: str,
|
||||||
|
model: Optional[str] = None,
|
||||||
stream: Optional[bool] = None,
|
stream: Optional[bool] = None,
|
||||||
fake_stream: Optional[bool] = None,
|
fake_stream: Optional[bool] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
@ -135,7 +138,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
aws_profile_name = optional_params.get("aws_profile_name", None)
|
aws_profile_name = optional_params.get("aws_profile_name", None)
|
||||||
aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
|
aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
|
||||||
aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
|
aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
|
||||||
aws_region_name = self._get_aws_region_name(optional_params)
|
aws_region_name = self._get_aws_region_name(
|
||||||
|
optional_params=optional_params, model=model
|
||||||
|
)
|
||||||
|
|
||||||
credentials: Credentials = self.get_credentials(
|
credentials: Credentials = self.get_credentials(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
@ -586,27 +591,60 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
|
|
||||||
modelId = modelId.replace("invoke/", "", 1)
|
modelId = modelId.replace("invoke/", "", 1)
|
||||||
if provider == "llama" and "llama/" in modelId:
|
if provider == "llama" and "llama/" in modelId:
|
||||||
modelId = self._get_model_id_for_llama_like_model(modelId)
|
modelId = self._get_model_id_from_model_with_spec(modelId, spec="llama")
|
||||||
|
elif provider == "deepseek_r1" and "deepseek_r1/" in modelId:
|
||||||
|
modelId = self._get_model_id_from_model_with_spec(
|
||||||
|
modelId, spec="deepseek_r1"
|
||||||
|
)
|
||||||
return modelId
|
return modelId
|
||||||
|
|
||||||
def _get_aws_region_name(self, optional_params: dict) -> str:
|
def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
# First check if the string contains the expected prefix
|
||||||
|
if not isinstance(model, str) or "arn:aws:bedrock" not in model:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Split the ARN and check if we have enough parts
|
||||||
|
parts = model.split(":")
|
||||||
|
if len(parts) < 4:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get the region from the correct position
|
||||||
|
region = parts[3]
|
||||||
|
if not region: # Check if region is empty
|
||||||
|
return None
|
||||||
|
|
||||||
|
return region
|
||||||
|
except Exception:
|
||||||
|
# Catch any unexpected errors and return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_aws_region_name(
|
||||||
|
self, optional_params: dict, model: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get the AWS region name from the environment variables
|
Get the AWS region name from the environment variables
|
||||||
"""
|
"""
|
||||||
aws_region_name = optional_params.get("aws_region_name", None)
|
aws_region_name = optional_params.get("aws_region_name", None)
|
||||||
### SET REGION NAME ###
|
### SET REGION NAME ###
|
||||||
if aws_region_name is None:
|
if aws_region_name is None:
|
||||||
|
# check model arn #
|
||||||
|
aws_region_name = self.get_aws_region_from_model_arn(model)
|
||||||
# check env #
|
# check env #
|
||||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
|
|
||||||
if litellm_aws_region_name is not None and isinstance(
|
if (
|
||||||
litellm_aws_region_name, str
|
aws_region_name is None
|
||||||
|
and litellm_aws_region_name is not None
|
||||||
|
and isinstance(litellm_aws_region_name, str)
|
||||||
):
|
):
|
||||||
aws_region_name = litellm_aws_region_name
|
aws_region_name = litellm_aws_region_name
|
||||||
|
|
||||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
if standard_aws_region_name is not None and isinstance(
|
if (
|
||||||
standard_aws_region_name, str
|
aws_region_name is None
|
||||||
|
and standard_aws_region_name is not None
|
||||||
|
and isinstance(standard_aws_region_name, str)
|
||||||
):
|
):
|
||||||
aws_region_name = standard_aws_region_name
|
aws_region_name = standard_aws_region_name
|
||||||
|
|
||||||
|
@ -615,14 +653,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
|
|
||||||
return aws_region_name
|
return aws_region_name
|
||||||
|
|
||||||
def _get_model_id_for_llama_like_model(
|
def _get_model_id_from_model_with_spec(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
spec: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models
|
Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models
|
||||||
"""
|
"""
|
||||||
model_id = model.replace("llama/", "")
|
model_id = model.replace(spec + "/", "")
|
||||||
return self.encode_model_id(model_id=model_id)
|
return self.encode_model_id(model_id=model_id)
|
||||||
|
|
||||||
def encode_model_id(self, model_id: str) -> str:
|
def encode_model_id(self, model_id: str) -> str:
|
||||||
|
|
|
@ -247,6 +247,7 @@ class BaseLLMHTTPHandler:
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
fake_stream=fake_stream,
|
fake_stream=fake_stream,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
|
|
@ -814,7 +814,6 @@ async def team_member_add(
|
||||||
@management_endpoint_wrapper
|
@management_endpoint_wrapper
|
||||||
async def team_member_delete(
|
async def team_member_delete(
|
||||||
data: TeamMemberDeleteRequest,
|
data: TeamMemberDeleteRequest,
|
||||||
http_request: Request,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -1128,15 +1127,21 @@ async def delete_team(
|
||||||
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
|
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
|
||||||
|
|
||||||
# check that all teams passed exist
|
# check that all teams passed exist
|
||||||
|
team_rows: List[LiteLLM_TeamTable] = []
|
||||||
for team_id in data.team_ids:
|
for team_id in data.team_ids:
|
||||||
team_row = await prisma_client.get_data( # type: ignore
|
try:
|
||||||
team_id=team_id, table_name="team", query_type="find_unique"
|
team_row_base: BaseModel = (
|
||||||
)
|
await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
if team_row is None:
|
where={"team_id": team_id}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=400,
|
||||||
detail={"error": f"Team not found, passed team_id={team_id}"},
|
detail={"error": f"Team not found, passed team_id={team_id}"},
|
||||||
)
|
)
|
||||||
|
team_row_pydantic = LiteLLM_TeamTable(**team_row_base.model_dump())
|
||||||
|
team_rows.append(team_row_pydantic)
|
||||||
|
|
||||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||||
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
|
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
|
||||||
|
@ -1174,6 +1179,26 @@ async def delete_team(
|
||||||
|
|
||||||
## DELETE ASSOCIATED KEYS
|
## DELETE ASSOCIATED KEYS
|
||||||
await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key")
|
await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key")
|
||||||
|
|
||||||
|
# ## DELETE TEAM MEMBERSHIPS
|
||||||
|
for team_row in team_rows:
|
||||||
|
### get all team members
|
||||||
|
team_members = team_row.members_with_roles
|
||||||
|
### call team_member_delete for each team member
|
||||||
|
tasks = []
|
||||||
|
for team_member in team_members:
|
||||||
|
tasks.append(
|
||||||
|
team_member_delete(
|
||||||
|
data=TeamMemberDeleteRequest(
|
||||||
|
team_id=team_row.team_id,
|
||||||
|
user_id=team_member.user_id,
|
||||||
|
user_email=team_member.user_email,
|
||||||
|
),
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
## DELETE TEAMS
|
## DELETE TEAMS
|
||||||
deleted_teams = await prisma_client.delete_data(
|
deleted_teams = await prisma_client.delete_data(
|
||||||
team_id_list=data.team_ids, table_name="team"
|
team_id_list=data.team_ids, table_name="team"
|
||||||
|
|
|
@ -2614,7 +2614,7 @@ def test_bedrock_custom_deepseek():
|
||||||
# Verify the URL
|
# Verify the URL
|
||||||
assert (
|
assert (
|
||||||
mock_post.call_args.kwargs["url"]
|
mock_post.call_args.kwargs["url"]
|
||||||
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
|
== "https://bedrock-runtime.us-east-1.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the request body format
|
# Verify the request body format
|
||||||
|
|
|
@ -294,7 +294,7 @@ class TestOpenAIChatCompletion(BaseLLMChatTest):
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
except litellm.InternalServerError:
|
except litellm.InternalServerError:
|
||||||
pytest.skip("Skipping test due to InternalServerError")
|
pytest.skip("OpenAI API is raising internal server errors")
|
||||||
|
|
||||||
|
|
||||||
def test_completion_bad_org():
|
def test_completion_bad_org():
|
||||||
|
|
|
@ -3311,12 +3311,16 @@ def test_bedrock_deepseek_custom_prompt_dict():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_bedrock_deepseek_known_tokenizer_config():
|
def test_bedrock_deepseek_known_tokenizer_config(monkeypatch):
|
||||||
model = "deepseek_r1/arn:aws:bedrock:us-east-1:1234:imported-model/45d34re"
|
model = (
|
||||||
|
"deepseek_r1/arn:aws:bedrock:us-west-2:888602223428:imported-model/bnnr6463ejgf"
|
||||||
|
)
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
monkeypatch.setenv("AWS_REGION", "us-east-1")
|
||||||
|
|
||||||
mock_response = Mock(spec=httpx.Response)
|
mock_response = Mock(spec=httpx.Response)
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.headers = {
|
mock_response.headers = {
|
||||||
|
@ -3350,6 +3354,10 @@ def test_bedrock_deepseek_known_tokenizer_config():
|
||||||
|
|
||||||
mock_post.assert_called_once()
|
mock_post.assert_called_once()
|
||||||
print(mock_post.call_args.kwargs)
|
print(mock_post.call_args.kwargs)
|
||||||
|
url = mock_post.call_args.kwargs["url"]
|
||||||
|
assert "deepseek_r1" not in url
|
||||||
|
assert "us-east-1" not in url
|
||||||
|
assert "us-west-2" in url
|
||||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||||
assert (
|
assert (
|
||||||
json_data["prompt"].rstrip()
|
json_data["prompt"].rstrip()
|
||||||
|
|
|
@ -538,6 +538,13 @@ async def test_team_delete():
|
||||||
{"role": "user", "user_id": normal_user},
|
{"role": "user", "user_id": normal_user},
|
||||||
]
|
]
|
||||||
team_data = await new_team(session=session, i=0, member_list=member_list)
|
team_data = await new_team(session=session, i=0, member_list=member_list)
|
||||||
|
|
||||||
|
## ASSERT USER MEMBERSHIP IS CREATED
|
||||||
|
user_info = await get_user_info(
|
||||||
|
session=session, get_user=normal_user, call_user="sk-1234"
|
||||||
|
)
|
||||||
|
assert len(user_info["teams"]) == 1
|
||||||
|
|
||||||
## Create key
|
## Create key
|
||||||
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
|
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
|
||||||
key = key_gen["key"]
|
key = key_gen["key"]
|
||||||
|
@ -546,6 +553,12 @@ async def test_team_delete():
|
||||||
## Delete team
|
## Delete team
|
||||||
await delete_team(session=session, i=0, team_id=team_data["team_id"])
|
await delete_team(session=session, i=0, team_id=team_data["team_id"])
|
||||||
|
|
||||||
|
## ASSERT USER MEMBERSHIP IS DELETED
|
||||||
|
user_info = await get_user_info(
|
||||||
|
session=session, get_user=normal_user, call_user="sk-1234"
|
||||||
|
)
|
||||||
|
assert len(user_info["teams"]) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dimension", ["user_id", "user_email"])
|
@pytest.mark.parametrize("dimension", ["user_id", "user_email"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue