Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-26 11:14:04 +00:00
Litellm dev 12 07 2024 (#7086)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 11s
* fix(main.py): support passing max retries to azure/openai embedding integrations
  Fixes https://github.com/BerriAI/litellm/issues/7003
* feat(team_endpoints.py): allow updating team model aliases
  Closes https://github.com/BerriAI/litellm/issues/6956
* feat(router.py): allow specifying model id as fallback - skips any cooldown check
  Allows a default model to be checked if all models in cooldown
  s/o @micahjsmith
* docs(reliability.md): add fallback to specific model to docs
* fix(utils.py): new 'is_prompt_caching_valid_prompt' helper util
  Allows user to identify if messages/tools have prompt caching
  Related issue: https://github.com/BerriAI/litellm/issues/6784
* feat(router.py): store model id for prompt caching valid prompt
  Allows routing to that model id on subsequent requests
* fix(router.py): only cache if prompt is a valid prompt caching prompt
  Prevents storing unnecessary items in cache
* feat(router.py): support routing prompt caching enabled models to previous deployments
  Closes https://github.com/BerriAI/litellm/issues/6784
* test: fix linting errors
* feat(databricks/): convert basemodel to dict and exclude none values
  Allows passing pydantic message to databricks
* fix(utils.py): ensure all chat completion messages are dict
* (feat) Track `custom_llm_provider` in LiteLLM SpendLogs (#7081)
  * add custom_llm_provider to SpendLogsPayload
  * add custom_llm_provider to SpendLogs
  * add custom llm provider to SpendLogs payload
  * test_spend_logs_payload
* Add MLflow to the side bar (#7031)
* (bug fix) SpendLogs update DB - catch all possible DB errors for retrying (#7082)
  * catch DB_CONNECTION_ERROR_TYPES
  * fix DB retry mechanism for SpendLog updates
  * use DB_CONNECTION_ERROR_TYPES in auth checks
  * fix exp back off for writing SpendLogs
  * use _raise_failed_update_spend_exception to ensure errors print as NON blocking
  * test_update_spend_logs_multiple_batches_with_failure
* (Feat) Add StructuredOutputs support for Fireworks.AI (#7085)
  * fix model cost map fireworks ai "supports_response_schema": true
  * fix supports_response_schema
  * fix map openai params fireworks ai
  * test_map_response_format
* added deepinfra/Meta-Llama-3.1-405B-Instruct (#7084)
* bump: version 1.53.9 → 1.54.0
* fix deepinfra
* litellm db fixes LiteLLM_UserTable (#7089)
* ci/cd queue new release
* fix llama-3.3-70b-versatile
* refactor - use consistent file naming convention `AI21/` -> `ai21` (#7090)
  * fix refactor - use consistent file naming convention
  * ci/cd run again
  * fix naming structure
* fix use consistent naming (#7092)

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com>
Co-authored-by: ali sayyah <ali.sayyah2@gmail.com>
This commit is contained in:
parent 36e99ebce7
commit 0c0498dd60

24 changed files with 840 additions and 193 deletions
@@ -315,6 +315,64 @@ litellm_settings:
       cooldown_time: 30 # how long to cooldown model if fails/min > allowed_fails
 ```
+
+### Fallback to Specific Model ID
+
+If all models in a group are in cooldown (e.g. rate limited), LiteLLM will fall back to the model with the specific model ID.
+
+This skips any cooldown check for the fallback model.
+
+1. Specify the model ID in `model_info`
+
+```yaml
+model_list:
+  - model_name: gpt-4
+    litellm_params:
+      model: openai/gpt-4
+    model_info:
+      id: my-specific-model-id # 👈 KEY CHANGE
+  - model_name: gpt-4
+    litellm_params:
+      model: azure/chatgpt-v-2
+      api_base: os.environ/AZURE_API_BASE
+      api_key: os.environ/AZURE_API_KEY
+  - model_name: anthropic-claude
+    litellm_params:
+      model: anthropic/claude-3-opus-20240229
+      api_key: os.environ/ANTHROPIC_API_KEY
+```
+
+**Note:** This will only fall back to the model with the specific model ID. If you want to fall back to another model group, you can set `fallbacks=[{"gpt-4": ["anthropic-claude"]}]`
+
+2. Set fallbacks in config
+
+```yaml
+litellm_settings:
+  fallbacks: [{"gpt-4": ["my-specific-model-id"]}]
+```
+
+3. Test it!
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+    "model": "gpt-4",
+    "messages": [
+        {
+            "role": "user",
+            "content": "ping"
+        }
+    ],
+    "mock_testing_fallbacks": true
+}'
+```
+
+Validate it works by checking the response header `x-litellm-model-id`:
+
+```bash
+x-litellm-model-id: my-specific-model-id
+```
+
 ### Test Fallbacks!

 Check if your fallbacks are working as expected.

@@ -337,6 +395,7 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
 '
 ```
+

 #### **Content Policy Fallbacks**
 ```bash
 curl -X POST 'http://0.0.0.0:4000/chat/completions' \
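The fallback-to-a-specific-model-id path documented above can also be driven from the Python Router SDK (the router unit tests later in this commit exercise it the same way). Below is a minimal sketch, not part of the commit itself; the model names, the `my-specific-model-id` value, and the API keys are illustrative placeholders:

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {"model": "openai/gpt-4"},
            "model_info": {"id": "my-specific-model-id"},  # the fallback target
        },
        {
            "model_name": "anthropic-claude",
            "litellm_params": {"model": "anthropic/claude-3-opus-20240229"},
        },
    ],
    # fall back to a specific deployment id - cooldown checks are skipped for it
    fallbacks=[{"gpt-4": ["my-specific-model-id"]}],
)

# requires real provider keys in the environment for the underlying calls
response = router.completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "ping"}],
    mock_testing_fallbacks=True,  # force the fallback path, as in the curl example above
)
print(response._hidden_params.get("model_id"))  # expected: "my-specific-model-id"
```

Because the fallback target is a deployment id rather than a model group, the router does not run its cooldown check for it.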
@@ -1130,7 +1130,7 @@ router_settings:

 If a call fails after num_retries, fall back to another model group.

-### Quick Start
+#### Quick Start

 ```python
 from litellm import Router

@@ -1366,6 +1366,7 @@ litellm --config /path/to/config.yaml
 </TabItem>
 </Tabs>
+

 ### Caching

 In production, we recommend using a Redis cache. For quickly testing things locally, we also support simple in-memory caching.
@@ -21,6 +21,7 @@ from litellm.constants import (
     DEFAULT_BATCH_SIZE,
     DEFAULT_FLUSH_INTERVAL_SECONDS,
     ROUTER_MAX_FALLBACKS,
+    DEFAULT_MAX_RETRIES,
 )
 from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (
@@ -1,3 +1,4 @@
 ROUTER_MAX_FALLBACKS = 5
 DEFAULT_BATCH_SIZE = 512
 DEFAULT_FLUSH_INTERVAL_SECONDS = 5
+DEFAULT_MAX_RETRIES = 2
@@ -1217,12 +1217,13 @@ class OpenAIChatCompletion(BaseLLM):
         api_base: Optional[str] = None,
         client=None,
         aembedding=None,
+        max_retries: Optional[int] = None,
     ) -> litellm.EmbeddingResponse:
         super().embedding()
         try:
             model = model
             data = {"model": model, "input": input, **optional_params}
-            max_retries = data.pop("max_retries", 2)
+            max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
             if not isinstance(max_retries, int):
                 raise OpenAIError(status_code=422, message="max retries must be an int")
             ## LOGGING
@@ -871,6 +871,7 @@ class AzureChatCompletion(BaseLLM):
         optional_params: dict,
         api_key: Optional[str] = None,
         azure_ad_token: Optional[str] = None,
+        max_retries: Optional[int] = None,
         client=None,
         aembedding=None,
     ) -> litellm.EmbeddingResponse:

@@ -879,7 +880,7 @@ class AzureChatCompletion(BaseLLM):
         self._client_session = self.create_client_session()
         try:
             data = {"model": model, "input": input, **optional_params}
-            max_retries = data.pop("max_retries", 2)
+            max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
             if not isinstance(max_retries, int):
                 raise AzureOpenAIError(
                     status_code=422, message="max retries must be an int"
@@ -219,6 +219,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
         api_base: Optional[str] = None,
         client=None,
         aembedding=None,
+        max_retries: Optional[int] = None,
     ) -> litellm.EmbeddingResponse:
         """
         - Separate image url from text
@@ -134,7 +134,7 @@ class DatabricksConfig(OpenAIGPTConfig):
         new_messages = []
         for idx, message in enumerate(messages):
             if isinstance(message, BaseModel):
-                _message = message.model_dump()
+                _message = message.model_dump(exclude_none=True)
             else:
                 _message = message
             new_messages.append(_message)
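For context on what `exclude_none=True` changes here: optional fields that were never set on the Pydantic message no longer get serialized as explicit nulls in the provider payload. A small illustration using litellm's `Message` model (the exact set of optional keys printed by the first call depends on the installed litellm version):

```python
from litellm import Message

msg = Message(content="Hello, how are you?", role="user")

# may include unset optional fields (e.g. tool_calls, function_call) as explicit None values
print(msg.model_dump())

# only the fields that were actually set survive - this is what now gets sent to Databricks
print(msg.model_dump(exclude_none=True))
```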
@@ -77,7 +77,7 @@ from litellm.utils import (
     read_config_args,
     supports_httpx_timeout,
     token_counter,
-    validate_chat_completion_user_messages,
+    validate_chat_completion_messages,
 )

 from ._logging import verbose_logger

@@ -931,7 +931,7 @@ def completion(  # type: ignore # noqa: PLR0915
     )  # support region-based pricing for bedrock

     ### VALIDATE USER MESSAGES ###
-    validate_chat_completion_user_messages(messages=messages)
+    messages = validate_chat_completion_messages(messages=messages)

     ### TIMEOUT LOGIC ###
     timeout = timeout or kwargs.get("request_timeout", 600) or 600

@@ -3274,6 +3274,7 @@ def embedding(  # noqa: PLR0915
     client = kwargs.pop("client", None)
     rpm = kwargs.pop("rpm", None)
     tpm = kwargs.pop("tpm", None)
+    max_retries = kwargs.get("max_retries", None)
     litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
     cooldown_time = kwargs.get("cooldown_time", None)
     mock_response: Optional[List[float]] = kwargs.get("mock_response", None)  # type: ignore

@@ -3422,6 +3423,7 @@ def embedding(  # noqa: PLR0915
             optional_params=optional_params,
             client=client,
             aembedding=aembedding,
+            max_retries=max_retries,
         )
     elif (
         model in litellm.open_ai_embedding_models

@@ -3466,6 +3468,7 @@ def embedding(  # noqa: PLR0915
             optional_params=optional_params,
             client=client,
             aembedding=aembedding,
+            max_retries=max_retries,
         )
     elif custom_llm_provider == "databricks":
         api_base = (
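End to end, this means a caller-supplied `max_retries` now reaches the Azure/OpenAI embedding clients instead of being silently defaulted. A minimal sketch, not part of the commit; the model name is a placeholder and a real `OPENAI_API_KEY` (or Azure credentials) is assumed:

```python
import litellm

response = litellm.embedding(
    model="text-embedding-3-small",   # any OpenAI/Azure embedding model
    input=["hello", "world"],
    max_retries=5,                    # forwarded to the underlying OpenAI/Azure client
)
print(len(response.data))  # one embedding per input string
```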
@@ -723,7 +723,7 @@ class KeyRequest(LiteLLMBase):


 class LiteLLM_ModelTable(LiteLLMBase):
-    model_aliases: Optional[str] = None  # json dump the dict
+    model_aliases: Optional[Union[str, dict]] = None  # json dump the dict
     created_by: str
     updated_by: str

@@ -981,6 +981,7 @@ class UpdateTeamRequest(LiteLLMBase):
     blocked: Optional[bool] = None
     budget_duration: Optional[str] = None
     tags: Optional[list] = None
+    model_aliases: Optional[dict] = None


 class ResetTeamBudgetRequest(LiteLLMBase):

@@ -1059,6 +1060,7 @@ class LiteLLM_TeamTable(TeamBase):
     budget_duration: Optional[str] = None
     budget_reset_at: Optional[datetime] = None
     model_id: Optional[int] = None
+    litellm_model_table: Optional[LiteLLM_ModelTable] = None

     model_config = ConfigDict(protected_namespaces=())

@@ -293,8 +293,13 @@ async def new_team(  # noqa: PLR0915
             reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
             complete_team_data.budget_reset_at = reset_at

-    team_row: LiteLLM_TeamTable = await prisma_client.insert_data(  # type: ignore
-        data=complete_team_data.json(exclude_none=True), table_name="team"
+    complete_team_data_dict = complete_team_data.model_dump(exclude_none=True)
+    complete_team_data_dict = prisma_client.jsonify_team_object(
+        db_data=complete_team_data_dict
+    )
+    team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create(
+        data=complete_team_data_dict,
+        include={"litellm_model_table": True},  # type: ignore
     )

     ## ADD TEAM ID TO USER TABLE ##

@@ -340,6 +345,37 @@ async def new_team(  # noqa: PLR0915
     return team_row.dict()


+async def _update_model_table(
+    data: UpdateTeamRequest,
+    model_id: Optional[str],
+    prisma_client: PrismaClient,
+    user_api_key_dict: UserAPIKeyAuth,
+    litellm_proxy_admin_name: str,
+) -> Optional[str]:
+    """
+    Upsert model table and return the model id
+    """
+    ## UPSERT MODEL TABLE
+    _model_id = model_id
+    if data.model_aliases is not None and isinstance(data.model_aliases, dict):
+        litellm_modeltable = LiteLLM_ModelTable(
+            model_aliases=json.dumps(data.model_aliases),
+            created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
+            updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
+        )
+        model_dict = await prisma_client.db.litellm_modeltable.upsert(
+            where={"id": model_id},
+            data={
+                "update": {**litellm_modeltable.json(exclude_none=True)},  # type: ignore
+                "create": {**litellm_modeltable.json(exclude_none=True)},  # type: ignore
+            },
+        )  # type: ignore
+
+        _model_id = model_dict.id
+
+    return _model_id
+
+
 @router.post(
     "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
 )

@@ -370,6 +406,7 @@ async def update_team(
     - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id.
     - tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
     - organization_id: Optional[str] - The organization id of the team. Default is None. Create via `/organization/new`.
+    - model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias)

     Example - update team TPM Limit

@@ -446,11 +483,25 @@ async def update_team(
         else:
             updated_kv["metadata"] = {"tags": _tags}

-    updated_kv = prisma_client.jsonify_object(data=updated_kv)
-    team_row: Optional[
-        LiteLLM_TeamTable
-    ] = await prisma_client.db.litellm_teamtable.update(
-        where={"team_id": data.team_id}, data=updated_kv  # type: ignore
+    if "model_aliases" in updated_kv:
+        updated_kv.pop("model_aliases")
+        _model_id = await _update_model_table(
+            data=data,
+            model_id=existing_team_row.model_id,
+            prisma_client=prisma_client,
+            user_api_key_dict=user_api_key_dict,
+            litellm_proxy_admin_name=litellm_proxy_admin_name,
+        )
+        if _model_id is not None:
+            updated_kv["model_id"] = _model_id
+
+    updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv)
+    team_row: Optional[LiteLLM_TeamTable] = (
+        await prisma_client.db.litellm_teamtable.update(
+            where={"team_id": data.team_id},
+            data=updated_kv,
+            include={"litellm_model_table": True},  # type: ignore
+        )
     )

     if team_row is None or team_row.team_id is None:
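With `model_aliases` accepted by `UpdateTeamRequest`, updating a team's aliases becomes a regular call to the proxy's `/team/update` endpoint. A hedged sketch using `requests`; the proxy URL, admin key, team id, and alias mapping are placeholders:

```python
import requests

resp = requests.post(
    "http://0.0.0.0:4000/team/update",
    headers={"Authorization": "Bearer sk-1234"},  # proxy admin key
    json={
        "team_id": "my-team-id",                      # placeholder team id
        "model_aliases": {"gpt-4o-alias": "gpt-4o"},  # placeholder alias -> model mapping
    },
)
print(resp.json())
```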
@@ -1207,173 +1207,6 @@ class PrismaClient:
         except Exception:
             raise

-            # try:
-            #     # Try to select one row from the view
-            #     await self.db.query_raw(
-            #         """SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1"""
-            #     )
-            #     print("LiteLLM_VerificationTokenView Exists!")  # noqa
-            # except Exception as e:
-            # If an error occurs, the view does not exist, so create it
-
-            # try:
-            #     await self.db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
-            #     print("MonthlyGlobalSpend Exists!")  # noqa
-            # except Exception as e:
-            #     sql_query = """
-            #     CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
-            #     SELECT
-            #     DATE("startTime") AS date,
-            #     SUM("spend") AS spend
-            #     FROM
-            #     "LiteLLM_SpendLogs"
-            #     WHERE
-            #     "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
-            #     GROUP BY
-            #     DATE("startTime");
-            #     """
-            #     await self.db.execute_raw(query=sql_query)
-
-            #     print("MonthlyGlobalSpend Created!")  # noqa
-
-            # try:
-            #     await self.db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
-            #     print("Last30dKeysBySpend Exists!")  # noqa
-            # except Exception as e:
-            #     sql_query = """
-            #     CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
-            #     SELECT
-            #     L."api_key",
-            #     V."key_alias",
-            #     V."key_name",
-            #     SUM(L."spend") AS total_spend
-            #     FROM
-            #     "LiteLLM_SpendLogs" L
-            #     LEFT JOIN
-            #     "LiteLLM_VerificationToken" V
-            #     ON
-            #     L."api_key" = V."token"
-            #     WHERE
-            #     L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
-            #     GROUP BY
-            #     L."api_key", V."key_alias", V."key_name"
-            #     ORDER BY
-            #     total_spend DESC;
-            #     """
-            #     await self.db.execute_raw(query=sql_query)
-
-            #     print("Last30dKeysBySpend Created!")  # noqa
-
-            # try:
-            #     await self.db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
-            #     print("Last30dModelsBySpend Exists!")  # noqa
-            # except Exception as e:
-            #     sql_query = """
-            #     CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
-            #     SELECT
-            #     "model",
-            #     SUM("spend") AS total_spend
-            #     FROM
-            #     "LiteLLM_SpendLogs"
-            #     WHERE
-            #     "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
-            #     AND "model" != ''
-            #     GROUP BY
-            #     "model"
-            #     ORDER BY
-            #     total_spend DESC;
-            #     """
-            #     await self.db.execute_raw(query=sql_query)
-
-            #     print("Last30dModelsBySpend Created!")  # noqa
-            # try:
-            #     await self.db.query_raw(
-            #         """SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1"""
-            #     )
-            #     print("MonthlyGlobalSpendPerKey Exists!")  # noqa
-            # except Exception as e:
-            #     sql_query = """
-            #     CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
-            #     SELECT
-            #     DATE("startTime") AS date,
-            #     SUM("spend") AS spend,
-            #     api_key as api_key
-            #     FROM
-            #     "LiteLLM_SpendLogs"
-            #     WHERE
-            #     "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
-            #     GROUP BY
-            #     DATE("startTime"),
-            #     api_key;
-            #     """
-            #     await self.db.execute_raw(query=sql_query)
-
-            #     print("MonthlyGlobalSpendPerKey Created!")  # noqa
-            # try:
-            #     await self.db.query_raw(
-            #         """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
-            #     )
-            #     print("MonthlyGlobalSpendPerUserPerKey Exists!")  # noqa
-            # except Exception as e:
-            #     sql_query = """
-            #     CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
-            #     SELECT
-            #     DATE("startTime") AS date,
-            #     SUM("spend") AS spend,
-            #     api_key as api_key,
-            #     "user" as "user"
-            #     FROM
-            #     "LiteLLM_SpendLogs"
-            #     WHERE
-            #     "startTime" >= (CURRENT_DATE - INTERVAL '20 days')
-            #     GROUP BY
-            #     DATE("startTime"),
-            #     "user",
-            #     api_key;
-            #     """
-            #     await self.db.execute_raw(query=sql_query)
-
-            #     print("MonthlyGlobalSpendPerUserPerKey Created!")  # noqa
-
-            # try:
-            #     await self.db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""")
-            #     print("DailyTagSpend Exists!")  # noqa
-            # except Exception as e:
-            #     sql_query = """
-            #     CREATE OR REPLACE VIEW DailyTagSpend AS
-            #     SELECT
-            #     jsonb_array_elements_text(request_tags) AS individual_request_tag,
-            #     DATE(s."startTime") AS spend_date,
-            #     COUNT(*) AS log_count,
-            #     SUM(spend) AS total_spend
-            #     FROM "LiteLLM_SpendLogs" s
-            #     GROUP BY individual_request_tag, DATE(s."startTime");
-            #     """
-            #     await self.db.execute_raw(query=sql_query)
-
-            #     print("DailyTagSpend Created!")  # noqa
-
-            # try:
-            #     await self.db.query_raw(
-            #         """SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1"""
-            #     )
-            #     print("Last30dTopEndUsersSpend Exists!")  # noqa
-            # except Exception as e:
-            #     sql_query = """
-            #     CREATE VIEW "Last30dTopEndUsersSpend" AS
-            #     SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
-            #     FROM "LiteLLM_SpendLogs"
-            #     WHERE end_user <> '' AND end_user <> user
-            #     AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
-            #     GROUP BY end_user
-            #     ORDER BY total_spend DESC
-            #     LIMIT 100;
-            #     """
-            #     await self.db.execute_raw(query=sql_query)
-
-            #     print("Last30dTopEndUsersSpend Created!")  # noqa
-
         return

     @log_db_metrics
@@ -1784,6 +1617,14 @@ class PrismaClient:
             )
             raise e

+    def jsonify_team_object(self, db_data: dict):
+        db_data = self.jsonify_object(data=db_data)
+        if db_data.get("members_with_roles", None) is not None and isinstance(
+            db_data["members_with_roles"], list
+        ):
+            db_data["members_with_roles"] = json.dumps(db_data["members_with_roles"])
+        return db_data
+
     # Define a retrying strategy with exponential backoff
     @backoff.on_exception(
         backoff.expo,
@@ -2348,7 +2189,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     module_name = value
     instance_name = None
     try:
-        print_verbose(f"value: {value}")
         # Split the path by dots to separate module from instance
         parts = value.split(".")

@@ -36,6 +36,7 @@ from typing import (
     Tuple,
     TypedDict,
     Union,
+    cast,
 )

 import httpx

@@ -96,6 +97,7 @@ from litellm.router_utils.router_callbacks.track_deployment_metrics import (
 )
 from litellm.scheduler import FlowItem, Scheduler
 from litellm.types.llms.openai import (
+    AllMessageValues,
     Assistant,
     AssistantToolParam,
     AsyncCursorPage,

@@ -149,10 +151,12 @@ from litellm.utils import (
     get_llm_provider,
     get_secret,
     get_utc_datetime,
+    is_prompt_caching_valid_prompt,
     is_region_allowed,
 )

 from .router_utils.pattern_match_deployments import PatternMatchRouter
+from .router_utils.prompt_caching_cache import PromptCachingCache

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

@@ -737,6 +741,8 @@ class Router:
                 model_client = potential_model_client

             ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+            ## only run if model group given, not model id
+            if model not in self.get_model_ids():
                 self.routing_strategy_pre_call_checks(deployment=deployment)

             response = litellm.completion(

@@ -2787,8 +2793,10 @@ class Router:
                     *args,
                     **input_kwargs,
                 )

                 return response
             except Exception as new_exception:
+                traceback.print_exc()
                 parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
                 verbose_router_logger.error(
                     "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(

@@ -3376,6 +3384,29 @@ class Router:
                 deployment_id=id,
             )

+            ## PROMPT CACHING
+            prompt_cache = PromptCachingCache(
+                cache=self.cache,
+            )
+            if (
+                standard_logging_object["messages"] is not None
+                and isinstance(standard_logging_object["messages"], list)
+                and deployment_name is not None
+                and isinstance(deployment_name, str)
+            ):
+                valid_prompt = is_prompt_caching_valid_prompt(
+                    messages=standard_logging_object["messages"],  # type: ignore
+                    tools=None,
+                    model=deployment_name,
+                    custom_llm_provider=None,
+                )
+                if valid_prompt:
+                    await prompt_cache.async_add_model_id(
+                        model_id=id,
+                        messages=standard_logging_object["messages"],  # type: ignore
+                        tools=None,
+                    )
+
             return tpm_key

         except Exception as e:

@@ -5190,7 +5221,6 @@ class Router:
         - List, if multiple models chosen
         - Dict, if specific model chosen
         """

         # check if aliases set on litellm model alias map
         if specific_deployment is True:
             return model, self._get_deployment_by_litellm_model(model=model)

@@ -5302,13 +5332,6 @@ class Router:
                 cooldown_deployments=cooldown_deployments,
             )

-        # filter pre-call checks
-        _allowed_model_region = (
-            request_kwargs.get("allowed_model_region")
-            if request_kwargs is not None
-            else None
-        )
-
         if self.enable_pre_call_checks and messages is not None:
             healthy_deployments = self._pre_call_checks(
                 model=model,

@@ -5317,6 +5340,24 @@ class Router:
                 request_kwargs=request_kwargs,
             )

+        if messages is not None and is_prompt_caching_valid_prompt(
+            messages=cast(List[AllMessageValues], messages),
+            model=model,
+            custom_llm_provider=None,
+        ):
+            prompt_cache = PromptCachingCache(
+                cache=self.cache,
+            )
+            healthy_deployment = (
+                await prompt_cache.async_get_prompt_caching_deployment(
+                    router=self,
+                    messages=cast(List[AllMessageValues], messages),
+                    tools=None,
+                )
+            )
+            if healthy_deployment is not None:
+                return healthy_deployment
+
         # check if user wants to do tag based routing
         healthy_deployments = await get_deployments_for_tag(  # type: ignore
             llm_router_instance=self,
@@ -49,6 +49,7 @@ async def run_async_fallback(
         raise original_exception

     error_from_fallbacks = original_exception
+
     for mg in fallback_model_group:
         if mg == original_model_group:
             continue
litellm/router_utils/prompt_caching_cache.py (new file, 193 lines)
@@ -0,0 +1,193 @@
"""
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
"""

import hashlib
import json
import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict

import litellm
from litellm import verbose_logger
from litellm.caching.caching import Cache, DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.router import Router

    litellm_router = Router
    Span = _Span
else:
    Span = Any
    litellm_router = Any


class PromptCachingCacheValue(TypedDict):
    model_id: str


class PromptCachingCache:
    def __init__(self, cache: DualCache):
        self.cache = cache
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """Helper function to serialize Pydantic objects, dictionaries, or fallback to string."""
        if hasattr(obj, "dict"):
            # If the object is a Pydantic model, use its `dict()` method
            return obj.dict()
        elif isinstance(obj, dict):
            # If the object is a dictionary, serialize it with sorted keys
            return json.dumps(
                obj, sort_keys=True, separators=(",", ":")
            )  # Standardize serialization

        elif isinstance(obj, list):
            # Serialize lists by ensuring each element is handled properly
            return [PromptCachingCache.serialize_object(item) for item in obj]
        elif isinstance(obj, (int, float, bool)):
            return obj  # Keep primitive types as-is
        return str(obj)

    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[str]:
        if messages is None and tools is None:
            return None
        # Use serialize_object for consistent and stable serialization
        data_to_hash = {}
        if messages is not None:
            serialized_messages = PromptCachingCache.serialize_object(messages)
            data_to_hash["messages"] = serialized_messages
        if tools is not None:
            serialized_tools = PromptCachingCache.serialize_object(tools)
            data_to_hash["tools"] = serialized_tools

        # Combine serialized data into a single string
        data_to_hash_str = json.dumps(
            data_to_hash,
            sort_keys=True,
            separators=(",", ":"),
        )

        # Create a hash of the serialized data for a stable cache key
        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
        return f"deployment:{hashed_data}:prompt_caching"

    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        self.cache.set_cache(
            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
        )
        return None

    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=300,  # store for 5 minutes
        )
        return None

    async def async_get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """
        if messages is not none
        - check full messages
        - check messages[:-1]
        - check messages[:-2]
        - check messages[:-3]

        use self.cache.async_batch_get_cache(keys=potential_cache_keys)
        """
        if messages is None and tools is None:
            return None

        # Generate potential cache keys by slicing messages

        potential_cache_keys = []

        if messages is not None:
            full_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
                messages, tools
            )
            potential_cache_keys.append(full_cache_key)

            # Check progressively shorter message slices
            for i in range(1, min(4, len(messages))):
                partial_messages = messages[:-i]
                partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
                    partial_messages, tools
                )
                potential_cache_keys.append(partial_cache_key)

        # Perform batch cache lookup
        cache_results = await self.cache.async_batch_get_cache(
            keys=potential_cache_keys
        )

        if cache_results is None:
            return None

        # Return the first non-None cache result
        for result in cache_results:
            if result is not None:
                return result

        return None

    def get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        return self.cache.get_cache(cache_key)

    async def async_get_prompt_caching_deployment(
        self,
        router: litellm_router,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[dict]:
        model_id_dict = await self.async_get_model_id(
            messages=messages,
            tools=tools,
        )

        if model_id_dict is not None:
            healthy_deployment_pydantic_obj = router.get_deployment(
                model_id=model_id_dict["model_id"]
            )
            if healthy_deployment_pydantic_obj is not None:
                return healthy_deployment_pydantic_obj.model_dump(exclude_none=True)
        return None
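In short, the cache maps a SHA-256 hash of the serialized (messages, tools) pair to the model id of the deployment that served it, with a 5-minute TTL, and lookups also probe progressively shorter message prefixes. A rough usage sketch under those assumptions; the `DualCache` instance and the model id below stand in for `router.cache` and a real deployment id:

```python
from litellm.caching.caching import DualCache
from litellm.router_utils.prompt_caching_cache import PromptCachingCache

prompt_cache = PromptCachingCache(cache=DualCache())  # inside the router this is router.cache

messages = [{"role": "user", "content": "long, cache-control marked prompt ..."}]

# record which deployment served this prompt (placeholder model id)
prompt_cache.add_model_id(
    model_id="azure-gpt-4o-deployment-1", messages=messages, tools=None
)

# later: find the deployment that previously served this prompt prefix
cached = prompt_cache.get_model_id(messages=messages, tools=None)
if cached is not None:
    print(cached["model_id"])  # -> "azure-gpt-4o-deployment-1"
```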
@@ -6151,6 +6151,38 @@ from litellm.types.llms.openai import (
 )


+def convert_to_dict(message: Union[BaseModel, dict]) -> dict:
+    """
+    Converts a message to a dictionary if it's a Pydantic model.
+
+    Args:
+        message: The message, which may be a Pydantic model or a dictionary.
+
+    Returns:
+        dict: The converted message.
+    """
+    if isinstance(message, BaseModel):
+        return message.model_dump(exclude_none=True)
+    elif isinstance(message, dict):
+        return message
+    else:
+        raise TypeError(
+            f"Invalid message type: {type(message)}. Expected dict or Pydantic model."
+        )
+
+
+def validate_chat_completion_messages(messages: List[AllMessageValues]):
+    """
+    Ensures all messages are valid OpenAI chat completion messages.
+    """
+    # 1. convert all messages to dict
+    messages = [
+        cast(AllMessageValues, convert_to_dict(cast(dict, m))) for m in messages
+    ]
+    # 2. validate user messages
+    return validate_chat_completion_user_messages(messages=messages)
+
+
 def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
     """
     Ensures all user messages are valid OpenAI chat completion messages.

@@ -6229,3 +6261,22 @@ def get_end_user_id_for_cost_tracking(
     ):
         return None
     return proxy_server_request.get("body", {}).get("user", None)
+
+
+def is_prompt_caching_valid_prompt(
+    model: str,
+    messages: Optional[List[AllMessageValues]],
+    tools: Optional[List[ChatCompletionToolParam]] = None,
+    custom_llm_provider: Optional[str] = None,
+) -> bool:
+    """
+    Returns true if the prompt is valid for prompt caching.
+
+    OpenAI + Anthropic providers have a minimum token count of 1024 for prompt caching.
+    """
+    if messages is None and tools is None:
+        return False
+    if custom_llm_provider is not None and not model.startswith(custom_llm_provider):
+        model = custom_llm_provider + "/" + model
+    token_count = token_counter(messages=messages, tools=tools, model=model)
+    return token_count >= 1024
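The helper is just a token-count threshold check: the prompt must reach the 1024-token minimum that OpenAI and Anthropic require for prompt caching. A minimal illustration (the message contents are made up):

```python
import litellm

short_messages = [{"role": "user", "content": "hi"}]
long_messages = [
    {"role": "user", "content": "Here is the full text of a long legal agreement. " * 400}
]

# well under 1024 tokens -> False
print(
    litellm.utils.is_prompt_caching_valid_prompt(
        model="anthropic/claude-3-5-sonnet-20240620", messages=short_messages
    )
)

# repeated text pushes this past the 1024-token minimum -> True
print(
    litellm.utils.is_prompt_caching_valid_prompt(
        model="anthropic/claude-3-5-sonnet-20240620", messages=long_messages
    )
)
```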
tests/llm_translation/base_embedding_unit_tests.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import (
    CustomStreamWrapper,
    get_supported_openai_params,
    get_optional_params,
    get_optional_params_embeddings,
)

# test_example.py
from abc import ABC, abstractmethod


class BaseLLMEmbeddingTest(ABC):
    """
    Abstract base test class that enforces a common test across all test classes.
    """

    @abstractmethod
    def get_base_embedding_call_args(self) -> dict:
        """Must return the base embedding call args"""
        pass

    @abstractmethod
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        """Must return the custom llm provider"""
        pass

    @pytest.mark.asyncio()
    @pytest.mark.parametrize("sync_mode", [True, False])
    async def test_basic_embedding(self, sync_mode):
        litellm.set_verbose = True
        embedding_call_args = self.get_base_embedding_call_args()
        if sync_mode is True:
            response = litellm.embedding(
                **embedding_call_args,
                input=["hello", "world"],
            )

            print("embedding response: ", response)
        else:
            response = await litellm.aembedding(
                **embedding_call_args,
                input=["hello", "world"],
            )

            print("async embedding response: ", response)

    def test_embedding_optional_params_max_retries(self):
        embedding_call_args = self.get_base_embedding_call_args()
        optional_params = get_optional_params_embeddings(
            **embedding_call_args, max_retries=20
        )
        assert optional_params["max_retries"] == 20
@@ -82,6 +82,16 @@ class BaseLLMChatTest(ABC):
         # for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
         assert response.choices[0].message.content is not None

+    def test_pydantic_model_input(self):
+        litellm.set_verbose = True
+
+        from litellm import completion, Message
+
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [Message(content="Hello, how are you?", role="user")]
+
+        completion(**base_completion_call_args, messages=messages)
+
     @pytest.mark.parametrize("image_url", ["str", "dict"])
     def test_pdf_handling(self, pdf_messages, image_url):
         from litellm.utils import supports_pdf_input
@@ -8,6 +8,7 @@ sys.path.insert(
 import pytest
 from litellm.llms.azure.common_utils import process_azure_headers
 from httpx import Headers
+from base_embedding_unit_tests import BaseLLMEmbeddingTest


 def test_process_azure_headers_empty():

@@ -188,3 +189,15 @@ def test_process_azure_endpoint_url(api_base, model, expected_endpoint):
     }
     result = azure_chat_completion.create_azure_base_url(**input_args)
     assert result == expected_endpoint, "Unexpected endpoint"
+
+
+class TestAzureEmbedding(BaseLLMEmbeddingTest):
+    def get_base_embedding_call_args(self) -> dict:
+        return {
+            "model": "azure/azure-embedding-model",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+        }
+
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.AZURE
@@ -161,6 +161,49 @@ async def test_litellm_anthropic_prompt_caching_tools():
     )


+@pytest.fixture
+def anthropic_messages():
+    return [
+        # System Message
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "Here is the full text of a complex legal agreement" * 400,
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+        },
+        # The final turn is marked with cache-control, for continuing in followups.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+    ]
+
+
 @pytest.mark.asyncio()
 async def test_anthropic_api_prompt_caching_basic():
     litellm.set_verbose = True

@@ -227,8 +270,6 @@ async def test_anthropic_api_prompt_caching_basic():

 @pytest.mark.asyncio()
 async def test_anthropic_api_prompt_caching_with_content_str():
-    from litellm.llms.prompt_templates.factory import anthropic_messages_pt
-
     system_message = [
         {
             "role": "system",

@@ -546,3 +587,134 @@ async def test_litellm_anthropic_prompt_caching_system():
     mock_post.assert_called_once_with(
         expected_url, json=expected_json, headers=expected_headers, timeout=600.0
     )
+
+
+def test_is_prompt_caching_enabled(anthropic_messages):
+    assert litellm.utils.is_prompt_caching_valid_prompt(
+        messages=anthropic_messages,
+        tools=None,
+        custom_llm_provider="anthropic",
+        model="anthropic/claude-3-5-sonnet-20240620",
+    )
+
+
+@pytest.mark.parametrize(
+    "messages, expected_model_id",
+    [("anthropic_messages", True), ("normal_messages", False)],
+)
+@pytest.mark.asyncio()
+async def test_router_prompt_caching_model_stored(
+    messages, expected_model_id, anthropic_messages
+):
+    """
+    If a model is called with prompt caching supported, then the model id should be stored in the router cache.
+    """
+    import asyncio
+    from litellm.router import Router
+    from litellm.router_utils.prompt_caching_cache import PromptCachingCache
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "claude-model",
+                "litellm_params": {
+                    "model": "anthropic/claude-3-5-sonnet-20240620",
+                    "api_key": os.environ.get("ANTHROPIC_API_KEY"),
+                },
+                "model_info": {"id": "1234"},
+            }
+        ]
+    )
+
+    if messages == "anthropic_messages":
+        _messages = anthropic_messages
+    else:
+        _messages = [{"role": "user", "content": "Hello"}]
+
+    await router.acompletion(
+        model="claude-model",
+        messages=_messages,
+        mock_response="The sky is blue.",
+    )
+    await asyncio.sleep(1)
+    cache = PromptCachingCache(
+        cache=router.cache,
+    )
+
+    cached_model_id = cache.get_model_id(messages=_messages, tools=None)
+
+    if expected_model_id:
+        assert cached_model_id["model_id"] == "1234"
+    else:
+        assert cached_model_id is None
+
+
+@pytest.mark.asyncio()
+async def test_router_with_prompt_caching(anthropic_messages):
+    """
+    if prompt caching supported model called with prompt caching valid prompt,
+    then 2nd call should go to the same model.
+    """
+    from litellm.router import Router
+    import asyncio
+    from litellm.router_utils.prompt_caching_cache import PromptCachingCache
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "claude-model",
+                "litellm_params": {
+                    "model": "anthropic/claude-3-5-sonnet-20240620",
+                    "api_key": os.environ.get("ANTHROPIC_API_KEY"),
+                },
+            },
+            {
+                "model_name": "claude-model",
+                "litellm_params": {
+                    "model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
+                },
+            },
+        ]
+    )
+
+    response = await router.acompletion(
+        messages=anthropic_messages,
+        model="claude-model",
+        mock_response="The sky is blue.",
+    )
+    print("response=", response)
+
+    initial_model_id = response._hidden_params["model_id"]
+
+    await asyncio.sleep(1)
+    cache = PromptCachingCache(
+        cache=router.cache,
+    )
+
+    cached_model_id = cache.get_model_id(messages=anthropic_messages, tools=None)
+
+    prompt_caching_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
+        messages=anthropic_messages, tools=None
+    )
+    print(f"prompt_caching_cache_key: {prompt_caching_cache_key}")
+    assert cached_model_id["model_id"] == initial_model_id
+
+    new_messages = anthropic_messages + [
+        {"role": "user", "content": "What is the weather in SF?"}
+    ]
+
+    pc_deployment = await cache.async_get_prompt_caching_deployment(
+        router=router,
+        messages=new_messages,
+        tools=None,
+    )
+    assert pc_deployment is not None
+
+    response = await router.acompletion(
+        messages=new_messages,
+        model="claude-model",
+        mock_response="The sky is blue.",
+    )
+    print("response=", response)
+
+    assert response._hidden_params["model_id"] == initial_model_id
@@ -2697,3 +2697,21 @@ def test_model_group_alias(hidden):
     # assert int(response_headers["x-ratelimit-remaining-requests"]) > 0
     # assert response_headers["x-ratelimit-limit-tokens"] == 100500
     # assert int(response_headers["x-ratelimit-remaining-tokens"]) > 0
+
+
+def test_router_completion_with_model_id():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+                "model_info": {"id": "123"},
+            }
+        ]
+    )
+
+    with patch.object(
+        router, "routing_strategy_pre_call_checks"
+    ) as mock_pre_call_checks:
+        router.completion(model="123", messages=[{"role": "user", "content": "hi"}])
+        mock_pre_call_checks.assert_not_called()
@@ -560,3 +560,34 @@ Unit tests for router set_cooldowns

    1. _set_cooldown_deployments() will cooldown a deployment after it fails 50% requests
    """
+
+
+def test_router_fallbacks_with_cooldowns_and_model_id():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1},
+                "model_info": {
+                    "id": "123",
+                },
+            }
+        ],
+        routing_strategy="usage-based-routing-v2",
+        fallbacks=[{"gpt-3.5-turbo": ["123"]}],
+    )
+
+    ## trigger ratelimit
+    try:
+        router.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "hi"}],
+            mock_response="litellm.RateLimitError",
+        )
+    except litellm.RateLimitError:
+        pass
+
+    router.completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "hi"}],
+    )
@@ -1498,3 +1498,26 @@ async def test_router_disable_fallbacks_dynamically():
         print(e)

     mock_client.assert_not_called()
+
+
+def test_router_fallbacks_with_model_id():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1},
+                "model_info": {
+                    "id": "123",
+                },
+            }
+        ],
+        routing_strategy="usage-based-routing-v2",
+        fallbacks=[{"gpt-3.5-turbo": ["123"]}],
+    )
+
+    ## test model id fallback works
+    router.completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "hi"}],
+        mock_testing_fallbacks=True,
+    )
tests/router_unit_tests/test_router_prompt_caching.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import sys
import os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from litellm import Router
import pytest
import litellm
from unittest.mock import patch, MagicMock, AsyncMock
from create_mock_standard_logging_payload import create_standard_logging_payload
from litellm.types.utils import StandardLoggingPayload
import unittest
from pydantic import BaseModel
from litellm.router_utils.prompt_caching_cache import PromptCachingCache


class ExampleModel(BaseModel):
    field1: str
    field2: int


def test_serialize_pydantic_object():
    model = ExampleModel(field1="value", field2=42)
    serialized = PromptCachingCache.serialize_object(model)
    assert serialized == {"field1": "value", "field2": 42}


def test_serialize_dict():
    obj = {"b": 2, "a": 1}
    serialized = PromptCachingCache.serialize_object(obj)
    assert serialized == '{"a":1,"b":2}'  # JSON string with sorted keys


def test_serialize_nested_dict():
    obj = {"z": {"b": 2, "a": 1}, "x": [1, 2, {"c": 3}]}
    serialized = PromptCachingCache.serialize_object(obj)
    expected = '{"x":[1,2,{"c":3}],"z":{"a":1,"b":2}}'  # JSON string with sorted keys
    assert serialized == expected


def test_serialize_list():
    obj = ["item1", {"a": 1, "b": 2}, 42]
    serialized = PromptCachingCache.serialize_object(obj)
    expected = ["item1", '{"a":1,"b":2}', 42]
    assert serialized == expected


def test_serialize_fallback():
    obj = 12345  # Simple non-serializable object
    serialized = PromptCachingCache.serialize_object(obj)
    assert serialized == 12345


def test_serialize_non_serializable():
    class CustomClass:
        def __str__(self):
            return "custom_object"

    obj = CustomClass()
    serialized = PromptCachingCache.serialize_object(obj)
    assert serialized == "custom_object"  # Fallback to string conversion