fix(main.py): fix retries being multiplied when using openai sdk (#7221)

* fix(main.py): fix retries being multiplied when using openai sdk

Closes https://github.com/BerriAI/litellm/pull/7130

* docs(prompt_management.md): add langfuse prompt management doc

* feat(team_endpoints.py): allow teams to add their own models

Enables teams to call their own finetuned models via the proxy

* test: add better enforcement check testing for `/model/new` now that teams can add their own models

* docs(team_model_add.md): tutorial for allowing teams to add their own models

* test: fix test
Krish Dholakia 2024-12-14 11:56:55 -08:00 committed by GitHub
parent 8060c5c698
commit ec36353b41
16 changed files with 2439 additions and 1540 deletions

@ -0,0 +1,83 @@
import Image from '@theme/IdealImage';
# Prompt Management
LiteLLM supports using [Langfuse](https://langfuse.com/docs/prompts/get-started) for prompt management on the proxy.
## Quick Start
1. Add Langfuse as a 'callback' in your config.yaml
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
litellm_settings:
callbacks: ["langfuse"] # 👈 KEY CHANGE
```
2. Start the proxy
```bash
litellm --config config.yaml
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "THIS WILL BE IGNORED"
}
],
"metadata": {
"langfuse_prompt_id": "value",
"langfuse_prompt_variables": { # [OPTIONAL]
"key": "value"
}
}
}'
```
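If you're calling the proxy through the OpenAI Python SDK instead of curl, the same request can be sketched as follows. This is a minimal sketch that assumes the proxy is running on `http://0.0.0.0:4000` with the `sk-1234` virtual key from above; the `metadata` field is passed via `extra_body`, since it isn't a standard OpenAI parameter:
```python
from openai import OpenAI

# Point the OpenAI SDK at the LiteLLM proxy (assumed local address + virtual key)
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "THIS WILL BE IGNORED"}],
    # 'metadata' is not a standard OpenAI param, so forward it via extra_body
    extra_body={
        "metadata": {
            "langfuse_prompt_id": "value",
            "langfuse_prompt_variables": {"key": "value"},  # optional
        }
    },
)
print(response.choices[0].message.content)
```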
## What is 'langfuse_prompt_id'?
- `langfuse_prompt_id`: The ID of the prompt that will be used for the request.
<Image img={require('../../img/langfuse_prompt_id.png')} />
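Roughly speaking, the proxy resolves this ID against Langfuse and compiles any `langfuse_prompt_variables` into the stored template before building the final request. A simplified sketch of that lookup using the Langfuse Python SDK (illustrative only, not LiteLLM's exact internals; it assumes `LANGFUSE_PUBLIC_KEY`/`LANGFUSE_SECRET_KEY` are set):
```python
from langfuse import Langfuse

langfuse = Langfuse()  # reads LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY from the environment

# Fetch the prompt referenced by `langfuse_prompt_id`
prompt = langfuse.get_prompt("value")

# Substitute `langfuse_prompt_variables` into the template, e.g. "Hello {{key}}"
compiled_prompt = prompt.compile(key="value")
```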
## What will the formatted prompt look like?
### `/chat/completions` messages
The compiled Langfuse prompt is added to the start of the `messages` list.
- If the Langfuse prompt compiles to a list, it is prepended to the `messages` list (assuming it's a list of OpenAI-compatible messages).
- If the Langfuse prompt compiles to a string, it is added as a system message at the start of the list.
```python
if isinstance(compiled_prompt, list):
data["messages"] = compiled_prompt + data["messages"]
else:
data["messages"] = [
{"role": "system", "content": compiled_prompt}
] + data["messages"]
```
### `/completions` prompt
The compiled Langfuse prompt is prepended to the existing `prompt` string.
```python
data["prompt"] = compiled_prompt + "\n" + data["prompt"]
```

@ -0,0 +1,77 @@
# Allow Teams to Add Models
Allow a team to add their own models/keys for a project, so any OpenAI call they make uses their own OpenAI key.
Useful for teams that want to call their own finetuned models.
## Specify Team ID in `/model/new` endpoint
```bash
curl -L -X POST 'http://0.0.0.0:4000/model/new' \
-H 'Authorization: Bearer sk-******2ql3-sm28WU0tTAmA' \ # 👈 Team API Key (has same 'team_id' as below)
-H 'Content-Type: application/json' \
-d '{
"model_name": "my-team-model", # 👈 Call LiteLLM with this model name
"litellm_params": {
"model": "openai/gpt-4o",
"custom_llm_provider": "openai",
"api_key": "******ccb07",
"api_base": "https://my-endpoint-sweden-berri992.openai.azure.com",
"api_version": "2023-12-01-preview"
},
"model_info": {
"team_id": "e59e2671-a064-436a-a0fa-16ae96e5a0a1" # 👈 Specify the team ID it belongs to
}
}'
```
## Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-******2ql3-sm28WU0tTAmA' \ # 👈 Team API Key
-d '{
"model": "my-team-model", # 👈 team model name
"messages": [
{
"role": "user",
"content": "What's the weather like in Boston today?"
}
]
}'
```
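Team members can also call their model through any OpenAI-compatible client. A minimal sketch with the OpenAI Python SDK, assuming the proxy is running on `http://0.0.0.0:4000` and using the team virtual key from above:
```python
from openai import OpenAI

# Use the team's virtual key; the proxy resolves "my-team-model" via the team's model aliases
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-******2ql3-sm28WU0tTAmA")

response = client.chat.completions.create(
    model="my-team-model",  # 👈 team model name
    messages=[{"role": "user", "content": "What is the weather like in Boston today?"}],
)
print(response.choices[0].message.content)
```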
## Debugging
### 'model_name' not found
Check if the model alias exists in the team table.
```bash
curl -L -X GET 'http://localhost:4000/team/info?team_id=e59e2671-a064-436a-a0fa-16ae96e5a0a1' \
-H 'Authorization: Bearer sk-******2ql3-sm28WU0tTAmA'
```
**Expected Response:**
```json
{
    "team_id": "e59e2671-a064-436a-a0fa-16ae96e5a0a1",
    "team_info": {
        ...,
        "litellm_model_table": {
            "model_aliases": {
                "my-team-model": # 👈 public model name
                    "model_name_e59e2671-a064-436a-a0fa-16ae96e5a0a1_e81c9286-2195-4bd9-81e1-cf393788a1a0" # 👈 internally generated model name (used to ensure uniqueness)
            },
            "created_by": "default_user_id",
            "updated_by": "default_user_id"
        }
    }
}
```

Binary file not shown (new image, 93 KiB).

File diff suppressed because it is too large.

@ -89,6 +89,13 @@ const sidebars = {
"proxy/custom_sso"
],
},
{
type: "category",
label: "Team Management",
items: [
"proxy/team_model_add"
],
},
{
type: "category",
label: "Spend Tracking + Budgets",
@ -127,6 +134,7 @@ const sidebars = {
"oidc"
]
},
"proxy/prompt_management",
"proxy/caching",
"proxy/call_hooks",
"proxy/rules",

@ -2969,6 +2969,9 @@ def completion_with_retries(*args, **kwargs):
)
num_retries = kwargs.pop("num_retries", 3)
# reset retries in .completion()
kwargs["max_retries"] = 0
kwargs["num_retries"] = 0
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop(
"retry_strategy", "constant_retry"
) # type: ignore
@ -2999,6 +3002,8 @@ async def acompletion_with_retries(*args, **kwargs):
)
num_retries = kwargs.pop("num_retries", 3)
kwargs["max_retries"] = 0
kwargs["num_retries"] = 0
retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "exponential_backoff_retry":

@ -366,6 +366,7 @@ class LiteLLMRoutes(enum.Enum):
self_managed_routes = [
"/team/member_add",
"/team/member_delete",
"/model/new",
] # routes that manage their own allowed/disallowed logic
## Org Admin Routes ##

@ -363,6 +363,11 @@ async def _update_model_table(
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
if model_id is None:
model_dict = await prisma_client.db.litellm_modeltable.create(
data={**litellm_modeltable.json(exclude_none=True)} # type: ignore
)
else:
model_dict = await prisma_client.db.litellm_modeltable.upsert(
where={"id": model_id},
data={

@ -199,6 +199,7 @@ from litellm.proxy.management_endpoints.team_callback_endpoints import (
router as team_callback_router,
)
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
from litellm.proxy.management_endpoints.team_endpoints import update_team
from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
@ -6202,38 +6203,11 @@ async def delete_budget(
#### MODEL MANAGEMENT ####
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post(
"/model/new",
description="Allows adding new models to the model list in the config.yaml",
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(
async def _add_model_to_db(
model_params: Deployment,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
):
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
try:
import base64
global prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={
"error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
},
)
model_response = None
# update DB
if store_model_in_db is True:
"""
- store model_list in db
- store keys separately
"""
# encrypt litellm params #
_litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
_orignal_litellm_model_name = model_params.litellm_params.model
@ -6255,22 +6229,135 @@ async def add_new_model(
model_response = await prisma_client.db.litellm_proxymodeltable.create(
data=_data # type: ignore
)
return model_response
async def _add_team_model_to_db(
model_params: Deployment,
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
):
"""
If 'team_id' is provided,
- generate a unique 'model_name' for the model (e.g. 'model_name_{team_id}_{uuid})
- store the model in the db with the unique 'model_name'
- store a team model alias mapping {"model_name": "model_name_{team_id}_{uuid}"}
"""
_team_id = model_params.model_info.team_id
original_model_name = model_params.model_name
if _team_id is None:
return None
unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}"
model_params.model_name = unique_model_name
## CREATE MODEL IN DB ##
model_response = await _add_model_to_db(
model_params=model_params,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
)
## CREATE MODEL ALIAS IN DB ##
await update_team(
data=UpdateTeamRequest(
team_id=_team_id,
model_aliases={original_model_name: unique_model_name},
),
user_api_key_dict=user_api_key_dict,
http_request=Request(scope={"type": "http"}),
)
return model_response
def check_if_team_id_matches_key(
team_id: Optional[str], user_api_key_dict: UserAPIKeyAuth
) -> bool:
can_make_call = True
if (
user_api_key_dict.user_role
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
):
return True
if team_id is None:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
can_make_call = False
else:
if user_api_key_dict.team_id != team_id:
can_make_call = False
return can_make_call
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post(
"/model/new",
description="Allows adding new models to the model list in the config.yaml",
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(
model_params: Deployment,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
try:
import base64
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={
"error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
},
)
if not check_if_team_id_matches_key(
team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
):
raise HTTPException(
status_code=403,
detail={"error": "Team ID does not match the API key's team ID"},
)
model_response = None
# update DB
if store_model_in_db is True:
"""
- store model_list in db
- store keys separately
"""
try:
_original_litellm_model_name = model_params.model_name
if model_params.model_info.team_id is None:
model_response = await _add_model_to_db(
model_params=model_params,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
)
else:
model_response = await _add_team_model_to_db(
model_params=model_params,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
)
await proxy_config.add_deployment(
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
)
try:
# don't let failed slack alert block the /model/new response
_alerting = general_settings.get("alerting", []) or []
if "slack" in _alerting:
# send notification - new model added
await proxy_logging_obj.slack_alerting_instance.model_added_alert(
model_name=model_params.model_name,
litellm_model_name=_orignal_litellm_model_name,
litellm_model_name=_original_litellm_model_name,
passed_model_info=model_params.model_info,
)
except Exception:
pass
except Exception as e:
verbose_proxy_logger.exception(f"Exception in add_new_model: {e}")
else:
raise HTTPException(

@ -1485,7 +1485,8 @@ class PrismaClient:
elif table_name == "team":
if query_type == "find_unique":
response = await self.db.litellm_teamtable.find_unique(
where={"team_id": team_id} # type: ignore
where={"team_id": team_id}, # type: ignore
include={"litellm_model_table": True}, # type: ignore
)
elif query_type == "find_all" and reset_at is not None:
response = await self.db.litellm_teamtable.find_many(

@ -103,6 +103,7 @@ class ModelInfo(BaseModel):
None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking
)
tier: Optional[Literal["free", "paid"]] = None
team_id: Optional[str] = None # the team id that this model belongs to
def __init__(self, id: Optional[Union[str, int]] = None, **params):
if id is None:

package-lock.json (generated)

@ -5,6 +5,7 @@
"packages": {
"": {
"dependencies": {
"prism-react-renderer": "^2.4.1",
"prisma": "^5.17.0",
"react-copy-to-clipboard": "^5.1.0"
},
@ -52,6 +53,11 @@
"@prisma/debug": "5.17.0"
}
},
"node_modules/@types/prismjs": {
"version": "1.26.5",
"resolved": "https://registry.npmjs.org/@types/prismjs/-/prismjs-1.26.5.tgz",
"integrity": "sha512-AUZTa7hQ2KY5L7AmtSiqxlhWxb4ina0yd8hNbl4TWuqnv/pFP0nDMb3YrfSBf4hJVGLh2YEIBfKaBW/9UEl6IQ=="
},
"node_modules/@types/prop-types": {
"version": "15.7.12",
"resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.12.tgz",
@ -77,6 +83,14 @@
"@types/react": "*"
}
},
"node_modules/clsx": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz",
"integrity": "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==",
"engines": {
"node": ">=6"
}
},
"node_modules/copy-to-clipboard": {
"version": "3.3.3",
"resolved": "https://registry.npmjs.org/copy-to-clipboard/-/copy-to-clipboard-3.3.3.tgz",
@ -115,6 +129,18 @@
"node": ">=0.10.0"
}
},
"node_modules/prism-react-renderer": {
"version": "2.4.1",
"resolved": "https://registry.npmjs.org/prism-react-renderer/-/prism-react-renderer-2.4.1.tgz",
"integrity": "sha512-ey8Ls/+Di31eqzUxC46h8MksNuGx/n0AAC8uKpwFau4RPDYLuE3EXTp8N8G2vX2N7UC/+IXeNUnlWBGGcAG+Ig==",
"dependencies": {
"@types/prismjs": "^1.26.0",
"clsx": "^2.0.0"
},
"peerDependencies": {
"react": ">=16.0.0"
}
},
"node_modules/prisma": {
"version": "5.17.0",
"resolved": "https://registry.npmjs.org/prisma/-/prisma-5.17.0.tgz",

@ -1,5 +1,6 @@
{
"dependencies": {
"prism-react-renderer": "^2.4.1",
"prisma": "^5.17.0",
"react-copy-to-clipboard": "^5.1.0"
},

@ -107,6 +107,49 @@ async def test_add_new_model(prisma_client):
assert _new_model_in_db is not None
@pytest.mark.parametrize(
"team_id, key_team_id, user_role, expected_result",
[
("1234", "1234", LitellmUserRoles.PROXY_ADMIN.value, True),
(
"1234",
"1235",
LitellmUserRoles.PROXY_ADMIN.value,
True,
), # proxy admin can add models for any team
(None, "1234", LitellmUserRoles.PROXY_ADMIN.value, True),
(None, None, LitellmUserRoles.PROXY_ADMIN.value, True),
(
"1234",
"1234",
LitellmUserRoles.INTERNAL_USER.value,
True,
), # internal users can add models for their team
("1234", "1235", LitellmUserRoles.INTERNAL_USER.value, False),
(None, "1234", LitellmUserRoles.INTERNAL_USER.value, False),
(
None,
None,
LitellmUserRoles.INTERNAL_USER.value,
False,
), # internal users cannot add models by default
],
)
def test_can_add_model(team_id, key_team_id, user_role, expected_result):
from litellm.proxy.proxy_server import check_if_team_id_matches_key
args = {
"team_id": team_id,
"user_api_key_dict": UserAPIKeyAuth(
user_role=user_role,
api_key="sk-1234",
team_id=key_team_id,
),
}
assert check_if_team_id_matches_key(**args) is expected_result
@pytest.mark.asyncio
@pytest.mark.skip(reason="new feature, tests passing locally")
async def test_add_update_model(prisma_client):

@ -1252,6 +1252,7 @@ def test_completion_cost_anthropic_prompt_caching():
assert cost_1 > cost_2
@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.parametrize(
"model",
[

@ -11,7 +11,7 @@ sys.path.insert(
import pytest
import openai
import litellm
from litellm import completion_with_retries, completion
from litellm import completion_with_retries, completion, acompletion_with_retries
from litellm import (
AuthenticationError,
BadRequestError,
@ -113,3 +113,36 @@ async def test_completion_with_retry_policy_no_error(sync_mode):
await completion(**data)
except Exception as e:
print(e)
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_with_retries(sync_mode):
"""
If completion_with_retries is called with num_retries=3, then litellm.completion should receive num_retries=0 and max_retries=0 (the wrapper handles retries itself)
"""
from unittest.mock import patch, MagicMock, AsyncMock
if sync_mode:
target_function = "completion"
else:
target_function = "acompletion"
with patch.object(litellm, target_function) as mock_completion:
if sync_mode:
completion_with_retries(
model="gpt-3.5-turbo",
messages=[{"gm": "vibe", "role": "user"}],
num_retries=3,
original_function=mock_completion,
)
else:
await acompletion_with_retries(
model="gpt-3.5-turbo",
messages=[{"gm": "vibe", "role": "user"}],
num_retries=3,
original_function=mock_completion,
)
mock_completion.assert_called_once()
assert mock_completion.call_args.kwargs["num_retries"] == 0
assert mock_completion.call_args.kwargs["max_retries"] == 0