Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
fix(main.py): fix retries being multiplied when using openai sdk (#7221)
* fix(main.py): fix retries being multiplied when using openai sdk

  Closes https://github.com/BerriAI/litellm/pull/7130

* docs(prompt_management.md): add langfuse prompt management doc

* feat(team_endpoints.py): allow teams to add their own models

  Enables teams to call their own finetuned models via the proxy

* test: add better enforcement check testing for `/model/new` now that teams can add their own models

* docs(team_model_add.md): tutorial for allowing teams to add their own models

* test: fix test
This commit is contained in:
parent
8060c5c698
commit
ec36353b41
16 changed files with 2439 additions and 1540 deletions
83  docs/my-website/docs/proxy/prompt_management.md  Normal file
@@ -0,0 +1,83 @@
import Image from '@theme/IdealImage';

# Prompt Management

LiteLLM supports using [Langfuse](https://langfuse.com/docs/prompts/get-started) for prompt management on the proxy.

## Quick Start

1. Add Langfuse as a 'callback' in your config.yaml

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/chatgpt-v-2
      api_key: os.environ/AZURE_API_KEY
      api_base: os.environ/AZURE_API_BASE

litellm_settings:
  callbacks: ["langfuse"] # 👈 KEY CHANGE
```

2. Start the proxy

```bash
litellm-proxy --config config.yaml
```

3. Test it!

```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "gpt-4",
    "messages": [
        {
            "role": "user",
            "content": "THIS WILL BE IGNORED"
        }
    ],
    "metadata": {
        "langfuse_prompt_id": "value",
        "langfuse_prompt_variables": { # [OPTIONAL]
            "key": "value"
        }
    }
}'
```
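Since the proxy exposes an OpenAI-compatible `/v1/chat/completions` endpoint, the same request can also be sent from the OpenAI Python SDK. This is a minimal sketch (not part of this commit), assuming the proxy URL and virtual key from the curl example above; the Langfuse fields are forwarded via `extra_body`:

```python
# Sketch: call the LiteLLM proxy with the OpenAI SDK, forwarding Langfuse metadata.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "THIS WILL BE IGNORED"}],
    extra_body={
        "metadata": {
            "langfuse_prompt_id": "value",
            "langfuse_prompt_variables": {"key": "value"},  # optional
        }
    },
)
print(response.choices[0].message.content)
```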

## What is 'langfuse_prompt_id'?

- `langfuse_prompt_id`: The ID of the Langfuse prompt that will be used for the request.

<Image img={require('../../img/langfuse_prompt_id.png')} />

## What will the formatted prompt look like?

### `/chat/completions` messages

The compiled Langfuse prompt is added to the start of the request's messages:

- If the Langfuse prompt is a list, it is prepended to the messages list (assuming it contains OpenAI-compatible messages).
- If the Langfuse prompt is a string, it is added as a system message.

```python
if isinstance(compiled_prompt, list):
    data["messages"] = compiled_prompt + data["messages"]
else:
    data["messages"] = [
        {"role": "system", "content": compiled_prompt}
    ] + data["messages"]
```

### `/completions` prompt

The compiled Langfuse prompt is prepended to the request's prompt string:

```python
data["prompt"] = compiled_prompt + "\n" + data["prompt"]
```
77  docs/my-website/docs/proxy/team_model_add.md  Normal file
@@ -0,0 +1,77 @@
# Allow Teams to Add Models

Allow a team to add their own models/keys for a project - so any OpenAI call they make uses their own OpenAI key.

Useful for teams that want to call their own finetuned models.

## Specify Team ID in the `/model/new` endpoint

```bash
curl -L -X POST 'http://0.0.0.0:4000/model/new' \
-H 'Authorization: Bearer sk-******2ql3-sm28WU0tTAmA' \ # 👈 Team API Key (has same 'team_id' as below)
-H 'Content-Type: application/json' \
-d '{
    "model_name": "my-team-model", # 👈 Call LiteLLM with this model name
    "litellm_params": {
        "model": "openai/gpt-4o",
        "custom_llm_provider": "openai",
        "api_key": "******ccb07",
        "api_base": "https://my-endpoint-sweden-berri992.openai.azure.com",
        "api_version": "2023-12-01-preview"
    },
    "model_info": {
        "team_id": "e59e2671-a064-436a-a0fa-16ae96e5a0a1" # 👈 Specify the team ID it belongs to
    }
}'
```
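The same registration can be done from Python. This is a rough sketch (not part of this commit), reusing the placeholder proxy URL, team key, and team_id from the curl call above:

```python
# Sketch: register a team-owned model via the proxy's /model/new endpoint.
import requests

resp = requests.post(
    "http://0.0.0.0:4000/model/new",
    headers={
        "Authorization": "Bearer sk-******2ql3-sm28WU0tTAmA",  # team API key
        "Content-Type": "application/json",
    },
    json={
        "model_name": "my-team-model",  # call LiteLLM with this model name
        "litellm_params": {
            "model": "openai/gpt-4o",
            "custom_llm_provider": "openai",
            "api_key": "******ccb07",
            "api_base": "https://my-endpoint-sweden-berri992.openai.azure.com",
            "api_version": "2023-12-01-preview",
        },
        "model_info": {
            "team_id": "e59e2671-a064-436a-a0fa-16ae96e5a0a1",  # team the model belongs to
        },
    },
)
resp.raise_for_status()
print(resp.json())
```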

## Test it!

```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-******2ql3-sm28WU0tTAmA' \ # 👈 Team API Key
-d '{
    "model": "my-team-model", # 👈 team model name
    "messages": [
        {
            "role": "user",
            "content": "What'\''s the weather like in Boston today?"
        }
    ]
}'
```

## Debugging

### 'model_name' not found

Check if the model alias exists in the team table.

```bash
curl -L -X GET 'http://localhost:4000/team/info?team_id=e59e2671-a064-436a-a0fa-16ae96e5a0a1' \
-H 'Authorization: Bearer sk-******2ql3-sm28WU0tTAmA'
```

**Expected Response:**

```json
{
    "team_id": "e59e2671-a064-436a-a0fa-16ae96e5a0a1",
    "team_info": {
        ...,
        "litellm_model_table": {
            "model_aliases": {
                "my-team-model": "model_name_e59e2671-a064-436a-a0fa-16ae96e5a0a1_e81c9286-2195-4bd9-81e1-cf393788a1a0"
                // 👈 public model name mapped to the internally generated model name (used to ensure uniqueness)
            },
            "created_by": "default_user_id",
            "updated_by": "default_user_id"
        }
    }
}
```
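A small Python sketch of the same check (not part of this commit), assuming the placeholder proxy URL, team key, and team_id used above:

```python
# Sketch: fetch /team/info and confirm the team's model alias mapping exists.
import requests

BASE_URL = "http://localhost:4000"
TEAM_ID = "e59e2671-a064-436a-a0fa-16ae96e5a0a1"

resp = requests.get(
    f"{BASE_URL}/team/info",
    params={"team_id": TEAM_ID},
    headers={"Authorization": "Bearer sk-******2ql3-sm28WU0tTAmA"},
)
resp.raise_for_status()

team_info = resp.json()["team_info"]
aliases = (team_info.get("litellm_model_table") or {}).get("model_aliases", {})
# expect something like "model_name_<team_id>_<uuid>" for the public alias
print(aliases.get("my-team-model"))
```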
BIN  docs/my-website/img/langfuse_prompt_id.png  Normal file
Binary file not shown (size: 93 KiB).
3531  docs/my-website/package-lock.json  generated
File diff suppressed because it is too large.
@@ -89,6 +89,13 @@ const sidebars = {
        "proxy/custom_sso"
      ],
    },
    {
      type: "category",
      label: "Team Management",
      items: [
        "proxy/team_model_add"
      ],
    },
    {
      type: "category",
      label: "Spend Tracking + Budgets",

@@ -127,6 +134,7 @@ const sidebars = {
        "oidc"
      ]
    },
    "proxy/prompt_management",
    "proxy/caching",
    "proxy/call_hooks",
    "proxy/rules",
@@ -2969,6 +2969,9 @@ def completion_with_retries(*args, **kwargs):
        )

    num_retries = kwargs.pop("num_retries", 3)
    # reset retries in .completion()
    kwargs["max_retries"] = 0
    kwargs["num_retries"] = 0
    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop(
        "retry_strategy", "constant_retry"
    )  # type: ignore

@@ -2999,6 +3002,8 @@ async def acompletion_with_retries(*args, **kwargs):
        )

    num_retries = kwargs.pop("num_retries", 3)
    kwargs["max_retries"] = 0
    kwargs["num_retries"] = 0
    retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
    original_function = kwargs.pop("original_function", completion)
    if retry_strategy == "exponential_backoff_retry":
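Rough intuition for the fix (illustrative only, not LiteLLM code): the wrapper pops `num_retries` and then zeroes out `num_retries`/`max_retries` on the kwargs it forwards, so only the outer retry loop retries. If those values were forwarded instead, the inner call (and the OpenAI SDK's own retry setting) would retry as well and attempts would multiply. The names in this sketch are made up for the example:

```python
# Illustrative sketch: nested retry settings multiply attempts.
class Upstream:
    def __init__(self) -> None:
        self.attempts = 0

    def call(self) -> str:
        self.attempts += 1
        raise RuntimeError("upstream error")


def inner_completion(upstream: Upstream, max_retries: int) -> str:
    # stands in for the inner call honoring whatever retry kwargs it receives
    for _ in range(max_retries + 1):
        try:
            return upstream.call()
        except RuntimeError:
            continue
    raise RuntimeError("inner retries exhausted")


def outer_with_retries(upstream: Upstream, num_retries: int, forward_retries: bool) -> str:
    # stands in for the retry wrapper; the fix corresponds to forward_retries=False
    inner_max_retries = num_retries if forward_retries else 0
    for _ in range(num_retries):
        try:
            return inner_completion(upstream, max_retries=inner_max_retries)
        except RuntimeError:
            continue
    raise RuntimeError("outer retries exhausted")


for forward in (True, False):
    upstream = Upstream()
    try:
        outer_with_retries(upstream, num_retries=3, forward_retries=forward)
    except RuntimeError:
        pass
    # forwarding retries: 3 * (3 + 1) = 12 attempts; resetting them: 3 attempts
    print(f"forward_retries={forward}: {upstream.attempts} attempts")
```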
@@ -366,6 +366,7 @@ class LiteLLMRoutes(enum.Enum):
    self_managed_routes = [
        "/team/member_add",
        "/team/member_delete",
        "/model/new",
    ]  # routes that manage their own allowed/disallowed logic

    ## Org Admin Routes ##
@@ -363,13 +363,18 @@ async def _update_model_table(
        created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
        updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
    )
    model_dict = await prisma_client.db.litellm_modeltable.upsert(
        where={"id": model_id},
        data={
            "update": {**litellm_modeltable.json(exclude_none=True)},  # type: ignore
            "create": {**litellm_modeltable.json(exclude_none=True)},  # type: ignore
        },
    )  # type: ignore
    if model_id is None:
        model_dict = await prisma_client.db.litellm_modeltable.create(
            data={**litellm_modeltable.json(exclude_none=True)}  # type: ignore
        )
    else:
        model_dict = await prisma_client.db.litellm_modeltable.upsert(
            where={"id": model_id},
            data={
                "update": {**litellm_modeltable.json(exclude_none=True)},  # type: ignore
                "create": {**litellm_modeltable.json(exclude_none=True)},  # type: ignore
            },
        )  # type: ignore

    _model_id = model_dict.id
@@ -199,6 +199,7 @@ from litellm.proxy.management_endpoints.team_callback_endpoints import (
    router as team_callback_router,
)
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
from litellm.proxy.management_endpoints.team_endpoints import update_team
from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
@@ -6202,6 +6203,94 @@ async def delete_budget(
#### MODEL MANAGEMENT ####


async def _add_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
):
    # encrypt litellm params #
    _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
    _orignal_litellm_model_name = model_params.litellm_params.model
    for k, v in _litellm_params_dict.items():
        encrypted_value = encrypt_value_helper(value=v)
        model_params.litellm_params[k] = encrypted_value
    _data: dict = {
        "model_id": model_params.model_info.id,
        "model_name": model_params.model_name,
        "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
        "model_info": model_params.model_info.model_dump_json(  # type: ignore
            exclude_none=True
        ),
        "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
        "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
    }
    if model_params.model_info.id is not None:
        _data["model_id"] = model_params.model_info.id
    model_response = await prisma_client.db.litellm_proxymodeltable.create(
        data=_data  # type: ignore
    )
    return model_response


async def _add_team_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
):
    """
    If 'team_id' is provided,

    - generate a unique 'model_name' for the model (e.g. 'model_name_{team_id}_{uuid}')
    - store the model in the db with the unique 'model_name'
    - store a team model alias mapping {"model_name": "model_name_{team_id}_{uuid}"}
    """
    _team_id = model_params.model_info.team_id
    original_model_name = model_params.model_name
    if _team_id is None:
        return None

    unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}"

    model_params.model_name = unique_model_name

    ## CREATE MODEL IN DB ##
    model_response = await _add_model_to_db(
        model_params=model_params,
        user_api_key_dict=user_api_key_dict,
        prisma_client=prisma_client,
    )

    ## CREATE MODEL ALIAS IN DB ##
    await update_team(
        data=UpdateTeamRequest(
            team_id=_team_id,
            model_aliases={original_model_name: unique_model_name},
        ),
        user_api_key_dict=user_api_key_dict,
        http_request=Request(scope={"type": "http"}),
    )

    return model_response


def check_if_team_id_matches_key(
    team_id: Optional[str], user_api_key_dict: UserAPIKeyAuth
) -> bool:
    can_make_call = True
    if (
        user_api_key_dict.user_role
        and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
    ):
        return True
    if team_id is None:
        if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
            can_make_call = False
    else:
        if user_api_key_dict.team_id != team_id:
            can_make_call = False
    return can_make_call


#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post(
    "/model/new",
@@ -6217,8 +6306,6 @@ async def add_new_model(
    try:
        import base64

        global prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
@@ -6227,6 +6314,14 @@ async def add_new_model(
            },
        )

        if not check_if_team_id_matches_key(
            team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
        ):
            raise HTTPException(
                status_code=403,
                detail={"error": "Team ID does not match the API key's team ID"},
            )

        model_response = None
        # update DB
        if store_model_in_db is True:
@@ -6234,43 +6329,35 @@ async def add_new_model(
            - store model_list in db
            - store keys separately
            """
            # encrypt litellm params #
            _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
            _orignal_litellm_model_name = model_params.litellm_params.model
            for k, v in _litellm_params_dict.items():
                encrypted_value = encrypt_value_helper(value=v)
                model_params.litellm_params[k] = encrypted_value
            _data: dict = {
                "model_id": model_params.model_info.id,
                "model_name": model_params.model_name,
                "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
                "model_info": model_params.model_info.model_dump_json(  # type: ignore
                    exclude_none=True
                ),
                "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
            }
            if model_params.model_info.id is not None:
                _data["model_id"] = model_params.model_info.id
            model_response = await prisma_client.db.litellm_proxymodeltable.create(
                data=_data  # type: ignore
            )

            await proxy_config.add_deployment(
                prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
            )
            try:
                _original_litellm_model_name = model_params.model_name
                if model_params.model_info.team_id is None:
                    model_response = await _add_model_to_db(
                        model_params=model_params,
                        user_api_key_dict=user_api_key_dict,
                        prisma_client=prisma_client,
                    )
                else:
                    model_response = await _add_team_model_to_db(
                        model_params=model_params,
                        user_api_key_dict=user_api_key_dict,
                        prisma_client=prisma_client,
                    )
                await proxy_config.add_deployment(
                    prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
                )
                # don't let failed slack alert block the /model/new response
                _alerting = general_settings.get("alerting", []) or []
                if "slack" in _alerting:
                    # send notification - new model added
                    await proxy_logging_obj.slack_alerting_instance.model_added_alert(
                        model_name=model_params.model_name,
                        litellm_model_name=_orignal_litellm_model_name,
                        litellm_model_name=_original_litellm_model_name,
                        passed_model_info=model_params.model_info,
                    )
            except Exception:
                pass
            except Exception as e:
                verbose_proxy_logger.exception(f"Exception in add_new_model: {e}")

        else:
            raise HTTPException(
@@ -1485,7 +1485,8 @@ class PrismaClient:
            elif table_name == "team":
                if query_type == "find_unique":
                    response = await self.db.litellm_teamtable.find_unique(
                        where={"team_id": team_id}  # type: ignore
                        where={"team_id": team_id},  # type: ignore
                        include={"litellm_model_table": True},  # type: ignore
                    )
                elif query_type == "find_all" and reset_at is not None:
                    response = await self.db.litellm_teamtable.find_many(
@@ -103,6 +103,7 @@ class ModelInfo(BaseModel):
        None  # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking
    )
    tier: Optional[Literal["free", "paid"]] = None
    team_id: Optional[str] = None  # the team id that this model belongs to

    def __init__(self, id: Optional[Union[str, int]] = None, **params):
        if id is None:
26  package-lock.json  generated
@@ -5,6 +5,7 @@
  "packages": {
    "": {
      "dependencies": {
        "prism-react-renderer": "^2.4.1",
        "prisma": "^5.17.0",
        "react-copy-to-clipboard": "^5.1.0"
      },

@@ -52,6 +53,11 @@
        "@prisma/debug": "5.17.0"
      }
    },
    "node_modules/@types/prismjs": {
      "version": "1.26.5",
      "resolved": "https://registry.npmjs.org/@types/prismjs/-/prismjs-1.26.5.tgz",
      "integrity": "sha512-AUZTa7hQ2KY5L7AmtSiqxlhWxb4ina0yd8hNbl4TWuqnv/pFP0nDMb3YrfSBf4hJVGLh2YEIBfKaBW/9UEl6IQ=="
    },
    "node_modules/@types/prop-types": {
      "version": "15.7.12",
      "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.12.tgz",

@@ -77,6 +83,14 @@
        "@types/react": "*"
      }
    },
    "node_modules/clsx": {
      "version": "2.1.1",
      "resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz",
      "integrity": "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==",
      "engines": {
        "node": ">=6"
      }
    },
    "node_modules/copy-to-clipboard": {
      "version": "3.3.3",
      "resolved": "https://registry.npmjs.org/copy-to-clipboard/-/copy-to-clipboard-3.3.3.tgz",

@@ -115,6 +129,18 @@
        "node": ">=0.10.0"
      }
    },
    "node_modules/prism-react-renderer": {
      "version": "2.4.1",
      "resolved": "https://registry.npmjs.org/prism-react-renderer/-/prism-react-renderer-2.4.1.tgz",
      "integrity": "sha512-ey8Ls/+Di31eqzUxC46h8MksNuGx/n0AAC8uKpwFau4RPDYLuE3EXTp8N8G2vX2N7UC/+IXeNUnlWBGGcAG+Ig==",
      "dependencies": {
        "@types/prismjs": "^1.26.0",
        "clsx": "^2.0.0"
      },
      "peerDependencies": {
        "react": ">=16.0.0"
      }
    },
    "node_modules/prisma": {
      "version": "5.17.0",
      "resolved": "https://registry.npmjs.org/prisma/-/prisma-5.17.0.tgz",

@@ -1,5 +1,6 @@
{
  "dependencies": {
    "prism-react-renderer": "^2.4.1",
    "prisma": "^5.17.0",
    "react-copy-to-clipboard": "^5.1.0"
  },
@@ -107,6 +107,49 @@ async def test_add_new_model(prisma_client):
    assert _new_model_in_db is not None


@pytest.mark.parametrize(
    "team_id, key_team_id, user_role, expected_result",
    [
        ("1234", "1234", LitellmUserRoles.PROXY_ADMIN.value, True),
        (
            "1234",
            "1235",
            LitellmUserRoles.PROXY_ADMIN.value,
            True,
        ),  # proxy admin can add models for any team
        (None, "1234", LitellmUserRoles.PROXY_ADMIN.value, True),
        (None, None, LitellmUserRoles.PROXY_ADMIN.value, True),
        (
            "1234",
            "1234",
            LitellmUserRoles.INTERNAL_USER.value,
            True,
        ),  # internal users can add models for their team
        ("1234", "1235", LitellmUserRoles.INTERNAL_USER.value, False),
        (None, "1234", LitellmUserRoles.INTERNAL_USER.value, False),
        (
            None,
            None,
            LitellmUserRoles.INTERNAL_USER.value,
            False,
        ),  # internal users cannot add models by default
    ],
)
def test_can_add_model(team_id, key_team_id, user_role, expected_result):
    from litellm.proxy.proxy_server import check_if_team_id_matches_key

    args = {
        "team_id": team_id,
        "user_api_key_dict": UserAPIKeyAuth(
            user_role=user_role,
            api_key="sk-1234",
            team_id=key_team_id,
        ),
    }

    assert check_if_team_id_matches_key(**args) is expected_result


@pytest.mark.asyncio
@pytest.mark.skip(reason="new feature, tests passing locally")
async def test_add_update_model(prisma_client):
@@ -1252,6 +1252,7 @@ def test_completion_cost_anthropic_prompt_caching():
    assert cost_1 > cost_2


@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.parametrize(
    "model",
    [
@@ -11,7 +11,7 @@ sys.path.insert(
import pytest
import openai
import litellm
from litellm import completion_with_retries, completion
from litellm import completion_with_retries, completion, acompletion_with_retries
from litellm import (
    AuthenticationError,
    BadRequestError,

@@ -113,3 +113,36 @@ async def test_completion_with_retry_policy_no_error(sync_mode):
            await completion(**data)
    except Exception as e:
        print(e)


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_with_retries(sync_mode):
    """
    If completion_with_retries is called with num_retries=3, litellm.completion should receive num_retries=0 and max_retries=0.
    """
    from unittest.mock import patch, MagicMock, AsyncMock

    if sync_mode:
        target_function = "completion"
    else:
        target_function = "acompletion"

    with patch.object(litellm, target_function) as mock_completion:
        if sync_mode:
            completion_with_retries(
                model="gpt-3.5-turbo",
                messages=[{"gm": "vibe", "role": "user"}],
                num_retries=3,
                original_function=mock_completion,
            )
        else:
            await acompletion_with_retries(
                model="gpt-3.5-turbo",
                messages=[{"gm": "vibe", "role": "user"}],
                num_retries=3,
                original_function=mock_completion,
            )
        mock_completion.assert_called_once()
        assert mock_completion.call_args.kwargs["num_retries"] == 0
        assert mock_completion.call_args.kwargs["max_retries"] == 0