fix(main.py): fix retries being multiplied when using openai sdk (#7221)

* fix(main.py): fix retries being multiplied when using openai sdk

Closes https://github.com/BerriAI/litellm/pull/7130

* docs(prompt_management.md): add langfuse prompt management doc

* feat(team_endpoints.py): allow teams to add their own models

Enables teams to call their own fine-tuned models via the proxy (see the request sketch after this change list)

* test: add better enforcement check testing for `/model/new` now that teams can add their own models

* docs(team_model_add.md): tutorial for allowing teams to add their own models

* test: fix test
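
For context on the team-model feature above: once registered, a team can add a model through the proxy's /model/new endpoint. A minimal sketch, assuming a proxy at http://localhost:4000 and a team-scoped virtual key; the host, key, team ID, and model values here are hypothetical, not from the commit:

    import requests  # assumes the `requests` package is installed

    PROXY_BASE = "http://localhost:4000"  # hypothetical proxy URL
    TEAM_KEY = "sk-team-1234"             # hypothetical team-scoped key

    # Register a team-owned fine-tuned model via /model/new.
    # model_info.team_id must match the team on the key, otherwise the
    # proxy returns 403 ("Team ID does not match the API key's team ID").
    resp = requests.post(
        f"{PROXY_BASE}/model/new",
        headers={"Authorization": f"Bearer {TEAM_KEY}"},
        json={
            "model_name": "my-team-finetune",
            "litellm_params": {"model": "openai/ft:gpt-3.5-turbo:my-org::abc123"},
            "model_info": {"team_id": "team-abc"},
        },
    )
    resp.raise_for_status()
    print(resp.json())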
commit ec36353b41 (parent 8060c5c698)
Author: Krish Dholakia, 2024-12-14 11:56:55 -08:00 (committed by GitHub)
16 changed files with 2439 additions and 1540 deletions


@@ -199,6 +199,7 @@ from litellm.proxy.management_endpoints.team_callback_endpoints import (
     router as team_callback_router,
 )
 from litellm.proxy.management_endpoints.team_endpoints import router as team_router
+from litellm.proxy.management_endpoints.team_endpoints import update_team
 from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
 from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
 from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
@@ -6202,6 +6203,94 @@ async def delete_budget(
 #### MODEL MANAGEMENT ####
+
+
+async def _add_model_to_db(
+    model_params: Deployment,
+    user_api_key_dict: UserAPIKeyAuth,
+    prisma_client: PrismaClient,
+):
+    # encrypt litellm params #
+    _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
+    _orignal_litellm_model_name = model_params.litellm_params.model
+    for k, v in _litellm_params_dict.items():
+        encrypted_value = encrypt_value_helper(value=v)
+        model_params.litellm_params[k] = encrypted_value
+    _data: dict = {
+        "model_id": model_params.model_info.id,
+        "model_name": model_params.model_name,
+        "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
+        "model_info": model_params.model_info.model_dump_json(  # type: ignore
+            exclude_none=True
+        ),
+        "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+        "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+    }
+    if model_params.model_info.id is not None:
+        _data["model_id"] = model_params.model_info.id
+    model_response = await prisma_client.db.litellm_proxymodeltable.create(
+        data=_data  # type: ignore
+    )
+    return model_response
+
+
+async def _add_team_model_to_db(
+    model_params: Deployment,
+    user_api_key_dict: UserAPIKeyAuth,
+    prisma_client: PrismaClient,
+):
+    """
+    If 'team_id' is provided,
+    - generate a unique 'model_name' for the model (e.g. 'model_name_{team_id}_{uuid}')
+    - store the model in the db with the unique 'model_name'
+    - store a team model alias mapping {"model_name": "model_name_{team_id}_{uuid}"}
+    """
+    _team_id = model_params.model_info.team_id
+    original_model_name = model_params.model_name
+    if _team_id is None:
+        return None
+
+    unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}"
+
+    model_params.model_name = unique_model_name
+
+    ## CREATE MODEL IN DB ##
+    model_response = await _add_model_to_db(
+        model_params=model_params,
+        user_api_key_dict=user_api_key_dict,
+        prisma_client=prisma_client,
+    )
+
+    ## CREATE MODEL ALIAS IN DB ##
+    await update_team(
+        data=UpdateTeamRequest(
+            team_id=_team_id,
+            model_aliases={original_model_name: unique_model_name},
+        ),
+        user_api_key_dict=user_api_key_dict,
+        http_request=Request(scope={"type": "http"}),
+    )
+
+    return model_response
+
+
+def check_if_team_id_matches_key(
+    team_id: Optional[str], user_api_key_dict: UserAPIKeyAuth
+) -> bool:
+    can_make_call = True
+    if (
+        user_api_key_dict.user_role
+        and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
+    ):
+        return True
+    if team_id is None:
+        if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
+            can_make_call = False
+    else:
+        if user_api_key_dict.team_id != team_id:
+            can_make_call = False
+    return can_make_call
+
+
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
 @router.post(
     "/model/new",
@@ -6217,8 +6306,6 @@ async def add_new_model(
     try:
         import base64
 
-        global prisma_client
-
         if prisma_client is None:
             raise HTTPException(
                 status_code=500,
@@ -6227,6 +6314,14 @@ async def add_new_model(
                 },
             )
 
+        if not check_if_team_id_matches_key(
+            team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail={"error": "Team ID does not match the API key's team ID"},
+            )
+
         model_response = None
         # update DB
         if store_model_in_db is True:
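
The 403 guard above delegates to `check_if_team_id_matches_key`, added in the hunk further up. A quick sketch of the outcomes it produces, using a stand-in object rather than the real `UserAPIKeyAuth` and a plain string for the admin role (assumption: only `user_role` and `team_id` matter, per the function body):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeKey:  # stand-in for UserAPIKeyAuth, illustration only
        user_role: Optional[str] = None
        team_id: Optional[str] = None

    def check(team_id: Optional[str], key: FakeKey) -> bool:
        # same logic as check_if_team_id_matches_key
        if key.user_role == "proxy_admin":
            return True           # proxy admins pass for any team_id
        if team_id is None:
            return False          # non-admins must pass a team_id
        return key.team_id == team_id

    assert check("team-abc", FakeKey(user_role="proxy_admin")) is True
    assert check(None, FakeKey(team_id="team-abc")) is False
    assert check("team-abc", FakeKey(team_id="team-abc")) is True
    assert check("team-abc", FakeKey(team_id="team-xyz")) is False
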
@@ -6234,43 +6329,35 @@ async def add_new_model(
             - store model_list in db
             - store keys separately
             """
-            # encrypt litellm params #
-            _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
-            _orignal_litellm_model_name = model_params.litellm_params.model
-            for k, v in _litellm_params_dict.items():
-                encrypted_value = encrypt_value_helper(value=v)
-                model_params.litellm_params[k] = encrypted_value
-            _data: dict = {
-                "model_id": model_params.model_info.id,
-                "model_name": model_params.model_name,
-                "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
-                "model_info": model_params.model_info.model_dump_json(  # type: ignore
-                    exclude_none=True
-                ),
-                "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
-                "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
-            }
-            if model_params.model_info.id is not None:
-                _data["model_id"] = model_params.model_info.id
-            model_response = await prisma_client.db.litellm_proxymodeltable.create(
-                data=_data  # type: ignore
-            )
-            await proxy_config.add_deployment(
-                prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
-            )
             try:
+                _original_litellm_model_name = model_params.model_name
+
+                if model_params.model_info.team_id is None:
+                    model_response = await _add_model_to_db(
+                        model_params=model_params,
+                        user_api_key_dict=user_api_key_dict,
+                        prisma_client=prisma_client,
+                    )
+                else:
+                    model_response = await _add_team_model_to_db(
+                        model_params=model_params,
+                        user_api_key_dict=user_api_key_dict,
+                        prisma_client=prisma_client,
+                    )
+
+                await proxy_config.add_deployment(
+                    prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
+                )
+
                 # don't let failed slack alert block the /model/new response
                 _alerting = general_settings.get("alerting", []) or []
                 if "slack" in _alerting:
                     # send notification - new model added
                     await proxy_logging_obj.slack_alerting_instance.model_added_alert(
                         model_name=model_params.model_name,
-                        litellm_model_name=_orignal_litellm_model_name,
+                        litellm_model_name=_original_litellm_model_name,
                         passed_model_info=model_params.model_info,
                     )
-            except Exception:
-                pass
+            except Exception as e:
+                verbose_proxy_logger.exception(f"Exception in add_new_model: {e}")
         else:
             raise HTTPException(