Compare commits

1 commit

Krish Dholakia
0f08577060 fix(key_management_endpoints.py): fix user-membership check when creating team key (#6890)
* fix(key_management_endpoints.py): fix user-membership check when creating team key

* docs: add deprecation notice on original `/v1/messages` endpoint + add better swagger tags on pass-through endpoints

* fix(gemini/): fix image_url handling for gemini

Fixes https://github.com/BerriAI/litellm/issues/6897

* fix(teams.tsx): fix member add when role is 'user'

* fix(team_endpoints.py): /team/member_add

fix adding several new members to team

* test(test_vertex.py): remove redundant test

* test(test_proxy_server.py): fix team member add tests
2024-11-27 17:32:52 -08:00
19 changed files with 406 additions and 169 deletions

@@ -33,6 +33,7 @@ from litellm.types.llms.openai import (
ChatCompletionAssistantToolCall,
ChatCompletionFunctionMessage,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
ChatCompletionTextObject,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolMessage,
@@ -681,6 +682,27 @@ def construct_tool_use_system_prompt(
return tool_use_system_prompt
def convert_generic_image_chunk_to_openai_image_obj(
image_chunk: GenericImageParsingChunk,
) -> str:
"""
Convert a generic image chunk to an OpenAI image object.
Input:
GenericImageParsingChunk(
type="base64",
media_type="image/jpeg",
data="...",
)
Return:
"data:image/jpeg;base64,{base64_image}"
"""
return "data:{};{},{}".format(
image_chunk["media_type"], image_chunk["type"], image_chunk["data"]
)
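
As an aside, a minimal sketch of the round trip these two helpers implement (the base64 payload is illustrative; the same pairing is exercised by test_convert_generic_image_chunk_to_openai_image_obj further down):

from litellm.llms.prompt_templates.factory import (
    convert_generic_image_chunk_to_openai_image_obj,
    convert_to_anthropic_image_obj,
)

# Parse an OpenAI-style data URL into a GenericImageParsingChunk...
chunk = convert_to_anthropic_image_obj("data:image/jpeg;base64,aGVsbG8=")
# ...then serialize it back to the identical data URL string.
assert convert_generic_image_chunk_to_openai_image_obj(chunk) == "data:image/jpeg;base64,aGVsbG8="
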
def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
"""
Input:
@@ -706,6 +728,7 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing
data=base64_data,
)
except Exception as e:
traceback.print_exc()
if "Error: Unable to fetch image from URL" in str(e):
raise e
raise Exception(

@@ -294,7 +294,12 @@ def _transform_request_body(
optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
try:
- content = _gemini_convert_messages_with_history(messages=messages)
+ if custom_llm_provider == "gemini":
+ content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
+ messages=messages
+ )
+ else:
+ content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(

@@ -35,7 +35,12 @@ from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
@@ -78,6 +83,8 @@ from ..common_utils import (
)
from ..vertex_llm_base import VertexBase
from .transformation import (
_gemini_convert_messages_with_history,
_process_gemini_image,
async_transform_request_body,
set_headers,
sync_transform_request_body,
@@ -912,6 +919,10 @@ class VertexGeminiConfig:
return model_response
@staticmethod
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
return _gemini_convert_messages_with_history(messages=messages)
class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
@@ -1015,6 +1026,32 @@ class GoogleAIStudioGeminiConfig(
model, non_default_params, optional_params, drop_params
)
@staticmethod
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
"""
Google AI Studio Gemini does not support image urls in messages.
"""
for message in messages:
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
if element.get("type") == "image_url":
img_element = element
_image_url: Optional[str] = None
if isinstance(img_element.get("image_url"), dict):
_image_url = img_element["image_url"].get("url") # type: ignore
else:
_image_url = img_element.get("image_url") # type: ignore
if _image_url and "https://" in _image_url:
image_obj = convert_to_anthropic_image_obj(_image_url)
img_element["image_url"] = ( # type: ignore
convert_generic_image_chunk_to_openai_image_obj(
image_obj
)
)
return _gemini_convert_messages_with_history(messages=messages)
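
To illustrate the override, a hedged sketch of what the Google AI Studio transform does to a vision message (the URL is illustrative):

import litellm

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        ],
    }
]
# Any https:// image_url is fetched and rewritten in place to a
# "data:<mime>;base64,..." string before the shared Gemini conversion runs.
contents = litellm.GoogleAIStudioGeminiConfig._transform_messages(messages=messages)
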
async def make_call(
client: Optional[AsyncHTTPHandler],

@@ -12,3 +12,23 @@ model_list:
vertex_ai_project: "adroit-crow-413218"
vertex_ai_location: "us-east5"
router_settings:
routing_strategy: usage-based-routing-v2
#redis_url: "os.environ/REDIS_URL"
redis_host: "os.environ/REDIS_HOST"
redis_port: "os.environ/REDIS_PORT"
litellm_settings:
cache: true
cache_params:
type: redis
host: "os.environ/REDIS_HOST"
port: "os.environ/REDIS_PORT"
namespace: "litellm.caching"
ttl: 600
# key_generation_settings:
# team_key_generation:
# allowed_team_member_roles: ["admin"]
# required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key
# personal_key_generation: # maps to 'Default Team' on UI
# allowed_user_roles: ["proxy_admin"]
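
For reference, the commented-out block above maps to litellm.key_generation_settings; a minimal in-code equivalent (field names taken from this diff and its tests):

import litellm

litellm.key_generation_settings = {
    "team_key_generation": {
        "allowed_team_member_roles": ["admin"],  # only team admins may create team keys
        "required_params": ["tags"],  # require tags for cost-tracking
    },
    "personal_key_generation": {  # maps to 'Default Team' on the UI
        "allowed_user_roles": ["proxy_admin"],
    },
}
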

@@ -1982,7 +1982,6 @@ class MemberAddRequest(LiteLLMBase):
# Replace member_data with the single Member object
data["member"] = member
# Call the superclass __init__ method to initialize the object
traceback.print_stack()
super().__init__(**data)

@@ -523,6 +523,10 @@ async def _cache_management_object(
proxy_logging_obj: Optional[ProxyLogging],
):
await user_api_key_cache.async_set_cache(key=key, value=value)
if proxy_logging_obj is not None:
await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
key=key, value=value
)
async def _cache_team_object(
@@ -586,26 +590,63 @@ async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
)
async def get_team_object(
team_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
check_cache_only: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
"""
- Check if team id in proxy Team Table
- if valid, return LiteLLM_TeamTable object with defined limits
- if not, then raise an error
"""
if prisma_client is None:
raise Exception(
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
return await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
# check if in cache
key = "team_id:{}".format(team_id)
async def _get_team_object_from_user_api_key_cache(
team_id: str,
prisma_client: PrismaClient,
user_api_key_cache: DualCache,
last_db_access_time: LimitedSizeOrderedDict,
db_cache_expiry: int,
proxy_logging_obj: Optional[ProxyLogging],
key: str,
) -> LiteLLM_TeamTableCachedObj:
db_access_time_key = key
should_check_db = _should_check_db(
key=db_access_time_key,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
)
if should_check_db:
response = await _get_team_db_check(
team_id=team_id, prisma_client=prisma_client
)
else:
response = None
if response is None:
raise Exception
_response = LiteLLM_TeamTableCachedObj(**response.dict())
# save the team object to cache
await _cache_team_object(
team_id=team_id,
team_table=_response,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
# save to db access time
# save to db access time
_update_last_db_access_time(
key=db_access_time_key,
value=_response,
last_db_access_time=last_db_access_time,
)
return _response
async def _get_team_object_from_cache(
key: str,
proxy_logging_obj: Optional[ProxyLogging],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
) -> Optional[LiteLLM_TeamTableCachedObj]:
cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None
## CHECK REDIS CACHE ##
@@ -613,6 +654,7 @@ async def get_team_object(
proxy_logging_obj is not None
and proxy_logging_obj.internal_usage_cache.dual_cache
):
cached_team_obj = (
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
key=key, parent_otel_span=parent_otel_span
@@ -628,47 +670,58 @@ async def get_team_object(
elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
return cached_team_obj
if check_cache_only:
return None
async def get_team_object(
team_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
check_cache_only: Optional[bool] = None,
check_db_only: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
"""
- Check if team id in proxy Team Table
- if valid, return LiteLLM_TeamTable object with defined limits
- if not, then raise an error
"""
if prisma_client is None:
raise Exception(
f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)
# check if in cache
key = "team_id:{}".format(team_id)
if not check_db_only:
cached_team_obj = await _get_team_object_from_cache(
key=key,
proxy_logging_obj=proxy_logging_obj,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
)
if cached_team_obj is not None:
return cached_team_obj
if check_cache_only:
raise Exception(
f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
)
# else, check db
try:
db_access_time_key = "team_id:{}".format(team_id)
should_check_db = _should_check_db(
key=db_access_time_key,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
)
if should_check_db:
response = await _get_team_db_check(
team_id=team_id, prisma_client=prisma_client
)
else:
response = None
if response is None:
raise Exception
_response = LiteLLM_TeamTableCachedObj(**response.dict())
# save the team object to cache
await _cache_team_object(
return await _get_team_object_from_user_api_key_cache(
team_id=team_id,
team_table=_response,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
# save to db access time
# save to db access time
_update_last_db_access_time(
key=db_access_time_key,
value=_response,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
key=key,
)
return _response
except Exception:
raise Exception(
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."

@@ -29,6 +29,7 @@ from litellm.proxy.auth.auth_checks import (
_cache_key_object,
_delete_cache_key_object,
get_key_object,
get_team_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
@@ -46,7 +47,19 @@ def _is_team_key(data: GenerateKeyRequest):
return data.team_id is not None
def _get_user_in_team(
team_table: LiteLLM_TeamTableCachedObj, user_id: Optional[str]
) -> Optional[Member]:
if user_id is None:
return None
for member in team_table.members_with_roles:
if member.user_id is not None and member.user_id == user_id:
return member
return None
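
A quick sketch of the new membership lookup, reusing the fixtures from the tests below; note the check no longer depends on user_api_key_dict.team_member:

from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member

team = LiteLLM_TeamTableCachedObj(
    team_id="test_team_id",
    team_alias="test_team_alias",
    members_with_roles=[Member(role="admin", user_id="test_user_id")],
)
# Membership is resolved from the team table itself.
assert _get_user_in_team(team_table=team, user_id="test_user_id").role == "admin"
assert _get_user_in_team(team_table=team, user_id="someone_else") is None
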
def _team_key_generation_team_member_check(
team_table: LiteLLM_TeamTableCachedObj,
user_api_key_dict: UserAPIKeyAuth,
team_key_generation: Optional[TeamUIKeyGenerationConfig],
):
@@ -56,17 +69,19 @@
):
return True
- if user_api_key_dict.team_member is None:
+ user_in_team = _get_user_in_team(
+ team_table=team_table, user_id=user_api_key_dict.user_id
+ )
+ if user_in_team is None:
raise HTTPException(
status_code=400,
- detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}",
+ detail=f"User={user_api_key_dict.user_id} not assigned to team={team_table.team_id}",
)
- team_member_role = user_api_key_dict.team_member.role
- if team_member_role not in team_key_generation["allowed_team_member_roles"]:
+ if user_in_team.role not in team_key_generation["allowed_team_member_roles"]:
raise HTTPException(
status_code=400,
- detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore
+ detail=f"Team member role {user_in_team.role} not in allowed_team_member_roles={team_key_generation['allowed_team_member_roles']}",
)
return True
@@ -88,7 +103,9 @@ def _key_generation_required_param_check(
def _team_key_generation_check(
- user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
+ team_table: LiteLLM_TeamTableCachedObj,
+ user_api_key_dict: UserAPIKeyAuth,
+ data: GenerateKeyRequest,
):
if (
litellm.key_generation_settings is None
@@ -99,7 +116,8 @@ def _team_key_generation_check(
_team_key_generation = litellm.key_generation_settings["team_key_generation"] # type: ignore
_team_key_generation_team_member_check(
- user_api_key_dict,
+ team_table=team_table,
+ user_api_key_dict=user_api_key_dict,
team_key_generation=_team_key_generation,
)
_key_generation_required_param_check(
@@ -155,7 +173,9 @@ def _personal_key_generation_check(
def key_generation_check(
- user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
+ team_table: Optional[LiteLLM_TeamTableCachedObj],
+ user_api_key_dict: UserAPIKeyAuth,
+ data: GenerateKeyRequest,
) -> bool:
"""
Check if admin has restricted key creation to certain roles for teams or individuals
@@ -170,8 +190,15 @@ def key_generation_check(
is_team_key = _is_team_key(data=data)
if is_team_key:
if team_table is None:
raise HTTPException(
status_code=400,
detail=f"Unable to find team object in database. Team ID: {data.team_id}",
)
return _team_key_generation_check(
- user_api_key_dict=user_api_key_dict, data=data
+ team_table=team_table,
+ user_api_key_dict=user_api_key_dict,
+ data=data,
)
else:
return _personal_key_generation_check(
@@ -254,6 +281,7 @@ async def generate_key_fn( # noqa: PLR0915
litellm_proxy_admin_name,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
user_custom_key_generate,
)
@@ -271,7 +299,20 @@
status_code=status.HTTP_403_FORBIDDEN, detail=message
)
elif litellm.key_generation_settings is not None:
- key_generation_check(user_api_key_dict=user_api_key_dict, data=data)
+ if data.team_id is None:
+ team_table: Optional[LiteLLM_TeamTableCachedObj] = None
+ else:
+ team_table = await get_team_object(
+ team_id=data.team_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=user_api_key_dict.parent_otel_span,
+ )
+ key_generation_check(
+ team_table=team_table,
+ user_api_key_dict=user_api_key_dict,
+ data=data,
+ )
# check if user set default key/generate params on config.yaml
if litellm.default_key_generate_params is not None:
for elem in data:

@@ -547,6 +547,7 @@ async def team_member_add(
parent_otel_span=None,
proxy_logging_obj=proxy_logging_obj,
check_cache_only=False,
check_db_only=True,
)
if existing_team_row is None:
raise HTTPException(

@@ -54,12 +54,19 @@ def create_request_copy(request: Request):
}
@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route(
"/gemini/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE"],
tags=["Google AI Studio Pass-through", "pass-through"],
)
async def gemini_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
):
"""
[Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio)
"""
## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
api_key = request.query_params.get("key")
@@ -111,13 +118,20 @@ async def gemini_proxy_route(
return received_value
@router.api_route("/cohere/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route(
"/cohere/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE"],
tags=["Cohere Pass-through", "pass-through"],
)
async def cohere_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/pass_through/cohere)
"""
base_target_url = "https://api.cohere.com"
encoded_endpoint = httpx.URL(endpoint).path
@@ -154,7 +168,9 @@
@router.api_route(
"/anthropic/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
"/anthropic/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE"],
tags=["Anthropic Pass-through", "pass-through"],
)
async def anthropic_proxy_route(
endpoint: str,
@@ -162,6 +178,9 @@
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/anthropic_completion)
"""
base_target_url = "https://api.anthropic.com"
encoded_endpoint = httpx.URL(endpoint).path
@@ -201,13 +220,20 @@
return received_value
@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route(
"/bedrock/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE"],
tags=["Bedrock Pass-through", "pass-through"],
)
async def bedrock_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/pass_through/bedrock)
"""
create_request_copy(request)
try:
@@ -275,13 +301,22 @@
return received_value
@router.api_route("/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route(
"/azure/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE"],
tags=["Azure Pass-through", "pass-through"],
)
async def azure_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Call any azure endpoint using the proxy.
Just use `{PROXY_BASE_URL}/azure/{endpoint:path}`
"""
base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
if base_target_url is None:
raise Exception(

@@ -5663,11 +5663,11 @@ async def anthropic_response( # noqa: PLR0915
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
- This is a BETA endpoint that calls 100+ LLMs in the anthropic format.
+ 🚨 DEPRECATED ENDPOINT🚨
- To do a simple pass-through for anthropic, do `{PROXY_BASE_URL}/anthropic/v1/messages`
+ Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion).
- Docs - https://docs.litellm.ai/docs/anthropic_completion
+ This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
"""
from litellm import adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
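
Per the deprecation notice, a hedged sketch of calling the pass-through route instead (assumes a LiteLLM proxy at localhost:4000 and an illustrative virtual key; body fields follow Anthropic's Messages API):

import httpx

resp = httpx.post(
    "http://localhost:4000/anthropic/v1/messages",
    headers={
        "x-api-key": "sk-1234",  # LiteLLM virtual key (illustrative)
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    },
    json={
        "model": "claude-3-5-sonnet-20240620",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())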

@@ -58,12 +58,21 @@ def create_request_copy(request: Request):
}
@router.api_route("/langfuse/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route(
"/langfuse/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE"],
tags=["Langfuse Pass-through", "pass-through"],
)
async def langfuse_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
):
"""
Call Langfuse via LiteLLM proxy. Works with Langfuse SDK.
[Docs](https://docs.litellm.ai/docs/pass_through/langfuse)
"""
## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
api_key = request.headers.get("Authorization") or ""

@@ -113,13 +113,26 @@ def construct_target_url(
@router.api_route(
"/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
"/vertex-ai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE"],
tags=["Vertex AI Pass-through", "pass-through"],
include_in_schema=False,
)
@router.api_route(
"/vertex_ai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE"],
tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
):
"""
Call LiteLLM proxy via Vertex AI SDK.
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
"""
encoded_endpoint = httpx.URL(endpoint).path
import re

@@ -190,6 +190,35 @@ class BaseLLMChatTest(ABC):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pass
def test_image_url(self):
litellm.set_verbose = True
from litellm.utils import supports_vision
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
base_completion_call_args = self.get_base_completion_call_args()
if not supports_vision(base_completion_call_args["model"], None):
pytest.skip("Model does not support image input")
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
},
},
],
}
]
response = litellm.completion(**base_completion_call_args, messages=messages)
assert response is not None
@pytest.fixture
def pdf_messages(self):
import base64

@@ -0,0 +1,15 @@
from base_llm_unit_tests import BaseLLMChatTest
class TestGoogleAIStudioGemini(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "gemini/gemini-1.5-flash"}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_invoke,
)
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
print(result)

@@ -687,3 +687,16 @@ def test_just_system_message():
llm_provider="bedrock",
)
assert "bedrock requires at least one non-system message" in str(e.value)
def test_convert_generic_image_chunk_to_openai_image_obj():
from litellm.llms.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
image_obj = convert_to_anthropic_image_obj(url)
url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj)
image_obj = convert_to_anthropic_image_obj(url_str)
print(image_obj)

@@ -1190,80 +1190,6 @@ def test_get_image_mime_type_from_url():
assert _get_image_mime_type_from_url("invalid_url") is None
@pytest.mark.parametrize(
"image_url", ["https://example.com/image.jpg", "https://example.com/image.png"]
)
def test_image_completion_request(image_url):
"""https:// .jpg, .png images are passed directly to the model"""
from unittest.mock import patch, Mock
import litellm
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
_get_image_mime_type_from_url,
)
# Mock response data
mock_response = Mock()
mock_response.json.return_value = {
"candidates": [{"content": {"parts": [{"text": "This is a sunflower"}]}}],
"usageMetadata": {
"promptTokenCount": 11,
"candidatesTokenCount": 50,
"totalTokenCount": 61,
},
"modelVersion": "gemini-1.5-pro",
}
mock_response.raise_for_status = MagicMock()
mock_response.status_code = 200
# Expected request body
expected_request_body = {
"contents": [
{
"role": "user",
"parts": [
{"text": "Whats in this image?"},
{
"file_data": {
"file_uri": image_url,
"mime_type": _get_image_mime_type_from_url(image_url),
}
},
],
}
],
"system_instruction": {"parts": [{"text": "Be a good bot"}]},
"generationConfig": {},
}
messages = [
{"role": "system", "content": "Be a good bot"},
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{"type": "image_url", "image_url": {"url": image_url}},
],
},
]
client = HTTPHandler()
with patch.object(client, "post", new=MagicMock()) as mock_post:
mock_post.return_value = mock_response
try:
litellm.completion(
model="gemini/gemini-1.5-pro",
messages=messages,
client=client,
)
except Exception as e:
print(e)
# Assert the request body matches expected
mock_post.assert_called_once()
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
assert mock_post.call_args.kwargs["json"] == expected_request_body
@pytest.mark.parametrize(
"model, expected_url",
[
@@ -1298,20 +1224,3 @@ def test_vertex_embedding_url(model, expected_url):
assert url == expected_url
assert endpoint == "predict"
from base_llm_unit_tests import BaseLLMChatTest
class TestVertexGemini(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "gemini/gemini-1.5-flash"}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_invoke,
)
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
print(result)

@@ -556,13 +556,22 @@ def test_team_key_generation_team_member_check():
_team_key_generation_check,
)
from fastapi import HTTPException
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
litellm.key_generation_settings = {
"team_key_generation": {"allowed_team_member_roles": ["admin"]}
}
team_table = LiteLLM_TeamTableCachedObj(
team_id="test_team_id",
team_alias="test_team_alias",
members_with_roles=[Member(role="admin", user_id="test_user_id")],
)
assert _team_key_generation_check(
team_table=team_table,
user_api_key_dict=UserAPIKeyAuth(
user_id="test_user_id",
user_role=LitellmUserRoles.INTERNAL_USER,
api_key="sk-1234",
team_member=Member(role="admin", user_id="test_user_id"),
@@ -570,8 +579,15 @@
data=GenerateKeyRequest(),
)
team_table = LiteLLM_TeamTableCachedObj(
team_id="test_team_id",
team_alias="test_team_alias",
members_with_roles=[Member(role="user", user_id="test_user_id")],
)
with pytest.raises(HTTPException):
_team_key_generation_check(
team_table=team_table,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.INTERNAL_USER,
api_key="sk-1234",
@@ -607,6 +623,7 @@
StandardKeyGenerationConfig,
PersonalUIKeyGenerationConfig,
)
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
from fastapi import HTTPException
user_api_key_dict = UserAPIKeyAuth(
@@ -614,7 +631,13 @@
api_key="sk-1234",
user_id="test_user_id",
team_id="test_team_id",
- team_member=Member(role="admin", user_id="test_user_id"),
+ team_member=None,
)
team_table = LiteLLM_TeamTableCachedObj(
team_id="test_team_id",
team_alias="test_team_alias",
members_with_roles=[Member(role="admin", user_id="test_user_id")],
)
if key_type == "team_key":
@@ -632,13 +655,13 @@
if expected_result:
if key_type == "team_key":
- assert _team_key_generation_check(user_api_key_dict, input_data)
+ assert _team_key_generation_check(team_table, user_api_key_dict, input_data)
elif key_type == "personal_key":
assert _personal_key_generation_check(user_api_key_dict, input_data)
else:
if key_type == "team_key":
with pytest.raises(HTTPException):
- _team_key_generation_check(user_api_key_dict, input_data)
+ _team_key_generation_check(team_table, user_api_key_dict, input_data)
elif key_type == "personal_key":
with pytest.raises(HTTPException):
_personal_key_generation_check(user_api_key_dict, input_data)

@@ -1014,7 +1014,11 @@ async def test_create_team_member_add(prisma_client, new_member_method):
with patch(
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
new_callable=AsyncMock,
- ) as mock_litellm_usertable:
+ ) as mock_litellm_usertable, patch(
+ "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
+ new=AsyncMock(return_value=team_obj),
+ ) as mock_team_obj:
mock_client = AsyncMock(
return_value=LiteLLM_UserTable(
user_id="1234", max_budget=100, user_email="1234"
@@ -1193,7 +1197,10 @@ async def test_create_team_member_add_team_admin(
with patch(
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
new_callable=AsyncMock,
- ) as mock_litellm_usertable:
+ ) as mock_litellm_usertable, patch(
+ "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
+ new=AsyncMock(return_value=team_obj),
+ ) as mock_team_obj:
mock_client = AsyncMock(
return_value=LiteLLM_UserTable(
user_id="1234", max_budget=100, user_email="1234"

@@ -413,6 +413,7 @@ const Team: React.FC<TeamProps> = ({
selectedTeam["team_id"],
user_role
);
message.success("Member added");
console.log(`response for team create call: ${response["data"]}`);
// Checking if the team exists in the list and updating or adding accordingly
const foundIndex = teams.findIndex((team) => {
@@ -430,6 +431,7 @@
setSelectedTeam(response.data);
}
setIsAddMemberModalVisible(false);
}
} catch (error) {
console.error("Error creating the team:", error);
@@ -825,6 +827,9 @@
labelCol={{ span: 8 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
initialValues={{
role: "user",
}}
>
<>
<Form.Item label="Email" name="user_email" className="mb-4">
@@ -842,8 +847,8 @@
</Form.Item>
<Form.Item label="Member Role" name="role" className="mb-4">
<Select2 defaultValue="user">
<Select2.Option value="user">user</Select2.Option>
<Select2.Option value="admin">admin</Select2.Option>
<Select2.Option value="user">user</Select2.Option>
</Select2>
</Form.Item>
</>