forked from phoenix/litellm-mirror
Compare commits: main ... litellm_st (1 commit)
Commit: 0f08577060
19 changed files with 406 additions and 169 deletions
@@ -33,6 +33,7 @@ from litellm.types.llms.openai import (
    ChatCompletionAssistantToolCall,
    ChatCompletionFunctionMessage,
    ChatCompletionImageObject,
    ChatCompletionImageUrlObject,
    ChatCompletionTextObject,
    ChatCompletionToolCallFunctionChunk,
    ChatCompletionToolMessage,
@@ -681,6 +682,27 @@ def construct_tool_use_system_prompt(
    return tool_use_system_prompt


def convert_generic_image_chunk_to_openai_image_obj(
    image_chunk: GenericImageParsingChunk,
) -> str:
    """
    Convert a generic image chunk to an OpenAI image object.

    Input:
    GenericImageParsingChunk(
        type="base64",
        media_type="image/jpeg",
        data="...",
    )

    Return:
    "data:image/jpeg;base64,{base64_image}"
    """
    return "data:{};{},{}".format(
        image_chunk["media_type"], image_chunk["type"], image_chunk["data"]
    )


def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
    """
    Input:
@@ -706,6 +728,7 @@ convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing
            data=base64_data,
        )
    except Exception as e:
        traceback.print_exc()
        if "Error: Unable to fetch image from URL" in str(e):
            raise e
        raise Exception(
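For orientation, a minimal usage sketch of the new helper (the dict literal is illustrative; the function indexes the chunk like a dict, as the body above shows):

from litellm.llms.prompt_templates.factory import (
    convert_generic_image_chunk_to_openai_image_obj,
)

chunk = {"type": "base64", "media_type": "image/jpeg", "data": "<base64-encoded bytes>"}
print(convert_generic_image_chunk_to_openai_image_obj(chunk))
# -> "data:image/jpeg;base64,<base64-encoded bytes>"

This is the inverse of convert_to_anthropic_image_obj, and the Gemini changes below chain the two to turn an https image URL into an OpenAI-style base64 data URL.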
@@ -294,7 +294,12 @@ def _transform_request_body(
    optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}

    try:
        content = _gemini_convert_messages_with_history(messages=messages)
        if custom_llm_provider == "gemini":
            content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
                messages=messages
            )
        else:
            content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
        tools: Optional[Tools] = optional_params.pop("tools", None)
        tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
        safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
@@ -35,7 +35,12 @@ from litellm.llms.custom_httpx.http_handler import (
    HTTPHandler,
    get_async_httpx_client,
)
from litellm.llms.prompt_templates.factory import (
    convert_generic_image_chunk_to_openai_image_obj,
    convert_to_anthropic_image_obj,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionResponseMessage,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
@@ -78,6 +83,8 @@ from ..common_utils import (
)
from ..vertex_llm_base import VertexBase
from .transformation import (
    _gemini_convert_messages_with_history,
    _process_gemini_image,
    async_transform_request_body,
    set_headers,
    sync_transform_request_body,
@@ -912,6 +919,10 @@ class VertexGeminiConfig:

        return model_response

    @staticmethod
    def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
        return _gemini_convert_messages_with_history(messages=messages)


class GoogleAIStudioGeminiConfig(
    VertexGeminiConfig
@@ -1015,6 +1026,32 @@ class GoogleAIStudioGeminiConfig(
            model, non_default_params, optional_params, drop_params
        )

    @staticmethod
    def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
        """
        Google AI Studio Gemini does not support image urls in messages.
        """
        for message in messages:
            _message_content = message.get("content")
            if _message_content is not None and isinstance(_message_content, list):
                _parts: List[PartType] = []
                for element in _message_content:
                    if element.get("type") == "image_url":
                        img_element = element
                        _image_url: Optional[str] = None
                        if isinstance(img_element.get("image_url"), dict):
                            _image_url = img_element["image_url"].get("url")  # type: ignore
                        else:
                            _image_url = img_element.get("image_url")  # type: ignore
                        if _image_url and "https://" in _image_url:
                            image_obj = convert_to_anthropic_image_obj(_image_url)
                            img_element["image_url"] = (  # type: ignore
                                convert_generic_image_chunk_to_openai_image_obj(
                                    image_obj
                                )
                            )
        return _gemini_convert_messages_with_history(messages=messages)


async def make_call(
    client: Optional[AsyncHTTPHandler],
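A hedged sketch of the rewrite this override performs (the URL and message payload are illustrative, not taken from the diff):

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        ],
    }
]
# GoogleAIStudioGeminiConfig._transform_messages(messages=messages) fetches the https
# image via convert_to_anthropic_image_obj and rewrites the element in place to a
# base64 data URL, roughly:
#   {"type": "image_url", "image_url": "data:image/jpeg;base64,..."}
# before handing the messages to _gemini_convert_messages_with_history.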
@@ -12,3 +12,23 @@ model_list:
      vertex_ai_project: "adroit-crow-413218"
      vertex_ai_location: "us-east5"

router_settings:
  routing_strategy: usage-based-routing-v2
  #redis_url: "os.environ/REDIS_URL"
  redis_host: "os.environ/REDIS_HOST"
  redis_port: "os.environ/REDIS_PORT"

litellm_settings:
  cache: true
  cache_params:
    type: redis
    host: "os.environ/REDIS_HOST"
    port: "os.environ/REDIS_PORT"
    namespace: "litellm.caching"
    ttl: 600
  # key_generation_settings:
  #   team_key_generation:
  #     allowed_team_member_roles: ["admin"]
  #     required_params: ["tags"]  # require team admins to set tags for cost-tracking when generating a team key
  #   personal_key_generation:  # maps to 'Default Team' on UI
  #     allowed_user_roles: ["proxy_admin"]
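For context, a proxy config like this is loaded by pointing the CLI at the file (the path is illustrative):

litellm --config /path/to/proxy_config.yaml

Values written as "os.environ/REDIS_HOST" follow the proxy convention of resolving the setting from that environment variable at startup, so Redis credentials stay out of the YAML.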
@@ -1982,7 +1982,6 @@ class MemberAddRequest(LiteLLMBase):
            # Replace member_data with the single Member object
            data["member"] = member
        # Call the superclass __init__ method to initialize the object
        traceback.print_stack()
        super().__init__(**data)
@@ -523,6 +523,10 @@ async def _cache_management_object(
    proxy_logging_obj: Optional[ProxyLogging],
):
    await user_api_key_cache.async_set_cache(key=key, value=value)
    if proxy_logging_obj is not None:
        await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
            key=key, value=value
        )


async def _cache_team_object(
@@ -586,26 +590,63 @@ async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
    )


async def get_team_object(
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
    """
    - Check if team id in proxy Team Table
    - if valid, return LiteLLM_TeamTable object with defined limits
    - if not, then raise an error
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
    return await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )

    # check if in cache
    key = "team_id:{}".format(team_id)

async def _get_team_object_from_user_api_key_cache(
    team_id: str,
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
    last_db_access_time: LimitedSizeOrderedDict,
    db_cache_expiry: int,
    proxy_logging_obj: Optional[ProxyLogging],
    key: str,
) -> LiteLLM_TeamTableCachedObj:
    db_access_time_key = key
    should_check_db = _should_check_db(
        key=db_access_time_key,
        last_db_access_time=last_db_access_time,
        db_cache_expiry=db_cache_expiry,
    )
    if should_check_db:
        response = await _get_team_db_check(
            team_id=team_id, prisma_client=prisma_client
        )
    else:
        response = None

    if response is None:
        raise Exception

    _response = LiteLLM_TeamTableCachedObj(**response.dict())
    # save the team object to cache
    await _cache_team_object(
        team_id=team_id,
        team_table=_response,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )

    # save to db access time
    # save to db access time
    _update_last_db_access_time(
        key=db_access_time_key,
        value=_response,
        last_db_access_time=last_db_access_time,
    )

    return _response


async def _get_team_object_from_cache(
    key: str,
    proxy_logging_obj: Optional[ProxyLogging],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span],
) -> Optional[LiteLLM_TeamTableCachedObj]:
    cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None

    ## CHECK REDIS CACHE ##
@@ -613,6 +654,7 @@ async def get_team_object(
        proxy_logging_obj is not None
        and proxy_logging_obj.internal_usage_cache.dual_cache
    ):

        cached_team_obj = (
            await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
                key=key, parent_otel_span=parent_otel_span
@@ -628,47 +670,58 @@ async def get_team_object(
    elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
        return cached_team_obj

    if check_cache_only:
        return None


async def get_team_object(
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
    check_db_only: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
    """
    - Check if team id in proxy Team Table
    - if valid, return LiteLLM_TeamTable object with defined limits
    - if not, then raise an error
    """
    if prisma_client is None:
        raise Exception(
            f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    key = "team_id:{}".format(team_id)

    if not check_db_only:
        cached_team_obj = await _get_team_object_from_cache(
            key=key,
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
        )

        if cached_team_obj is not None:
            return cached_team_obj

        if check_cache_only:
            raise Exception(
                f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
            )

    # else, check db
    try:
        db_access_time_key = "team_id:{}".format(team_id)
        should_check_db = _should_check_db(
            key=db_access_time_key,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
        )
        if should_check_db:
            response = await _get_team_db_check(
                team_id=team_id, prisma_client=prisma_client
            )
        else:
            response = None

        if response is None:
            raise Exception

        _response = LiteLLM_TeamTableCachedObj(**response.dict())
        # save the team object to cache
        await _cache_team_object(
        return await _get_team_object_from_user_api_key_cache(
            team_id=team_id,
            team_table=_response,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

        # save to db access time
        # save to db access time
        _update_last_db_access_time(
            key=db_access_time_key,
            value=_response,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
            key=key,
        )

        return _response
    except Exception:
        raise Exception(
            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
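A short usage sketch of the reworked lookup (call sites are illustrative; the parameters come from the hunk above):

# Force a DB read, as team_member_add does below, so membership changes are not served stale from cache:
team_row = await get_team_object(
    team_id="my-team",
    prisma_client=prisma_client,
    user_api_key_cache=user_api_key_cache,
    check_db_only=True,
)

# Cache-only lookup: returns the cached row, or raises if the team has not been cached yet:
cached_row = await get_team_object(
    team_id="my-team",
    prisma_client=prisma_client,
    user_api_key_cache=user_api_key_cache,
    check_cache_only=True,
)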
@@ -29,6 +29,7 @@ from litellm.proxy.auth.auth_checks import (
    _cache_key_object,
    _delete_cache_key_object,
    get_key_object,
    get_team_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
@@ -46,7 +47,19 @@ def _is_team_key(data: GenerateKeyRequest):
    return data.team_id is not None


def _get_user_in_team(
    team_table: LiteLLM_TeamTableCachedObj, user_id: Optional[str]
) -> Optional[Member]:
    if user_id is None:
        return None
    for member in team_table.members_with_roles:
        if member.user_id is not None and member.user_id == user_id:
            return member
    return None


def _team_key_generation_team_member_check(
    team_table: LiteLLM_TeamTableCachedObj,
    user_api_key_dict: UserAPIKeyAuth,
    team_key_generation: Optional[TeamUIKeyGenerationConfig],
):
@@ -56,17 +69,19 @@ def _team_key_generation_team_member_check(
    ):
        return True

    if user_api_key_dict.team_member is None:
    user_in_team = _get_user_in_team(
        team_table=team_table, user_id=user_api_key_dict.user_id
    )
    if user_in_team is None:
        raise HTTPException(
            status_code=400,
            detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}",
            detail=f"User={user_api_key_dict.user_id} not assigned to team={team_table.team_id}",
        )

    team_member_role = user_api_key_dict.team_member.role
    if team_member_role not in team_key_generation["allowed_team_member_roles"]:
    if user_in_team.role not in team_key_generation["allowed_team_member_roles"]:
        raise HTTPException(
            status_code=400,
            detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}",  # type: ignore
            detail=f"Team member role {user_in_team.role} not in allowed_team_member_roles={team_key_generation['allowed_team_member_roles']}",
        )
    return True
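The role check above is driven by litellm.key_generation_settings. A minimal sketch of how an admin would restrict team-key creation (mirrors the commented-out YAML earlier and the unit test further down):

import litellm

litellm.key_generation_settings = {
    "team_key_generation": {
        "allowed_team_member_roles": ["admin"],  # only team admins may create team keys
        # "required_params": ["tags"],           # optionally require metadata on new keys
    }
}
# With this set, _team_key_generation_team_member_check resolves the caller's role from
# team_table.members_with_roles via _get_user_in_team instead of requiring the calling key
# to carry a team_member object.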
@@ -88,7 +103,9 @@ def _key_generation_required_param_check(


def _team_key_generation_check(
    user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
    team_table: LiteLLM_TeamTableCachedObj,
    user_api_key_dict: UserAPIKeyAuth,
    data: GenerateKeyRequest,
):
    if (
        litellm.key_generation_settings is None
@@ -99,7 +116,8 @@ def _team_key_generation_check(
    _team_key_generation = litellm.key_generation_settings["team_key_generation"]  # type: ignore

    _team_key_generation_team_member_check(
        user_api_key_dict,
        team_table=team_table,
        user_api_key_dict=user_api_key_dict,
        team_key_generation=_team_key_generation,
    )
    _key_generation_required_param_check(
@@ -155,7 +173,9 @@ def _personal_key_generation_check(


def key_generation_check(
    user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
    team_table: Optional[LiteLLM_TeamTableCachedObj],
    user_api_key_dict: UserAPIKeyAuth,
    data: GenerateKeyRequest,
) -> bool:
    """
    Check if admin has restricted key creation to certain roles for teams or individuals
@@ -170,8 +190,15 @@ def key_generation_check(
    is_team_key = _is_team_key(data=data)

    if is_team_key:
        if team_table is None:
            raise HTTPException(
                status_code=400,
                detail=f"Unable to find team object in database. Team ID: {data.team_id}",
            )
        return _team_key_generation_check(
            user_api_key_dict=user_api_key_dict, data=data
            team_table=team_table,
            user_api_key_dict=user_api_key_dict,
            data=data,
        )
    else:
        return _personal_key_generation_check(
@@ -254,6 +281,7 @@ async def generate_key_fn(  # noqa: PLR0915
        litellm_proxy_admin_name,
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
        user_custom_key_generate,
    )

@@ -271,7 +299,20 @@ async def generate_key_fn(  # noqa: PLR0915
                status_code=status.HTTP_403_FORBIDDEN, detail=message
            )
        elif litellm.key_generation_settings is not None:
            key_generation_check(user_api_key_dict=user_api_key_dict, data=data)
            if data.team_id is None:
                team_table: Optional[LiteLLM_TeamTableCachedObj] = None
            else:
                team_table = await get_team_object(
                    team_id=data.team_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    parent_otel_span=user_api_key_dict.parent_otel_span,
                )
            key_generation_check(
                team_table=team_table,
                user_api_key_dict=user_api_key_dict,
                data=data,
            )
        # check if user set default key/generate params on config.yaml
        if litellm.default_key_generate_params is not None:
            for elem in data:
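With this change, generating a team-scoped key resolves the team row before the role check runs. A hedged request sketch against a running proxy (URL, key, and team id are placeholders):

curl -X POST "http://localhost:4000/key/generate" \
  -H "Authorization: Bearer sk-<admin-or-team-member-key>" \
  -H "Content-Type: application/json" \
  -d '{"team_id": "my-team"}'

If key_generation_settings limits team keys to team admins, a caller whose role in the team's members_with_roles is not "admin" receives the HTTP 400 raised above.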
@@ -547,6 +547,7 @@ async def team_member_add(
        parent_otel_span=None,
        proxy_logging_obj=proxy_logging_obj,
        check_cache_only=False,
        check_db_only=True,
    )
    if existing_team_row is None:
        raise HTTPException(
@@ -54,12 +54,19 @@ def create_request_copy(request: Request):
    }


@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route(
    "/gemini/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE"],
    tags=["Google AI Studio Pass-through", "pass-through"],
)
async def gemini_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio)
    """
    ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
    api_key = request.query_params.get("key")

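The decorator change only adds OpenAPI tags; routing behavior is unchanged. For reference, a hedged example of calling the Google AI Studio pass-through with the LiteLLM key in the ?key= query param (host, model path, and key are placeholders):

curl "http://localhost:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent?key=sk-<litellm-key>" \
  -H "Content-Type: application/json" \
  -d '{"contents": [{"parts": [{"text": "Say hi"}]}]}'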
@@ -111,13 +118,20 @@ async def gemini_proxy_route(
    return received_value


@router.api_route("/cohere/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route(
    "/cohere/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE"],
    tags=["Cohere Pass-through", "pass-through"],
)
async def cohere_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/cohere)
    """
    base_target_url = "https://api.cohere.com"
    encoded_endpoint = httpx.URL(endpoint).path

@@ -154,7 +168,9 @@ async def cohere_proxy_route(


@router.api_route(
    "/anthropic/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
    "/anthropic/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE"],
    tags=["Anthropic Pass-through", "pass-through"],
)
async def anthropic_proxy_route(
    endpoint: str,
@@ -162,6 +178,9 @@ async def anthropic_proxy_route(
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/anthropic_completion)
    """
    base_target_url = "https://api.anthropic.com"
    encoded_endpoint = httpx.URL(endpoint).path

@@ -201,13 +220,20 @@ async def anthropic_proxy_route(
    return received_value


@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route(
    "/bedrock/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE"],
    tags=["Bedrock Pass-through", "pass-through"],
)
async def bedrock_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/bedrock)
    """
    create_request_copy(request)

    try:
@@ -275,13 +301,22 @@ async def bedrock_proxy_route(
    return received_value


@router.api_route("/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route(
    "/azure/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE"],
    tags=["Azure Pass-through", "pass-through"],
)
async def azure_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Call any azure endpoint using the proxy.

    Just use `{PROXY_BASE_URL}/azure/{endpoint:path}`
    """
    base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
    if base_target_url is None:
        raise Exception(
@@ -5663,11 +5663,11 @@ async def anthropic_response(  # noqa: PLR0915
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    This is a BETA endpoint that calls 100+ LLMs in the anthropic format.
    🚨 DEPRECATED ENDPOINT🚨

    To do a simple pass-through for anthropic, do `{PROXY_BASE_URL}/anthropic/v1/messages`
    Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion).

    Docs - https://docs.litellm.ai/docs/anthropic_completion
    This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
    """
    from litellm import adapter_completion
    from litellm.adapters.anthropic_adapter import anthropic_adapter

|
|||
}
|
||||
|
||||
|
||||
@router.api_route("/langfuse/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
@router.api_route(
|
||||
"/langfuse/{endpoint:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE"],
|
||||
tags=["Langfuse Pass-through", "pass-through"],
|
||||
)
|
||||
async def langfuse_proxy_route(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
):
|
||||
"""
|
||||
Call Langfuse via LiteLLM proxy. Works with Langfuse SDK.
|
||||
|
||||
[Docs](https://docs.litellm.ai/docs/pass_through/langfuse)
|
||||
"""
|
||||
## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
|
||||
api_key = request.headers.get("Authorization") or ""
|
||||
|
||||
|
|
|
@ -113,13 +113,26 @@ def construct_target_url(
|
|||
|
||||
|
||||
@router.api_route(
|
||||
"/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
|
||||
"/vertex-ai/{endpoint:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE"],
|
||||
tags=["Vertex AI Pass-through", "pass-through"],
|
||||
include_in_schema=False,
|
||||
)
|
||||
@router.api_route(
|
||||
"/vertex_ai/{endpoint:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE"],
|
||||
tags=["Vertex AI Pass-through", "pass-through"],
|
||||
)
|
||||
async def vertex_proxy_route(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
):
|
||||
"""
|
||||
Call LiteLLM proxy via Vertex AI SDK.
|
||||
|
||||
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
|
||||
"""
|
||||
encoded_endpoint = httpx.URL(endpoint).path
|
||||
|
||||
import re
|
||||
|
|
|
@@ -190,6 +190,35 @@ class BaseLLMChatTest(ABC):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        pass

    def test_image_url(self):
        litellm.set_verbose = True
        from litellm.utils import supports_vision

        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
        litellm.model_cost = litellm.get_model_cost_map(url="")

        base_completion_call_args = self.get_base_completion_call_args()
        if not supports_vision(base_completion_call_args["model"], None):
            pytest.skip("Model does not support image input")

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
                        },
                    },
                ],
            }
        ]

        response = litellm.completion(**base_completion_call_args, messages=messages)
        assert response is not None

    @pytest.fixture
    def pdf_messages(self):
        import base64

tests/llm_translation/test_gemini.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from base_llm_unit_tests import BaseLLMChatTest


class TestGoogleAIStudioGemini(BaseLLMChatTest):
    def get_base_completion_call_args(self) -> dict:
        return {"model": "gemini/gemini-1.5-flash"}

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        from litellm.llms.prompt_templates.factory import (
            convert_to_gemini_tool_call_invoke,
        )

        result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
        print(result)
@@ -687,3 +687,16 @@ def test_just_system_message():
            llm_provider="bedrock",
        )
    assert "bedrock requires at least one non-system message" in str(e.value)


def test_convert_generic_image_chunk_to_openai_image_obj():
    from litellm.llms.prompt_templates.factory import (
        convert_generic_image_chunk_to_openai_image_obj,
        convert_to_anthropic_image_obj,
    )

    url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
    image_obj = convert_to_anthropic_image_obj(url)
    url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj)
    image_obj = convert_to_anthropic_image_obj(url_str)
    print(image_obj)
@@ -1190,80 +1190,6 @@ def test_get_image_mime_type_from_url():
    assert _get_image_mime_type_from_url("invalid_url") is None


@pytest.mark.parametrize(
    "image_url", ["https://example.com/image.jpg", "https://example.com/image.png"]
)
def test_image_completion_request(image_url):
    """https:// .jpg, .png images are passed directly to the model"""
    from unittest.mock import patch, Mock
    import litellm
    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
        _get_image_mime_type_from_url,
    )

    # Mock response data
    mock_response = Mock()
    mock_response.json.return_value = {
        "candidates": [{"content": {"parts": [{"text": "This is a sunflower"}]}}],
        "usageMetadata": {
            "promptTokenCount": 11,
            "candidatesTokenCount": 50,
            "totalTokenCount": 61,
        },
        "modelVersion": "gemini-1.5-pro",
    }
    mock_response.raise_for_status = MagicMock()
    mock_response.status_code = 200

    # Expected request body
    expected_request_body = {
        "contents": [
            {
                "role": "user",
                "parts": [
                    {"text": "Whats in this image?"},
                    {
                        "file_data": {
                            "file_uri": image_url,
                            "mime_type": _get_image_mime_type_from_url(image_url),
                        }
                    },
                ],
            }
        ],
        "system_instruction": {"parts": [{"text": "Be a good bot"}]},
        "generationConfig": {},
    }

    messages = [
        {"role": "system", "content": "Be a good bot"},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Whats in this image?"},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        },
    ]

    client = HTTPHandler()
    with patch.object(client, "post", new=MagicMock()) as mock_post:
        mock_post.return_value = mock_response
        try:
            litellm.completion(
                model="gemini/gemini-1.5-pro",
                messages=messages,
                client=client,
            )
        except Exception as e:
            print(e)

        # Assert the request body matches expected
        mock_post.assert_called_once()
        print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
        assert mock_post.call_args.kwargs["json"] == expected_request_body


@pytest.mark.parametrize(
    "model, expected_url",
    [
@@ -1298,20 +1224,3 @@ def test_vertex_embedding_url(model, expected_url):

    assert url == expected_url
    assert endpoint == "predict"


from base_llm_unit_tests import BaseLLMChatTest


class TestVertexGemini(BaseLLMChatTest):
    def get_base_completion_call_args(self) -> dict:
        return {"model": "gemini/gemini-1.5-flash"}

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        from litellm.llms.prompt_templates.factory import (
            convert_to_gemini_tool_call_invoke,
        )

        result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
        print(result)
@@ -556,13 +556,22 @@ def test_team_key_generation_team_member_check():
        _team_key_generation_check,
    )
    from fastapi import HTTPException
    from litellm.proxy._types import LiteLLM_TeamTableCachedObj

    litellm.key_generation_settings = {
        "team_key_generation": {"allowed_team_member_roles": ["admin"]}
    }

    team_table = LiteLLM_TeamTableCachedObj(
        team_id="test_team_id",
        team_alias="test_team_alias",
        members_with_roles=[Member(role="admin", user_id="test_user_id")],
    )

    assert _team_key_generation_check(
        team_table=team_table,
        user_api_key_dict=UserAPIKeyAuth(
            user_id="test_user_id",
            user_role=LitellmUserRoles.INTERNAL_USER,
            api_key="sk-1234",
            team_member=Member(role="admin", user_id="test_user_id"),
|
@ -570,8 +579,15 @@ def test_team_key_generation_team_member_check():
|
|||
data=GenerateKeyRequest(),
|
||||
)
|
||||
|
||||
team_table = LiteLLM_TeamTableCachedObj(
|
||||
team_id="test_team_id",
|
||||
team_alias="test_team_alias",
|
||||
members_with_roles=[Member(role="user", user_id="test_user_id")],
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
_team_key_generation_check(
|
||||
team_table=team_table,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
api_key="sk-1234",
|
||||
|
@@ -607,6 +623,7 @@ def test_key_generation_required_params_check(
        StandardKeyGenerationConfig,
        PersonalUIKeyGenerationConfig,
    )
    from litellm.proxy._types import LiteLLM_TeamTableCachedObj
    from fastapi import HTTPException

    user_api_key_dict = UserAPIKeyAuth(
@@ -614,7 +631,13 @@ def test_key_generation_required_params_check(
        api_key="sk-1234",
        user_id="test_user_id",
        team_id="test_team_id",
        team_member=Member(role="admin", user_id="test_user_id"),
        team_member=None,
    )

    team_table = LiteLLM_TeamTableCachedObj(
        team_id="test_team_id",
        team_alias="test_team_alias",
        members_with_roles=[Member(role="admin", user_id="test_user_id")],
    )

    if key_type == "team_key":
@@ -632,13 +655,13 @@ def test_key_generation_required_params_check(

    if expected_result:
        if key_type == "team_key":
            assert _team_key_generation_check(user_api_key_dict, input_data)
            assert _team_key_generation_check(team_table, user_api_key_dict, input_data)
        elif key_type == "personal_key":
            assert _personal_key_generation_check(user_api_key_dict, input_data)
    else:
        if key_type == "team_key":
            with pytest.raises(HTTPException):
                _team_key_generation_check(user_api_key_dict, input_data)
                _team_key_generation_check(team_table, user_api_key_dict, input_data)
        elif key_type == "personal_key":
            with pytest.raises(HTTPException):
                _personal_key_generation_check(user_api_key_dict, input_data)
|
|||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_litellm_usertable:
|
||||
) as mock_litellm_usertable, patch(
|
||||
"litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
|
||||
new=AsyncMock(return_value=team_obj),
|
||||
) as mock_team_obj:
|
||||
|
||||
mock_client = AsyncMock(
|
||||
return_value=LiteLLM_UserTable(
|
||||
user_id="1234", max_budget=100, user_email="1234"
|
||||
|
@@ -1193,7 +1197,10 @@ async def test_create_team_member_add_team_admin(
    with patch(
        "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
        new_callable=AsyncMock,
    ) as mock_litellm_usertable:
    ) as mock_litellm_usertable, patch(
        "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
        new=AsyncMock(return_value=team_obj),
    ) as mock_team_obj:
        mock_client = AsyncMock(
            return_value=LiteLLM_UserTable(
                user_id="1234", max_budget=100, user_email="1234"
@@ -413,6 +413,7 @@ const Team: React.FC<TeamProps> = ({
          selectedTeam["team_id"],
          user_role
        );
        message.success("Member added");
        console.log(`response for team create call: ${response["data"]}`);
        // Checking if the team exists in the list and updating or adding accordingly
        const foundIndex = teams.findIndex((team) => {
@@ -430,6 +431,7 @@ const Team: React.FC<TeamProps> = ({
          setSelectedTeam(response.data);
        }
        setIsAddMemberModalVisible(false);

      }
    } catch (error) {
      console.error("Error creating the team:", error);
@@ -825,6 +827,9 @@ const Team: React.FC<TeamProps> = ({
          labelCol={{ span: 8 }}
          wrapperCol={{ span: 16 }}
          labelAlign="left"
          initialValues={{
            role: "user",
          }}
        >
          <>
            <Form.Item label="Email" name="user_email" className="mb-4">
@@ -842,8 +847,8 @@ const Team: React.FC<TeamProps> = ({
            </Form.Item>
            <Form.Item label="Member Role" name="role" className="mb-4">
              <Select2 defaultValue="user">
                <Select2.Option value="user">user</Select2.Option>
                <Select2.Option value="admin">admin</Select2.Option>
                <Select2.Option value="user">user</Select2.Option>
              </Select2>
            </Form.Item>
          </>