diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index 9d3778ed68..d98931d23b 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -233,6 +233,7 @@ class BaseConfig(ABC): optional_params: dict, request_data: dict, api_base: str, + model: Optional[str] = None, stream: Optional[bool] = None, fake_stream: Optional[bool] = None, ) -> dict: diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py index a080e55bb3..f369057744 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py @@ -94,7 +94,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, - aws_region_name=self._get_aws_region_name(optional_params=optional_params), + aws_region_name=self._get_aws_region_name( + optional_params=optional_params, model=model + ), ) if (stream is not None and stream is True) and provider != "ai21": @@ -114,6 +116,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): optional_params: dict, request_data: dict, api_base: str, + model: Optional[str] = None, stream: Optional[bool] = None, fake_stream: Optional[bool] = None, ) -> dict: @@ -135,7 +138,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): aws_profile_name = optional_params.get("aws_profile_name", None) aws_web_identity_token = optional_params.get("aws_web_identity_token", None) aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None) - aws_region_name = self._get_aws_region_name(optional_params) + aws_region_name = self._get_aws_region_name( + optional_params=optional_params, model=model + ) credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, @@ -586,27 +591,60 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): modelId = modelId.replace("invoke/", "", 1) if provider == "llama" and "llama/" in modelId: - modelId = self._get_model_id_for_llama_like_model(modelId) + modelId = self._get_model_id_from_model_with_spec(modelId, spec="llama") + elif provider == "deepseek_r1" and "deepseek_r1/" in modelId: + modelId = self._get_model_id_from_model_with_spec( + modelId, spec="deepseek_r1" + ) return modelId - def _get_aws_region_name(self, optional_params: dict) -> str: + def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]: + try: + # First check if the string contains the expected prefix + if not isinstance(model, str) or "arn:aws:bedrock" not in model: + return None + + # Split the ARN and check if we have enough parts + parts = model.split(":") + if len(parts) < 4: + return None + + # Get the region from the correct position + region = parts[3] + if not region: # Check if region is empty + return None + + return region + except Exception: + # Catch any unexpected errors and return None + return None + + def _get_aws_region_name( + self, optional_params: dict, model: Optional[str] = None + ) -> str: """ Get the AWS region name from the environment variables """ aws_region_name = optional_params.get("aws_region_name", None) ### SET REGION NAME ### if aws_region_name is None: + # check model arn # + aws_region_name = self.get_aws_region_from_model_arn(model) # check env # litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) - if litellm_aws_region_name is not None and isinstance( - litellm_aws_region_name, str + if ( + aws_region_name is None + and litellm_aws_region_name is not None + and isinstance(litellm_aws_region_name, str) ): aws_region_name = litellm_aws_region_name standard_aws_region_name = get_secret("AWS_REGION", None) - if standard_aws_region_name is not None and isinstance( - standard_aws_region_name, str + if ( + aws_region_name is None + and standard_aws_region_name is not None + and isinstance(standard_aws_region_name, str) ): aws_region_name = standard_aws_region_name @@ -615,14 +653,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): return aws_region_name - def _get_model_id_for_llama_like_model( + def _get_model_id_from_model_with_spec( self, model: str, + spec: str, ) -> str: """ Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models """ - model_id = model.replace("llama/", "") + model_id = model.replace(spec + "/", "") return self.encode_model_id(model_id=model_id) def encode_model_id(self, model_id: str) -> str: diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index eafc345aa6..74cded6c45 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -247,6 +247,7 @@ class BaseLLMHTTPHandler: api_base=api_base, stream=stream, fake_stream=fake_stream, + model=model, ) ## LOGGING diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index dbd566078e..00c6aa14f1 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -7,4 +7,4 @@ model_list: litellm_settings: callbacks: ["prometheus"] - # custom_prometheus_metadata_labels: ["metadata.foo", "metadata.bar"] \ No newline at end of file + # custom_prometheus_metadata_labels: ["metadata.foo", "metadata.bar"] diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 4fa3e4b89b..57d0d0d957 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -814,7 +814,6 @@ async def team_member_add( @management_endpoint_wrapper async def team_member_delete( data: TeamMemberDeleteRequest, - http_request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ @@ -1128,15 +1127,21 @@ async def delete_team( raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) # check that all teams passed exist + team_rows: List[LiteLLM_TeamTable] = [] for team_id in data.team_ids: - team_row = await prisma_client.get_data( # type: ignore - team_id=team_id, table_name="team", query_type="find_unique" - ) - if team_row is None: + try: + team_row_base: BaseModel = ( + await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) + ) + except Exception: raise HTTPException( - status_code=404, + status_code=400, detail={"error": f"Team not found, passed team_id={team_id}"}, ) + team_row_pydantic = LiteLLM_TeamTable(**team_row_base.model_dump()) + team_rows.append(team_row_pydantic) # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes @@ -1174,6 +1179,26 @@ async def delete_team( ## DELETE ASSOCIATED KEYS await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key") + + # ## DELETE TEAM MEMBERSHIPS + for team_row in team_rows: + ### get all team members + team_members = team_row.members_with_roles + ### call team_member_delete for each team member + tasks = [] + for team_member in team_members: + tasks.append( + team_member_delete( + data=TeamMemberDeleteRequest( + team_id=team_row.team_id, + user_id=team_member.user_id, + user_email=team_member.user_email, + ), + user_api_key_dict=user_api_key_dict, + ) + ) + await asyncio.gather(*tasks) + ## DELETE TEAMS deleted_teams = await prisma_client.delete_data( team_id_list=data.team_ids, table_name="team" diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index 8823ccdbe3..685c5e5409 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -2614,7 +2614,7 @@ def test_bedrock_custom_deepseek(): # Verify the URL assert ( mock_post.call_args.kwargs["url"] - == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke" + == "https://bedrock-runtime.us-east-1.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke" ) # Verify the request body format diff --git a/tests/llm_translation/test_openai.py b/tests/llm_translation/test_openai.py index 2f10d42820..6fd8662d16 100644 --- a/tests/llm_translation/test_openai.py +++ b/tests/llm_translation/test_openai.py @@ -294,7 +294,7 @@ class TestOpenAIChatCompletion(BaseLLMChatTest): ) assert response is not None except litellm.InternalServerError: - pytest.skip("Skipping test due to InternalServerError") + pytest.skip("OpenAI API is raising internal server errors") def test_completion_bad_org(): diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 819dea8f93..0539d04aba 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -3311,12 +3311,16 @@ def test_bedrock_deepseek_custom_prompt_dict(): ) -def test_bedrock_deepseek_known_tokenizer_config(): - model = "deepseek_r1/arn:aws:bedrock:us-east-1:1234:imported-model/45d34re" +def test_bedrock_deepseek_known_tokenizer_config(monkeypatch): + model = ( + "deepseek_r1/arn:aws:bedrock:us-west-2:888602223428:imported-model/bnnr6463ejgf" + ) from litellm.llms.custom_httpx.http_handler import HTTPHandler from unittest.mock import Mock import httpx + monkeypatch.setenv("AWS_REGION", "us-east-1") + mock_response = Mock(spec=httpx.Response) mock_response.status_code = 200 mock_response.headers = { @@ -3350,6 +3354,10 @@ def test_bedrock_deepseek_known_tokenizer_config(): mock_post.assert_called_once() print(mock_post.call_args.kwargs) + url = mock_post.call_args.kwargs["url"] + assert "deepseek_r1" not in url + assert "us-east-1" not in url + assert "us-west-2" in url json_data = json.loads(mock_post.call_args.kwargs["data"]) assert ( json_data["prompt"].rstrip() diff --git a/tests/test_team.py b/tests/test_team.py index 2381096daa..d5d71bdc5a 100644 --- a/tests/test_team.py +++ b/tests/test_team.py @@ -538,6 +538,13 @@ async def test_team_delete(): {"role": "user", "user_id": normal_user}, ] team_data = await new_team(session=session, i=0, member_list=member_list) + + ## ASSERT USER MEMBERSHIP IS CREATED + user_info = await get_user_info( + session=session, get_user=normal_user, call_user="sk-1234" + ) + assert len(user_info["teams"]) == 1 + ## Create key key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"]) key = key_gen["key"] @@ -546,6 +553,12 @@ async def test_team_delete(): ## Delete team await delete_team(session=session, i=0, team_id=team_data["team_id"]) + ## ASSERT USER MEMBERSHIP IS DELETED + user_info = await get_user_info( + session=session, get_user=normal_user, call_user="sk-1234" + ) + assert len(user_info["teams"]) == 0 + @pytest.mark.parametrize("dimension", ["user_id", "user_email"]) @pytest.mark.asyncio