Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
Litellm dev 12 07 2024 (#7086)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 11s
* fix(main.py): support passing max retries to azure/openai embedding integrations. Fixes https://github.com/BerriAI/litellm/issues/7003
* feat(team_endpoints.py): allow updating team model aliases. Closes https://github.com/BerriAI/litellm/issues/6956
* feat(router.py): allow specifying model id as fallback - skips any cooldown check. Allows a default model to be checked if all models in cooldown. s/o @micahjsmith
* docs(reliability.md): add fallback to specific model to docs
* fix(utils.py): new 'is_prompt_caching_valid_prompt' helper util. Allows user to identify if messages/tools have prompt caching. Related issue: https://github.com/BerriAI/litellm/issues/6784
* feat(router.py): store model id for prompt caching valid prompt. Allows routing to that model id on subsequent requests
* fix(router.py): only cache if prompt is valid prompt caching prompt. Prevents storing unnecessary items in cache
* feat(router.py): support routing prompt caching enabled models to previous deployments. Closes https://github.com/BerriAI/litellm/issues/6784
* test: fix linting errors
* feat(databricks/): convert basemodel to dict and exclude none values. Allow passing pydantic message to databricks
* fix(utils.py): ensure all chat completion messages are dict
* (feat) Track `custom_llm_provider` in LiteLLMSpendLogs (#7081)
  * add custom_llm_provider to SpendLogsPayload
  * add custom_llm_provider to SpendLogs
  * add custom llm provider to SpendLogs payload
  * test_spend_logs_payload
* Add MLflow to the side bar (#7031). Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
* (bug fix) SpendLogs update DB - catch all possible DB errors for retrying (#7082)
  * catch DB_CONNECTION_ERROR_TYPES
  * fix DB retry mechanism for SpendLog updates
  * use DB_CONNECTION_ERROR_TYPES in auth checks
  * fix exp back off for writing SpendLogs
  * use _raise_failed_update_spend_exception to ensure errors print as NON blocking
  * test_update_spend_logs_multiple_batches_with_failure
* (Feat) Add StructuredOutputs support for Fireworks.AI (#7085)
  * fix model cost map fireworks ai: "supports_response_schema": true
  * fix supports_response_schema
  * fix map openai params fireworks ai
  * test_map_response_format
* added deepinfra/Meta-Llama-3.1-405B-Instruct (#7084)
* bump: version 1.53.9 → 1.54.0
* fix deepinfra
* litellm db fixes LiteLLM_UserTable (#7089)
* ci/cd queue new release
* fix llama-3.3-70b-versatile
* refactor - use consistent file naming convention `AI21/` -> `ai21` (#7090)
  * fix refactor - use consistent file naming convention
  * ci/cd run again
  * fix naming structure
* fix use consistent naming (#7092)

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com>
Co-authored-by: ali sayyah <ali.sayyah2@gmail.com>
This commit is contained in:
parent 36e99ebce7
commit 0c0498dd60
24 changed files with 840 additions and 193 deletions
@@ -315,6 +315,64 @@ litellm_settings:
  cooldown_time: 30 # how long to cooldown model if fails/min > allowed_fails
```

### Fallback to Specific Model ID

If all models in a group are in cooldown (e.g. rate limited), LiteLLM will fall back to the model with the specific model ID.

This skips any cooldown check for the fallback model.

1. Specify the model ID in `model_info`

```yaml
model_list:
  - model_name: gpt-4
    litellm_params:
      model: openai/gpt-4
    model_info:
      id: my-specific-model-id # 👈 KEY CHANGE
  - model_name: gpt-4
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: os.environ/AZURE_API_BASE
      api_key: os.environ/AZURE_API_KEY
  - model_name: anthropic-claude
    litellm_params:
      model: anthropic/claude-3-opus-20240229
      api_key: os.environ/ANTHROPIC_API_KEY
```

**Note:** This will only fall back to the model with the specific model ID. If you want to fall back to another model group, you can set `fallbacks=[{"gpt-4": ["anthropic-claude"]}]`.

2. Set fallbacks in config

```yaml
litellm_settings:
  fallbacks: [{"gpt-4": ["my-specific-model-id"]}]
```

3. Test it!

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
  "model": "gpt-4",
  "messages": [
    {
      "role": "user",
      "content": "ping"
    }
  ],
  "mock_testing_fallbacks": true
}'
```

Validate it works by checking the response header `x-litellm-model-id`:

```bash
x-litellm-model-id: my-specific-model-id
```
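The same behaviour is available from the Python SDK. Below is a minimal sketch, not taken from the docs above, assuming a valid `OPENAI_API_KEY`; the deployment id is a placeholder:

```python
# Minimal sketch: fall back to a specific deployment id via the Python Router.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {"model": "openai/gpt-4"},
            "model_info": {"id": "my-specific-model-id"},  # deployment id used as the fallback target
        },
    ],
    fallbacks=[{"gpt-4": ["my-specific-model-id"]}],  # fall back to a deployment id, skipping cooldown checks
)

# mock_testing_fallbacks=True forces the primary attempt to fail so the fallback path runs
response = router.completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "ping"}],
    mock_testing_fallbacks=True,
)
print(response._hidden_params["model_id"])  # expected: my-specific-model-id
```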
### Test Fallbacks!

Check if your fallbacks are working as expected.
@@ -337,6 +395,7 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
'
```

#### **Content Policy Fallbacks**

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
@@ -1130,7 +1130,7 @@ router_settings:

If a call fails after num_retries, fall back to another model group.

-### Quick Start
+#### Quick Start

```python
from litellm import Router
@@ -1366,6 +1366,7 @@ litellm --config /path/to/config.yaml
</TabItem>
</Tabs>

### Caching

In production, we recommend using a Redis cache. For quickly testing things locally, we also support simple in-memory caching.
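For the SDK, the equivalent is passing Redis connection details to the `Router`. A minimal sketch, assuming a reachable Redis instance and the usual provider credentials; the environment variable names are placeholders:

```python
# Minimal sketch: Router-level Redis caching (omit the Redis arguments to fall
# back to simple in-memory caching for local testing).
import os
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    ],
    redis_host=os.environ["REDIS_HOST"],
    redis_port=int(os.environ["REDIS_PORT"]),
    redis_password=os.environ["REDIS_PASSWORD"],
    cache_responses=True,  # cache completion responses across deployments
)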
@@ -21,6 +21,7 @@ from litellm.constants import (
     DEFAULT_BATCH_SIZE,
     DEFAULT_FLUSH_INTERVAL_SECONDS,
     ROUTER_MAX_FALLBACKS,
+    DEFAULT_MAX_RETRIES,
 )
 from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (
@@ -1,3 +1,4 @@
 ROUTER_MAX_FALLBACKS = 5
 DEFAULT_BATCH_SIZE = 512
 DEFAULT_FLUSH_INTERVAL_SECONDS = 5
+DEFAULT_MAX_RETRIES = 2
@@ -1217,12 +1217,13 @@ class OpenAIChatCompletion(BaseLLM):
         api_base: Optional[str] = None,
         client=None,
         aembedding=None,
+        max_retries: Optional[int] = None,
     ) -> litellm.EmbeddingResponse:
         super().embedding()
         try:
             model = model
             data = {"model": model, "input": input, **optional_params}
-            max_retries = data.pop("max_retries", 2)
+            max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
             if not isinstance(max_retries, int):
                 raise OpenAIError(status_code=422, message="max retries must be an int")
             ## LOGGING
@@ -871,6 +871,7 @@ class AzureChatCompletion(BaseLLM):
         optional_params: dict,
         api_key: Optional[str] = None,
         azure_ad_token: Optional[str] = None,
+        max_retries: Optional[int] = None,
         client=None,
         aembedding=None,
     ) -> litellm.EmbeddingResponse:

@@ -879,7 +880,7 @@ class AzureChatCompletion(BaseLLM):
         self._client_session = self.create_client_session()
         try:
             data = {"model": model, "input": input, **optional_params}
-            max_retries = data.pop("max_retries", 2)
+            max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
             if not isinstance(max_retries, int):
                 raise AzureOpenAIError(
                     status_code=422, message="max retries must be an int"
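With this change, a `max_retries` value passed to the embedding call is forwarded to the OpenAI/Azure integrations instead of being dropped, defaulting to `litellm.DEFAULT_MAX_RETRIES` when omitted. A hedged usage sketch, assuming `OPENAI_API_KEY` is set:

```python
# Usage sketch: pass max_retries through litellm.embedding (OpenAI shown; the
# azure/ embedding path accepts the same parameter after this change).
import litellm

response = litellm.embedding(
    model="text-embedding-ada-002",
    input=["hello world"],
    max_retries=5,  # client-side retries for transient failures; must be an int
)
print(response.model)
```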
@@ -219,6 +219,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
         api_base: Optional[str] = None,
         client=None,
         aembedding=None,
+        max_retries: Optional[int] = None,
     ) -> litellm.EmbeddingResponse:
         """
         - Separate image url from text
@@ -134,7 +134,7 @@ class DatabricksConfig(OpenAIGPTConfig):
         new_messages = []
         for idx, message in enumerate(messages):
             if isinstance(message, BaseModel):
-                _message = message.model_dump()
+                _message = message.model_dump(exclude_none=True)
             else:
                 _message = message
             new_messages.append(_message)
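This is what the change enables: a Pydantic message is now dumped with `exclude_none=True` before being sent, so unset fields don't leak into the Databricks request body. A sketch mirroring the new `test_pydantic_model_input` test; the Databricks model name and credentials are placeholders:

```python
# Sketch: pass a Pydantic Message object directly to completion().
# Assumes DATABRICKS_API_KEY / DATABRICKS_API_BASE are set; the model name is illustrative.
from litellm import completion, Message

messages = [Message(content="Hello, how are you?", role="user")]

response = completion(
    model="databricks/databricks-meta-llama-3-1-70b-instruct",
    messages=messages,  # converted internally with model_dump(exclude_none=True)
)
print(response.choices[0].message.content)
```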
@ -77,7 +77,7 @@ from litellm.utils import (
|
|||
read_config_args,
|
||||
supports_httpx_timeout,
|
||||
token_counter,
|
||||
validate_chat_completion_user_messages,
|
||||
validate_chat_completion_messages,
|
||||
)
|
||||
|
||||
from ._logging import verbose_logger
|
||||
|
@ -931,7 +931,7 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
) # support region-based pricing for bedrock
|
||||
|
||||
### VALIDATE USER MESSAGES ###
|
||||
validate_chat_completion_user_messages(messages=messages)
|
||||
messages = validate_chat_completion_messages(messages=messages)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = timeout or kwargs.get("request_timeout", 600) or 600
|
||||
|
@ -3274,6 +3274,7 @@ def embedding( # noqa: PLR0915
|
|||
client = kwargs.pop("client", None)
|
||||
rpm = kwargs.pop("rpm", None)
|
||||
tpm = kwargs.pop("tpm", None)
|
||||
max_retries = kwargs.get("max_retries", None)
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
|
||||
cooldown_time = kwargs.get("cooldown_time", None)
|
||||
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
|
||||
|
@ -3422,6 +3423,7 @@ def embedding( # noqa: PLR0915
|
|||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
elif (
|
||||
model in litellm.open_ai_embedding_models
|
||||
|
@ -3466,6 +3468,7 @@ def embedding( # noqa: PLR0915
|
|||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
elif custom_llm_provider == "databricks":
|
||||
api_base = (
|
||||
|
|
|
@ -723,7 +723,7 @@ class KeyRequest(LiteLLMBase):
|
|||
|
||||
|
||||
class LiteLLM_ModelTable(LiteLLMBase):
|
||||
model_aliases: Optional[str] = None # json dump the dict
|
||||
model_aliases: Optional[Union[str, dict]] = None # json dump the dict
|
||||
created_by: str
|
||||
updated_by: str
|
||||
|
||||
|
@ -981,6 +981,7 @@ class UpdateTeamRequest(LiteLLMBase):
|
|||
blocked: Optional[bool] = None
|
||||
budget_duration: Optional[str] = None
|
||||
tags: Optional[list] = None
|
||||
model_aliases: Optional[dict] = None
|
||||
|
||||
|
||||
class ResetTeamBudgetRequest(LiteLLMBase):
|
||||
|
@ -1059,6 +1060,7 @@ class LiteLLM_TeamTable(TeamBase):
|
|||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
model_id: Optional[int] = None
|
||||
litellm_model_table: Optional[LiteLLM_ModelTable] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
|
|
@ -293,8 +293,13 @@ async def new_team( # noqa: PLR0915
|
|||
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
complete_team_data.budget_reset_at = reset_at
|
||||
|
||||
team_row: LiteLLM_TeamTable = await prisma_client.insert_data( # type: ignore
|
||||
data=complete_team_data.json(exclude_none=True), table_name="team"
|
||||
complete_team_data_dict = complete_team_data.model_dump(exclude_none=True)
|
||||
complete_team_data_dict = prisma_client.jsonify_team_object(
|
||||
db_data=complete_team_data_dict
|
||||
)
|
||||
team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create(
|
||||
data=complete_team_data_dict,
|
||||
include={"litellm_model_table": True}, # type: ignore
|
||||
)
|
||||
|
||||
## ADD TEAM ID TO USER TABLE ##
|
||||
|
@ -340,6 +345,37 @@ async def new_team( # noqa: PLR0915
|
|||
return team_row.dict()
|
||||
|
||||
|
||||
async def _update_model_table(
|
||||
data: UpdateTeamRequest,
|
||||
model_id: Optional[str],
|
||||
prisma_client: PrismaClient,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
litellm_proxy_admin_name: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Upsert model table and return the model id
|
||||
"""
|
||||
## UPSERT MODEL TABLE
|
||||
_model_id = model_id
|
||||
if data.model_aliases is not None and isinstance(data.model_aliases, dict):
|
||||
litellm_modeltable = LiteLLM_ModelTable(
|
||||
model_aliases=json.dumps(data.model_aliases),
|
||||
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
)
|
||||
model_dict = await prisma_client.db.litellm_modeltable.upsert(
|
||||
where={"id": model_id},
|
||||
data={
|
||||
"update": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore
|
||||
"create": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore
|
||||
},
|
||||
) # type: ignore
|
||||
|
||||
_model_id = model_dict.id
|
||||
|
||||
return _model_id
|
||||
|
||||
|
||||
@router.post(
|
||||
"/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
|
||||
)
|
||||
|
@@ -370,6 +406,7 @@ async def update_team(
    - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id.
    - tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
    - organization_id: Optional[str] - The organization id of the team. Default is None. Create via `/organization/new`.
+   - model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias)

    Example - update team TPM Limit
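For the newly added `model_aliases` field, a hypothetical update call could look like the sketch below; the proxy URL, master key, team id, and alias mapping are all placeholders:

```python
# Hypothetical sketch: update a team's model aliases via the proxy's /team/update endpoint.
import requests

resp = requests.post(
    "http://0.0.0.0:4000/team/update",
    headers={"Authorization": "Bearer sk-1234", "Content-Type": "application/json"},
    json={
        "team_id": "my-team-id",
        "model_aliases": {"gpt-4o": "gpt-4o-team-1"},  # illustrative alias -> model mapping
    },
)
print(resp.json())
```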
@ -446,11 +483,25 @@ async def update_team(
|
|||
else:
|
||||
updated_kv["metadata"] = {"tags": _tags}
|
||||
|
||||
updated_kv = prisma_client.jsonify_object(data=updated_kv)
|
||||
team_row: Optional[
|
||||
LiteLLM_TeamTable
|
||||
] = await prisma_client.db.litellm_teamtable.update(
|
||||
where={"team_id": data.team_id}, data=updated_kv # type: ignore
|
||||
if "model_aliases" in updated_kv:
|
||||
updated_kv.pop("model_aliases")
|
||||
_model_id = await _update_model_table(
|
||||
data=data,
|
||||
model_id=existing_team_row.model_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||
)
|
||||
if _model_id is not None:
|
||||
updated_kv["model_id"] = _model_id
|
||||
|
||||
updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv)
|
||||
team_row: Optional[LiteLLM_TeamTable] = (
|
||||
await prisma_client.db.litellm_teamtable.update(
|
||||
where={"team_id": data.team_id},
|
||||
data=updated_kv,
|
||||
include={"litellm_model_table": True}, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
if team_row is None or team_row.team_id is None:
|
||||
|
|
|
@ -1207,173 +1207,6 @@ class PrismaClient:
|
|||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
# try:
|
||||
# # Try to select one row from the view
|
||||
# await self.db.query_raw(
|
||||
# """SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1"""
|
||||
# )
|
||||
# print("LiteLLM_VerificationTokenView Exists!") # noqa
|
||||
# except Exception as e:
|
||||
# If an error occurs, the view does not exist, so create it
|
||||
|
||||
# try:
|
||||
# await self.db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
|
||||
# print("MonthlyGlobalSpend Exists!") # noqa
|
||||
# except Exception as e:
|
||||
# sql_query = """
|
||||
# CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
|
||||
# SELECT
|
||||
# DATE("startTime") AS date,
|
||||
# SUM("spend") AS spend
|
||||
# FROM
|
||||
# "LiteLLM_SpendLogs"
|
||||
# WHERE
|
||||
# "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
|
||||
# GROUP BY
|
||||
# DATE("startTime");
|
||||
# """
|
||||
# await self.db.execute_raw(query=sql_query)
|
||||
|
||||
# print("MonthlyGlobalSpend Created!") # noqa
|
||||
|
||||
# try:
|
||||
# await self.db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
|
||||
# print("Last30dKeysBySpend Exists!") # noqa
|
||||
# except Exception as e:
|
||||
# sql_query = """
|
||||
# CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
|
||||
# SELECT
|
||||
# L."api_key",
|
||||
# V."key_alias",
|
||||
# V."key_name",
|
||||
# SUM(L."spend") AS total_spend
|
||||
# FROM
|
||||
# "LiteLLM_SpendLogs" L
|
||||
# LEFT JOIN
|
||||
# "LiteLLM_VerificationToken" V
|
||||
# ON
|
||||
# L."api_key" = V."token"
|
||||
# WHERE
|
||||
# L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
|
||||
# GROUP BY
|
||||
# L."api_key", V."key_alias", V."key_name"
|
||||
# ORDER BY
|
||||
# total_spend DESC;
|
||||
# """
|
||||
# await self.db.execute_raw(query=sql_query)
|
||||
|
||||
# print("Last30dKeysBySpend Created!") # noqa
|
||||
|
||||
# try:
|
||||
# await self.db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
|
||||
# print("Last30dModelsBySpend Exists!") # noqa
|
||||
# except Exception as e:
|
||||
# sql_query = """
|
||||
# CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
|
||||
# SELECT
|
||||
# "model",
|
||||
# SUM("spend") AS total_spend
|
||||
# FROM
|
||||
# "LiteLLM_SpendLogs"
|
||||
# WHERE
|
||||
# "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
|
||||
# AND "model" != ''
|
||||
# GROUP BY
|
||||
# "model"
|
||||
# ORDER BY
|
||||
# total_spend DESC;
|
||||
# """
|
||||
# await self.db.execute_raw(query=sql_query)
|
||||
|
||||
# print("Last30dModelsBySpend Created!") # noqa
|
||||
# try:
|
||||
# await self.db.query_raw(
|
||||
# """SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1"""
|
||||
# )
|
||||
# print("MonthlyGlobalSpendPerKey Exists!") # noqa
|
||||
# except Exception as e:
|
||||
# sql_query = """
|
||||
# CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
|
||||
# SELECT
|
||||
# DATE("startTime") AS date,
|
||||
# SUM("spend") AS spend,
|
||||
# api_key as api_key
|
||||
# FROM
|
||||
# "LiteLLM_SpendLogs"
|
||||
# WHERE
|
||||
# "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
|
||||
# GROUP BY
|
||||
# DATE("startTime"),
|
||||
# api_key;
|
||||
# """
|
||||
# await self.db.execute_raw(query=sql_query)
|
||||
|
||||
# print("MonthlyGlobalSpendPerKey Created!") # noqa
|
||||
# try:
|
||||
# await self.db.query_raw(
|
||||
# """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
|
||||
# )
|
||||
# print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa
|
||||
# except Exception as e:
|
||||
# sql_query = """
|
||||
# CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
|
||||
# SELECT
|
||||
# DATE("startTime") AS date,
|
||||
# SUM("spend") AS spend,
|
||||
# api_key as api_key,
|
||||
# "user" as "user"
|
||||
# FROM
|
||||
# "LiteLLM_SpendLogs"
|
||||
# WHERE
|
||||
# "startTime" >= (CURRENT_DATE - INTERVAL '20 days')
|
||||
# GROUP BY
|
||||
# DATE("startTime"),
|
||||
# "user",
|
||||
# api_key;
|
||||
# """
|
||||
# await self.db.execute_raw(query=sql_query)
|
||||
|
||||
# print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa
|
||||
|
||||
# try:
|
||||
# await self.db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""")
|
||||
# print("DailyTagSpend Exists!") # noqa
|
||||
# except Exception as e:
|
||||
# sql_query = """
|
||||
# CREATE OR REPLACE VIEW DailyTagSpend AS
|
||||
# SELECT
|
||||
# jsonb_array_elements_text(request_tags) AS individual_request_tag,
|
||||
# DATE(s."startTime") AS spend_date,
|
||||
# COUNT(*) AS log_count,
|
||||
# SUM(spend) AS total_spend
|
||||
# FROM "LiteLLM_SpendLogs" s
|
||||
# GROUP BY individual_request_tag, DATE(s."startTime");
|
||||
# """
|
||||
# await self.db.execute_raw(query=sql_query)
|
||||
|
||||
# print("DailyTagSpend Created!") # noqa
|
||||
|
||||
# try:
|
||||
# await self.db.query_raw(
|
||||
# """SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1"""
|
||||
# )
|
||||
# print("Last30dTopEndUsersSpend Exists!") # noqa
|
||||
# except Exception as e:
|
||||
# sql_query = """
|
||||
# CREATE VIEW "Last30dTopEndUsersSpend" AS
|
||||
# SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
|
||||
# FROM "LiteLLM_SpendLogs"
|
||||
# WHERE end_user <> '' AND end_user <> user
|
||||
# AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
|
||||
# GROUP BY end_user
|
||||
# ORDER BY total_spend DESC
|
||||
# LIMIT 100;
|
||||
# """
|
||||
# await self.db.execute_raw(query=sql_query)
|
||||
|
||||
# print("Last30dTopEndUsersSpend Created!") # noqa
|
||||
|
||||
return
|
||||
|
||||
@log_db_metrics
|
||||
|
@ -1784,6 +1617,14 @@ class PrismaClient:
|
|||
)
|
||||
raise e
|
||||
|
||||
def jsonify_team_object(self, db_data: dict):
|
||||
db_data = self.jsonify_object(data=db_data)
|
||||
if db_data.get("members_with_roles", None) is not None and isinstance(
|
||||
db_data["members_with_roles"], list
|
||||
):
|
||||
db_data["members_with_roles"] = json.dumps(db_data["members_with_roles"])
|
||||
return db_data
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
|
@ -2348,7 +2189,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
|||
module_name = value
|
||||
instance_name = None
|
||||
try:
|
||||
print_verbose(f"value: {value}")
|
||||
# Split the path by dots to separate module from instance
|
||||
parts = value.split(".")
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ from typing import (
|
|||
Tuple,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
@ -96,6 +97,7 @@ from litellm.router_utils.router_callbacks.track_deployment_metrics import (
|
|||
)
|
||||
from litellm.scheduler import FlowItem, Scheduler
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
Assistant,
|
||||
AssistantToolParam,
|
||||
AsyncCursorPage,
|
||||
|
@ -149,10 +151,12 @@ from litellm.utils import (
|
|||
get_llm_provider,
|
||||
get_secret,
|
||||
get_utc_datetime,
|
||||
is_prompt_caching_valid_prompt,
|
||||
is_region_allowed,
|
||||
)
|
||||
|
||||
from .router_utils.pattern_match_deployments import PatternMatchRouter
|
||||
from .router_utils.prompt_caching_cache import PromptCachingCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
@@ -737,7 +741,9 @@ class Router:
                 model_client = potential_model_client

             ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
-            self.routing_strategy_pre_call_checks(deployment=deployment)
+            ## only run if model group given, not model id
+            if model not in self.get_model_ids():
+                self.routing_strategy_pre_call_checks(deployment=deployment)

             response = litellm.completion(
                 **{
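The guard above means a request that targets a specific deployment id, rather than a model group, skips the routing-strategy pre-call checks. A sketch mirroring the new `test_router_completion_with_model_id` test; `"123"` is a placeholder deployment id:

```python
# Sketch: calling the Router by deployment id bypasses routing_strategy_pre_call_checks.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
            "model_info": {"id": "123"},  # placeholder deployment id
        }
    ]
)

# model="123" targets the deployment directly, so no pre-call checks run
router.completion(
    model="123",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="hi",  # avoids needing a real OPENAI_API_KEY
)
```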
@ -2787,8 +2793,10 @@ class Router:
|
|||
*args,
|
||||
**input_kwargs,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as new_exception:
|
||||
traceback.print_exc()
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
verbose_router_logger.error(
|
||||
"litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
|
||||
|
@ -3376,6 +3384,29 @@ class Router:
|
|||
deployment_id=id,
|
||||
)
|
||||
|
||||
## PROMPT CACHING
|
||||
prompt_cache = PromptCachingCache(
|
||||
cache=self.cache,
|
||||
)
|
||||
if (
|
||||
standard_logging_object["messages"] is not None
|
||||
and isinstance(standard_logging_object["messages"], list)
|
||||
and deployment_name is not None
|
||||
and isinstance(deployment_name, str)
|
||||
):
|
||||
valid_prompt = is_prompt_caching_valid_prompt(
|
||||
messages=standard_logging_object["messages"], # type: ignore
|
||||
tools=None,
|
||||
model=deployment_name,
|
||||
custom_llm_provider=None,
|
||||
)
|
||||
if valid_prompt:
|
||||
await prompt_cache.async_add_model_id(
|
||||
model_id=id,
|
||||
messages=standard_logging_object["messages"], # type: ignore
|
||||
tools=None,
|
||||
)
|
||||
|
||||
return tpm_key
|
||||
|
||||
except Exception as e:
|
||||
|
@ -5190,7 +5221,6 @@ class Router:
|
|||
- List, if multiple models chosen
|
||||
- Dict, if specific model chosen
|
||||
"""
|
||||
|
||||
# check if aliases set on litellm model alias map
|
||||
if specific_deployment is True:
|
||||
return model, self._get_deployment_by_litellm_model(model=model)
|
||||
|
@ -5302,13 +5332,6 @@ class Router:
|
|||
cooldown_deployments=cooldown_deployments,
|
||||
)
|
||||
|
||||
# filter pre-call checks
|
||||
_allowed_model_region = (
|
||||
request_kwargs.get("allowed_model_region")
|
||||
if request_kwargs is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if self.enable_pre_call_checks and messages is not None:
|
||||
healthy_deployments = self._pre_call_checks(
|
||||
model=model,
|
||||
|
@ -5317,6 +5340,24 @@ class Router:
|
|||
request_kwargs=request_kwargs,
|
||||
)
|
||||
|
||||
if messages is not None and is_prompt_caching_valid_prompt(
|
||||
messages=cast(List[AllMessageValues], messages),
|
||||
model=model,
|
||||
custom_llm_provider=None,
|
||||
):
|
||||
prompt_cache = PromptCachingCache(
|
||||
cache=self.cache,
|
||||
)
|
||||
healthy_deployment = (
|
||||
await prompt_cache.async_get_prompt_caching_deployment(
|
||||
router=self,
|
||||
messages=cast(List[AllMessageValues], messages),
|
||||
tools=None,
|
||||
)
|
||||
)
|
||||
if healthy_deployment is not None:
|
||||
return healthy_deployment
|
||||
|
||||
# check if user wants to do tag based routing
|
||||
healthy_deployments = await get_deployments_for_tag( # type: ignore
|
||||
llm_router_instance=self,
|
||||
|
|
|
@ -49,6 +49,7 @@ async def run_async_fallback(
|
|||
raise original_exception
|
||||
|
||||
error_from_fallbacks = original_exception
|
||||
|
||||
for mg in fallback_model_group:
|
||||
if mg == original_model_group:
|
||||
continue
|
||||
|
|
litellm/router_utils/prompt_caching_cache.py (new file, 193 lines)
@ -0,0 +1,193 @@
|
|||
"""
|
||||
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching.caching import Cache, DualCache
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.router import Router
|
||||
|
||||
litellm_router = Router
|
||||
Span = _Span
|
||||
else:
|
||||
Span = Any
|
||||
litellm_router = Any
|
||||
|
||||
|
||||
class PromptCachingCacheValue(TypedDict):
|
||||
model_id: str
|
||||
|
||||
|
||||
class PromptCachingCache:
|
||||
def __init__(self, cache: DualCache):
|
||||
self.cache = cache
|
||||
self.in_memory_cache = InMemoryCache()
|
||||
|
||||
@staticmethod
|
||||
def serialize_object(obj: Any) -> Any:
|
||||
"""Helper function to serialize Pydantic objects, dictionaries, or fallback to string."""
|
||||
if hasattr(obj, "dict"):
|
||||
# If the object is a Pydantic model, use its `dict()` method
|
||||
return obj.dict()
|
||||
elif isinstance(obj, dict):
|
||||
# If the object is a dictionary, serialize it with sorted keys
|
||||
return json.dumps(
|
||||
obj, sort_keys=True, separators=(",", ":")
|
||||
) # Standardize serialization
|
||||
|
||||
elif isinstance(obj, list):
|
||||
# Serialize lists by ensuring each element is handled properly
|
||||
return [PromptCachingCache.serialize_object(item) for item in obj]
|
||||
elif isinstance(obj, (int, float, bool)):
|
||||
return obj # Keep primitive types as-is
|
||||
return str(obj)
|
||||
|
||||
@staticmethod
|
||||
def get_prompt_caching_cache_key(
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
tools: Optional[List[ChatCompletionToolParam]],
|
||||
) -> Optional[str]:
|
||||
if messages is None and tools is None:
|
||||
return None
|
||||
# Use serialize_object for consistent and stable serialization
|
||||
data_to_hash = {}
|
||||
if messages is not None:
|
||||
serialized_messages = PromptCachingCache.serialize_object(messages)
|
||||
data_to_hash["messages"] = serialized_messages
|
||||
if tools is not None:
|
||||
serialized_tools = PromptCachingCache.serialize_object(tools)
|
||||
data_to_hash["tools"] = serialized_tools
|
||||
|
||||
# Combine serialized data into a single string
|
||||
data_to_hash_str = json.dumps(
|
||||
data_to_hash,
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
|
||||
# Create a hash of the serialized data for a stable cache key
|
||||
hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
|
||||
return f"deployment:{hashed_data}:prompt_caching"
|
||||
|
||||
def add_model_id(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
tools: Optional[List[ChatCompletionToolParam]],
|
||||
) -> None:
|
||||
if messages is None and tools is None:
|
||||
return None
|
||||
|
||||
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
|
||||
self.cache.set_cache(
|
||||
cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
|
||||
)
|
||||
return None
|
||||
|
||||
async def async_add_model_id(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
tools: Optional[List[ChatCompletionToolParam]],
|
||||
) -> None:
|
||||
if messages is None and tools is None:
|
||||
return None
|
||||
|
||||
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
|
||||
await self.cache.async_set_cache(
|
||||
cache_key,
|
||||
PromptCachingCacheValue(model_id=model_id),
|
||||
ttl=300, # store for 5 minutes
|
||||
)
|
||||
return None
|
||||
|
||||
async def async_get_model_id(
|
||||
self,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
tools: Optional[List[ChatCompletionToolParam]],
|
||||
) -> Optional[PromptCachingCacheValue]:
|
||||
"""
|
||||
if messages is not none
|
||||
- check full messages
|
||||
- check messages[:-1]
|
||||
- check messages[:-2]
|
||||
- check messages[:-3]
|
||||
|
||||
use self.cache.async_batch_get_cache(keys=potential_cache_keys)
|
||||
"""
|
||||
if messages is None and tools is None:
|
||||
return None
|
||||
|
||||
# Generate potential cache keys by slicing messages
|
||||
|
||||
potential_cache_keys = []
|
||||
|
||||
if messages is not None:
|
||||
full_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
|
||||
messages, tools
|
||||
)
|
||||
potential_cache_keys.append(full_cache_key)
|
||||
|
||||
# Check progressively shorter message slices
|
||||
for i in range(1, min(4, len(messages))):
|
||||
partial_messages = messages[:-i]
|
||||
partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
|
||||
partial_messages, tools
|
||||
)
|
||||
potential_cache_keys.append(partial_cache_key)
|
||||
|
||||
# Perform batch cache lookup
|
||||
cache_results = await self.cache.async_batch_get_cache(
|
||||
keys=potential_cache_keys
|
||||
)
|
||||
|
||||
if cache_results is None:
|
||||
return None
|
||||
|
||||
# Return the first non-None cache result
|
||||
for result in cache_results:
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def get_model_id(
|
||||
self,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
tools: Optional[List[ChatCompletionToolParam]],
|
||||
) -> Optional[PromptCachingCacheValue]:
|
||||
if messages is None and tools is None:
|
||||
return None
|
||||
|
||||
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
|
||||
return self.cache.get_cache(cache_key)
|
||||
|
||||
async def async_get_prompt_caching_deployment(
|
||||
self,
|
||||
router: litellm_router,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
tools: Optional[List[ChatCompletionToolParam]],
|
||||
) -> Optional[dict]:
|
||||
model_id_dict = await self.async_get_model_id(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if model_id_dict is not None:
|
||||
healthy_deployment_pydantic_obj = router.get_deployment(
|
||||
model_id=model_id_dict["model_id"]
|
||||
)
|
||||
if healthy_deployment_pydantic_obj is not None:
|
||||
return healthy_deployment_pydantic_obj.model_dump(exclude_none=True)
|
||||
return None
|
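A hedged usage sketch of the cache above: the router stores the deployment id that served a prompt-caching-valid prompt, keyed by a hash of the messages/tools, and later lookups return that id so follow-up requests can be routed to the same deployment. The model name and deployment id below are placeholders:

```python
# Sketch: store and retrieve a deployment id for a prompt via PromptCachingCache.
from litellm.router import Router
from litellm.router_utils.prompt_caching_cache import PromptCachingCache

router = Router(
    model_list=[
        {
            "model_name": "claude-model",
            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},
            "model_info": {"id": "1234"},  # placeholder deployment id
        }
    ]
)
cache = PromptCachingCache(cache=router.cache)

messages = [{"role": "user", "content": "hello"}]
cache.add_model_id(model_id="1234", messages=messages, tools=None)
print(cache.get_model_id(messages=messages, tools=None))  # {'model_id': '1234'}
```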
|
@ -6151,6 +6151,38 @@ from litellm.types.llms.openai import (
|
|||
)
|
||||
|
||||
|
||||
def convert_to_dict(message: Union[BaseModel, dict]) -> dict:
|
||||
"""
|
||||
Converts a message to a dictionary if it's a Pydantic model.
|
||||
|
||||
Args:
|
||||
message: The message, which may be a Pydantic model or a dictionary.
|
||||
|
||||
Returns:
|
||||
dict: The converted message.
|
||||
"""
|
||||
if isinstance(message, BaseModel):
|
||||
return message.model_dump(exclude_none=True)
|
||||
elif isinstance(message, dict):
|
||||
return message
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Invalid message type: {type(message)}. Expected dict or Pydantic model."
|
||||
)
|
||||
|
||||
|
||||
def validate_chat_completion_messages(messages: List[AllMessageValues]):
|
||||
"""
|
||||
Ensures all messages are valid OpenAI chat completion messages.
|
||||
"""
|
||||
# 1. convert all messages to dict
|
||||
messages = [
|
||||
cast(AllMessageValues, convert_to_dict(cast(dict, m))) for m in messages
|
||||
]
|
||||
# 2. validate user messages
|
||||
return validate_chat_completion_user_messages(messages=messages)
|
||||
|
||||
|
||||
def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
|
||||
"""
|
||||
Ensures all user messages are valid OpenAI chat completion messages.
|
||||
|
@@ -6229,3 +6261,22 @@ def get_end_user_id_for_cost_tracking(
     ):
         return None
     return proxy_server_request.get("body", {}).get("user", None)


 def is_prompt_caching_valid_prompt(
     model: str,
     messages: Optional[List[AllMessageValues]],
     tools: Optional[List[ChatCompletionToolParam]] = None,
     custom_llm_provider: Optional[str] = None,
 ) -> bool:
     """
     Returns true if the prompt is valid for prompt caching.

     OpenAI + Anthropic providers have a minimum token count of 1024 for prompt caching.
     """
     if messages is None and tools is None:
         return False
     if custom_llm_provider is not None and not model.startswith(custom_llm_provider):
         model = custom_llm_provider + "/" + model
     token_count = token_counter(messages=messages, tools=tools, model=model)
     return token_count >= 1024
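A hedged usage sketch, mirroring the new unit test: a prompt only counts as prompt-caching valid once it crosses the 1024-token minimum.

```python
# Sketch: is_prompt_caching_valid_prompt returns True only at or above 1024 tokens.
import litellm

long_messages = [
    {"role": "system", "content": "Here is the full text of a complex legal agreement. " * 400},
    {"role": "user", "content": "What are the key terms and conditions in this agreement?"},
]

print(
    litellm.utils.is_prompt_caching_valid_prompt(
        model="anthropic/claude-3-5-sonnet-20240620",
        messages=long_messages,
        custom_llm_provider="anthropic",
    )
)  # True - well above the minimum token count

print(
    litellm.utils.is_prompt_caching_valid_prompt(
        model="anthropic/claude-3-5-sonnet-20240620",
        messages=[{"role": "user", "content": "hi"}],
        custom_llm_provider="anthropic",
    )
)  # False
```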
|
tests/llm_translation/base_embedding_unit_tests.py (new file, 67 lines)
@ -0,0 +1,67 @@
|
|||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
import pytest
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
get_supported_openai_params,
|
||||
get_optional_params,
|
||||
get_optional_params_embeddings,
|
||||
)
|
||||
|
||||
# test_example.py
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseLLMEmbeddingTest(ABC):
|
||||
"""
|
||||
Abstract base test class that enforces a common test across all test classes.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_base_embedding_call_args(self) -> dict:
|
||||
"""Must return the base embedding call args"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||
"""Must return the custom llm provider"""
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
async def test_basic_embedding(self, sync_mode):
|
||||
litellm.set_verbose = True
|
||||
embedding_call_args = self.get_base_embedding_call_args()
|
||||
if sync_mode is True:
|
||||
response = litellm.embedding(
|
||||
**embedding_call_args,
|
||||
input=["hello", "world"],
|
||||
)
|
||||
|
||||
print("embedding response: ", response)
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
**embedding_call_args,
|
||||
input=["hello", "world"],
|
||||
)
|
||||
|
||||
print("async embedding response: ", response)
|
||||
|
||||
def test_embedding_optional_params_max_retries(self):
|
||||
embedding_call_args = self.get_base_embedding_call_args()
|
||||
optional_params = get_optional_params_embeddings(
|
||||
**embedding_call_args, max_retries=20
|
||||
)
|
||||
assert optional_params["max_retries"] == 20
|
|
@ -82,6 +82,16 @@ class BaseLLMChatTest(ABC):
|
|||
# for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
|
||||
assert response.choices[0].message.content is not None
|
||||
|
||||
def test_pydantic_model_input(self):
|
||||
litellm.set_verbose = True
|
||||
|
||||
from litellm import completion, Message
|
||||
|
||||
base_completion_call_args = self.get_base_completion_call_args()
|
||||
messages = [Message(content="Hello, how are you?", role="user")]
|
||||
|
||||
completion(**base_completion_call_args, messages=messages)
|
||||
|
||||
@pytest.mark.parametrize("image_url", ["str", "dict"])
|
||||
def test_pdf_handling(self, pdf_messages, image_url):
|
||||
from litellm.utils import supports_pdf_input
|
||||
|
|
|
@ -8,6 +8,7 @@ sys.path.insert(
|
|||
import pytest
|
||||
from litellm.llms.azure.common_utils import process_azure_headers
|
||||
from httpx import Headers
|
||||
from base_embedding_unit_tests import BaseLLMEmbeddingTest
|
||||
|
||||
|
||||
def test_process_azure_headers_empty():
|
||||
|
@ -188,3 +189,15 @@ def test_process_azure_endpoint_url(api_base, model, expected_endpoint):
|
|||
}
|
||||
result = azure_chat_completion.create_azure_base_url(**input_args)
|
||||
assert result == expected_endpoint, "Unexpected endpoint"
|
||||
|
||||
|
||||
class TestAzureEmbedding(BaseLLMEmbeddingTest):
|
||||
def get_base_embedding_call_args(self) -> dict:
|
||||
return {
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
}
|
||||
|
||||
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||
return litellm.LlmProviders.AZURE
|
||||
|
|
|
@ -161,6 +161,49 @@ async def test_litellm_anthropic_prompt_caching_tools():
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def anthropic_messages():
|
||||
return [
|
||||
# System Message
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Here is the full text of a complex legal agreement" * 400,
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are the key terms and conditions in this agreement?",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
|
||||
},
|
||||
# The final turn is marked with cache-control, for continuing in followups.
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are the key terms and conditions in this agreement?",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_anthropic_api_prompt_caching_basic():
|
||||
litellm.set_verbose = True
|
||||
|
@ -227,8 +270,6 @@ async def test_anthropic_api_prompt_caching_basic():
|
|||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_anthropic_api_prompt_caching_with_content_str():
|
||||
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||
|
||||
system_message = [
|
||||
{
|
||||
"role": "system",
|
||||
|
@ -546,3 +587,134 @@ async def test_litellm_anthropic_prompt_caching_system():
|
|||
mock_post.assert_called_once_with(
|
||||
expected_url, json=expected_json, headers=expected_headers, timeout=600.0
|
||||
)
|
||||
|
||||
|
||||
def test_is_prompt_caching_enabled(anthropic_messages):
|
||||
assert litellm.utils.is_prompt_caching_valid_prompt(
|
||||
messages=anthropic_messages,
|
||||
tools=None,
|
||||
custom_llm_provider="anthropic",
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"messages, expected_model_id",
|
||||
[("anthropic_messages", True), ("normal_messages", False)],
|
||||
)
|
||||
@pytest.mark.asyncio()
|
||||
async def test_router_prompt_caching_model_stored(
|
||||
messages, expected_model_id, anthropic_messages
|
||||
):
|
||||
"""
|
||||
If a model is called with prompt caching supported, then the model id should be stored in the router cache.
|
||||
"""
|
||||
import asyncio
|
||||
from litellm.router import Router
|
||||
from litellm.router_utils.prompt_caching_cache import PromptCachingCache
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "claude-model",
|
||||
"litellm_params": {
|
||||
"model": "anthropic/claude-3-5-sonnet-20240620",
|
||||
"api_key": os.environ.get("ANTHROPIC_API_KEY"),
|
||||
},
|
||||
"model_info": {"id": "1234"},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
if messages == "anthropic_messages":
|
||||
_messages = anthropic_messages
|
||||
else:
|
||||
_messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
await router.acompletion(
|
||||
model="claude-model",
|
||||
messages=_messages,
|
||||
mock_response="The sky is blue.",
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
cache = PromptCachingCache(
|
||||
cache=router.cache,
|
||||
)
|
||||
|
||||
cached_model_id = cache.get_model_id(messages=_messages, tools=None)
|
||||
|
||||
if expected_model_id:
|
||||
assert cached_model_id["model_id"] == "1234"
|
||||
else:
|
||||
assert cached_model_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_router_with_prompt_caching(anthropic_messages):
|
||||
"""
|
||||
if prompt caching supported model called with prompt caching valid prompt,
|
||||
then 2nd call should go to the same model.
|
||||
"""
|
||||
from litellm.router import Router
|
||||
import asyncio
|
||||
from litellm.router_utils.prompt_caching_cache import PromptCachingCache
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "claude-model",
|
||||
"litellm_params": {
|
||||
"model": "anthropic/claude-3-5-sonnet-20240620",
|
||||
"api_key": os.environ.get("ANTHROPIC_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "claude-model",
|
||||
"litellm_params": {
|
||||
"model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
response = await router.acompletion(
|
||||
messages=anthropic_messages,
|
||||
model="claude-model",
|
||||
mock_response="The sky is blue.",
|
||||
)
|
||||
print("response=", response)
|
||||
|
||||
initial_model_id = response._hidden_params["model_id"]
|
||||
|
||||
await asyncio.sleep(1)
|
||||
cache = PromptCachingCache(
|
||||
cache=router.cache,
|
||||
)
|
||||
|
||||
cached_model_id = cache.get_model_id(messages=anthropic_messages, tools=None)
|
||||
|
||||
prompt_caching_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
|
||||
messages=anthropic_messages, tools=None
|
||||
)
|
||||
print(f"prompt_caching_cache_key: {prompt_caching_cache_key}")
|
||||
assert cached_model_id["model_id"] == initial_model_id
|
||||
|
||||
new_messages = anthropic_messages + [
|
||||
{"role": "user", "content": "What is the weather in SF?"}
|
||||
]
|
||||
|
||||
pc_deployment = await cache.async_get_prompt_caching_deployment(
|
||||
router=router,
|
||||
messages=new_messages,
|
||||
tools=None,
|
||||
)
|
||||
assert pc_deployment is not None
|
||||
|
||||
response = await router.acompletion(
|
||||
messages=new_messages,
|
||||
model="claude-model",
|
||||
mock_response="The sky is blue.",
|
||||
)
|
||||
print("response=", response)
|
||||
|
||||
assert response._hidden_params["model_id"] == initial_model_id
|
||||
|
|
|
@ -2697,3 +2697,21 @@ def test_model_group_alias(hidden):
|
|||
# assert int(response_headers["x-ratelimit-remaining-requests"]) > 0
|
||||
# assert response_headers["x-ratelimit-limit-tokens"] == 100500
|
||||
# assert int(response_headers["x-ratelimit-remaining-tokens"]) > 0
|
||||
|
||||
|
||||
def test_router_completion_with_model_id():
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {"model": "gpt-3.5-turbo"},
|
||||
"model_info": {"id": "123"},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
router, "routing_strategy_pre_call_checks"
|
||||
) as mock_pre_call_checks:
|
||||
router.completion(model="123", messages=[{"role": "user", "content": "hi"}])
|
||||
mock_pre_call_checks.assert_not_called()
|
||||
|
|
|
@ -560,3 +560,34 @@ Unit tests for router set_cooldowns
|
|||
|
||||
1. _set_cooldown_deployments() will cooldown a deployment after it fails 50% requests
|
||||
"""
|
||||
|
||||
|
||||
def test_router_fallbacks_with_cooldowns_and_model_id():
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1},
|
||||
"model_info": {
|
||||
"id": "123",
|
||||
},
|
||||
}
|
||||
],
|
||||
routing_strategy="usage-based-routing-v2",
|
||||
fallbacks=[{"gpt-3.5-turbo": ["123"]}],
|
||||
)
|
||||
|
||||
## trigger ratelimit
|
||||
try:
|
||||
router.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
mock_response="litellm.RateLimitError",
|
||||
)
|
||||
except litellm.RateLimitError:
|
||||
pass
|
||||
|
||||
router.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
|
|
@ -1498,3 +1498,26 @@ async def test_router_disable_fallbacks_dynamically():
|
|||
print(e)
|
||||
|
||||
mock_client.assert_not_called()
|
||||
|
||||
|
||||
def test_router_fallbacks_with_model_id():
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1},
|
||||
"model_info": {
|
||||
"id": "123",
|
||||
},
|
||||
}
|
||||
],
|
||||
routing_strategy="usage-based-routing-v2",
|
||||
fallbacks=[{"gpt-3.5-turbo": ["123"]}],
|
||||
)
|
||||
|
||||
## test model id fallback works
|
||||
router.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
mock_testing_fallbacks=True,
|
||||
)
|
||||
|
|
tests/router_unit_tests/test_router_prompt_caching.py (new file, 66 lines)
@ -0,0 +1,66 @@
|
|||
import sys
|
||||
import os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Request
|
||||
from datetime import datetime
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from litellm import Router
|
||||
import pytest
|
||||
import litellm
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from create_mock_standard_logging_payload import create_standard_logging_payload
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
import unittest
|
||||
from pydantic import BaseModel
|
||||
from litellm.router_utils.prompt_caching_cache import PromptCachingCache
|
||||
|
||||
|
||||
class ExampleModel(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
|
||||
|
||||
def test_serialize_pydantic_object():
|
||||
model = ExampleModel(field1="value", field2=42)
|
||||
serialized = PromptCachingCache.serialize_object(model)
|
||||
assert serialized == {"field1": "value", "field2": 42}
|
||||
|
||||
|
||||
def test_serialize_dict():
|
||||
obj = {"b": 2, "a": 1}
|
||||
serialized = PromptCachingCache.serialize_object(obj)
|
||||
assert serialized == '{"a":1,"b":2}' # JSON string with sorted keys
|
||||
|
||||
|
||||
def test_serialize_nested_dict():
|
||||
obj = {"z": {"b": 2, "a": 1}, "x": [1, 2, {"c": 3}]}
|
||||
serialized = PromptCachingCache.serialize_object(obj)
|
||||
expected = '{"x":[1,2,{"c":3}],"z":{"a":1,"b":2}}' # JSON string with sorted keys
|
||||
assert serialized == expected
|
||||
|
||||
|
||||
def test_serialize_list():
|
||||
obj = ["item1", {"a": 1, "b": 2}, 42]
|
||||
serialized = PromptCachingCache.serialize_object(obj)
|
||||
expected = ["item1", '{"a":1,"b":2}', 42]
|
||||
assert serialized == expected
|
||||
|
||||
|
||||
def test_serialize_fallback():
|
||||
obj = 12345 # Simple non-serializable object
|
||||
serialized = PromptCachingCache.serialize_object(obj)
|
||||
assert serialized == 12345
|
||||
|
||||
|
||||
def test_serialize_non_serializable():
|
||||
class CustomClass:
|
||||
def __str__(self):
|
||||
return "custom_object"
|
||||
|
||||
obj = CustomClass()
|
||||
serialized = PromptCachingCache.serialize_object(obj)
|
||||
assert serialized == "custom_object" # Fallback to string conversion
|