From 177acd1c93737bc7ae9d2284e5afeb814b72be0c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sun, 24 Nov 2024 17:21:49 +0530 Subject: [PATCH 1/7] fix(key_management_endpoints.py): fix user-membership check when creating team key --- litellm/proxy/_new_secret_config.yaml | 7 +++ .../key_management_endpoints.py | 61 ++++++++++++++++--- .../test_key_management.py | 29 ++++++++- 3 files changed, 84 insertions(+), 13 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 7ff209094..fbb714211 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -12,3 +12,10 @@ model_list: vertex_ai_project: "adroit-crow-413218" vertex_ai_location: "us-east5" +litellm_settings: + key_generation_settings: + team_key_generation: + allowed_team_member_roles: ["admin"] + required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key + personal_key_generation: # maps to 'Default Team' on UI + allowed_user_roles: ["proxy_admin"] \ No newline at end of file diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 511e5a940..f7a383183 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -29,6 +29,7 @@ from litellm.proxy.auth.auth_checks import ( _cache_key_object, _delete_cache_key_object, get_key_object, + get_team_object, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks @@ -46,7 +47,19 @@ def _is_team_key(data: GenerateKeyRequest): return data.team_id is not None +def _get_user_in_team( + team_table: LiteLLM_TeamTableCachedObj, user_id: Optional[str] +) -> Optional[Member]: + if user_id is None: + return None + for member in team_table.members_with_roles: + if member.user_id is not None and member.user_id == user_id: + return member + return None + + def _team_key_generation_team_member_check( + team_table: LiteLLM_TeamTableCachedObj, user_api_key_dict: UserAPIKeyAuth, team_key_generation: Optional[TeamUIKeyGenerationConfig], ): @@ -56,17 +69,19 @@ def _team_key_generation_team_member_check( ): return True - if user_api_key_dict.team_member is None: + user_in_team = _get_user_in_team( + team_table=team_table, user_id=user_api_key_dict.user_id + ) + if user_in_team is None: raise HTTPException( status_code=400, - detail=f"User not assigned to team. 
Got team_member={user_api_key_dict.team_member}", + detail=f"User={user_api_key_dict.user_id} not assigned to team={team_table.team_id}", ) - team_member_role = user_api_key_dict.team_member.role - if team_member_role not in team_key_generation["allowed_team_member_roles"]: + if user_in_team.role not in team_key_generation["allowed_team_member_roles"]: raise HTTPException( status_code=400, - detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore + detail=f"Team member role {user_in_team.role} not in allowed_team_member_roles={team_key_generation['allowed_team_member_roles']}", ) return True @@ -88,7 +103,9 @@ def _key_generation_required_param_check( def _team_key_generation_check( - user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest + team_table: LiteLLM_TeamTableCachedObj, + user_api_key_dict: UserAPIKeyAuth, + data: GenerateKeyRequest, ): if ( litellm.key_generation_settings is None @@ -99,7 +116,8 @@ def _team_key_generation_check( _team_key_generation = litellm.key_generation_settings["team_key_generation"] # type: ignore _team_key_generation_team_member_check( - user_api_key_dict, + team_table=team_table, + user_api_key_dict=user_api_key_dict, team_key_generation=_team_key_generation, ) _key_generation_required_param_check( @@ -155,7 +173,9 @@ def _personal_key_generation_check( def key_generation_check( - user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest + team_table: Optional[LiteLLM_TeamTableCachedObj], + user_api_key_dict: UserAPIKeyAuth, + data: GenerateKeyRequest, ) -> bool: """ Check if admin has restricted key creation to certain roles for teams or individuals @@ -170,8 +190,15 @@ def key_generation_check( is_team_key = _is_team_key(data=data) if is_team_key: + if team_table is None: + raise HTTPException( + status_code=400, + detail=f"Unable to find team object in database. 
Team ID: {data.team_id}", + ) return _team_key_generation_check( - user_api_key_dict=user_api_key_dict, data=data + team_table=team_table, + user_api_key_dict=user_api_key_dict, + data=data, ) else: return _personal_key_generation_check( @@ -254,6 +281,7 @@ async def generate_key_fn( # noqa: PLR0915 litellm_proxy_admin_name, prisma_client, proxy_logging_obj, + user_api_key_cache, user_custom_key_generate, ) @@ -271,7 +299,20 @@ async def generate_key_fn( # noqa: PLR0915 status_code=status.HTTP_403_FORBIDDEN, detail=message ) elif litellm.key_generation_settings is not None: - key_generation_check(user_api_key_dict=user_api_key_dict, data=data) + if data.team_id is None: + team_table: Optional[LiteLLM_TeamTableCachedObj] = None + else: + team_table = await get_team_object( + team_id=data.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=user_api_key_dict.parent_otel_span, + ) + key_generation_check( + team_table=team_table, + user_api_key_dict=user_api_key_dict, + data=data, + ) # check if user set default key/generate params on config.yaml if litellm.default_key_generate_params is not None: for elem in data: diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 0b392a268..d0b1ab294 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -556,13 +556,22 @@ def test_team_key_generation_team_member_check(): _team_key_generation_check, ) from fastapi import HTTPException + from litellm.proxy._types import LiteLLM_TeamTableCachedObj litellm.key_generation_settings = { "team_key_generation": {"allowed_team_member_roles": ["admin"]} } + team_table = LiteLLM_TeamTableCachedObj( + team_id="test_team_id", + team_alias="test_team_alias", + members_with_roles=[Member(role="admin", user_id="test_user_id")], + ) + assert _team_key_generation_check( + team_table=team_table, user_api_key_dict=UserAPIKeyAuth( + user_id="test_user_id", user_role=LitellmUserRoles.INTERNAL_USER, api_key="sk-1234", team_member=Member(role="admin", user_id="test_user_id"), @@ -570,8 +579,15 @@ def test_team_key_generation_team_member_check(): data=GenerateKeyRequest(), ) + team_table = LiteLLM_TeamTableCachedObj( + team_id="test_team_id", + team_alias="test_team_alias", + members_with_roles=[Member(role="user", user_id="test_user_id")], + ) + with pytest.raises(HTTPException): _team_key_generation_check( + team_table=team_table, user_api_key_dict=UserAPIKeyAuth( user_role=LitellmUserRoles.INTERNAL_USER, api_key="sk-1234", @@ -607,6 +623,7 @@ def test_key_generation_required_params_check( StandardKeyGenerationConfig, PersonalUIKeyGenerationConfig, ) + from litellm.proxy._types import LiteLLM_TeamTableCachedObj from fastapi import HTTPException user_api_key_dict = UserAPIKeyAuth( @@ -614,7 +631,13 @@ def test_key_generation_required_params_check( api_key="sk-1234", user_id="test_user_id", team_id="test_team_id", - team_member=Member(role="admin", user_id="test_user_id"), + team_member=None, + ) + + team_table = LiteLLM_TeamTableCachedObj( + team_id="test_team_id", + team_alias="test_team_alias", + members_with_roles=[Member(role="admin", user_id="test_user_id")], ) if key_type == "team_key": @@ -632,13 +655,13 @@ def test_key_generation_required_params_check( if expected_result: if key_type == "team_key": - assert _team_key_generation_check(user_api_key_dict, input_data) + assert _team_key_generation_check(team_table, user_api_key_dict, input_data) 
elif key_type == "personal_key": assert _personal_key_generation_check(user_api_key_dict, input_data) else: if key_type == "team_key": with pytest.raises(HTTPException): - _team_key_generation_check(user_api_key_dict, input_data) + _team_key_generation_check(team_table, user_api_key_dict, input_data) elif key_type == "personal_key": with pytest.raises(HTTPException): _personal_key_generation_check(user_api_key_dict, input_data) From b55c829561a2ba91b2270263e5e7718d6f430bda Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 25 Nov 2024 19:40:54 +0530 Subject: [PATCH 2/7] docs: add deprecation notice on original `/v1/messages` endpoint + add better swagger tags on pass-through endpoints --- litellm/proxy/_new_secret_config.yaml | 14 +++--- .../llm_passthrough_endpoints.py | 45 ++++++++++++++++--- litellm/proxy/proxy_server.py | 6 +-- .../vertex_ai_endpoints/langfuse_endpoints.py | 11 ++++- .../vertex_ai_endpoints/vertex_endpoints.py | 9 +++- 5 files changed, 68 insertions(+), 17 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index fbb714211..a968e80d4 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -12,10 +12,10 @@ model_list: vertex_ai_project: "adroit-crow-413218" vertex_ai_location: "us-east5" -litellm_settings: - key_generation_settings: - team_key_generation: - allowed_team_member_roles: ["admin"] - required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key - personal_key_generation: # maps to 'Default Team' on UI - allowed_user_roles: ["proxy_admin"] \ No newline at end of file +# litellm_settings: +# key_generation_settings: +# team_key_generation: +# allowed_team_member_roles: ["admin"] +# required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key +# personal_key_generation: # maps to 'Default Team' on UI +# allowed_user_roles: ["proxy_admin"] \ No newline at end of file diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index 3f4643afc..6a96638a7 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -54,12 +54,19 @@ def create_request_copy(request: Request): } -@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) +@router.api_route( + "/gemini/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE"], + tags=["Google AI Studio Pass-through", "pass-through"], +) async def gemini_proxy_route( endpoint: str, request: Request, fastapi_response: Response, ): + """ + [Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio) + """ ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY api_key = request.query_params.get("key") @@ -111,13 +118,20 @@ async def gemini_proxy_route( return received_value -@router.api_route("/cohere/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) +@router.api_route( + "/cohere/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE"], + tags=["Cohere Pass-through", "pass-through"], +) async def cohere_proxy_route( endpoint: str, request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): + """ + [Docs](https://docs.litellm.ai/docs/pass_through/cohere) + """ base_target_url = "https://api.cohere.com" encoded_endpoint = 
httpx.URL(endpoint).path @@ -154,7 +168,9 @@ async def cohere_proxy_route( @router.api_route( - "/anthropic/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"] + "/anthropic/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE"], + tags=["Anthropic Pass-through", "pass-through"], ) async def anthropic_proxy_route( endpoint: str, @@ -162,6 +178,9 @@ async def anthropic_proxy_route( fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): + """ + [Docs](https://docs.litellm.ai/docs/anthropic_completion) + """ base_target_url = "https://api.anthropic.com" encoded_endpoint = httpx.URL(endpoint).path @@ -201,13 +220,20 @@ async def anthropic_proxy_route( return received_value -@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) +@router.api_route( + "/bedrock/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE"], + tags=["Bedrock Pass-through", "pass-through"], +) async def bedrock_proxy_route( endpoint: str, request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): + """ + [Docs](https://docs.litellm.ai/docs/pass_through/bedrock) + """ create_request_copy(request) try: @@ -275,13 +301,22 @@ async def bedrock_proxy_route( return received_value -@router.api_route("/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) +@router.api_route( + "/azure/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE"], + tags=["Azure Pass-through", "pass-through"], +) async def azure_proxy_route( endpoint: str, request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): + """ + Call any azure endpoint using the proxy. + + Just use `{PROXY_BASE_URL}/azure/{endpoint:path}` + """ base_target_url = get_secret_str(secret_name="AZURE_API_BASE") if base_target_url is None: raise Exception( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 70bf5b523..a982ba39f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -5663,11 +5663,11 @@ async def anthropic_response( # noqa: PLR0915 user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - This is a BETA endpoint that calls 100+ LLMs in the anthropic format. + 🚨 DEPRECATED ENDPOINT🚨 - To do a simple pass-through for anthropic, do `{PROXY_BASE_URL}/anthropic/v1/messages` + Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion). - Docs - https://docs.litellm.ai/docs/anthropic_completion + This was a BETA endpoint that calls 100+ LLMs in the anthropic format. """ from litellm import adapter_completion from litellm.adapters.anthropic_adapter import anthropic_adapter diff --git a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py index ba8653d82..6ce9d5dd8 100644 --- a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py @@ -58,12 +58,21 @@ def create_request_copy(request: Request): } -@router.api_route("/langfuse/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) +@router.api_route( + "/langfuse/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE"], + tags=["Langfuse Pass-through", "pass-through"], +) async def langfuse_proxy_route( endpoint: str, request: Request, fastapi_response: Response, ): + """ + Call Langfuse via LiteLLM proxy. Works with Langfuse SDK. 
+ + [Docs](https://docs.litellm.ai/docs/pass_through/langfuse) + """ ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY api_key = request.headers.get("Authorization") or "" diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index fbf37ce8d..470744e19 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -113,13 +113,20 @@ def construct_target_url( @router.api_route( - "/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"] + "/vertex-ai/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE"], + tags=["Vertex AI Pass-through", "pass-through"], ) async def vertex_proxy_route( endpoint: str, request: Request, fastapi_response: Response, ): + """ + Call LiteLLM proxy via Vertex AI SDK. + + [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai) + """ encoded_endpoint = httpx.URL(endpoint).path import re From 05b5a21014048b8f71a428d2722b827fcdf17ba7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 25 Nov 2024 21:15:19 +0530 Subject: [PATCH 3/7] fix(gemini/): fix image_url handling for gemini Fixes https://github.com/BerriAI/litellm/issues/6897 --- litellm/llms/prompt_templates/factory.py | 23 ++++++++++++ .../gemini/transformation.py | 7 +++- .../vertex_and_google_ai_studio_gemini.py | 37 +++++++++++++++++++ tests/llm_translation/base_llm_unit_tests.py | 29 +++++++++++++++ tests/llm_translation/test_gemini.py | 15 ++++++++ tests/llm_translation/test_prompt_factory.py | 13 +++++++ tests/llm_translation/test_vertex.py | 17 --------- 7 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 tests/llm_translation/test_gemini.py diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 45b7a6c5b..cb79a81b7 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -33,6 +33,7 @@ from litellm.types.llms.openai import ( ChatCompletionAssistantToolCall, ChatCompletionFunctionMessage, ChatCompletionImageObject, + ChatCompletionImageUrlObject, ChatCompletionTextObject, ChatCompletionToolCallFunctionChunk, ChatCompletionToolMessage, @@ -681,6 +682,27 @@ def construct_tool_use_system_prompt( return tool_use_system_prompt +def convert_generic_image_chunk_to_openai_image_obj( + image_chunk: GenericImageParsingChunk, +) -> str: + """ + Convert a generic image chunk to an OpenAI image object. 
+ + Input: + GenericImageParsingChunk( + type="base64", + media_type="image/jpeg", + data="...", + ) + + Return: + "data:image/jpeg;base64,{base64_image}" + """ + return "data:{};{},{}".format( + image_chunk["media_type"], image_chunk["type"], image_chunk["data"] + ) + + def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk: """ Input: @@ -706,6 +728,7 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing data=base64_data, ) except Exception as e: + traceback.print_exc() if "Error: Unable to fetch image from URL" in str(e): raise e raise Exception( diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py index f828d93c8..4b5b7281b 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py @@ -294,7 +294,12 @@ def _transform_request_body( optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys} try: - content = _gemini_convert_messages_with_history(messages=messages) + if custom_llm_provider == "gemini": + content = litellm.GoogleAIStudioGeminiConfig._transform_messages( + messages=messages + ) + else: + content = litellm.VertexGeminiConfig._transform_messages(messages=messages) tools: Optional[Tools] = optional_params.pop("tools", None) tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py index f2fc599ed..4287ed1bc 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py @@ -35,7 +35,12 @@ from litellm.llms.custom_httpx.http_handler import ( HTTPHandler, get_async_httpx_client, ) +from litellm.llms.prompt_templates.factory import ( + convert_generic_image_chunk_to_openai_image_obj, + convert_to_anthropic_image_obj, +) from litellm.types.llms.openai import ( + AllMessageValues, ChatCompletionResponseMessage, ChatCompletionToolCallChunk, ChatCompletionToolCallFunctionChunk, @@ -78,6 +83,8 @@ from ..common_utils import ( ) from ..vertex_llm_base import VertexBase from .transformation import ( + _gemini_convert_messages_with_history, + _process_gemini_image, async_transform_request_body, set_headers, sync_transform_request_body, @@ -912,6 +919,10 @@ class VertexGeminiConfig: return model_response + @staticmethod + def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]: + return _gemini_convert_messages_with_history(messages=messages) + class GoogleAIStudioGeminiConfig( VertexGeminiConfig @@ -1015,6 +1026,32 @@ class GoogleAIStudioGeminiConfig( model, non_default_params, optional_params, drop_params ) + @staticmethod + def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]: + """ + Google AI Studio Gemini does not support image urls in messages. 
+ """ + for message in messages: + _message_content = message.get("content") + if _message_content is not None and isinstance(_message_content, list): + _parts: List[PartType] = [] + for element in _message_content: + if element.get("type") == "image_url": + img_element = element + _image_url: Optional[str] = None + if isinstance(img_element.get("image_url"), dict): + _image_url = img_element["image_url"].get("url") # type: ignore + else: + _image_url = img_element.get("image_url") # type: ignore + if _image_url and "https://" in _image_url: + image_obj = convert_to_anthropic_image_obj(_image_url) + img_element["image_url"] = ( # type: ignore + convert_generic_image_chunk_to_openai_image_obj( + image_obj + ) + ) + return _gemini_convert_messages_with_history(messages=messages) + async def make_call( client: Optional[AsyncHTTPHandler], diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index 88fce6dac..24a972e20 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -190,6 +190,35 @@ class BaseLLMChatTest(ABC): """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" pass + def test_image_url(self): + litellm.set_verbose = True + from litellm.utils import supports_vision + + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + base_completion_call_args = self.get_base_completion_call_args() + if not supports_vision(base_completion_call_args["model"], None): + pytest.skip("Model does not support image input") + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg" + }, + }, + ], + } + ] + + response = litellm.completion(**base_completion_call_args, messages=messages) + assert response is not None + @pytest.fixture def pdf_messages(self): import base64 diff --git a/tests/llm_translation/test_gemini.py b/tests/llm_translation/test_gemini.py new file mode 100644 index 000000000..4e6c5118d --- /dev/null +++ b/tests/llm_translation/test_gemini.py @@ -0,0 +1,15 @@ +from base_llm_unit_tests import BaseLLMChatTest + + +class TestGoogleAIStudioGemini(BaseLLMChatTest): + def get_base_completion_call_args(self) -> dict: + return {"model": "gemini/gemini-1.5-flash"} + + def test_tool_call_no_arguments(self, tool_call_no_arguments): + """Test that tool calls with no arguments is translated correctly. 
Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" + from litellm.llms.prompt_templates.factory import ( + convert_to_gemini_tool_call_invoke, + ) + + result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments) + print(result) diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index 104997563..d8cf191f6 100644 --- a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -687,3 +687,16 @@ def test_just_system_message(): llm_provider="bedrock", ) assert "bedrock requires at least one non-system message" in str(e.value) + + +def test_convert_generic_image_chunk_to_openai_image_obj(): + from litellm.llms.prompt_templates.factory import ( + convert_generic_image_chunk_to_openai_image_obj, + convert_to_anthropic_image_obj, + ) + + url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg" + image_obj = convert_to_anthropic_image_obj(url) + url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj) + image_obj = convert_to_anthropic_image_obj(url_str) + print(image_obj) diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index 3e1087536..c2c1fdecf 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -1298,20 +1298,3 @@ def test_vertex_embedding_url(model, expected_url): assert url == expected_url assert endpoint == "predict" - - -from base_llm_unit_tests import BaseLLMChatTest - - -class TestVertexGemini(BaseLLMChatTest): - def get_base_completion_call_args(self) -> dict: - return {"model": "gemini/gemini-1.5-flash"} - - def test_tool_call_no_arguments(self, tool_call_no_arguments): - """Test that tool calls with no arguments is translated correctly. 
Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" - from litellm.llms.prompt_templates.factory import ( - convert_to_gemini_tool_call_invoke, - ) - - result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments) - print(result) From 1413fdfc06f94e93350bed21b7e60d9108fceb2c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Nov 2024 00:05:09 +0530 Subject: [PATCH 4/7] fix(teams.tsx): fix member add when role is 'user' --- litellm/proxy/_new_secret_config.yaml | 15 ++++++++++++++- litellm/proxy/_types.py | 1 - ui/litellm-dashboard/src/components/teams.tsx | 7 ++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index a968e80d4..86ece3788 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -12,7 +12,20 @@ model_list: vertex_ai_project: "adroit-crow-413218" vertex_ai_location: "us-east5" -# litellm_settings: +router_settings: + routing_strategy: usage-based-routing-v2 + #redis_url: "os.environ/REDIS_URL" + redis_host: "os.environ/REDIS_HOST" + redis_port: "os.environ/REDIS_PORT" + +litellm_settings: + cache: true + cache_params: + type: redis + host: "os.environ/REDIS_HOST" + port: "os.environ/REDIS_PORT" + namespace: "litellm.caching" + ttl: 600 # key_generation_settings: # team_key_generation: # allowed_team_member_roles: ["admin"] diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 74e82b0ea..09cbe7cc9 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1982,7 +1982,6 @@ class MemberAddRequest(LiteLLMBase): # Replace member_data with the single Member object data["member"] = member # Call the superclass __init__ method to initialize the object - traceback.print_stack() super().__init__(**data) diff --git a/ui/litellm-dashboard/src/components/teams.tsx b/ui/litellm-dashboard/src/components/teams.tsx index 0364245be..db83fd532 100644 --- a/ui/litellm-dashboard/src/components/teams.tsx +++ b/ui/litellm-dashboard/src/components/teams.tsx @@ -413,6 +413,7 @@ const Team: React.FC = ({ selectedTeam["team_id"], user_role ); + message.success("Member added"); console.log(`response for team create call: ${response["data"]}`); // Checking if the team exists in the list and updating or adding accordingly const foundIndex = teams.findIndex((team) => { @@ -430,6 +431,7 @@ const Team: React.FC = ({ setSelectedTeam(response.data); } setIsAddMemberModalVisible(false); + } } catch (error) { console.error("Error creating the team:", error); @@ -825,6 +827,9 @@ const Team: React.FC = ({ labelCol={{ span: 8 }} wrapperCol={{ span: 16 }} labelAlign="left" + initialValues={{ + role: "user", + }} > <> @@ -842,8 +847,8 @@ const Team: React.FC = ({ - user admin + user From c77af015f8d1164685b4d83d8a3c5d5c021ad9ff Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Nov 2024 10:28:40 +0530 Subject: [PATCH 5/7] fix(team_endpoints.py): /team/member_add fix adding several new members to team --- litellm/proxy/auth/auth_checks.py | 153 ++++++++++++------ .../management_endpoints/team_endpoints.py | 1 + 2 files changed, 104 insertions(+), 50 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 7d29032c6..5d789436a 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -523,6 +523,10 @@ async def _cache_management_object( proxy_logging_obj: Optional[ProxyLogging], ): await user_api_key_cache.async_set_cache(key=key, 
value=value) + if proxy_logging_obj is not None: + await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache( + key=key, value=value + ) async def _cache_team_object( @@ -586,26 +590,63 @@ async def _get_team_db_check(team_id: str, prisma_client: PrismaClient): ) -async def get_team_object( - team_id: str, - prisma_client: Optional[PrismaClient], - user_api_key_cache: DualCache, - parent_otel_span: Optional[Span] = None, - proxy_logging_obj: Optional[ProxyLogging] = None, - check_cache_only: Optional[bool] = None, -) -> LiteLLM_TeamTableCachedObj: - """ - - Check if team id in proxy Team Table - - if valid, return LiteLLM_TeamTable object with defined limits - - if not, then raise an error - """ - if prisma_client is None: - raise Exception( - "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys" - ) +async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient): + return await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) - # check if in cache - key = "team_id:{}".format(team_id) + +async def _get_team_object_from_user_api_key_cache( + team_id: str, + prisma_client: PrismaClient, + user_api_key_cache: DualCache, + last_db_access_time: LimitedSizeOrderedDict, + db_cache_expiry: int, + proxy_logging_obj: Optional[ProxyLogging], + key: str, +) -> LiteLLM_TeamTableCachedObj: + db_access_time_key = key + should_check_db = _should_check_db( + key=db_access_time_key, + last_db_access_time=last_db_access_time, + db_cache_expiry=db_cache_expiry, + ) + if should_check_db: + response = await _get_team_db_check( + team_id=team_id, prisma_client=prisma_client + ) + else: + response = None + + if response is None: + raise Exception + + _response = LiteLLM_TeamTableCachedObj(**response.dict()) + # save the team object to cache + await _cache_team_object( + team_id=team_id, + team_table=_response, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + # save to db access time + # save to db access time + _update_last_db_access_time( + key=db_access_time_key, + value=_response, + last_db_access_time=last_db_access_time, + ) + + return _response + + +async def _get_team_object_from_cache( + key: str, + proxy_logging_obj: Optional[ProxyLogging], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], +) -> Optional[LiteLLM_TeamTableCachedObj]: cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None ## CHECK REDIS CACHE ## @@ -613,6 +654,7 @@ async def get_team_object( proxy_logging_obj is not None and proxy_logging_obj.internal_usage_cache.dual_cache ): + cached_team_obj = ( await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache( key=key, parent_otel_span=parent_otel_span @@ -628,47 +670,58 @@ async def get_team_object( elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj): return cached_team_obj - if check_cache_only: + return None + + +async def get_team_object( + team_id: str, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, + check_cache_only: Optional[bool] = None, + check_db_only: Optional[bool] = None, +) -> LiteLLM_TeamTableCachedObj: + """ + - Check if team id in proxy Team Table + - if valid, return LiteLLM_TeamTable object with defined limits + - if not, then raise an error + """ + if prisma_client is None: raise Exception( - f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}." + "No DB Connected. 
See - https://docs.litellm.ai/docs/proxy/virtual_keys" ) + # check if in cache + key = "team_id:{}".format(team_id) + + if not check_db_only: + cached_team_obj = await _get_team_object_from_cache( + key=key, + proxy_logging_obj=proxy_logging_obj, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + ) + + if cached_team_obj is not None: + return cached_team_obj + + if check_cache_only: + raise Exception( + f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}." + ) + # else, check db try: - db_access_time_key = "team_id:{}".format(team_id) - should_check_db = _should_check_db( - key=db_access_time_key, - last_db_access_time=last_db_access_time, - db_cache_expiry=db_cache_expiry, - ) - if should_check_db: - response = await _get_team_db_check( - team_id=team_id, prisma_client=prisma_client - ) - else: - response = None - - if response is None: - raise Exception - - _response = LiteLLM_TeamTableCachedObj(**response.dict()) - # save the team object to cache - await _cache_team_object( + return await _get_team_object_from_user_api_key_cache( team_id=team_id, - team_table=_response, + prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, proxy_logging_obj=proxy_logging_obj, - ) - - # save to db access time - # save to db access time - _update_last_db_access_time( - key=db_access_time_key, - value=_response, last_db_access_time=last_db_access_time, + db_cache_expiry=db_cache_expiry, + key=key, ) - - return _response except Exception: raise Exception( f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call." diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index dc1ec444d..9f749cee1 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -547,6 +547,7 @@ async def team_member_add( parent_otel_span=None, proxy_logging_obj=proxy_logging_obj, check_cache_only=False, + check_db_only=True, ) if existing_team_row is None: raise HTTPException( From 1aab1c9b04906a6ea6fbf76e869a3582cd2bb227 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Nov 2024 10:33:52 +0530 Subject: [PATCH 6/7] test(test_vertex.py): remove redundant test --- tests/llm_translation/test_vertex.py | 74 ---------------------------- 1 file changed, 74 deletions(-) diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index c2c1fdecf..1ea2514c9 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -1190,80 +1190,6 @@ def test_get_image_mime_type_from_url(): assert _get_image_mime_type_from_url("invalid_url") is None -@pytest.mark.parametrize( - "image_url", ["https://example.com/image.jpg", "https://example.com/image.png"] -) -def test_image_completion_request(image_url): - """https:// .jpg, .png images are passed directly to the model""" - from unittest.mock import patch, Mock - import litellm - from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import ( - _get_image_mime_type_from_url, - ) - - # Mock response data - mock_response = Mock() - mock_response.json.return_value = { - "candidates": [{"content": {"parts": [{"text": "This is a sunflower"}]}}], - "usageMetadata": { - "promptTokenCount": 11, - "candidatesTokenCount": 50, - "totalTokenCount": 61, - }, - "modelVersion": "gemini-1.5-pro", - } - mock_response.raise_for_status = MagicMock() - mock_response.status_code = 200 - - # Expected request body - 
expected_request_body = { - "contents": [ - { - "role": "user", - "parts": [ - {"text": "Whats in this image?"}, - { - "file_data": { - "file_uri": image_url, - "mime_type": _get_image_mime_type_from_url(image_url), - } - }, - ], - } - ], - "system_instruction": {"parts": [{"text": "Be a good bot"}]}, - "generationConfig": {}, - } - - messages = [ - {"role": "system", "content": "Be a good bot"}, - { - "role": "user", - "content": [ - {"type": "text", "text": "Whats in this image?"}, - {"type": "image_url", "image_url": {"url": image_url}}, - ], - }, - ] - - client = HTTPHandler() - with patch.object(client, "post", new=MagicMock()) as mock_post: - mock_post.return_value = mock_response - try: - litellm.completion( - model="gemini/gemini-1.5-pro", - messages=messages, - client=client, - ) - except Exception as e: - print(e) - - # Assert the request body matches expected - mock_post.assert_called_once() - print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"]) - assert mock_post.call_args.kwargs["json"] == expected_request_body - - @pytest.mark.parametrize( "model, expected_url", [ From 6d8c17967f1121994329c100742a6648891d1eb3 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Nov 2024 12:21:55 +0530 Subject: [PATCH 7/7] test(test_proxy_server.py): fix team member add tests --- tests/proxy_unit_tests/test_proxy_server.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/proxy_unit_tests/test_proxy_server.py b/tests/proxy_unit_tests/test_proxy_server.py index d70962858..64bb67b58 100644 --- a/tests/proxy_unit_tests/test_proxy_server.py +++ b/tests/proxy_unit_tests/test_proxy_server.py @@ -1014,7 +1014,11 @@ async def test_create_team_member_add(prisma_client, new_member_method): with patch( "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable", new_callable=AsyncMock, - ) as mock_litellm_usertable: + ) as mock_litellm_usertable, patch( + "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache", + new=AsyncMock(return_value=team_obj), + ) as mock_team_obj: + mock_client = AsyncMock( return_value=LiteLLM_UserTable( user_id="1234", max_budget=100, user_email="1234" @@ -1193,7 +1197,10 @@ async def test_create_team_member_add_team_admin( with patch( "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable", new_callable=AsyncMock, - ) as mock_litellm_usertable: + ) as mock_litellm_usertable, patch( + "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache", + new=AsyncMock(return_value=team_obj), + ) as mock_team_obj: mock_client = AsyncMock( return_value=LiteLLM_UserTable( user_id="1234", max_budget=100, user_email="1234"
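
The core gating behavior introduced in PATCH 1/7 (and exercised by the tests in PATCH 1/7 and PATCH 7/7) is easier to follow outside the proxy codebase. The sketch below is illustrative only: it uses simplified stand-in types instead of the real UserAPIKeyAuth / LiteLLM_TeamTableCachedObj models and collapses several helpers into one function; the authoritative logic is _team_key_generation_check and _key_generation_required_param_check in litellm/proxy/management_endpoints/key_management_endpoints.py.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TeamMember:
    # Stand-in for the proxy's Member model (role + user_id are the fields the check uses).
    role: str
    user_id: Optional[str] = None


@dataclass
class Team:
    # Stand-in for LiteLLM_TeamTableCachedObj: membership now comes from the team table,
    # not from the cached team_member field on the caller's key.
    team_id: str
    members_with_roles: List[TeamMember] = field(default_factory=list)


def can_generate_team_key(
    user_id: Optional[str],
    team: Team,
    key_generation_settings: Optional[dict],
    request_params: dict,
) -> bool:
    """Simplified mirror of the team-key generation checks added in PATCH 1/7."""
    if not key_generation_settings or "team_key_generation" not in key_generation_settings:
        return True  # no restrictions configured

    team_rules = key_generation_settings["team_key_generation"]

    # 1. The caller must appear in team.members_with_roles (user-membership check).
    member = next(
        (m for m in team.members_with_roles if m.user_id is not None and m.user_id == user_id),
        None,
    )
    if member is None:
        raise ValueError(f"User={user_id} not assigned to team={team.team_id}")

    # 2. The member's team role must be one of allowed_team_member_roles (if configured).
    allowed_roles = team_rules.get("allowed_team_member_roles")
    if allowed_roles is not None and member.role not in allowed_roles:
        raise ValueError(
            f"Team member role {member.role} not in allowed_team_member_roles={allowed_roles}"
        )

    # 3. Any required_params (e.g. tags for cost-tracking) must be set on the request.
    for param in team_rules.get("required_params", []):
        if request_params.get(param) is None:
            raise ValueError(f"Required param '{param}' missing from key generation request")

    return True


if __name__ == "__main__":
    # Mirrors the example config from PATCH 1/7: only team admins may create keys,
    # and they must supply tags.
    settings = {
        "team_key_generation": {
            "allowed_team_member_roles": ["admin"],
            "required_params": ["tags"],
        }
    }
    team = Team(
        team_id="test_team_id",
        members_with_roles=[TeamMember(role="admin", user_id="test_user_id")],
    )
    assert can_generate_team_key("test_user_id", team, settings, {"tags": ["prod"]})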