mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(main.py): fix retries being multiplied when using openai sdk (#7221)
* fix(main.py): fix retries being multiplied when using the OpenAI SDK. Closes https://github.com/BerriAI/litellm/pull/7130
* docs(prompt_management.md): add Langfuse prompt-management doc
* feat(team_endpoints.py): allow teams to add their own models, enabling teams to call their own fine-tuned models via the proxy
* test: add better enforcement-check testing for `/model/new` now that teams can add their own models
* docs(team_model_add.md): tutorial for allowing teams to add their own models
* test: fix test
This commit is contained in:
parent
8060c5c698
commit
ec36353b41
16 changed files with 2439 additions and 1540 deletions
|
@ -199,6 +199,7 @@ from litellm.proxy.management_endpoints.team_callback_endpoints import (
|
|||
router as team_callback_router,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
||||
from litellm.proxy.management_endpoints.team_endpoints import update_team
|
||||
from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
|
||||
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
|
||||
|
@ -6202,6 +6203,94 @@ async def delete_budget(
|
|||
#### MODEL MANAGEMENT ####
|
||||
|
||||
|
||||
async def _add_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
):
    """
    Create a new model row in LiteLLM_ProxyModelTable.

    Encrypts every value in the deployment's litellm_params before
    persisting, and records the calling user as creator/updater.

    Args:
        model_params: deployment to store. NOTE: its litellm_params are
            mutated in place - each value is replaced with its encrypted form.
        user_api_key_dict: auth info of the caller; user_id is stored as
            created_by/updated_by (falls back to the proxy admin name).
        prisma_client: Prisma DB client used for the insert.

    Returns:
        The created LiteLLM_ProxyModelTable record.
    """
    # encrypt litellm params #
    # Iterate over a plain-dict snapshot so we can safely write the
    # encrypted values back into model_params.litellm_params.
    _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
    for k, v in _litellm_params_dict.items():
        encrypted_value = encrypt_value_helper(value=v)
        model_params.litellm_params[k] = encrypted_value
    _data: dict = {
        # model_info.id may be None here; the DB layer assigns an id then.
        "model_id": model_params.model_info.id,
        "model_name": model_params.model_name,
        "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
        "model_info": model_params.model_info.model_dump_json(  # type: ignore
            exclude_none=True
        ),
        "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
        "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
    }
    model_response = await prisma_client.db.litellm_proxymodeltable.create(
        data=_data  # type: ignore
    )
    return model_response
|
||||
|
||||
|
||||
async def _add_team_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
):
    """
    If 'team_id' is provided,

    - generate a unique 'model_name' for the model (e.g. 'model_name_{team_id}_{uuid})
    - store the model in the db with the unique 'model_name'
    - store a team model alias mapping {"model_name": "model_name_{team_id}_{uuid}"}
    """
    team_id = model_params.model_info.team_id
    if team_id is None:
        # Not a team-scoped model - nothing for this helper to do.
        return None

    # Remember the caller-facing name before replacing it with the
    # team-unique one.
    alias_source_name = model_params.model_name
    model_params.model_name = f"model_name_{team_id}_{uuid.uuid4()}"

    ## CREATE MODEL IN DB ##
    model_response = await _add_model_to_db(
        model_params=model_params,
        user_api_key_dict=user_api_key_dict,
        prisma_client=prisma_client,
    )

    ## CREATE MODEL ALIAS IN DB ##
    # Map the original model name to the unique one on the team object, so
    # team members can keep calling the model by its original name.
    await update_team(
        data=UpdateTeamRequest(
            team_id=team_id,
            model_aliases={alias_source_name: model_params.model_name},
        ),
        user_api_key_dict=user_api_key_dict,
        http_request=Request(scope={"type": "http"}),
    )

    return model_response
|
||||
|
||||
|
||||
def check_if_team_id_matches_key(
    team_id: Optional[str], user_api_key_dict: UserAPIKeyAuth
) -> bool:
    """
    Decide whether the calling key may act on the given team.

    - Proxy admins are always allowed, for any (or no) team_id.
    - Non-admin callers must supply a team_id that matches their key's team.
    """
    # Proxy admins can manage models for any team.
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return True
    # Non-admins must target a specific team...
    if team_id is None:
        return False
    # ...and it has to be the team their key belongs to.
    return user_api_key_dict.team_id == team_id
|
||||
|
||||
|
||||
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
|
||||
@router.post(
|
||||
"/model/new",
|
||||
|
@ -6217,8 +6306,6 @@ async def add_new_model(
|
|||
try:
|
||||
import base64
|
||||
|
||||
global prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
|
@ -6227,6 +6314,14 @@ async def add_new_model(
|
|||
},
|
||||
)
|
||||
|
||||
if not check_if_team_id_matches_key(
|
||||
team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": "Team ID does not match the API key's team ID"},
|
||||
)
|
||||
|
||||
model_response = None
|
||||
# update DB
|
||||
if store_model_in_db is True:
|
||||
|
@ -6234,43 +6329,35 @@ async def add_new_model(
|
|||
- store model_list in db
|
||||
- store keys separately
|
||||
"""
|
||||
# encrypt litellm params #
|
||||
_litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
|
||||
_orignal_litellm_model_name = model_params.litellm_params.model
|
||||
for k, v in _litellm_params_dict.items():
|
||||
encrypted_value = encrypt_value_helper(value=v)
|
||||
model_params.litellm_params[k] = encrypted_value
|
||||
_data: dict = {
|
||||
"model_id": model_params.model_info.id,
|
||||
"model_name": model_params.model_name,
|
||||
"litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True), # type: ignore
|
||||
"model_info": model_params.model_info.model_dump_json( # type: ignore
|
||||
exclude_none=True
|
||||
),
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
}
|
||||
if model_params.model_info.id is not None:
|
||||
_data["model_id"] = model_params.model_info.id
|
||||
model_response = await prisma_client.db.litellm_proxymodeltable.create(
|
||||
data=_data # type: ignore
|
||||
)
|
||||
|
||||
await proxy_config.add_deployment(
|
||||
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
try:
|
||||
_original_litellm_model_name = model_params.model_name
|
||||
if model_params.model_info.team_id is None:
|
||||
model_response = await _add_model_to_db(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
else:
|
||||
model_response = await _add_team_model_to_db(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
await proxy_config.add_deployment(
|
||||
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# don't let failed slack alert block the /model/new response
|
||||
_alerting = general_settings.get("alerting", []) or []
|
||||
if "slack" in _alerting:
|
||||
# send notification - new model added
|
||||
await proxy_logging_obj.slack_alerting_instance.model_added_alert(
|
||||
model_name=model_params.model_name,
|
||||
litellm_model_name=_orignal_litellm_model_name,
|
||||
litellm_model_name=_original_litellm_model_name,
|
||||
passed_model_info=model_params.model_info,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Exception in add_new_model: {e}")
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue