forked from phoenix/litellm-mirror
Compare commits
1 commit
main
...
litellm_st
Author | SHA1 | Date | |
---|---|---|---|
|
0f08577060 |
19 changed files with 406 additions and 169 deletions
|
@ -33,6 +33,7 @@ from litellm.types.llms.openai import (
|
||||||
ChatCompletionAssistantToolCall,
|
ChatCompletionAssistantToolCall,
|
||||||
ChatCompletionFunctionMessage,
|
ChatCompletionFunctionMessage,
|
||||||
ChatCompletionImageObject,
|
ChatCompletionImageObject,
|
||||||
|
ChatCompletionImageUrlObject,
|
||||||
ChatCompletionTextObject,
|
ChatCompletionTextObject,
|
||||||
ChatCompletionToolCallFunctionChunk,
|
ChatCompletionToolCallFunctionChunk,
|
||||||
ChatCompletionToolMessage,
|
ChatCompletionToolMessage,
|
||||||
|
@ -681,6 +682,27 @@ def construct_tool_use_system_prompt(
|
||||||
return tool_use_system_prompt
|
return tool_use_system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def convert_generic_image_chunk_to_openai_image_obj(
|
||||||
|
image_chunk: GenericImageParsingChunk,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Convert a generic image chunk to an OpenAI image object.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
GenericImageParsingChunk(
|
||||||
|
type="base64",
|
||||||
|
media_type="image/jpeg",
|
||||||
|
data="...",
|
||||||
|
)
|
||||||
|
|
||||||
|
Return:
|
||||||
|
"data:image/jpeg;base64,{base64_image}"
|
||||||
|
"""
|
||||||
|
return "data:{};{},{}".format(
|
||||||
|
image_chunk["media_type"], image_chunk["type"], image_chunk["data"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
|
def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
|
||||||
"""
|
"""
|
||||||
Input:
|
Input:
|
||||||
|
@ -706,6 +728,7 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing
|
||||||
data=base64_data,
|
data=base64_data,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
if "Error: Unable to fetch image from URL" in str(e):
|
if "Error: Unable to fetch image from URL" in str(e):
|
||||||
raise e
|
raise e
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
|
@ -294,7 +294,12 @@ def _transform_request_body(
|
||||||
optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
|
optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = _gemini_convert_messages_with_history(messages=messages)
|
if custom_llm_provider == "gemini":
|
||||||
|
content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
|
||||||
|
messages=messages
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
|
||||||
tools: Optional[Tools] = optional_params.pop("tools", None)
|
tools: Optional[Tools] = optional_params.pop("tools", None)
|
||||||
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
|
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
|
||||||
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
|
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
|
||||||
|
|
|
@ -35,7 +35,12 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||||
HTTPHandler,
|
HTTPHandler,
|
||||||
get_async_httpx_client,
|
get_async_httpx_client,
|
||||||
)
|
)
|
||||||
|
from litellm.llms.prompt_templates.factory import (
|
||||||
|
convert_generic_image_chunk_to_openai_image_obj,
|
||||||
|
convert_to_anthropic_image_obj,
|
||||||
|
)
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
|
AllMessageValues,
|
||||||
ChatCompletionResponseMessage,
|
ChatCompletionResponseMessage,
|
||||||
ChatCompletionToolCallChunk,
|
ChatCompletionToolCallChunk,
|
||||||
ChatCompletionToolCallFunctionChunk,
|
ChatCompletionToolCallFunctionChunk,
|
||||||
|
@ -78,6 +83,8 @@ from ..common_utils import (
|
||||||
)
|
)
|
||||||
from ..vertex_llm_base import VertexBase
|
from ..vertex_llm_base import VertexBase
|
||||||
from .transformation import (
|
from .transformation import (
|
||||||
|
_gemini_convert_messages_with_history,
|
||||||
|
_process_gemini_image,
|
||||||
async_transform_request_body,
|
async_transform_request_body,
|
||||||
set_headers,
|
set_headers,
|
||||||
sync_transform_request_body,
|
sync_transform_request_body,
|
||||||
|
@ -912,6 +919,10 @@ class VertexGeminiConfig:
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
|
||||||
|
return _gemini_convert_messages_with_history(messages=messages)
|
||||||
|
|
||||||
|
|
||||||
class GoogleAIStudioGeminiConfig(
|
class GoogleAIStudioGeminiConfig(
|
||||||
VertexGeminiConfig
|
VertexGeminiConfig
|
||||||
|
@ -1015,6 +1026,32 @@ class GoogleAIStudioGeminiConfig(
|
||||||
model, non_default_params, optional_params, drop_params
|
model, non_default_params, optional_params, drop_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
|
||||||
|
"""
|
||||||
|
Google AI Studio Gemini does not support image urls in messages.
|
||||||
|
"""
|
||||||
|
for message in messages:
|
||||||
|
_message_content = message.get("content")
|
||||||
|
if _message_content is not None and isinstance(_message_content, list):
|
||||||
|
_parts: List[PartType] = []
|
||||||
|
for element in _message_content:
|
||||||
|
if element.get("type") == "image_url":
|
||||||
|
img_element = element
|
||||||
|
_image_url: Optional[str] = None
|
||||||
|
if isinstance(img_element.get("image_url"), dict):
|
||||||
|
_image_url = img_element["image_url"].get("url") # type: ignore
|
||||||
|
else:
|
||||||
|
_image_url = img_element.get("image_url") # type: ignore
|
||||||
|
if _image_url and "https://" in _image_url:
|
||||||
|
image_obj = convert_to_anthropic_image_obj(_image_url)
|
||||||
|
img_element["image_url"] = ( # type: ignore
|
||||||
|
convert_generic_image_chunk_to_openai_image_obj(
|
||||||
|
image_obj
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return _gemini_convert_messages_with_history(messages=messages)
|
||||||
|
|
||||||
|
|
||||||
async def make_call(
|
async def make_call(
|
||||||
client: Optional[AsyncHTTPHandler],
|
client: Optional[AsyncHTTPHandler],
|
||||||
|
|
|
@ -12,3 +12,23 @@ model_list:
|
||||||
vertex_ai_project: "adroit-crow-413218"
|
vertex_ai_project: "adroit-crow-413218"
|
||||||
vertex_ai_location: "us-east5"
|
vertex_ai_location: "us-east5"
|
||||||
|
|
||||||
|
router_settings:
|
||||||
|
routing_strategy: usage-based-routing-v2
|
||||||
|
#redis_url: "os.environ/REDIS_URL"
|
||||||
|
redis_host: "os.environ/REDIS_HOST"
|
||||||
|
redis_port: "os.environ/REDIS_PORT"
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
cache: true
|
||||||
|
cache_params:
|
||||||
|
type: redis
|
||||||
|
host: "os.environ/REDIS_HOST"
|
||||||
|
port: "os.environ/REDIS_PORT"
|
||||||
|
namespace: "litellm.caching"
|
||||||
|
ttl: 600
|
||||||
|
# key_generation_settings:
|
||||||
|
# team_key_generation:
|
||||||
|
# allowed_team_member_roles: ["admin"]
|
||||||
|
# required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key
|
||||||
|
# personal_key_generation: # maps to 'Default Team' on UI
|
||||||
|
# allowed_user_roles: ["proxy_admin"]
|
|
@ -1982,7 +1982,6 @@ class MemberAddRequest(LiteLLMBase):
|
||||||
# Replace member_data with the single Member object
|
# Replace member_data with the single Member object
|
||||||
data["member"] = member
|
data["member"] = member
|
||||||
# Call the superclass __init__ method to initialize the object
|
# Call the superclass __init__ method to initialize the object
|
||||||
traceback.print_stack()
|
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -523,6 +523,10 @@ async def _cache_management_object(
|
||||||
proxy_logging_obj: Optional[ProxyLogging],
|
proxy_logging_obj: Optional[ProxyLogging],
|
||||||
):
|
):
|
||||||
await user_api_key_cache.async_set_cache(key=key, value=value)
|
await user_api_key_cache.async_set_cache(key=key, value=value)
|
||||||
|
if proxy_logging_obj is not None:
|
||||||
|
await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
|
||||||
|
key=key, value=value
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _cache_team_object(
|
async def _cache_team_object(
|
||||||
|
@ -586,26 +590,63 @@ async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_team_object(
|
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
|
||||||
team_id: str,
|
return await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
prisma_client: Optional[PrismaClient],
|
where={"team_id": team_id}
|
||||||
user_api_key_cache: DualCache,
|
)
|
||||||
parent_otel_span: Optional[Span] = None,
|
|
||||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
|
||||||
check_cache_only: Optional[bool] = None,
|
|
||||||
) -> LiteLLM_TeamTableCachedObj:
|
|
||||||
"""
|
|
||||||
- Check if team id in proxy Team Table
|
|
||||||
- if valid, return LiteLLM_TeamTable object with defined limits
|
|
||||||
- if not, then raise an error
|
|
||||||
"""
|
|
||||||
if prisma_client is None:
|
|
||||||
raise Exception(
|
|
||||||
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
|
||||||
)
|
|
||||||
|
|
||||||
# check if in cache
|
|
||||||
key = "team_id:{}".format(team_id)
|
async def _get_team_object_from_user_api_key_cache(
|
||||||
|
team_id: str,
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
last_db_access_time: LimitedSizeOrderedDict,
|
||||||
|
db_cache_expiry: int,
|
||||||
|
proxy_logging_obj: Optional[ProxyLogging],
|
||||||
|
key: str,
|
||||||
|
) -> LiteLLM_TeamTableCachedObj:
|
||||||
|
db_access_time_key = key
|
||||||
|
should_check_db = _should_check_db(
|
||||||
|
key=db_access_time_key,
|
||||||
|
last_db_access_time=last_db_access_time,
|
||||||
|
db_cache_expiry=db_cache_expiry,
|
||||||
|
)
|
||||||
|
if should_check_db:
|
||||||
|
response = await _get_team_db_check(
|
||||||
|
team_id=team_id, prisma_client=prisma_client
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = None
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
_response = LiteLLM_TeamTableCachedObj(**response.dict())
|
||||||
|
# save the team object to cache
|
||||||
|
await _cache_team_object(
|
||||||
|
team_id=team_id,
|
||||||
|
team_table=_response,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
# save to db access time
|
||||||
|
# save to db access time
|
||||||
|
_update_last_db_access_time(
|
||||||
|
key=db_access_time_key,
|
||||||
|
value=_response,
|
||||||
|
last_db_access_time=last_db_access_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _response
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_team_object_from_cache(
|
||||||
|
key: str,
|
||||||
|
proxy_logging_obj: Optional[ProxyLogging],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
parent_otel_span: Optional[Span],
|
||||||
|
) -> Optional[LiteLLM_TeamTableCachedObj]:
|
||||||
cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None
|
cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None
|
||||||
|
|
||||||
## CHECK REDIS CACHE ##
|
## CHECK REDIS CACHE ##
|
||||||
|
@ -613,6 +654,7 @@ async def get_team_object(
|
||||||
proxy_logging_obj is not None
|
proxy_logging_obj is not None
|
||||||
and proxy_logging_obj.internal_usage_cache.dual_cache
|
and proxy_logging_obj.internal_usage_cache.dual_cache
|
||||||
):
|
):
|
||||||
|
|
||||||
cached_team_obj = (
|
cached_team_obj = (
|
||||||
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
|
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
|
||||||
key=key, parent_otel_span=parent_otel_span
|
key=key, parent_otel_span=parent_otel_span
|
||||||
|
@ -628,47 +670,58 @@ async def get_team_object(
|
||||||
elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
|
elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
|
||||||
return cached_team_obj
|
return cached_team_obj
|
||||||
|
|
||||||
if check_cache_only:
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_team_object(
|
||||||
|
team_id: str,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
parent_otel_span: Optional[Span] = None,
|
||||||
|
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||||
|
check_cache_only: Optional[bool] = None,
|
||||||
|
check_db_only: Optional[bool] = None,
|
||||||
|
) -> LiteLLM_TeamTableCachedObj:
|
||||||
|
"""
|
||||||
|
- Check if team id in proxy Team Table
|
||||||
|
- if valid, return LiteLLM_TeamTable object with defined limits
|
||||||
|
- if not, then raise an error
|
||||||
|
"""
|
||||||
|
if prisma_client is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
|
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# check if in cache
|
||||||
|
key = "team_id:{}".format(team_id)
|
||||||
|
|
||||||
|
if not check_db_only:
|
||||||
|
cached_team_obj = await _get_team_object_from_cache(
|
||||||
|
key=key,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cached_team_obj is not None:
|
||||||
|
return cached_team_obj
|
||||||
|
|
||||||
|
if check_cache_only:
|
||||||
|
raise Exception(
|
||||||
|
f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
|
||||||
|
)
|
||||||
|
|
||||||
# else, check db
|
# else, check db
|
||||||
try:
|
try:
|
||||||
db_access_time_key = "team_id:{}".format(team_id)
|
return await _get_team_object_from_user_api_key_cache(
|
||||||
should_check_db = _should_check_db(
|
|
||||||
key=db_access_time_key,
|
|
||||||
last_db_access_time=last_db_access_time,
|
|
||||||
db_cache_expiry=db_cache_expiry,
|
|
||||||
)
|
|
||||||
if should_check_db:
|
|
||||||
response = await _get_team_db_check(
|
|
||||||
team_id=team_id, prisma_client=prisma_client
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = None
|
|
||||||
|
|
||||||
if response is None:
|
|
||||||
raise Exception
|
|
||||||
|
|
||||||
_response = LiteLLM_TeamTableCachedObj(**response.dict())
|
|
||||||
# save the team object to cache
|
|
||||||
await _cache_team_object(
|
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
team_table=_response,
|
prisma_client=prisma_client,
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
)
|
|
||||||
|
|
||||||
# save to db access time
|
|
||||||
# save to db access time
|
|
||||||
_update_last_db_access_time(
|
|
||||||
key=db_access_time_key,
|
|
||||||
value=_response,
|
|
||||||
last_db_access_time=last_db_access_time,
|
last_db_access_time=last_db_access_time,
|
||||||
|
db_cache_expiry=db_cache_expiry,
|
||||||
|
key=key,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _response
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||||
|
|
|
@ -29,6 +29,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
_cache_key_object,
|
_cache_key_object,
|
||||||
_delete_cache_key_object,
|
_delete_cache_key_object,
|
||||||
get_key_object,
|
get_key_object,
|
||||||
|
get_team_object,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||||
|
@ -46,7 +47,19 @@ def _is_team_key(data: GenerateKeyRequest):
|
||||||
return data.team_id is not None
|
return data.team_id is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_in_team(
|
||||||
|
team_table: LiteLLM_TeamTableCachedObj, user_id: Optional[str]
|
||||||
|
) -> Optional[Member]:
|
||||||
|
if user_id is None:
|
||||||
|
return None
|
||||||
|
for member in team_table.members_with_roles:
|
||||||
|
if member.user_id is not None and member.user_id == user_id:
|
||||||
|
return member
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _team_key_generation_team_member_check(
|
def _team_key_generation_team_member_check(
|
||||||
|
team_table: LiteLLM_TeamTableCachedObj,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
team_key_generation: Optional[TeamUIKeyGenerationConfig],
|
team_key_generation: Optional[TeamUIKeyGenerationConfig],
|
||||||
):
|
):
|
||||||
|
@ -56,17 +69,19 @@ def _team_key_generation_team_member_check(
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if user_api_key_dict.team_member is None:
|
user_in_team = _get_user_in_team(
|
||||||
|
team_table=team_table, user_id=user_api_key_dict.user_id
|
||||||
|
)
|
||||||
|
if user_in_team is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}",
|
detail=f"User={user_api_key_dict.user_id} not assigned to team={team_table.team_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
team_member_role = user_api_key_dict.team_member.role
|
if user_in_team.role not in team_key_generation["allowed_team_member_roles"]:
|
||||||
if team_member_role not in team_key_generation["allowed_team_member_roles"]:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore
|
detail=f"Team member role {user_in_team.role} not in allowed_team_member_roles={team_key_generation['allowed_team_member_roles']}",
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -88,7 +103,9 @@ def _key_generation_required_param_check(
|
||||||
|
|
||||||
|
|
||||||
def _team_key_generation_check(
|
def _team_key_generation_check(
|
||||||
user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
|
team_table: LiteLLM_TeamTableCachedObj,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
data: GenerateKeyRequest,
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
litellm.key_generation_settings is None
|
litellm.key_generation_settings is None
|
||||||
|
@ -99,7 +116,8 @@ def _team_key_generation_check(
|
||||||
_team_key_generation = litellm.key_generation_settings["team_key_generation"] # type: ignore
|
_team_key_generation = litellm.key_generation_settings["team_key_generation"] # type: ignore
|
||||||
|
|
||||||
_team_key_generation_team_member_check(
|
_team_key_generation_team_member_check(
|
||||||
user_api_key_dict,
|
team_table=team_table,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
team_key_generation=_team_key_generation,
|
team_key_generation=_team_key_generation,
|
||||||
)
|
)
|
||||||
_key_generation_required_param_check(
|
_key_generation_required_param_check(
|
||||||
|
@ -155,7 +173,9 @@ def _personal_key_generation_check(
|
||||||
|
|
||||||
|
|
||||||
def key_generation_check(
|
def key_generation_check(
|
||||||
user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
|
team_table: Optional[LiteLLM_TeamTableCachedObj],
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
data: GenerateKeyRequest,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if admin has restricted key creation to certain roles for teams or individuals
|
Check if admin has restricted key creation to certain roles for teams or individuals
|
||||||
|
@ -170,8 +190,15 @@ def key_generation_check(
|
||||||
is_team_key = _is_team_key(data=data)
|
is_team_key = _is_team_key(data=data)
|
||||||
|
|
||||||
if is_team_key:
|
if is_team_key:
|
||||||
|
if team_table is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Unable to find team object in database. Team ID: {data.team_id}",
|
||||||
|
)
|
||||||
return _team_key_generation_check(
|
return _team_key_generation_check(
|
||||||
user_api_key_dict=user_api_key_dict, data=data
|
team_table=team_table,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
data=data,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return _personal_key_generation_check(
|
return _personal_key_generation_check(
|
||||||
|
@ -254,6 +281,7 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
litellm_proxy_admin_name,
|
litellm_proxy_admin_name,
|
||||||
prisma_client,
|
prisma_client,
|
||||||
proxy_logging_obj,
|
proxy_logging_obj,
|
||||||
|
user_api_key_cache,
|
||||||
user_custom_key_generate,
|
user_custom_key_generate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -271,7 +299,20 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=message
|
status_code=status.HTTP_403_FORBIDDEN, detail=message
|
||||||
)
|
)
|
||||||
elif litellm.key_generation_settings is not None:
|
elif litellm.key_generation_settings is not None:
|
||||||
key_generation_check(user_api_key_dict=user_api_key_dict, data=data)
|
if data.team_id is None:
|
||||||
|
team_table: Optional[LiteLLM_TeamTableCachedObj] = None
|
||||||
|
else:
|
||||||
|
team_table = await get_team_object(
|
||||||
|
team_id=data.team_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||||
|
)
|
||||||
|
key_generation_check(
|
||||||
|
team_table=team_table,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
# check if user set default key/generate params on config.yaml
|
# check if user set default key/generate params on config.yaml
|
||||||
if litellm.default_key_generate_params is not None:
|
if litellm.default_key_generate_params is not None:
|
||||||
for elem in data:
|
for elem in data:
|
||||||
|
|
|
@ -547,6 +547,7 @@ async def team_member_add(
|
||||||
parent_otel_span=None,
|
parent_otel_span=None,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
check_cache_only=False,
|
check_cache_only=False,
|
||||||
|
check_db_only=True,
|
||||||
)
|
)
|
||||||
if existing_team_row is None:
|
if existing_team_row is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
@ -54,12 +54,19 @@ def create_request_copy(request: Request):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
@router.api_route(
|
||||||
|
"/gemini/{endpoint:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
|
tags=["Google AI Studio Pass-through", "pass-through"],
|
||||||
|
)
|
||||||
async def gemini_proxy_route(
|
async def gemini_proxy_route(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
[Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio)
|
||||||
|
"""
|
||||||
## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
|
## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
|
||||||
api_key = request.query_params.get("key")
|
api_key = request.query_params.get("key")
|
||||||
|
|
||||||
|
@ -111,13 +118,20 @@ async def gemini_proxy_route(
|
||||||
return received_value
|
return received_value
|
||||||
|
|
||||||
|
|
||||||
@router.api_route("/cohere/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
@router.api_route(
|
||||||
|
"/cohere/{endpoint:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
|
tags=["Cohere Pass-through", "pass-through"],
|
||||||
|
)
|
||||||
async def cohere_proxy_route(
|
async def cohere_proxy_route(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
[Docs](https://docs.litellm.ai/docs/pass_through/cohere)
|
||||||
|
"""
|
||||||
base_target_url = "https://api.cohere.com"
|
base_target_url = "https://api.cohere.com"
|
||||||
encoded_endpoint = httpx.URL(endpoint).path
|
encoded_endpoint = httpx.URL(endpoint).path
|
||||||
|
|
||||||
|
@ -154,7 +168,9 @@ async def cohere_proxy_route(
|
||||||
|
|
||||||
|
|
||||||
@router.api_route(
|
@router.api_route(
|
||||||
"/anthropic/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
|
"/anthropic/{endpoint:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
|
tags=["Anthropic Pass-through", "pass-through"],
|
||||||
)
|
)
|
||||||
async def anthropic_proxy_route(
|
async def anthropic_proxy_route(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
|
@ -162,6 +178,9 @@ async def anthropic_proxy_route(
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
[Docs](https://docs.litellm.ai/docs/anthropic_completion)
|
||||||
|
"""
|
||||||
base_target_url = "https://api.anthropic.com"
|
base_target_url = "https://api.anthropic.com"
|
||||||
encoded_endpoint = httpx.URL(endpoint).path
|
encoded_endpoint = httpx.URL(endpoint).path
|
||||||
|
|
||||||
|
@ -201,13 +220,20 @@ async def anthropic_proxy_route(
|
||||||
return received_value
|
return received_value
|
||||||
|
|
||||||
|
|
||||||
@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
@router.api_route(
|
||||||
|
"/bedrock/{endpoint:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
|
tags=["Bedrock Pass-through", "pass-through"],
|
||||||
|
)
|
||||||
async def bedrock_proxy_route(
|
async def bedrock_proxy_route(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
[Docs](https://docs.litellm.ai/docs/pass_through/bedrock)
|
||||||
|
"""
|
||||||
create_request_copy(request)
|
create_request_copy(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -275,13 +301,22 @@ async def bedrock_proxy_route(
|
||||||
return received_value
|
return received_value
|
||||||
|
|
||||||
|
|
||||||
@router.api_route("/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
@router.api_route(
|
||||||
|
"/azure/{endpoint:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
|
tags=["Azure Pass-through", "pass-through"],
|
||||||
|
)
|
||||||
async def azure_proxy_route(
|
async def azure_proxy_route(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Call any azure endpoint using the proxy.
|
||||||
|
|
||||||
|
Just use `{PROXY_BASE_URL}/azure/{endpoint:path}`
|
||||||
|
"""
|
||||||
base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
|
base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
|
||||||
if base_target_url is None:
|
if base_target_url is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
|
@ -5663,11 +5663,11 @@ async def anthropic_response( # noqa: PLR0915
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This is a BETA endpoint that calls 100+ LLMs in the anthropic format.
|
🚨 DEPRECATED ENDPOINT🚨
|
||||||
|
|
||||||
To do a simple pass-through for anthropic, do `{PROXY_BASE_URL}/anthropic/v1/messages`
|
Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion).
|
||||||
|
|
||||||
Docs - https://docs.litellm.ai/docs/anthropic_completion
|
This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
|
||||||
"""
|
"""
|
||||||
from litellm import adapter_completion
|
from litellm import adapter_completion
|
||||||
from litellm.adapters.anthropic_adapter import anthropic_adapter
|
from litellm.adapters.anthropic_adapter import anthropic_adapter
|
||||||
|
|
|
@ -58,12 +58,21 @@ def create_request_copy(request: Request):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.api_route("/langfuse/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
@router.api_route(
|
||||||
|
"/langfuse/{endpoint:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
|
tags=["Langfuse Pass-through", "pass-through"],
|
||||||
|
)
|
||||||
async def langfuse_proxy_route(
|
async def langfuse_proxy_route(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Call Langfuse via LiteLLM proxy. Works with Langfuse SDK.
|
||||||
|
|
||||||
|
[Docs](https://docs.litellm.ai/docs/pass_through/langfuse)
|
||||||
|
"""
|
||||||
## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
|
## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
|
||||||
api_key = request.headers.get("Authorization") or ""
|
api_key = request.headers.get("Authorization") or ""
|
||||||
|
|
||||||
|
|
|
@ -113,13 +113,26 @@ def construct_target_url(
|
||||||
|
|
||||||
|
|
||||||
@router.api_route(
|
@router.api_route(
|
||||||
"/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
|
"/vertex-ai/{endpoint:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
|
tags=["Vertex AI Pass-through", "pass-through"],
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
@router.api_route(
|
||||||
|
"/vertex_ai/{endpoint:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
|
tags=["Vertex AI Pass-through", "pass-through"],
|
||||||
)
|
)
|
||||||
async def vertex_proxy_route(
|
async def vertex_proxy_route(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Call LiteLLM proxy via Vertex AI SDK.
|
||||||
|
|
||||||
|
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
|
||||||
|
"""
|
||||||
encoded_endpoint = httpx.URL(endpoint).path
|
encoded_endpoint = httpx.URL(endpoint).path
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
|
@ -190,6 +190,35 @@ class BaseLLMChatTest(ABC):
|
||||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_image_url(self):
|
||||||
|
litellm.set_verbose = True
|
||||||
|
from litellm.utils import supports_vision
|
||||||
|
|
||||||
|
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||||
|
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||||
|
|
||||||
|
base_completion_call_args = self.get_base_completion_call_args()
|
||||||
|
if not supports_vision(base_completion_call_args["model"], None):
|
||||||
|
pytest.skip("Model does not support image input")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = litellm.completion(**base_completion_call_args, messages=messages)
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pdf_messages(self):
|
def pdf_messages(self):
|
||||||
import base64
|
import base64
|
||||||
|
|
15
tests/llm_translation/test_gemini.py
Normal file
15
tests/llm_translation/test_gemini.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from base_llm_unit_tests import BaseLLMChatTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestGoogleAIStudioGemini(BaseLLMChatTest):
|
||||||
|
def get_base_completion_call_args(self) -> dict:
|
||||||
|
return {"model": "gemini/gemini-1.5-flash"}
|
||||||
|
|
||||||
|
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||||
|
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||||
|
from litellm.llms.prompt_templates.factory import (
|
||||||
|
convert_to_gemini_tool_call_invoke,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
|
||||||
|
print(result)
|
|
@ -687,3 +687,16 @@ def test_just_system_message():
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
)
|
)
|
||||||
assert "bedrock requires at least one non-system message" in str(e.value)
|
assert "bedrock requires at least one non-system message" in str(e.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_generic_image_chunk_to_openai_image_obj():
|
||||||
|
from litellm.llms.prompt_templates.factory import (
|
||||||
|
convert_generic_image_chunk_to_openai_image_obj,
|
||||||
|
convert_to_anthropic_image_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
|
||||||
|
image_obj = convert_to_anthropic_image_obj(url)
|
||||||
|
url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj)
|
||||||
|
image_obj = convert_to_anthropic_image_obj(url_str)
|
||||||
|
print(image_obj)
|
||||||
|
|
|
@ -1190,80 +1190,6 @@ def test_get_image_mime_type_from_url():
|
||||||
assert _get_image_mime_type_from_url("invalid_url") is None
|
assert _get_image_mime_type_from_url("invalid_url") is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"image_url", ["https://example.com/image.jpg", "https://example.com/image.png"]
|
|
||||||
)
|
|
||||||
def test_image_completion_request(image_url):
|
|
||||||
"""https:// .jpg, .png images are passed directly to the model"""
|
|
||||||
from unittest.mock import patch, Mock
|
|
||||||
import litellm
|
|
||||||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
|
|
||||||
_get_image_mime_type_from_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock response data
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"candidates": [{"content": {"parts": [{"text": "This is a sunflower"}]}}],
|
|
||||||
"usageMetadata": {
|
|
||||||
"promptTokenCount": 11,
|
|
||||||
"candidatesTokenCount": 50,
|
|
||||||
"totalTokenCount": 61,
|
|
||||||
},
|
|
||||||
"modelVersion": "gemini-1.5-pro",
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status = MagicMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
|
|
||||||
# Expected request body
|
|
||||||
expected_request_body = {
|
|
||||||
"contents": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"parts": [
|
|
||||||
{"text": "Whats in this image?"},
|
|
||||||
{
|
|
||||||
"file_data": {
|
|
||||||
"file_uri": image_url,
|
|
||||||
"mime_type": _get_image_mime_type_from_url(image_url),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"system_instruction": {"parts": [{"text": "Be a good bot"}]},
|
|
||||||
"generationConfig": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "Be a good bot"},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "Whats in this image?"},
|
|
||||||
{"type": "image_url", "image_url": {"url": image_url}},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
client = HTTPHandler()
|
|
||||||
with patch.object(client, "post", new=MagicMock()) as mock_post:
|
|
||||||
mock_post.return_value = mock_response
|
|
||||||
try:
|
|
||||||
litellm.completion(
|
|
||||||
model="gemini/gemini-1.5-pro",
|
|
||||||
messages=messages,
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
# Assert the request body matches expected
|
|
||||||
mock_post.assert_called_once()
|
|
||||||
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
|
|
||||||
assert mock_post.call_args.kwargs["json"] == expected_request_body
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, expected_url",
|
"model, expected_url",
|
||||||
[
|
[
|
||||||
|
@ -1298,20 +1224,3 @@ def test_vertex_embedding_url(model, expected_url):
|
||||||
|
|
||||||
assert url == expected_url
|
assert url == expected_url
|
||||||
assert endpoint == "predict"
|
assert endpoint == "predict"
|
||||||
|
|
||||||
|
|
||||||
from base_llm_unit_tests import BaseLLMChatTest
|
|
||||||
|
|
||||||
|
|
||||||
class TestVertexGemini(BaseLLMChatTest):
|
|
||||||
def get_base_completion_call_args(self) -> dict:
|
|
||||||
return {"model": "gemini/gemini-1.5-flash"}
|
|
||||||
|
|
||||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
|
||||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
|
||||||
from litellm.llms.prompt_templates.factory import (
|
|
||||||
convert_to_gemini_tool_call_invoke,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
|
|
||||||
print(result)
|
|
||||||
|
|
|
@ -556,13 +556,22 @@ def test_team_key_generation_team_member_check():
|
||||||
_team_key_generation_check,
|
_team_key_generation_check,
|
||||||
)
|
)
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
|
||||||
|
|
||||||
litellm.key_generation_settings = {
|
litellm.key_generation_settings = {
|
||||||
"team_key_generation": {"allowed_team_member_roles": ["admin"]}
|
"team_key_generation": {"allowed_team_member_roles": ["admin"]}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
team_table = LiteLLM_TeamTableCachedObj(
|
||||||
|
team_id="test_team_id",
|
||||||
|
team_alias="test_team_alias",
|
||||||
|
members_with_roles=[Member(role="admin", user_id="test_user_id")],
|
||||||
|
)
|
||||||
|
|
||||||
assert _team_key_generation_check(
|
assert _team_key_generation_check(
|
||||||
|
team_table=team_table,
|
||||||
user_api_key_dict=UserAPIKeyAuth(
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_id="test_user_id",
|
||||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
api_key="sk-1234",
|
api_key="sk-1234",
|
||||||
team_member=Member(role="admin", user_id="test_user_id"),
|
team_member=Member(role="admin", user_id="test_user_id"),
|
||||||
|
@ -570,8 +579,15 @@ def test_team_key_generation_team_member_check():
|
||||||
data=GenerateKeyRequest(),
|
data=GenerateKeyRequest(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
team_table = LiteLLM_TeamTableCachedObj(
|
||||||
|
team_id="test_team_id",
|
||||||
|
team_alias="test_team_alias",
|
||||||
|
members_with_roles=[Member(role="user", user_id="test_user_id")],
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
_team_key_generation_check(
|
_team_key_generation_check(
|
||||||
|
team_table=team_table,
|
||||||
user_api_key_dict=UserAPIKeyAuth(
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
api_key="sk-1234",
|
api_key="sk-1234",
|
||||||
|
@ -607,6 +623,7 @@ def test_key_generation_required_params_check(
|
||||||
StandardKeyGenerationConfig,
|
StandardKeyGenerationConfig,
|
||||||
PersonalUIKeyGenerationConfig,
|
PersonalUIKeyGenerationConfig,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
user_api_key_dict = UserAPIKeyAuth(
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
@ -614,7 +631,13 @@ def test_key_generation_required_params_check(
|
||||||
api_key="sk-1234",
|
api_key="sk-1234",
|
||||||
user_id="test_user_id",
|
user_id="test_user_id",
|
||||||
team_id="test_team_id",
|
team_id="test_team_id",
|
||||||
team_member=Member(role="admin", user_id="test_user_id"),
|
team_member=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
team_table = LiteLLM_TeamTableCachedObj(
|
||||||
|
team_id="test_team_id",
|
||||||
|
team_alias="test_team_alias",
|
||||||
|
members_with_roles=[Member(role="admin", user_id="test_user_id")],
|
||||||
)
|
)
|
||||||
|
|
||||||
if key_type == "team_key":
|
if key_type == "team_key":
|
||||||
|
@ -632,13 +655,13 @@ def test_key_generation_required_params_check(
|
||||||
|
|
||||||
if expected_result:
|
if expected_result:
|
||||||
if key_type == "team_key":
|
if key_type == "team_key":
|
||||||
assert _team_key_generation_check(user_api_key_dict, input_data)
|
assert _team_key_generation_check(team_table, user_api_key_dict, input_data)
|
||||||
elif key_type == "personal_key":
|
elif key_type == "personal_key":
|
||||||
assert _personal_key_generation_check(user_api_key_dict, input_data)
|
assert _personal_key_generation_check(user_api_key_dict, input_data)
|
||||||
else:
|
else:
|
||||||
if key_type == "team_key":
|
if key_type == "team_key":
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
_team_key_generation_check(user_api_key_dict, input_data)
|
_team_key_generation_check(team_table, user_api_key_dict, input_data)
|
||||||
elif key_type == "personal_key":
|
elif key_type == "personal_key":
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
_personal_key_generation_check(user_api_key_dict, input_data)
|
_personal_key_generation_check(user_api_key_dict, input_data)
|
||||||
|
|
|
@ -1014,7 +1014,11 @@ async def test_create_team_member_add(prisma_client, new_member_method):
|
||||||
with patch(
|
with patch(
|
||||||
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
) as mock_litellm_usertable:
|
) as mock_litellm_usertable, patch(
|
||||||
|
"litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
|
||||||
|
new=AsyncMock(return_value=team_obj),
|
||||||
|
) as mock_team_obj:
|
||||||
|
|
||||||
mock_client = AsyncMock(
|
mock_client = AsyncMock(
|
||||||
return_value=LiteLLM_UserTable(
|
return_value=LiteLLM_UserTable(
|
||||||
user_id="1234", max_budget=100, user_email="1234"
|
user_id="1234", max_budget=100, user_email="1234"
|
||||||
|
@ -1193,7 +1197,10 @@ async def test_create_team_member_add_team_admin(
|
||||||
with patch(
|
with patch(
|
||||||
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
) as mock_litellm_usertable:
|
) as mock_litellm_usertable, patch(
|
||||||
|
"litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
|
||||||
|
new=AsyncMock(return_value=team_obj),
|
||||||
|
) as mock_team_obj:
|
||||||
mock_client = AsyncMock(
|
mock_client = AsyncMock(
|
||||||
return_value=LiteLLM_UserTable(
|
return_value=LiteLLM_UserTable(
|
||||||
user_id="1234", max_budget=100, user_email="1234"
|
user_id="1234", max_budget=100, user_email="1234"
|
||||||
|
|
|
@ -413,6 +413,7 @@ const Team: React.FC<TeamProps> = ({
|
||||||
selectedTeam["team_id"],
|
selectedTeam["team_id"],
|
||||||
user_role
|
user_role
|
||||||
);
|
);
|
||||||
|
message.success("Member added");
|
||||||
console.log(`response for team create call: ${response["data"]}`);
|
console.log(`response for team create call: ${response["data"]}`);
|
||||||
// Checking if the team exists in the list and updating or adding accordingly
|
// Checking if the team exists in the list and updating or adding accordingly
|
||||||
const foundIndex = teams.findIndex((team) => {
|
const foundIndex = teams.findIndex((team) => {
|
||||||
|
@ -430,6 +431,7 @@ const Team: React.FC<TeamProps> = ({
|
||||||
setSelectedTeam(response.data);
|
setSelectedTeam(response.data);
|
||||||
}
|
}
|
||||||
setIsAddMemberModalVisible(false);
|
setIsAddMemberModalVisible(false);
|
||||||
|
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error creating the team:", error);
|
console.error("Error creating the team:", error);
|
||||||
|
@ -825,6 +827,9 @@ const Team: React.FC<TeamProps> = ({
|
||||||
labelCol={{ span: 8 }}
|
labelCol={{ span: 8 }}
|
||||||
wrapperCol={{ span: 16 }}
|
wrapperCol={{ span: 16 }}
|
||||||
labelAlign="left"
|
labelAlign="left"
|
||||||
|
initialValues={{
|
||||||
|
role: "user",
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
<>
|
<>
|
||||||
<Form.Item label="Email" name="user_email" className="mb-4">
|
<Form.Item label="Email" name="user_email" className="mb-4">
|
||||||
|
@ -842,8 +847,8 @@ const Team: React.FC<TeamProps> = ({
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item label="Member Role" name="role" className="mb-4">
|
<Form.Item label="Member Role" name="role" className="mb-4">
|
||||||
<Select2 defaultValue="user">
|
<Select2 defaultValue="user">
|
||||||
<Select2.Option value="user">user</Select2.Option>
|
|
||||||
<Select2.Option value="admin">admin</Select2.Option>
|
<Select2.Option value="admin">admin</Select2.Option>
|
||||||
|
<Select2.Option value="user">user</Select2.Option>
|
||||||
</Select2>
|
</Select2>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</>
|
</>
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue