Litellm dev 12 07 2024 (#7086)

* fix(main.py): support passing max retries to azure/openai embedding integrations

Fixes https://github.com/BerriAI/litellm/issues/7003
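
A minimal sketch of the fix from the caller's side (the deployment name and env vars are placeholders):

```python
import os

import litellm

response = litellm.embedding(
    model="azure/azure-embedding-model",  # placeholder azure/openai embedding deployment
    input=["hello", "world"],
    api_key=os.getenv("AZURE_API_KEY"),
    api_base=os.getenv("AZURE_API_BASE"),
    max_retries=5,  # now forwarded to the Azure/OpenAI embedding client
)
print(response.usage)
```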

* feat(team_endpoints.py): allow updating team model aliases

Closes https://github.com/BerriAI/litellm/issues/6956
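
A rough sketch of exercising the new field against the proxy's `/team/update` endpoint (base URL, admin key, team id, and alias mapping are placeholders):

```python
import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/team/update",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "team_id": "my-team-id",
        # `model_aliases` is the field added to UpdateTeamRequest in this PR
        "model_aliases": {"gpt-4o-alias": "gpt-4o"},
    },
)
print(resp.json())
```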

* feat(router.py): allow specifying model id as fallback - skips any cooldown check

Allows a default model to be tried if all models in the group are in cooldown

s/o @micahjsmith

* docs(reliability.md): add fallback to specific model to docs

* fix(utils.py): new 'is_prompt_caching_valid_prompt' helper util

Allows users to check whether messages/tools qualify for prompt caching

Related issue: https://github.com/BerriAI/litellm/issues/6784
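
A small illustration of the helper (mirrors the unit test added below; the filler text just pushes the prompt past the 1024-token minimum):

```python
import litellm

messages = [
    {
        "role": "user",
        "content": "Here is the full text of a complex legal agreement. " * 400,
    }
]

# True only if the prompt is large enough for OpenAI/Anthropic prompt caching
print(
    litellm.utils.is_prompt_caching_valid_prompt(
        model="anthropic/claude-3-5-sonnet-20240620",
        messages=messages,
        tools=None,
        custom_llm_provider="anthropic",
    )
)
```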

* feat(router.py): store model id for prompt caching valid prompt

Allows routing to that model id on subsequent requests

* fix(router.py): only cache if prompt is valid prompt caching prompt

Prevents storing unnecessary items in the cache

* feat(router.py): support routing prompt caching enabled models to previous deployments

Closes https://github.com/BerriAI/litellm/issues/6784
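
Roughly, the intended behavior, mirroring the router tests added below (deployment ids are placeholders; `mock_response` keeps the sketch self-contained):

```python
import asyncio
import os

from litellm import Router

# Prompt large enough (>= 1024 tokens) to qualify for prompt caching
cached_messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Here is the full text of a complex legal agreement. " * 400,
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {"role": "user", "content": "What are the key terms and conditions in this agreement?"},
]

router = Router(
    model_list=[
        {
            "model_name": "claude-model",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20240620",
                "api_key": os.environ.get("ANTHROPIC_API_KEY"),
            },
            "model_info": {"id": "deployment-1"},  # placeholder id
        },
        {
            "model_name": "claude-model",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20240620",
                "api_key": os.environ.get("ANTHROPIC_API_KEY"),
            },
            "model_info": {"id": "deployment-2"},  # placeholder id
        },
    ]
)


async def main():
    first = await router.acompletion(
        model="claude-model", messages=cached_messages, mock_response="The sky is blue."
    )
    await asyncio.sleep(1)  # the model id is stored in an async success callback

    follow_up = cached_messages + [{"role": "user", "content": "And the termination clause?"}]
    second = await router.acompletion(
        model="claude-model", messages=follow_up, mock_response="The sky is blue."
    )
    # Both requests share the cached prefix, so they should land on the same deployment
    print(first._hidden_params["model_id"], second._hidden_params["model_id"])


asyncio.run(main())
```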

* test: fix linting errors

* feat(databricks/): convert BaseModel messages to dicts and exclude None values

Allows passing Pydantic messages to Databricks
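
Illustratively (the Databricks model name is a placeholder; assumes DATABRICKS_API_KEY/DATABRICKS_API_BASE are set), a Pydantic `Message` can now be passed straight into `completion`, with `None` fields dropped from the payload:

```python
from litellm import Message, completion

# Message is a Pydantic model; optional fields that default to None are now
# excluded when the request payload is built
messages = [Message(content="Hello, how are you?", role="user")]

response = completion(
    model="databricks/databricks-dbrx-instruct",  # placeholder deployment
    messages=messages,
)
print(response.choices[0].message.content)
```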

* fix(utils.py): ensure all chat completion messages are dict

* (feat) Track `custom_llm_provider` in LiteLLMSpendLogs (#7081)

* add custom_llm_provider to SpendLogsPayload

* add custom_llm_provider to SpendLogs

* add custom llm provider to SpendLogs payload

* test_spend_logs_payload

* Add MLflow to the side bar (#7031)

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* (bug fix) SpendLogs DB updates: catch all possible DB errors for retrying (#7082)

* catch DB_CONNECTION_ERROR_TYPES

* fix DB retry mechanism for SpendLog updates

* use DB_CONNECTION_ERROR_TYPES in auth checks

* fix exponential backoff for writing SpendLogs

* use _raise_failed_update_spend_exception to ensure errors print as non-blocking

* test_update_spend_logs_multiple_batches_with_failure
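
The retry shape being fixed is, roughly, exponential backoff around the spend-log write that only catches known DB connection error types (a simplified sketch, not the proxy's actual code; the error tuple is a stand-in):

```python
import asyncio
import random

# Stand-in; the proxy defines DB_CONNECTION_ERROR_TYPES centrally
DB_CONNECTION_ERROR_TYPES = (ConnectionError, TimeoutError)


async def write_spend_logs_with_retry(write_fn, payloads, max_retries=3, base_delay=1.0):
    for attempt in range(max_retries + 1):
        try:
            return await write_fn(payloads)
        except DB_CONNECTION_ERROR_TYPES:
            if attempt == max_retries:
                # surface as a non-blocking error instead of failing the request path
                raise
            # exponential backoff with a little jitter: ~1s, 2s, 4s, ...
            await asyncio.sleep(base_delay * (2**attempt) + random.random() * 0.1)
```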

* (Feat) Add StructuredOutputs support for Fireworks.AI (#7085)

* fix model cost map for Fireworks AI: set "supports_response_schema": true

* fix supports_response_schema

* fix mapping of OpenAI params for Fireworks AI

* test_map_response_format

* test_map_response_format
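
In caller terms, this enables the OpenAI-style `response_format` JSON-schema parameter for Fireworks AI models (model name is a placeholder; assumes FIREWORKS_AI_API_KEY is set):

```python
import litellm

response = litellm.completion(
    model="fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct",
    messages=[{"role": "user", "content": "Give me a fruit and its color."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "fruit",
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}, "color": {"type": "string"}},
                "required": ["name", "color"],
            },
        },
    },
)
print(response.choices[0].message.content)
```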

* added deepinfra/Meta-Llama-3.1-405B-Instruct (#7084)

* bump: version 1.53.9 → 1.54.0

* fix deepinfra

* litellm db fixes LiteLLM_UserTable (#7089)

* ci/cd queue new release

* fix llama-3.3-70b-versatile

* refactor - use consistent file naming convention `AI21/` -> `ai21`  (#7090)

* fix refactor - use consistent file naming convention

* ci/cd run again

* fix naming structure

* fix use consistent naming (#7092)

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com>
Co-authored-by: ali sayyah <ali.sayyah2@gmail.com>
Krish Dholakia 2024-12-08 00:30:33 -08:00 committed by GitHub
parent 36e99ebce7
commit 0c0498dd60
24 changed files with 840 additions and 193 deletions

View file

@@ -315,6 +315,64 @@ litellm_settings:
cooldown_time: 30 # how long to cooldown model if fails/min > allowed_fails
```
### Fallback to Specific Model ID
If all models in a group are in cooldown (e.g. rate limited), LiteLLM will fall back to the model with the specified model ID.
This skips any cooldown check for the fallback model.
1. Specify the model ID in `model_info`
```yaml
model_list:
- model_name: gpt-4
litellm_params:
model: openai/gpt-4
model_info:
id: my-specific-model-id # 👈 KEY CHANGE
- model_name: gpt-4
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
- model_name: anthropic-claude
litellm_params:
model: anthropic/claude-3-opus-20240229
api_key: os.environ/ANTHROPIC_API_KEY
```
**Note:** This will only fall back to the model with the specified model ID. If you want to fall back to another model group, you can set `fallbacks=[{"gpt-4": ["anthropic-claude"]}]`
2. Set fallbacks in config
```yaml
litellm_settings:
fallbacks: [{"gpt-4": ["my-specific-model-id"]}]
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "ping"
}
],
"mock_testing_fallbacks": true
}'
```
Validate it works by checking the response header `x-litellm-model-id`:
```bash
x-litellm-model-id: my-specific-model-id
```
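For SDK usage, a rough equivalent with the Python `Router` (model names are illustrative; assumes your OpenAI/Azure credentials are set in the environment):
```python
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {"model": "openai/gpt-4"},
            "model_info": {"id": "my-specific-model-id"},  # 👈 same id as above
        },
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_key": os.getenv("AZURE_API_KEY"),
            },
        },
    ],
    fallbacks=[{"gpt-4": ["my-specific-model-id"]}],
)

response = router.completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "ping"}],
    mock_testing_fallbacks=True,  # force the fallback path
)
print(response._hidden_params["model_id"])  # expect: my-specific-model-id
```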
### Test Fallbacks!
Check if your fallbacks are working as expected.
@@ -337,6 +395,7 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
'
```
#### **Content Policy Fallbacks**
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \

View file

@@ -1130,7 +1130,7 @@ router_settings:
If a call fails after num_retries, fall back to another model group.
### Quick Start
#### Quick Start
```python
from litellm import Router
@@ -1366,6 +1366,7 @@ litellm --config /path/to/config.yaml
</TabItem>
</Tabs>
### Caching
In production, we recommend using a Redis cache. For quickly testing things locally, we also support simple in-memory caching.
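A minimal SDK-side sketch of that recommendation (env var names are illustrative):
```python
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    redis_host=os.getenv("REDIS_HOST"),
    redis_port=int(os.getenv("REDIS_PORT", "6379")),
    redis_password=os.getenv("REDIS_PASSWORD"),
    cache_responses=True,  # optional: cache responses as well as router state
)
```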

View file

@@ -21,6 +21,7 @@ from litellm.constants import (
DEFAULT_BATCH_SIZE,
DEFAULT_FLUSH_INTERVAL_SECONDS,
ROUTER_MAX_FALLBACKS,
DEFAULT_MAX_RETRIES,
)
from litellm.types.guardrails import GuardrailItem
from litellm.proxy._types import (

View file

@@ -1,3 +1,4 @@
ROUTER_MAX_FALLBACKS = 5
DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2

View file

@@ -1217,12 +1217,13 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str] = None,
client=None,
aembedding=None,
max_retries: Optional[int] = None,
) -> litellm.EmbeddingResponse:
super().embedding()
try:
model = model
data = {"model": model, "input": input, **optional_params}
max_retries = data.pop("max_retries", 2)
max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
## LOGGING

View file

@@ -871,6 +871,7 @@ class AzureChatCompletion(BaseLLM):
optional_params: dict,
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
max_retries: Optional[int] = None,
client=None,
aembedding=None,
) -> litellm.EmbeddingResponse:
@@ -879,7 +880,7 @@ class AzureChatCompletion(BaseLLM):
self._client_session = self.create_client_session()
try:
data = {"model": model, "input": input, **optional_params}
max_retries = data.pop("max_retries", 2)
max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"

View file

@@ -219,6 +219,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
api_base: Optional[str] = None,
client=None,
aembedding=None,
max_retries: Optional[int] = None,
) -> litellm.EmbeddingResponse:
"""
- Separate image url from text

View file

@@ -134,7 +134,7 @@ class DatabricksConfig(OpenAIGPTConfig):
new_messages = []
for idx, message in enumerate(messages):
if isinstance(message, BaseModel):
_message = message.model_dump()
_message = message.model_dump(exclude_none=True)
else:
_message = message
new_messages.append(_message)

View file

@@ -77,7 +77,7 @@ from litellm.utils import (
read_config_args,
supports_httpx_timeout,
token_counter,
validate_chat_completion_user_messages,
validate_chat_completion_messages,
)
from ._logging import verbose_logger
@@ -931,7 +931,7 @@ def completion( # type: ignore # noqa: PLR0915
) # support region-based pricing for bedrock
### VALIDATE USER MESSAGES ###
validate_chat_completion_user_messages(messages=messages)
messages = validate_chat_completion_messages(messages=messages)
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
@@ -3274,6 +3274,7 @@ def embedding( # noqa: PLR0915
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
max_retries = kwargs.get("max_retries", None)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
cooldown_time = kwargs.get("cooldown_time", None)
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
@@ -3422,6 +3423,7 @@
optional_params=optional_params,
client=client,
aembedding=aembedding,
max_retries=max_retries,
)
elif (
model in litellm.open_ai_embedding_models
@@ -3466,6 +3468,7 @@
optional_params=optional_params,
client=client,
aembedding=aembedding,
max_retries=max_retries,
)
elif custom_llm_provider == "databricks":
api_base = (

View file

@@ -723,7 +723,7 @@ class KeyRequest(LiteLLMBase):
class LiteLLM_ModelTable(LiteLLMBase):
model_aliases: Optional[str] = None # json dump the dict
model_aliases: Optional[Union[str, dict]] = None # json dump the dict
created_by: str
updated_by: str
@@ -981,6 +981,7 @@ class UpdateTeamRequest(LiteLLMBase):
blocked: Optional[bool] = None
budget_duration: Optional[str] = None
tags: Optional[list] = None
model_aliases: Optional[dict] = None
class ResetTeamBudgetRequest(LiteLLMBase):
@@ -1059,6 +1060,7 @@ class LiteLLM_TeamTable(TeamBase):
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
litellm_model_table: Optional[LiteLLM_ModelTable] = None
model_config = ConfigDict(protected_namespaces=())

View file

@@ -293,8 +293,13 @@ async def new_team( # noqa: PLR0915
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
complete_team_data.budget_reset_at = reset_at
team_row: LiteLLM_TeamTable = await prisma_client.insert_data( # type: ignore
data=complete_team_data.json(exclude_none=True), table_name="team"
complete_team_data_dict = complete_team_data.model_dump(exclude_none=True)
complete_team_data_dict = prisma_client.jsonify_team_object(
db_data=complete_team_data_dict
)
team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create(
data=complete_team_data_dict,
include={"litellm_model_table": True}, # type: ignore
)
## ADD TEAM ID TO USER TABLE ##
@@ -340,6 +345,37 @@ async def new_team( # noqa: PLR0915
return team_row.dict()
async def _update_model_table(
data: UpdateTeamRequest,
model_id: Optional[str],
prisma_client: PrismaClient,
user_api_key_dict: UserAPIKeyAuth,
litellm_proxy_admin_name: str,
) -> Optional[str]:
"""
Upsert model table and return the model id
"""
## UPSERT MODEL TABLE
_model_id = model_id
if data.model_aliases is not None and isinstance(data.model_aliases, dict):
litellm_modeltable = LiteLLM_ModelTable(
model_aliases=json.dumps(data.model_aliases),
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
model_dict = await prisma_client.db.litellm_modeltable.upsert(
where={"id": model_id},
data={
"update": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore
"create": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore
},
) # type: ignore
_model_id = model_dict.id
return _model_id
@router.post(
"/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@@ -370,6 +406,7 @@ async def update_team(
- blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id.
- tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
- organization_id: Optional[str] - The organization id of the team. Default is None. Create via `/organization/new`.
- model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias)
Example - update team TPM Limit
@@ -446,11 +483,25 @@ async def update_team(
else:
updated_kv["metadata"] = {"tags": _tags}
updated_kv = prisma_client.jsonify_object(data=updated_kv)
team_row: Optional[
LiteLLM_TeamTable
] = await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id}, data=updated_kv # type: ignore
if "model_aliases" in updated_kv:
updated_kv.pop("model_aliases")
_model_id = await _update_model_table(
data=data,
model_id=existing_team_row.model_id,
prisma_client=prisma_client,
user_api_key_dict=user_api_key_dict,
litellm_proxy_admin_name=litellm_proxy_admin_name,
)
if _model_id is not None:
updated_kv["model_id"] = _model_id
updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv)
team_row: Optional[LiteLLM_TeamTable] = (
await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id},
data=updated_kv,
include={"litellm_model_table": True}, # type: ignore
)
)
if team_row is None or team_row.team_id is None:

View file

@@ -1207,173 +1207,6 @@ class PrismaClient:
except Exception:
raise
# try:
# # Try to select one row from the view
# await self.db.query_raw(
# """SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1"""
# )
# print("LiteLLM_VerificationTokenView Exists!") # noqa
# except Exception as e:
# If an error occurs, the view does not exist, so create it
# try:
# await self.db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
# print("MonthlyGlobalSpend Exists!") # noqa
# except Exception as e:
# sql_query = """
# CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
# SELECT
# DATE("startTime") AS date,
# SUM("spend") AS spend
# FROM
# "LiteLLM_SpendLogs"
# WHERE
# "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
# GROUP BY
# DATE("startTime");
# """
# await self.db.execute_raw(query=sql_query)
# print("MonthlyGlobalSpend Created!") # noqa
# try:
# await self.db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
# print("Last30dKeysBySpend Exists!") # noqa
# except Exception as e:
# sql_query = """
# CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
# SELECT
# L."api_key",
# V."key_alias",
# V."key_name",
# SUM(L."spend") AS total_spend
# FROM
# "LiteLLM_SpendLogs" L
# LEFT JOIN
# "LiteLLM_VerificationToken" V
# ON
# L."api_key" = V."token"
# WHERE
# L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
# GROUP BY
# L."api_key", V."key_alias", V."key_name"
# ORDER BY
# total_spend DESC;
# """
# await self.db.execute_raw(query=sql_query)
# print("Last30dKeysBySpend Created!") # noqa
# try:
# await self.db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
# print("Last30dModelsBySpend Exists!") # noqa
# except Exception as e:
# sql_query = """
# CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
# SELECT
# "model",
# SUM("spend") AS total_spend
# FROM
# "LiteLLM_SpendLogs"
# WHERE
# "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
# AND "model" != ''
# GROUP BY
# "model"
# ORDER BY
# total_spend DESC;
# """
# await self.db.execute_raw(query=sql_query)
# print("Last30dModelsBySpend Created!") # noqa
# try:
# await self.db.query_raw(
# """SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1"""
# )
# print("MonthlyGlobalSpendPerKey Exists!") # noqa
# except Exception as e:
# sql_query = """
# CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
# SELECT
# DATE("startTime") AS date,
# SUM("spend") AS spend,
# api_key as api_key
# FROM
# "LiteLLM_SpendLogs"
# WHERE
# "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
# GROUP BY
# DATE("startTime"),
# api_key;
# """
# await self.db.execute_raw(query=sql_query)
# print("MonthlyGlobalSpendPerKey Created!") # noqa
# try:
# await self.db.query_raw(
# """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
# )
# print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa
# except Exception as e:
# sql_query = """
# CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
# SELECT
# DATE("startTime") AS date,
# SUM("spend") AS spend,
# api_key as api_key,
# "user" as "user"
# FROM
# "LiteLLM_SpendLogs"
# WHERE
# "startTime" >= (CURRENT_DATE - INTERVAL '20 days')
# GROUP BY
# DATE("startTime"),
# "user",
# api_key;
# """
# await self.db.execute_raw(query=sql_query)
# print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa
# try:
# await self.db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""")
# print("DailyTagSpend Exists!") # noqa
# except Exception as e:
# sql_query = """
# CREATE OR REPLACE VIEW DailyTagSpend AS
# SELECT
# jsonb_array_elements_text(request_tags) AS individual_request_tag,
# DATE(s."startTime") AS spend_date,
# COUNT(*) AS log_count,
# SUM(spend) AS total_spend
# FROM "LiteLLM_SpendLogs" s
# GROUP BY individual_request_tag, DATE(s."startTime");
# """
# await self.db.execute_raw(query=sql_query)
# print("DailyTagSpend Created!") # noqa
# try:
# await self.db.query_raw(
# """SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1"""
# )
# print("Last30dTopEndUsersSpend Exists!") # noqa
# except Exception as e:
# sql_query = """
# CREATE VIEW "Last30dTopEndUsersSpend" AS
# SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
# FROM "LiteLLM_SpendLogs"
# WHERE end_user <> '' AND end_user <> user
# AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
# GROUP BY end_user
# ORDER BY total_spend DESC
# LIMIT 100;
# """
# await self.db.execute_raw(query=sql_query)
# print("Last30dTopEndUsersSpend Created!") # noqa
return
@log_db_metrics
@@ -1784,6 +1617,14 @@ class PrismaClient:
)
raise e
def jsonify_team_object(self, db_data: dict):
db_data = self.jsonify_object(data=db_data)
if db_data.get("members_with_roles", None) is not None and isinstance(
db_data["members_with_roles"], list
):
db_data["members_with_roles"] = json.dumps(db_data["members_with_roles"])
return db_data
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
@@ -2348,7 +2189,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
module_name = value
instance_name = None
try:
print_verbose(f"value: {value}")
# Split the path by dots to separate module from instance
parts = value.split(".")

View file

@@ -36,6 +36,7 @@ from typing import (
Tuple,
TypedDict,
Union,
cast,
)
import httpx
@@ -96,6 +97,7 @@ from litellm.router_utils.router_callbacks.track_deployment_metrics import (
)
from litellm.scheduler import FlowItem, Scheduler
from litellm.types.llms.openai import (
AllMessageValues,
Assistant,
AssistantToolParam,
AsyncCursorPage,
@@ -149,10 +151,12 @@ from litellm.utils import (
get_llm_provider,
get_secret,
get_utc_datetime,
is_prompt_caching_valid_prompt,
is_region_allowed,
)
from .router_utils.pattern_match_deployments import PatternMatchRouter
from .router_utils.prompt_caching_cache import PromptCachingCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@@ -737,7 +741,9 @@ class Router:
model_client = potential_model_client
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
## only run if model group given, not model id
if model not in self.get_model_ids():
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.completion(
**{
@@ -2787,8 +2793,10 @@ class Router:
*args,
**input_kwargs,
)
return response
except Exception as new_exception:
traceback.print_exc()
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
verbose_router_logger.error(
"litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
@@ -3376,6 +3384,29 @@ class Router:
deployment_id=id,
)
## PROMPT CACHING
prompt_cache = PromptCachingCache(
cache=self.cache,
)
if (
standard_logging_object["messages"] is not None
and isinstance(standard_logging_object["messages"], list)
and deployment_name is not None
and isinstance(deployment_name, str)
):
valid_prompt = is_prompt_caching_valid_prompt(
messages=standard_logging_object["messages"], # type: ignore
tools=None,
model=deployment_name,
custom_llm_provider=None,
)
if valid_prompt:
await prompt_cache.async_add_model_id(
model_id=id,
messages=standard_logging_object["messages"], # type: ignore
tools=None,
)
return tpm_key
except Exception as e:
@@ -5190,7 +5221,6 @@ class Router:
- List, if multiple models chosen
- Dict, if specific model chosen
"""
# check if aliases set on litellm model alias map
if specific_deployment is True:
return model, self._get_deployment_by_litellm_model(model=model)
@@ -5302,13 +5332,6 @@ class Router:
cooldown_deployments=cooldown_deployments,
)
# filter pre-call checks
_allowed_model_region = (
request_kwargs.get("allowed_model_region")
if request_kwargs is not None
else None
)
if self.enable_pre_call_checks and messages is not None:
healthy_deployments = self._pre_call_checks(
model=model,
@@ -5317,6 +5340,24 @@ class Router:
request_kwargs=request_kwargs,
)
if messages is not None and is_prompt_caching_valid_prompt(
messages=cast(List[AllMessageValues], messages),
model=model,
custom_llm_provider=None,
):
prompt_cache = PromptCachingCache(
cache=self.cache,
)
healthy_deployment = (
await prompt_cache.async_get_prompt_caching_deployment(
router=self,
messages=cast(List[AllMessageValues], messages),
tools=None,
)
)
if healthy_deployment is not None:
return healthy_deployment
# check if user wants to do tag based routing
healthy_deployments = await get_deployments_for_tag( # type: ignore
llm_router_instance=self,

View file

@@ -49,6 +49,7 @@ async def run_async_fallback(
raise original_exception
error_from_fallbacks = original_exception
for mg in fallback_model_group:
if mg == original_model_group:
continue

View file

@@ -0,0 +1,193 @@
"""
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
"""
import hashlib
import json
import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict
import litellm
from litellm import verbose_logger
from litellm.caching.caching import Cache, DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router
litellm_router = Router
Span = _Span
else:
Span = Any
litellm_router = Any
class PromptCachingCacheValue(TypedDict):
model_id: str
class PromptCachingCache:
def __init__(self, cache: DualCache):
self.cache = cache
self.in_memory_cache = InMemoryCache()
@staticmethod
def serialize_object(obj: Any) -> Any:
"""Helper function to serialize Pydantic objects, dictionaries, or fallback to string."""
if hasattr(obj, "dict"):
# If the object is a Pydantic model, use its `dict()` method
return obj.dict()
elif isinstance(obj, dict):
# If the object is a dictionary, serialize it with sorted keys
return json.dumps(
obj, sort_keys=True, separators=(",", ":")
) # Standardize serialization
elif isinstance(obj, list):
# Serialize lists by ensuring each element is handled properly
return [PromptCachingCache.serialize_object(item) for item in obj]
elif isinstance(obj, (int, float, bool)):
return obj # Keep primitive types as-is
return str(obj)
@staticmethod
def get_prompt_caching_cache_key(
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[str]:
if messages is None and tools is None:
return None
# Use serialize_object for consistent and stable serialization
data_to_hash = {}
if messages is not None:
serialized_messages = PromptCachingCache.serialize_object(messages)
data_to_hash["messages"] = serialized_messages
if tools is not None:
serialized_tools = PromptCachingCache.serialize_object(tools)
data_to_hash["tools"] = serialized_tools
# Combine serialized data into a single string
data_to_hash_str = json.dumps(
data_to_hash,
sort_keys=True,
separators=(",", ":"),
)
# Create a hash of the serialized data for a stable cache key
hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
return f"deployment:{hashed_data}:prompt_caching"
def add_model_id(
self,
model_id: str,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> None:
if messages is None and tools is None:
return None
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
self.cache.set_cache(
cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
)
return None
async def async_add_model_id(
self,
model_id: str,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> None:
if messages is None and tools is None:
return None
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
await self.cache.async_set_cache(
cache_key,
PromptCachingCacheValue(model_id=model_id),
ttl=300, # store for 5 minutes
)
return None
async def async_get_model_id(
self,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[PromptCachingCacheValue]:
"""
if messages is not none
- check full messages
- check messages[:-1]
- check messages[:-2]
- check messages[:-3]
use self.cache.async_batch_get_cache(keys=potential_cache_keys)
"""
if messages is None and tools is None:
return None
# Generate potential cache keys by slicing messages
potential_cache_keys = []
if messages is not None:
full_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
messages, tools
)
potential_cache_keys.append(full_cache_key)
# Check progressively shorter message slices
for i in range(1, min(4, len(messages))):
partial_messages = messages[:-i]
partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
partial_messages, tools
)
potential_cache_keys.append(partial_cache_key)
# Perform batch cache lookup
cache_results = await self.cache.async_batch_get_cache(
keys=potential_cache_keys
)
if cache_results is None:
return None
# Return the first non-None cache result
for result in cache_results:
if result is not None:
return result
return None
def get_model_id(
self,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[PromptCachingCacheValue]:
if messages is None and tools is None:
return None
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
return self.cache.get_cache(cache_key)
async def async_get_prompt_caching_deployment(
self,
router: litellm_router,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[dict]:
model_id_dict = await self.async_get_model_id(
messages=messages,
tools=tools,
)
if model_id_dict is not None:
healthy_deployment_pydantic_obj = router.get_deployment(
model_id=model_id_dict["model_id"]
)
if healthy_deployment_pydantic_obj is not None:
return healthy_deployment_pydantic_obj.model_dump(exclude_none=True)
return None

View file

@@ -6151,6 +6151,38 @@ from litellm.types.llms.openai import (
)
def convert_to_dict(message: Union[BaseModel, dict]) -> dict:
"""
Converts a message to a dictionary if it's a Pydantic model.
Args:
message: The message, which may be a Pydantic model or a dictionary.
Returns:
dict: The converted message.
"""
if isinstance(message, BaseModel):
return message.model_dump(exclude_none=True)
elif isinstance(message, dict):
return message
else:
raise TypeError(
f"Invalid message type: {type(message)}. Expected dict or Pydantic model."
)
def validate_chat_completion_messages(messages: List[AllMessageValues]):
"""
Ensures all messages are valid OpenAI chat completion messages.
"""
# 1. convert all messages to dict
messages = [
cast(AllMessageValues, convert_to_dict(cast(dict, m))) for m in messages
]
# 2. validate user messages
return validate_chat_completion_user_messages(messages=messages)
def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
"""
Ensures all user messages are valid OpenAI chat completion messages.
@@ -6229,3 +6261,22 @@ def get_end_user_id_for_cost_tracking(
):
return None
return proxy_server_request.get("body", {}).get("user", None)
def is_prompt_caching_valid_prompt(
model: str,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]] = None,
custom_llm_provider: Optional[str] = None,
) -> bool:
"""
Returns true if the prompt is valid for prompt caching.
OpenAI + Anthropic providers have a minimum token count of 1024 for prompt caching.
"""
if messages is None and tools is None:
return False
if custom_llm_provider is not None and not model.startswith(custom_llm_provider):
model = custom_llm_provider + "/" + model
token_count = token_counter(messages=messages, tools=tools, model=model)
return token_count >= 1024

View file

@@ -0,0 +1,67 @@
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
get_optional_params_embeddings,
)
# test_example.py
from abc import ABC, abstractmethod
class BaseLLMEmbeddingTest(ABC):
"""
Abstract base test class that enforces a common test across all test classes.
"""
@abstractmethod
def get_base_embedding_call_args(self) -> dict:
"""Must return the base embedding call args"""
pass
@abstractmethod
def get_custom_llm_provider(self) -> litellm.LlmProviders:
"""Must return the custom llm provider"""
pass
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_embedding(self, sync_mode):
litellm.set_verbose = True
embedding_call_args = self.get_base_embedding_call_args()
if sync_mode is True:
response = litellm.embedding(
**embedding_call_args,
input=["hello", "world"],
)
print("embedding response: ", response)
else:
response = await litellm.aembedding(
**embedding_call_args,
input=["hello", "world"],
)
print("async embedding response: ", response)
def test_embedding_optional_params_max_retries(self):
embedding_call_args = self.get_base_embedding_call_args()
optional_params = get_optional_params_embeddings(
**embedding_call_args, max_retries=20
)
assert optional_params["max_retries"] == 20

View file

@@ -82,6 +82,16 @@ class BaseLLMChatTest(ABC):
# for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
assert response.choices[0].message.content is not None
def test_pydantic_model_input(self):
litellm.set_verbose = True
from litellm import completion, Message
base_completion_call_args = self.get_base_completion_call_args()
messages = [Message(content="Hello, how are you?", role="user")]
completion(**base_completion_call_args, messages=messages)
@pytest.mark.parametrize("image_url", ["str", "dict"])
def test_pdf_handling(self, pdf_messages, image_url):
from litellm.utils import supports_pdf_input

View file

@@ -8,6 +8,7 @@ sys.path.insert(
import pytest
from litellm.llms.azure.common_utils import process_azure_headers
from httpx import Headers
from base_embedding_unit_tests import BaseLLMEmbeddingTest
def test_process_azure_headers_empty():
@@ -188,3 +189,15 @@ def test_process_azure_endpoint_url(api_base, model, expected_endpoint):
}
result = azure_chat_completion.create_azure_base_url(**input_args)
assert result == expected_endpoint, "Unexpected endpoint"
class TestAzureEmbedding(BaseLLMEmbeddingTest):
def get_base_embedding_call_args(self) -> dict:
return {
"model": "azure/azure-embedding-model",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
}
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.AZURE

View file

@@ -161,6 +161,49 @@ async def test_litellm_anthropic_prompt_caching_tools():
)
@pytest.fixture
def anthropic_messages():
return [
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
]
@pytest.mark.asyncio()
async def test_anthropic_api_prompt_caching_basic():
litellm.set_verbose = True
@@ -227,8 +270,6 @@ async def test_anthropic_api_prompt_caching_basic():
@pytest.mark.asyncio()
async def test_anthropic_api_prompt_caching_with_content_str():
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
system_message = [
{
"role": "system",
@@ -546,3 +587,134 @@ async def test_litellm_anthropic_prompt_caching_system():
mock_post.assert_called_once_with(
expected_url, json=expected_json, headers=expected_headers, timeout=600.0
)
def test_is_prompt_caching_enabled(anthropic_messages):
assert litellm.utils.is_prompt_caching_valid_prompt(
messages=anthropic_messages,
tools=None,
custom_llm_provider="anthropic",
model="anthropic/claude-3-5-sonnet-20240620",
)
@pytest.mark.parametrize(
"messages, expected_model_id",
[("anthropic_messages", True), ("normal_messages", False)],
)
@pytest.mark.asyncio()
async def test_router_prompt_caching_model_stored(
messages, expected_model_id, anthropic_messages
):
"""
If a model is called with prompt caching supported, then the model id should be stored in the router cache.
"""
import asyncio
from litellm.router import Router
from litellm.router_utils.prompt_caching_cache import PromptCachingCache
router = Router(
model_list=[
{
"model_name": "claude-model",
"litellm_params": {
"model": "anthropic/claude-3-5-sonnet-20240620",
"api_key": os.environ.get("ANTHROPIC_API_KEY"),
},
"model_info": {"id": "1234"},
}
]
)
if messages == "anthropic_messages":
_messages = anthropic_messages
else:
_messages = [{"role": "user", "content": "Hello"}]
await router.acompletion(
model="claude-model",
messages=_messages,
mock_response="The sky is blue.",
)
await asyncio.sleep(1)
cache = PromptCachingCache(
cache=router.cache,
)
cached_model_id = cache.get_model_id(messages=_messages, tools=None)
if expected_model_id:
assert cached_model_id["model_id"] == "1234"
else:
assert cached_model_id is None
@pytest.mark.asyncio()
async def test_router_with_prompt_caching(anthropic_messages):
"""
if prompt caching supported model called with prompt caching valid prompt,
then 2nd call should go to the same model.
"""
from litellm.router import Router
import asyncio
from litellm.router_utils.prompt_caching_cache import PromptCachingCache
router = Router(
model_list=[
{
"model_name": "claude-model",
"litellm_params": {
"model": "anthropic/claude-3-5-sonnet-20240620",
"api_key": os.environ.get("ANTHROPIC_API_KEY"),
},
},
{
"model_name": "claude-model",
"litellm_params": {
"model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
},
},
]
)
response = await router.acompletion(
messages=anthropic_messages,
model="claude-model",
mock_response="The sky is blue.",
)
print("response=", response)
initial_model_id = response._hidden_params["model_id"]
await asyncio.sleep(1)
cache = PromptCachingCache(
cache=router.cache,
)
cached_model_id = cache.get_model_id(messages=anthropic_messages, tools=None)
prompt_caching_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
messages=anthropic_messages, tools=None
)
print(f"prompt_caching_cache_key: {prompt_caching_cache_key}")
assert cached_model_id["model_id"] == initial_model_id
new_messages = anthropic_messages + [
{"role": "user", "content": "What is the weather in SF?"}
]
pc_deployment = await cache.async_get_prompt_caching_deployment(
router=router,
messages=new_messages,
tools=None,
)
assert pc_deployment is not None
response = await router.acompletion(
messages=new_messages,
model="claude-model",
mock_response="The sky is blue.",
)
print("response=", response)
assert response._hidden_params["model_id"] == initial_model_id

View file

@@ -2697,3 +2697,21 @@ def test_model_group_alias(hidden):
# assert int(response_headers["x-ratelimit-remaining-requests"]) > 0
# assert response_headers["x-ratelimit-limit-tokens"] == 100500
# assert int(response_headers["x-ratelimit-remaining-tokens"]) > 0
def test_router_completion_with_model_id():
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
"model_info": {"id": "123"},
}
]
)
with patch.object(
router, "routing_strategy_pre_call_checks"
) as mock_pre_call_checks:
router.completion(model="123", messages=[{"role": "user", "content": "hi"}])
mock_pre_call_checks.assert_not_called()

View file

@@ -560,3 +560,34 @@ Unit tests for router set_cooldowns
1. _set_cooldown_deployments() will cooldown a deployment after it fails 50% requests
"""
def test_router_fallbacks_with_cooldowns_and_model_id():
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1},
"model_info": {
"id": "123",
},
}
],
routing_strategy="usage-based-routing-v2",
fallbacks=[{"gpt-3.5-turbo": ["123"]}],
)
## trigger ratelimit
try:
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
mock_response="litellm.RateLimitError",
)
except litellm.RateLimitError:
pass
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
)

View file

@@ -1498,3 +1498,26 @@ async def test_router_disable_fallbacks_dynamically():
print(e)
mock_client.assert_not_called()
def test_router_fallbacks_with_model_id():
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1},
"model_info": {
"id": "123",
},
}
],
routing_strategy="usage-based-routing-v2",
fallbacks=[{"gpt-3.5-turbo": ["123"]}],
)
## test model id fallback works
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
mock_testing_fallbacks=True,
)

View file

@@ -0,0 +1,66 @@
import sys
import os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm import Router
import pytest
import litellm
from unittest.mock import patch, MagicMock, AsyncMock
from create_mock_standard_logging_payload import create_standard_logging_payload
from litellm.types.utils import StandardLoggingPayload
import unittest
from pydantic import BaseModel
from litellm.router_utils.prompt_caching_cache import PromptCachingCache
class ExampleModel(BaseModel):
field1: str
field2: int
def test_serialize_pydantic_object():
model = ExampleModel(field1="value", field2=42)
serialized = PromptCachingCache.serialize_object(model)
assert serialized == {"field1": "value", "field2": 42}
def test_serialize_dict():
obj = {"b": 2, "a": 1}
serialized = PromptCachingCache.serialize_object(obj)
assert serialized == '{"a":1,"b":2}' # JSON string with sorted keys
def test_serialize_nested_dict():
obj = {"z": {"b": 2, "a": 1}, "x": [1, 2, {"c": 3}]}
serialized = PromptCachingCache.serialize_object(obj)
expected = '{"x":[1,2,{"c":3}],"z":{"a":1,"b":2}}' # JSON string with sorted keys
assert serialized == expected
def test_serialize_list():
obj = ["item1", {"a": 1, "b": 2}, 42]
serialized = PromptCachingCache.serialize_object(obj)
expected = ["item1", '{"a":1,"b":2}', 42]
assert serialized == expected
def test_serialize_fallback():
obj = 12345 # Simple non-serializable object
serialized = PromptCachingCache.serialize_object(obj)
assert serialized == 12345
def test_serialize_non_serializable():
class CustomClass:
def __str__(self):
return "custom_object"
obj = CustomClass()
serialized = PromptCachingCache.serialize_object(obj)
assert serialized == "custom_object" # Fallback to string conversion