forked from phoenix/litellm-mirror
fix(key_management_endpoints.py): fix user-membership check when creating team key (#6890)
* fix(key_management_endpoints.py): fix user-membership check when creating team key * docs: add deprecation notice on original `/v1/messages` endpoint + add better swagger tags on pass-through endpoints * fix(gemini/): fix image_url handling for gemini Fixes https://github.com/BerriAI/litellm/issues/6897 * fix(teams.tsx): fix member add when role is 'user' * fix(team_endpoints.py): /team/member_add fix adding several new members to team * test(test_vertex.py): remove redundant test * test(test_proxy_server.py): fix team member add tests
This commit is contained in:
parent
dcea31e50a
commit
8673f2541e
19 changed files with 399 additions and 169 deletions
|
@ -190,6 +190,35 @@ class BaseLLMChatTest(ABC):
|
|||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
pass
|
||||
|
||||
def test_image_url(self):
|
||||
litellm.set_verbose = True
|
||||
from litellm.utils import supports_vision
|
||||
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
|
||||
base_completion_call_args = self.get_base_completion_call_args()
|
||||
if not supports_vision(base_completion_call_args["model"], None):
|
||||
pytest.skip("Model does not support image input")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
response = litellm.completion(**base_completion_call_args, messages=messages)
|
||||
assert response is not None
|
||||
|
||||
@pytest.fixture
|
||||
def pdf_messages(self):
|
||||
import base64
|
||||
|
|
15
tests/llm_translation/test_gemini.py
Normal file
15
tests/llm_translation/test_gemini.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
|
||||
|
||||
class TestGoogleAIStudioGemini(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
return {"model": "gemini/gemini-1.5-flash"}
|
||||
|
||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
from litellm.llms.prompt_templates.factory import (
|
||||
convert_to_gemini_tool_call_invoke,
|
||||
)
|
||||
|
||||
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
|
||||
print(result)
|
|
@ -687,3 +687,16 @@ def test_just_system_message():
|
|||
llm_provider="bedrock",
|
||||
)
|
||||
assert "bedrock requires at least one non-system message" in str(e.value)
|
||||
|
||||
|
||||
def test_convert_generic_image_chunk_to_openai_image_obj():
|
||||
from litellm.llms.prompt_templates.factory import (
|
||||
convert_generic_image_chunk_to_openai_image_obj,
|
||||
convert_to_anthropic_image_obj,
|
||||
)
|
||||
|
||||
url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
|
||||
image_obj = convert_to_anthropic_image_obj(url)
|
||||
url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj)
|
||||
image_obj = convert_to_anthropic_image_obj(url_str)
|
||||
print(image_obj)
|
||||
|
|
|
@ -1190,80 +1190,6 @@ def test_get_image_mime_type_from_url():
|
|||
assert _get_image_mime_type_from_url("invalid_url") is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"image_url", ["https://example.com/image.jpg", "https://example.com/image.png"]
|
||||
)
|
||||
def test_image_completion_request(image_url):
|
||||
"""https:// .jpg, .png images are passed directly to the model"""
|
||||
from unittest.mock import patch, Mock
|
||||
import litellm
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
|
||||
_get_image_mime_type_from_url,
|
||||
)
|
||||
|
||||
# Mock response data
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"candidates": [{"content": {"parts": [{"text": "This is a sunflower"}]}}],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 11,
|
||||
"candidatesTokenCount": 50,
|
||||
"totalTokenCount": 61,
|
||||
},
|
||||
"modelVersion": "gemini-1.5-pro",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
# Expected request body
|
||||
expected_request_body = {
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"text": "Whats in this image?"},
|
||||
{
|
||||
"file_data": {
|
||||
"file_uri": image_url,
|
||||
"mime_type": _get_image_mime_type_from_url(image_url),
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"system_instruction": {"parts": [{"text": "Be a good bot"}]},
|
||||
"generationConfig": {},
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "Be a good bot"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Whats in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
client = HTTPHandler()
|
||||
with patch.object(client, "post", new=MagicMock()) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
try:
|
||||
litellm.completion(
|
||||
model="gemini/gemini-1.5-pro",
|
||||
messages=messages,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
# Assert the request body matches expected
|
||||
mock_post.assert_called_once()
|
||||
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
|
||||
assert mock_post.call_args.kwargs["json"] == expected_request_body
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expected_url",
|
||||
[
|
||||
|
@ -1298,20 +1224,3 @@ def test_vertex_embedding_url(model, expected_url):
|
|||
|
||||
assert url == expected_url
|
||||
assert endpoint == "predict"
|
||||
|
||||
|
||||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
|
||||
|
||||
class TestVertexGemini(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
return {"model": "gemini/gemini-1.5-flash"}
|
||||
|
||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
from litellm.llms.prompt_templates.factory import (
|
||||
convert_to_gemini_tool_call_invoke,
|
||||
)
|
||||
|
||||
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
|
||||
print(result)
|
||||
|
|
|
@ -556,13 +556,22 @@ def test_team_key_generation_team_member_check():
|
|||
_team_key_generation_check,
|
||||
)
|
||||
from fastapi import HTTPException
|
||||
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
|
||||
|
||||
litellm.key_generation_settings = {
|
||||
"team_key_generation": {"allowed_team_member_roles": ["admin"]}
|
||||
}
|
||||
|
||||
team_table = LiteLLM_TeamTableCachedObj(
|
||||
team_id="test_team_id",
|
||||
team_alias="test_team_alias",
|
||||
members_with_roles=[Member(role="admin", user_id="test_user_id")],
|
||||
)
|
||||
|
||||
assert _team_key_generation_check(
|
||||
team_table=team_table,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_id="test_user_id",
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
api_key="sk-1234",
|
||||
team_member=Member(role="admin", user_id="test_user_id"),
|
||||
|
@ -570,8 +579,15 @@ def test_team_key_generation_team_member_check():
|
|||
data=GenerateKeyRequest(),
|
||||
)
|
||||
|
||||
team_table = LiteLLM_TeamTableCachedObj(
|
||||
team_id="test_team_id",
|
||||
team_alias="test_team_alias",
|
||||
members_with_roles=[Member(role="user", user_id="test_user_id")],
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
_team_key_generation_check(
|
||||
team_table=team_table,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
api_key="sk-1234",
|
||||
|
@ -607,6 +623,7 @@ def test_key_generation_required_params_check(
|
|||
StandardKeyGenerationConfig,
|
||||
PersonalUIKeyGenerationConfig,
|
||||
)
|
||||
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
|
||||
from fastapi import HTTPException
|
||||
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
|
@ -614,7 +631,13 @@ def test_key_generation_required_params_check(
|
|||
api_key="sk-1234",
|
||||
user_id="test_user_id",
|
||||
team_id="test_team_id",
|
||||
team_member=Member(role="admin", user_id="test_user_id"),
|
||||
team_member=None,
|
||||
)
|
||||
|
||||
team_table = LiteLLM_TeamTableCachedObj(
|
||||
team_id="test_team_id",
|
||||
team_alias="test_team_alias",
|
||||
members_with_roles=[Member(role="admin", user_id="test_user_id")],
|
||||
)
|
||||
|
||||
if key_type == "team_key":
|
||||
|
@ -632,13 +655,13 @@ def test_key_generation_required_params_check(
|
|||
|
||||
if expected_result:
|
||||
if key_type == "team_key":
|
||||
assert _team_key_generation_check(user_api_key_dict, input_data)
|
||||
assert _team_key_generation_check(team_table, user_api_key_dict, input_data)
|
||||
elif key_type == "personal_key":
|
||||
assert _personal_key_generation_check(user_api_key_dict, input_data)
|
||||
else:
|
||||
if key_type == "team_key":
|
||||
with pytest.raises(HTTPException):
|
||||
_team_key_generation_check(user_api_key_dict, input_data)
|
||||
_team_key_generation_check(team_table, user_api_key_dict, input_data)
|
||||
elif key_type == "personal_key":
|
||||
with pytest.raises(HTTPException):
|
||||
_personal_key_generation_check(user_api_key_dict, input_data)
|
||||
|
|
|
@ -1014,7 +1014,11 @@ async def test_create_team_member_add(prisma_client, new_member_method):
|
|||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_litellm_usertable:
|
||||
) as mock_litellm_usertable, patch(
|
||||
"litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
|
||||
new=AsyncMock(return_value=team_obj),
|
||||
) as mock_team_obj:
|
||||
|
||||
mock_client = AsyncMock(
|
||||
return_value=LiteLLM_UserTable(
|
||||
user_id="1234", max_budget=100, user_email="1234"
|
||||
|
@ -1193,7 +1197,10 @@ async def test_create_team_member_add_team_admin(
|
|||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_litellm_usertable:
|
||||
) as mock_litellm_usertable, patch(
|
||||
"litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
|
||||
new=AsyncMock(return_value=team_obj),
|
||||
) as mock_team_obj:
|
||||
mock_client = AsyncMock(
|
||||
return_value=LiteLLM_UserTable(
|
||||
user_id="1234", max_budget=100, user_email="1234"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue