LiteLLM Minor Fixes & Improvements (09/21/2024) (#5819)

* fix(router.py): fix error message

* Litellm disable keys (#5814)

* build(schema.prisma): allow blocking/unblocking keys

Fixes https://github.com/BerriAI/litellm/issues/5328

* fix(key_management_endpoints.py): fix pop

* feat(auth_checks.py): allow admin to enable/disable virtual keys

Closes https://github.com/BerriAI/litellm/issues/5328
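
A minimal sketch of exercising the new `/key/block` and `/key/unblock` endpoints added here (the proxy URL, master key, and target key below are placeholders):

```python
import requests

PROXY_BASE = "http://0.0.0.0:4000"             # placeholder proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # placeholder admin / master key

# disable a virtual key - blocked keys now fail auth with
# "Key is blocked. Update via `/key/unblock` if you're admin."
resp = requests.post(
    f"{PROXY_BASE}/key/block",
    headers=HEADERS,
    json={"key": "sk-key-to-disable"},  # hypothetical key (raw key or hashed token)
)
print(resp.json())  # updated key record, blocked=True

# re-enable it later
requests.post(
    f"{PROXY_BASE}/key/unblock",
    headers=HEADERS,
    json={"key": "sk-key-to-disable"},
)
```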

* docs(vertex.md): add auth section for vertex ai

Addresses - https://github.com/BerriAI/litellm/issues/5768#issuecomment-2365284223

* build(model_prices_and_context_window.json): show which models support prompt_caching

Closes https://github.com/BerriAI/litellm/issues/5776
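
A minimal check of the new flag, mirroring the unit test added later in this commit:

```python
import os

import litellm

# read model metadata from the local cost map bundled with litellm
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

info = litellm.get_model_info("deepseek/deepseek-chat")
print(info.get("supports_prompt_caching"))  # True for models flagged in model_prices_and_context_window.json
```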

* fix(router.py): allow setting default priority for requests
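
A sketch of the new `default_priority` router setting (the deployment below is a placeholder; a per-request `priority` still overrides the default):

```python
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # placeholder deployment
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    default_priority=0,  # requests without an explicit `priority` now go through the scheduler
)

async def main():
    return await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        priority=1,  # overrides default_priority for this request
    )

# asyncio.run(main())  # requires valid provider credentials
```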

* fix(router.py): add 'retry-after' header for concurrent request limit errors

Fixes https://github.com/BerriAI/litellm/issues/5783

* fix(router.py): correctly raise and use retry-after header from azure+openai

Fixes https://github.com/BerriAI/litellm/issues/5783
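
A hedged sketch of reading the propagated header on the caller side (model and prompt are placeholders):

```python
import litellm

try:
    litellm.completion(
        model="azure/my-deployment",  # placeholder deployment
        messages=[{"role": "user", "content": "hi"}],
    )
except litellm.RateLimitError as e:
    # 429 headers from Azure/OpenAI are now surfaced on the exception
    headers = getattr(e, "litellm_response_headers", None)
    retry_after = headers.get("retry-after") if headers else None
    print(f"rate limited; retry after {retry_after}s")
```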

* fix(user_api_key_auth.py): fix valid token being none

* fix(auth_checks.py): fix model dump for cache management object

* fix(user_api_key_auth.py): pass prisma_client to obj

* test(test_otel.py): update test for new key check

* test: fix test
Krish Dholakia 2024-09-21 18:51:53 -07:00 committed by GitHub
parent 1ca638973f
commit 8039b95aaf
25 changed files with 1006 additions and 182 deletions


@ -2118,8 +2118,93 @@ print("response from proxy", response)
## Authentication - vertex_project, vertex_location, etc.
Set your vertex credentials via:
- dynamic params
OR
- env vars
### **Dynamic Params**
You can set:
- `vertex_credentials` (str) - can be a json string or the filepath to your Vertex AI service_account.json
- `vertex_location` (str) - region where the vertex model is deployed (us-central1, asia-southeast1, etc.)
- `vertex_project` (Optional[str]) - use if the vertex project is different from the one in `vertex_credentials`
as dynamic params for a `litellm.completion` call.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import json

## GET CREDENTIALS
file_path = 'path/to/vertex_ai_service_account.json'

# Load the JSON file
with open(file_path, 'r') as file:
    vertex_credentials = json.load(file)

# Convert to JSON string
vertex_credentials_json = json.dumps(vertex_credentials)

response = completion(
    model="vertex_ai/gemini-pro",
    messages=[{"content": "You are a good bot.", "role": "system"}, {"content": "Hello, how are you?", "role": "user"}],
    vertex_credentials=vertex_credentials_json,
    vertex_project="my-special-project",
    vertex_location="my-special-location",
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
model_list:
  - model_name: gemini-1.5-pro
    litellm_params:
      model: gemini-1.5-pro
      vertex_credentials: os.environ/VERTEX_FILE_PATH_ENV_VAR # os.environ["VERTEX_FILE_PATH_ENV_VAR"] = "/path/to/service_account.json"
      vertex_project: "my-special-project"
      vertex_location: "my-special-location"
```
</TabItem>
</Tabs>
### **Environment Variables**
You can set:
- `GOOGLE_APPLICATION_CREDENTIALS` - filepath to your service_account.json (used by the vertex sdk directly).
- `VERTEXAI_LOCATION` - region where the vertex model is deployed (us-central1, asia-southeast1, etc.)
- `VERTEXAI_PROJECT` (Optional[str]) - use if the vertex project is different from the one in the service account
1. GOOGLE_APPLICATION_CREDENTIALS
```bash
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service_account.json"
```
2. VERTEXAI_LOCATION
```bash
export VERTEXAI_LOCATION="us-central1" # can be any vertex location
```
3. VERTEXAI_PROJECT
```bash
export VERTEXAI_PROJECT="my-test-project" # ONLY use if model project is different from service account project
```
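With those env vars exported, the vertex-specific params can be omitted from the call itself (a sketch; the model name is an example):
```python
from litellm import completion

# credentials, project and location are picked up from
# GOOGLE_APPLICATION_CREDENTIALS / VERTEXAI_PROJECT / VERTEXAI_LOCATION
response = completion(
    model="vertex_ai/gemini-pro",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
```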
## Extra


@ -253,7 +253,7 @@ def get_redis_client(**env_overrides):
return redis.Redis(**redis_kwargs)
def get_redis_async_client(**env_overrides):
def get_redis_async_client(**env_overrides) -> async_redis.Redis:
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)


@ -207,7 +207,7 @@ class RedisCache(BaseCache):
host=None,
port=None,
password=None,
redis_flush_size=100,
redis_flush_size: Optional[int] = 100,
namespace: Optional[str] = None,
startup_nodes: Optional[List] = None, # for redis-cluster
**kwargs,
@ -244,7 +244,10 @@ class RedisCache(BaseCache):
self.namespace = namespace
# for high traffic, we store the redis results in memory and then batch write to redis
self.redis_batch_writing_buffer: list = []
self.redis_flush_size = redis_flush_size
if redis_flush_size is None:
self.redis_flush_size: int = 100
else:
self.redis_flush_size = redis_flush_size
self.redis_version = "Unknown"
try:
self.redis_version = self.redis_client.info()["redis_version"]
@ -317,7 +320,7 @@ class RedisCache(BaseCache):
current_ttl = _redis_client.ttl(key)
if current_ttl == -1:
# Key has no expiration
_redis_client.expire(key, ttl)
_redis_client.expire(key, ttl) # type: ignore
return result
except Exception as e:
## LOGGING ##
@ -331,10 +334,13 @@ class RedisCache(BaseCache):
raise e
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
from redis.asyncio import Redis
start_time = time.time()
try:
keys = []
_redis_client = self.init_async_client()
_redis_client: Redis = self.init_async_client() # type: ignore
async with _redis_client as redis_client:
async for key in redis_client.scan_iter(
match=pattern + "*", count=count
@ -374,9 +380,11 @@ class RedisCache(BaseCache):
raise e
async def async_set_cache(self, key, value, **kwargs):
from redis.asyncio import Redis
start_time = time.time()
try:
_redis_client = self.init_async_client()
_redis_client: Redis = self.init_async_client() # type: ignore
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
@ -397,6 +405,7 @@ class RedisCache(BaseCache):
str(e),
value,
)
raise e
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client:
@ -405,6 +414,10 @@ class RedisCache(BaseCache):
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
try:
if not hasattr(redis_client, "set"):
raise Exception(
"Redis client cannot set cache. Attribute not found."
)
await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
print_verbose(
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
@ -446,12 +459,15 @@ class RedisCache(BaseCache):
"""
Use Redis Pipelines for bulk write operations
"""
_redis_client = self.init_async_client()
from redis.asyncio import Redis
_redis_client: Redis = self.init_async_client() # type: ignore
start_time = time.time()
print_verbose(
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
)
cache_value: Any = None
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
@ -463,6 +479,7 @@ class RedisCache(BaseCache):
)
json_cache_value = json.dumps(cache_value)
# Set the value with a TTL if it's provided.
if ttl is not None:
pipe.setex(cache_key, ttl, json_cache_value)
else:
@ -511,9 +528,11 @@ class RedisCache(BaseCache):
async def async_set_cache_sadd(
self, key, value: List, ttl: Optional[float], **kwargs
):
from redis.asyncio import Redis
start_time = time.time()
try:
_redis_client = self.init_async_client()
_redis_client: Redis = self.init_async_client() # type: ignore
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
@ -592,9 +611,11 @@ class RedisCache(BaseCache):
await self.flush_cache_buffer() # logging done in here
async def async_increment(
self, key, value: float, ttl: Optional[float] = None, **kwargs
self, key, value: float, ttl: Optional[int] = None, **kwargs
) -> float:
_redis_client = self.init_async_client()
from redis.asyncio import Redis
_redis_client: Redis = self.init_async_client() # type: ignore
start_time = time.time()
try:
async with _redis_client as redis_client:
@ -708,7 +729,9 @@ class RedisCache(BaseCache):
return key_value_dict
async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client()
from redis.asyncio import Redis
_redis_client: Redis = self.init_async_client() # type: ignore
key = self.check_and_fix_namespace(key=key)
start_time = time.time()
async with _redis_client as redis_client:
@ -903,6 +926,12 @@ class RedisCache(BaseCache):
async def disconnect(self):
await self.async_redis_conn_pool.disconnect(inuse_connections=True)
async def async_delete_cache(self, key: str):
_redis_client = self.init_async_client()
# keys is str
async with _redis_client as redis_client:
await redis_client.delete(key)
def delete_cache(self, key):
self.redis_client.delete(key)
@ -1241,6 +1270,7 @@ class QdrantSemanticCache(BaseCache):
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.secret_managers.main import get_secret_str
if collection_name is None:
raise Exception("collection_name must be provided, passed None")
@ -1261,12 +1291,12 @@ class QdrantSemanticCache(BaseCache):
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
"os.environ/"
):
qdrant_api_base = litellm.get_secret(qdrant_api_base)
qdrant_api_base = get_secret_str(qdrant_api_base)
if qdrant_api_key:
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
"os.environ/"
):
qdrant_api_key = litellm.get_secret(qdrant_api_key)
qdrant_api_key = get_secret_str(qdrant_api_key)
qdrant_api_base = (
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
@ -1633,7 +1663,7 @@ class S3Cache(BaseCache):
s3_bucket_name,
s3_region_name=None,
s3_api_version=None,
s3_use_ssl=True,
s3_use_ssl: Optional[bool] = True,
s3_verify=None,
s3_endpoint_url=None,
s3_aws_access_key_id=None,
@ -1721,7 +1751,7 @@ class S3Cache(BaseCache):
Bucket=self.bucket_name, Key=key
)
if cached_response != None:
if cached_response is not None:
# cached_response is in `b{} convert it to ModelResponse
cached_response = (
cached_response["Body"].read().decode("utf-8")
@ -1739,7 +1769,7 @@ class S3Cache(BaseCache):
)
return cached_response
except botocore.exceptions.ClientError as e:
except botocore.exceptions.ClientError as e: # type: ignore
if e.response["Error"]["Code"] == "NoSuchKey":
verbose_logger.debug(
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
@ -2081,6 +2111,15 @@ class DualCache(BaseCache):
if self.redis_cache is not None:
self.redis_cache.delete_cache(key)
async def async_delete_cache(self, key: str):
"""
Delete a key from the cache
"""
if self.in_memory_cache is not None:
self.in_memory_cache.delete_cache(key)
if self.redis_cache is not None:
await self.redis_cache.async_delete_cache(key)
#### LiteLLM.Completion / Embedding Cache ####
class Cache:
@ -2137,7 +2176,7 @@ class Cache:
s3_path: Optional[str] = None,
redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None,
redis_flush_size: Optional[int] = None,
redis_startup_nodes: Optional[List] = None,
disk_cache_dir=None,
qdrant_api_base: Optional[str] = None,
@ -2501,10 +2540,9 @@ class Cache:
if self.ttl is not None:
kwargs["ttl"] = self.ttl
## Get Cache-Controls ##
if kwargs.get("cache", None) is not None and isinstance(
kwargs.get("cache"), dict
):
for k, v in kwargs.get("cache").items():
_cache_kwargs = kwargs.get("cache", None)
if isinstance(_cache_kwargs, dict):
for k, v in _cache_kwargs.items():
if k == "ttl":
kwargs["ttl"] = v
@ -2574,14 +2612,15 @@ class Cache:
**kwargs,
)
cache_list.append((cache_key, cached_data))
if hasattr(self.cache, "async_set_cache_pipeline"):
await self.cache.async_set_cache_pipeline(cache_list=cache_list)
async_set_cache_pipeline = getattr(
self.cache, "async_set_cache_pipeline", None
)
if async_set_cache_pipeline:
await async_set_cache_pipeline(cache_list=cache_list)
else:
tasks = []
for val in cache_list:
tasks.append(
self.cache.async_set_cache(cache_key, cached_data, **kwargs)
)
tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
@ -2611,13 +2650,15 @@ class Cache:
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
async def ping(self):
if hasattr(self.cache, "ping"):
return await self.cache.ping()
cache_ping = getattr(self.cache, "ping")
if cache_ping:
return await cache_ping()
return None
async def delete_cache_keys(self, keys):
if hasattr(self.cache, "delete_cache_keys"):
return await self.cache.delete_cache_keys(keys)
cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys")
if cache_delete_cache_keys:
return await cache_delete_cache_keys(keys)
return None
async def disconnect(self):


@ -292,8 +292,12 @@ class RateLimitError(openai.RateLimitError): # type: ignore
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
_response_headers = (
getattr(response, "headers", None) if response is not None else None
)
self.response = httpx.Response(
status_code=429,
headers=_response_headers,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
@ -750,8 +754,14 @@ class InvalidRequestError(openai.BadRequestError): # type: ignore
self.message = message
self.model = model
self.llm_provider = llm_provider
self.response = httpx.Response(
status_code=400,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
super().__init__(
self.message, f"{self.model}"
message=self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs


@ -767,6 +767,9 @@ class AzureChatCompletion(BaseLLM):
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise AzureOpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
@ -1023,6 +1026,9 @@ class AzureChatCompletion(BaseLLM):
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise AzureOpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
@ -1165,6 +1171,9 @@ class AzureChatCompletion(BaseLLM):
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise AzureOpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)


@ -704,7 +704,6 @@ class OpenAIChatCompletion(BaseLLM):
drop_params: Optional[bool] = None,
):
super().completion()
exception_mapping_worked = False
try:
if headers:
optional_params["extra_headers"] = headers
@ -911,6 +910,9 @@ class OpenAIChatCompletion(BaseLLM):
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
@ -1003,8 +1005,12 @@ class OpenAIChatCompletion(BaseLLM):
raise e
# e.message
except Exception as e:
exception_response = getattr(e, "response", None)
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
if error_headers is None and exception_response:
error_headers = getattr(exception_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
@ -1144,10 +1150,13 @@ class OpenAIChatCompletion(BaseLLM):
raise e
error_headers = getattr(e, "headers", None)
status_code = getattr(e, "status_code", 500)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
if response is not None and hasattr(response, "text"):
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=500,
status_code=status_code,
message=f"{str(e)}\n\nOriginal Response: {response.text}", # type: ignore
headers=error_headers,
)
@ -1272,8 +1281,12 @@ class OpenAIChatCompletion(BaseLLM):
)
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
status_code=status_code, message=error_text, headers=error_headers
)
def embedding( # type: ignore
@ -1352,8 +1365,12 @@ class OpenAIChatCompletion(BaseLLM):
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
status_code=status_code, message=error_text, headers=error_headers
)
async def aimage_generation(
@ -1774,7 +1791,15 @@ class OpenAITextCompletion(BaseLLM):
## RESPONSE OBJECT
return TextCompletionResponse(**response_json)
except Exception as e:
raise e
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
async def acompletion(
self,
@ -1825,7 +1850,15 @@ class OpenAITextCompletion(BaseLLM):
response_obj._hidden_params.original_response = json.dumps(response_json)
return response_obj
except Exception as e:
raise e
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
def streaming(
self,
@ -1860,8 +1893,12 @@ class OpenAITextCompletion(BaseLLM):
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
status_code=status_code, message=error_text, headers=error_headers
)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
@ -1871,8 +1908,19 @@ class OpenAITextCompletion(BaseLLM):
stream_options=data.get("stream_options", None),
)
for chunk in streamwrapper:
yield chunk
try:
for chunk in streamwrapper:
yield chunk
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
async def async_streaming(
self,
@ -1910,8 +1958,19 @@ class OpenAITextCompletion(BaseLLM):
stream_options=data.get("stream_options", None),
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
try:
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
class OpenAIFilesAPI(BaseLLM):


@ -319,6 +319,9 @@ class AzureTextCompletion(BaseLLM):
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise AzureOpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
@ -332,9 +335,9 @@ class AzureTextCompletion(BaseLLM):
data: dict,
timeout: Any,
model_response: ModelResponse,
logging_obj: Any,
azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI
logging_obj=None,
):
response = None
try:
@ -395,6 +398,9 @@ class AzureTextCompletion(BaseLLM):
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise AzureOpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
@ -526,6 +532,9 @@ class AzureTextCompletion(BaseLLM):
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise AzureOpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)


@ -1245,7 +1245,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_prompt_caching": true
},
"codestral/codestral-latest": {
"max_tokens": 8191,
@ -1300,7 +1301,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_prompt_caching": true
},
"groq/llama2-70b-4096": {
"max_tokens": 4096,
@ -1502,7 +1504,8 @@
"supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 264,
"supports_assistant_prefill": true
"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"claude-3-opus-20240229": {
"max_tokens": 4096,
@ -1517,7 +1520,8 @@
"supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 395,
"supports_assistant_prefill": true
"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"claude-3-sonnet-20240229": {
"max_tokens": 4096,
@ -1530,7 +1534,8 @@
"supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 159,
"supports_assistant_prefill": true
"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"claude-3-5-sonnet-20240620": {
"max_tokens": 8192,
@ -1545,7 +1550,8 @@
"supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 159,
"supports_assistant_prefill": true
"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"text-bison": {
"max_tokens": 2048,
@ -2664,6 +2670,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash": {
@ -2783,6 +2790,24 @@
"supports_response_schema": true,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-001": {
"max_tokens": 8192,
"max_input_tokens": 2097152,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0000035,
"input_cost_per_token_above_128k_tokens": 0.000007,
"output_cost_per_token": 0.0000105,
"output_cost_per_token_above_128k_tokens": 0.000021,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-exp-0801": {
"max_tokens": 8192,
"max_input_tokens": 2097152,


@ -26,4 +26,5 @@ model_list:
litellm_settings:
success_callback: ["langfuse"]
success_callback: ["langfuse"]
cache: true


@ -632,6 +632,7 @@ class _GenerateKeyRequest(GenerateRequestBase):
model_rpm_limit: Optional[dict] = None
model_tpm_limit: Optional[dict] = None
guardrails: Optional[List[str]] = None
blocked: Optional[bool] = None
class GenerateKeyRequest(_GenerateKeyRequest):
@ -967,6 +968,10 @@ class BlockTeamRequest(LiteLLMBase):
team_id: str # required
class BlockKeyRequest(LiteLLMBase):
key: str # required
class AddTeamCallback(LiteLLMBase):
callback_name: str
callback_type: Literal["success", "failure", "success_and_failure"]
@ -1359,6 +1364,7 @@ class LiteLLM_VerificationToken(LiteLLMBase):
model_spend: Dict = {}
model_max_budget: Dict = {}
soft_budget_cooldown: bool = False
blocked: Optional[bool] = None
litellm_budget_table: Optional[dict] = None
org_id: Optional[str] = None # org id for a given key
@ -1516,7 +1522,7 @@ class LiteLLM_AuditLogs(LiteLLMBase):
updated_at: datetime
changed_by: str
changed_by_api_key: Optional[str] = None
action: Literal["created", "updated", "deleted"]
action: Literal["created", "updated", "deleted", "blocked"]
table_name: Literal[
LitellmTableNames.TEAM_TABLE_NAME,
LitellmTableNames.USER_TABLE_NAME,


@ -12,6 +12,8 @@ import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Literal, Optional
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
@ -424,6 +426,24 @@ async def get_user_object(
)
async def _cache_management_object(
key: str,
value: BaseModel,
user_api_key_cache: DualCache,
proxy_logging_obj: Optional[ProxyLogging],
):
await user_api_key_cache.async_set_cache(key=key, value=value)
## UPDATE REDIS CACHE ##
if proxy_logging_obj is not None:
_value = value.model_dump_json(
exclude_unset=True, exclude={"parent_otel_span": True}
)
await proxy_logging_obj.internal_usage_cache.async_set_cache(
key=key, value=_value
)
async def _cache_team_object(
team_id: str,
team_table: LiteLLM_TeamTableCachedObj,
@ -435,20 +455,45 @@ async def _cache_team_object(
## CACHE REFRESH TIME!
team_table.last_refreshed_at = time.time()
value = team_table.model_dump_json(exclude_unset=True)
await user_api_key_cache.async_set_cache(key=key, value=value)
await _cache_management_object(
key=key,
value=team_table,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
async def _cache_key_object(
hashed_token: str,
user_api_key_obj: UserAPIKeyAuth,
user_api_key_cache: DualCache,
proxy_logging_obj: Optional[ProxyLogging],
):
key = hashed_token
## CACHE REFRESH TIME!
user_api_key_obj.last_refreshed_at = time.time()
await _cache_management_object(
key=key,
value=user_api_key_obj,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
async def _delete_cache_key_object(
hashed_token: str,
user_api_key_cache: DualCache,
proxy_logging_obj: Optional[ProxyLogging],
):
key = hashed_token
user_api_key_cache.delete_cache(key=key)
## UPDATE REDIS CACHE ##
if proxy_logging_obj is not None:
await proxy_logging_obj.internal_usage_cache.async_set_cache(
key=key, value=value
)
## UPDATE REDIS CACHE ##
if proxy_logging_obj is not None:
await proxy_logging_obj.internal_usage_cache.async_set_cache(
key=key, value=team_table
)
await proxy_logging_obj.internal_usage_cache.async_delete_cache(key=key)
@log_to_opentelemetry
@ -524,6 +569,83 @@ async def get_team_object(
)
@log_to_opentelemetry
async def get_key_object(
hashed_token: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
check_cache_only: Optional[bool] = None,
) -> UserAPIKeyAuth:
"""
- Check if the hashed token exists in the proxy's key table
- if valid, return a UserAPIKeyAuth object with defined limits
- if not, then raise an error
"""
if prisma_client is None:
raise Exception(
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)
# check if in cache
key = hashed_token
cached_team_obj: Optional[UserAPIKeyAuth] = None
## CHECK REDIS CACHE ##
if (
proxy_logging_obj is not None
and proxy_logging_obj.internal_usage_cache.redis_cache is not None
):
cached_team_obj = (
await proxy_logging_obj.internal_usage_cache.redis_cache.async_get_cache(
key=key
)
)
if cached_team_obj is None:
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
if cached_team_obj is not None:
if isinstance(cached_team_obj, dict):
return UserAPIKeyAuth(**cached_team_obj)
elif isinstance(cached_team_obj, UserAPIKeyAuth):
return cached_team_obj
if check_cache_only:
raise Exception(
f"Key doesn't exist in cache + check_cache_only=True. key={key}."
)
# else, check db
try:
_valid_token: Optional[BaseModel] = await prisma_client.get_data(
token=hashed_token,
table_name="combined_view",
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if _valid_token is None:
raise Exception
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
# save the key object to cache
await _cache_key_object(
hashed_token=hashed_token,
user_api_key_obj=_response,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
return _response
except Exception:
raise Exception(
f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
)
@log_to_opentelemetry
async def get_org_object(
org_id: str,


@ -46,11 +46,13 @@ import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
_cache_key_object,
allowed_routes_check,
can_key_call_model,
common_checks,
get_actual_routes,
get_end_user_object,
get_key_object,
get_org_object,
get_team_object,
get_user_object,
@ -525,9 +527,19 @@ async def user_api_key_auth(
### CHECK IF ADMIN ###
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
## Check CACHE
valid_token: Optional[UserAPIKeyAuth] = user_api_key_cache.get_cache(
key=hash_token(api_key)
)
try:
valid_token = await get_key_object(
hashed_token=hash_token(api_key),
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
check_cache_only=True,
)
except Exception:
verbose_logger.debug("api key not found in cache.")
valid_token = None
if (
valid_token is not None
and isinstance(valid_token, UserAPIKeyAuth)
@ -578,7 +590,7 @@ async def user_api_key_auth(
try:
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
except Exception as e:
except Exception:
is_master_key_valid = False
## VALIDATE MASTER KEY ##
@ -602,8 +614,11 @@ async def user_api_key_auth(
parent_otel_span=parent_otel_span,
**end_user_params,
)
await user_api_key_cache.async_set_cache(
key=hash_token(master_key), value=_user_api_key_obj
await _cache_key_object(
hashed_token=hash_token(master_key),
user_api_key_obj=_user_api_key_obj,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
return _user_api_key_obj
@ -640,38 +655,31 @@ async def user_api_key_auth(
_user_role = None
if api_key.startswith("sk-"):
api_key = hash_token(token=api_key)
valid_token: Optional[UserAPIKeyAuth] = user_api_key_cache.get_cache( # type: ignore
key=api_key
)
if valid_token is None:
## check db
verbose_proxy_logger.debug("api key: %s", api_key)
if prisma_client is not None:
_valid_token: Optional[BaseModel] = await prisma_client.get_data(
token=api_key,
table_name="combined_view",
try:
valid_token = await get_key_object(
hashed_token=api_key,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# update end-user params on valid token
# These can change per request - it's important to update them here
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get(
"end_user_tpm_limit"
)
valid_token.end_user_rpm_limit = end_user_params.get(
"end_user_rpm_limit"
)
valid_token.allowed_model_region = end_user_params.get(
"allowed_model_region"
)
if _valid_token is not None:
## update cached token
valid_token = UserAPIKeyAuth(
**_valid_token.model_dump(exclude_none=True)
)
verbose_proxy_logger.debug("Token from db: %s", valid_token)
elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth):
verbose_proxy_logger.debug("API Key Cache Hit!")
# update end-user params on valid token
# These can change per request - it's important to update them here
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
valid_token.allowed_model_region = end_user_params.get(
"allowed_model_region"
)
except Exception:
valid_token = None
user_obj: Optional[LiteLLM_UserTable] = None
valid_token_dict: dict = {}
@ -689,6 +697,12 @@ async def user_api_key_auth(
# 8. If token spend is under team budget
# 9. If team spend is under team budget
## base case ## key is disabled
if valid_token.blocked is True:
raise Exception(
"Key is blocked. Update via `/key/unblock` if you're admin."
)
# Check 1. If token can call model
_model_alias_map = {}
model: Optional[str] = None
@ -1006,10 +1020,13 @@ async def user_api_key_auth(
api_key = valid_token.token
# Add hashed token to cache
await user_api_key_cache.async_set_cache(
key=api_key,
value=valid_token,
await _cache_key_object(
hashed_token=api_key,
user_api_key_obj=valid_token,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
valid_token_dict = valid_token.model_dump(exclude_none=True)
valid_token_dict.pop("token", None)


@ -25,6 +25,11 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
_cache_key_object,
_delete_cache_key_object,
get_key_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds
@ -302,15 +307,18 @@ async def prepare_key_update_data(
data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
):
data_json: dict = data.dict(exclude_unset=True)
key = data_json.pop("key", None)
data_json.pop("key", None)
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
non_default_values = {}
for k, v in data_json.items():
if k in _metadata_fields:
continue
if v is not None and v not in ([], {}, 0):
non_default_values[k] = v
if v is not None:
if not isinstance(v, bool) and v in ([], {}, 0):
pass
else:
non_default_values[k] = v
if "duration" in non_default_values:
duration = non_default_values.pop("duration")
if duration and (isinstance(duration, str)) and len(duration) > 0:
@ -364,12 +372,10 @@ async def update_key_fn(
"""
from litellm.proxy.proxy_server import (
create_audit_log_for_update,
general_settings,
litellm_proxy_admin_name,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
user_custom_key_generate,
)
try:
@ -399,9 +405,11 @@ async def update_key_fn(
# Delete - key from cache, since it's been updated!
# key updated - a new model could have been added to this key. it should not block requests after this is done
user_api_key_cache.delete_cache(key)
hashed_token = hash_token(key)
user_api_key_cache.delete_cache(hashed_token)
await _delete_cache_key_object(
hashed_token=hash_token(key),
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
@ -434,6 +442,11 @@ async def update_key_fn(
return {"key": key, **response["data"]}
# update based on remaining passed in values
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.update_key_fn(): Exception occured - {}".format(
str(e)
)
)
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -771,6 +784,7 @@ async def generate_key_helper_fn(
float
] = None, # soft_budget is used to set soft Budgets Per user
max_budget: Optional[float] = None, # max_budget is used to Budget Per user
blocked: Optional[bool] = None,
budget_duration: Optional[str] = None, # max_budget is used to Budget Per user
token: Optional[str] = None,
key: Optional[
@ -899,6 +913,7 @@ async def generate_key_helper_fn(
"permissions": permissions_json,
"model_max_budget": model_max_budget_json,
"budget_id": budget_id,
"blocked": blocked,
}
if (
@ -1047,6 +1062,7 @@ async def regenerate_key_fn(
hash_token,
premium_user,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
)
@ -1124,10 +1140,18 @@ async def regenerate_key_fn(
### 3. remove existing key entry from cache
######################################################################
if key:
user_api_key_cache.delete_cache(key)
await _delete_cache_key_object(
hashed_token=hash_token(key),
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
if hashed_api_key:
user_api_key_cache.delete_cache(hashed_api_key)
await _delete_cache_key_object(
hashed_token=hashed_api_key,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
return GenerateKeyResponse(
**updated_token_dict,
@ -1240,3 +1264,187 @@ async def list_keys(
param=getattr(e, "param", "None"),
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@router.post(
"/key/block", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def block_key(
data: BlockKeyRequest,
http_request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
litellm_changed_by: Optional[str] = Header(
None,
description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
),
):
"""
Block all calls from this key.
"""
from litellm.proxy.proxy_server import (
create_audit_log_for_update,
hash_token,
litellm_proxy_admin_name,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
)
if prisma_client is None:
raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))
if data.key.startswith("sk-"):
hashed_token = hash_token(token=data.key)
else:
hashed_token = data.key
if litellm.store_audit_logs is True:
# make an audit log for key update
record = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_token}
)
if record is None:
raise ProxyException(
message=f"Key {data.key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=hashed_token,
action="blocked",
updated_values="{}",
before_value=record.model_dump_json(),
)
)
)
record = await prisma_client.db.litellm_verificationtoken.update(
where={"token": hashed_token}, data={"blocked": True} # type: ignore
)
## UPDATE KEY CACHE
### get cached object ###
key_object = await get_key_object(
hashed_token=hashed_token,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=proxy_logging_obj,
)
### update cached object ###
key_object.blocked = True
### store cached object ###
await _cache_key_object(
hashed_token=hashed_token,
user_api_key_obj=key_object,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
return record
@router.post(
"/key/unblock", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def unblock_key(
data: BlockKeyRequest,
http_request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
litellm_changed_by: Optional[str] = Header(
None,
description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
),
):
"""
Unblocks all calls from this key.
"""
from litellm.proxy.proxy_server import (
create_audit_log_for_update,
hash_token,
litellm_proxy_admin_name,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
)
if prisma_client is None:
raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))
if data.key.startswith("sk-"):
hashed_token = hash_token(token=data.key)
else:
hashed_token = data.key
if litellm.store_audit_logs is True:
# make an audit log for key update
record = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_token}
)
if record is None:
raise ProxyException(
message=f"Key {data.key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=hashed_token,
action="blocked",
updated_values="{}",
before_value=record.model_dump_json(),
)
)
)
record = await prisma_client.db.litellm_verificationtoken.update(
where={"token": hashed_token}, data={"blocked": False} # type: ignore
)
## UPDATE KEY CACHE
### get cached object ###
key_object = await get_key_object(
hashed_token=hashed_token,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=proxy_logging_obj,
)
### update cached object ###
key_object.blocked = False
### store cached object ###
await _cache_key_object(
hashed_token=hashed_token,
user_api_key_obj=key_object,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
return record


@ -154,6 +154,7 @@ class Router:
client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds
## SCHEDULER ##
polling_interval: Optional[float] = None,
default_priority: Optional[int] = None,
## RELIABILITY ##
num_retries: Optional[int] = None,
timeout: Optional[float] = None,
@ -220,6 +221,7 @@ class Router:
caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
polling_interval: (Optional[float]): frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms.
default_priority: (Optional[int]): the default priority for a request. Only for '.scheduler_acompletion()'. Default is None.
num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2.
timeout (Optional[float]): Timeout for requests. Defaults to None.
default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
@ -336,6 +338,7 @@ class Router:
self.scheduler = Scheduler(
polling_interval=polling_interval, redis_cache=redis_cache
)
self.default_priority = default_priority
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
self.default_max_parallel_requests = default_max_parallel_requests
self.provider_default_deployments: Dict[str, List] = {}
@ -712,12 +715,11 @@ class Router:
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
if kwargs.get("priority", None) is not None and isinstance(
kwargs.get("priority"), int
):
request_priority = kwargs.get("priority") or self.default_priority
if request_priority is not None and isinstance(request_priority, int):
response = await self.schedule_acompletion(**kwargs)
else:
response = await self.async_function_with_fallbacks(**kwargs)
@ -3085,9 +3087,9 @@ class Router:
except Exception as e:
current_attempt = None
original_exception = e
"""
Retry Logic
"""
_healthy_deployments, _all_deployments = (
await self._async_get_healthy_deployments(
@ -3105,16 +3107,6 @@ class Router:
content_policy_fallbacks=content_policy_fallbacks,
)
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
# sleeps for the length of the timeout
await asyncio.sleep(_timeout)
if (
self.retry_policy is not None
or self.model_group_retry_policy is not None
@ -3128,11 +3120,19 @@ class Router:
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
else:
raise
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
# sleeps for the length of the timeout
await asyncio.sleep(_timeout)
for current_attempt in range(num_retries):
verbose_router_logger.debug(
f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}"
)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = await original_function(*args, **kwargs)
@ -3370,14 +3370,14 @@ class Router:
if (
healthy_deployments is not None
and isinstance(healthy_deployments, list)
and len(healthy_deployments) > 0
and len(healthy_deployments) > 1
):
return 0
response_headers: Optional[httpx.Headers] = None
if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
response_headers = e.response.headers # type: ignore
elif hasattr(e, "litellm_response_headers"):
if hasattr(e, "litellm_response_headers"):
response_headers = e.litellm_response_headers # type: ignore
if response_headers is not None:
@ -3561,7 +3561,7 @@ class Router:
except Exception as e:
verbose_router_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
"litellm.router.Router::deployment_callback_on_success(): Exception occured - {}".format(
str(e)
)
)
@ -5324,7 +5324,6 @@ class Router:
return deployment
except Exception as e:
traceback_exception = traceback.format_exc()
# if router rejects call -> log to langfuse/otel/etc.
if request_kwargs is not None:


@ -21,7 +21,7 @@ class LiteLLMBase(BaseModel):
Implements default functions, all pydantic objects should have.
"""
def json(self, **kwargs):
def json(self, **kwargs): # type: ignore
try:
return self.model_dump() # noqa
except Exception as e:
@ -185,6 +185,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
deployment_rpm,
local_result,
),
headers={"retry-after": 60}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
@ -207,6 +208,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
deployment_rpm,
result,
),
headers={"retry-after": 60}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
@ -321,15 +323,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
model_group: str,
healthy_deployments: list,
tpm_keys: list,
tpm_values: list,
tpm_values: Optional[list],
rpm_keys: list,
rpm_values: list,
rpm_values: Optional[list],
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
) -> Optional[dict]:
"""
Common checks for get available deployment, across sync + async implementations
"""
if tpm_values is None or rpm_values is None:
return None
tpm_dict = {} # {model_id: 1, ..}
for idx, key in enumerate(tpm_keys):
tpm_dict[tpm_keys[idx]] = tpm_values[idx]
@ -455,8 +460,12 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
keys=combined_tpm_rpm_keys
) # [1, 2, None, ..]
tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]
rpm_values = combined_tpm_rpm_values[len(tpm_keys) :]
if combined_tpm_rpm_values is not None:
tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]
rpm_values = combined_tpm_rpm_values[len(tpm_keys) :]
else:
tpm_values = None
rpm_values = None
deployment = self._common_checks_available_deployment(
model_group=model_group,
@ -472,7 +481,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
try:
assert deployment is not None
return deployment
except Exception as e:
except Exception:
### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
deployment_dict = {}
for index, _deployment in enumerate(healthy_deployments):
@ -494,7 +503,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
_deployment_tpm = float("inf")
### GET CURRENT TPM ###
current_tpm = tpm_values[index]
current_tpm = tpm_values[index] if tpm_values else 0
### GET DEPLOYMENT TPM LIMIT ###
_deployment_rpm = None
@ -512,7 +521,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
_deployment_rpm = float("inf")
### GET CURRENT RPM ###
current_rpm = rpm_values[index]
current_rpm = rpm_values[index] if rpm_values else 0
deployment_dict[id] = {
"current_tpm": current_tpm,
@ -520,8 +529,16 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
"current_rpm": current_rpm,
"rpm_limit": _deployment_rpm,
}
raise ValueError(
f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}"
raise litellm.RateLimitError(
message=f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}",
llm_provider="",
model=model_group,
response=httpx.Response(
status_code=429,
content="",
headers={"retry-after": str(60)}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
def get_available_deployments(
@ -597,7 +614,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
_deployment_tpm = float("inf")
### GET CURRENT TPM ###
current_tpm = tpm_values[index]
current_tpm = tpm_values[index] if tpm_values else 0
### GET DEPLOYMENT TPM LIMIT ###
_deployment_rpm = None
@ -615,7 +632,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
_deployment_rpm = float("inf")
### GET CURRENT RPM ###
current_rpm = rpm_values[index]
current_rpm = rpm_values[index] if rpm_values else 0
deployment_dict[id] = {
"current_tpm": current_tpm,


@ -910,6 +910,7 @@ async def test_exception_with_headers(sync_mode, provider, model, call_type, str
{"message": "litellm.proxy.proxy_server.embeddings(): Exception occured - No deployments available for selected model, Try again in 60 seconds. Passed model=text-embedding-ada-002. pre-call-checks=False, allowed_model_region=n/a, cooldown_list=[('b49cbc9314273db7181fe69b1b19993f04efb88f2c1819947c538bac08097e4c', {'Exception Received': 'litellm.RateLimitError: AzureException RateLimitError - Requests to the Embeddings_Create Operation under Azure OpenAI API version 2023-09-01-preview have exceeded call rate limit of your current OpenAI S0 pricing tier. Please retry after 9 seconds. Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit.', 'Status Code': '429'})]", "level": "ERROR", "timestamp": "2024-08-22T03:25:36.900476"}
```
"""
print(f"Received args: {locals()}")
import openai
if sync_mode:
@ -939,13 +940,38 @@ async def test_exception_with_headers(sync_mode, provider, model, call_type, str
cooldown_time = 30.0
def _return_exception(*args, **kwargs):
from fastapi import HTTPException
import datetime
raise HTTPException(
status_code=429,
detail="Rate Limited!",
headers={"retry-after": cooldown_time}, # type: ignore
)
from httpx import Headers, Request, Response
kwargs = {
"request": Request("POST", "https://www.google.com"),
"message": "Error code: 429 - Rate Limit Error!",
"body": {"detail": "Rate Limit Error!"},
"code": None,
"param": None,
"type": None,
"response": Response(
status_code=429,
headers=Headers(
{
"date": "Sat, 21 Sep 2024 22:56:53 GMT",
"server": "uvicorn",
"retry-after": "30",
"content-length": "30",
"content-type": "application/json",
}
),
request=Request("POST", "http://0.0.0.0:9000/chat/completions"),
),
"status_code": 429,
"request_id": None,
}
exception = Exception()
for k, v in kwargs.items():
setattr(exception, k, v)
raise exception
with patch.object(
mapped_target,
@ -975,7 +1001,7 @@ async def test_exception_with_headers(sync_mode, provider, model, call_type, str
except litellm.RateLimitError as e:
exception_raised = True
assert e.litellm_response_headers is not None
assert e.litellm_response_headers["retry-after"] == cooldown_time
assert int(e.litellm_response_headers["retry-after"]) == cooldown_time
if exception_raised is False:
print(resp)


@ -54,3 +54,11 @@ def test_get_model_info_shows_assistant_prefill():
info = litellm.get_model_info("deepseek/deepseek-chat")
print("info", info)
assert info.get("supports_assistant_prefill") is True
def test_get_model_info_shows_supports_prompt_caching():
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
info = litellm.get_model_info("deepseek/deepseek-chat")
print("info", info)
assert info.get("supports_prompt_caching") is True


@ -2215,16 +2215,39 @@ def test_router_dynamic_cooldown_correct_retry_after_time():
openai_client = openai.OpenAI(api_key="")
cooldown_time = 30.0
cooldown_time = 30
def _return_exception(*args, **kwargs):
from fastapi import HTTPException
from httpx import Headers, Request, Response
raise HTTPException(
status_code=429,
detail="Rate Limited!",
headers={"retry-after": cooldown_time}, # type: ignore
)
kwargs = {
"request": Request("POST", "https://www.google.com"),
"message": "Error code: 429 - Rate Limit Error!",
"body": {"detail": "Rate Limit Error!"},
"code": None,
"param": None,
"type": None,
"response": Response(
status_code=429,
headers=Headers(
{
"date": "Sat, 21 Sep 2024 22:56:53 GMT",
"server": "uvicorn",
"retry-after": f"{cooldown_time}",
"content-length": "30",
"content-type": "application/json",
}
),
request=Request("POST", "http://0.0.0.0:9000/chat/completions"),
),
"status_code": 429,
"request_id": None,
}
exception = Exception()
for k, v in kwargs.items():
setattr(exception, k, v)
raise exception
with patch.object(
openai_client.embeddings.with_raw_response,
@ -2250,12 +2273,12 @@ def test_router_dynamic_cooldown_correct_retry_after_time():
print(
f"new_retry_after_mock_client.call_args.kwargs: {new_retry_after_mock_client.call_args.kwargs}"
)
print(
f"new_retry_after_mock_client.call_args: {new_retry_after_mock_client.call_args[0][0]}"
)
response_headers: httpx.Headers = new_retry_after_mock_client.call_args.kwargs[
"response_headers"
]
assert "retry-after" in response_headers
assert response_headers["retry-after"] == cooldown_time
response_headers: httpx.Headers = new_retry_after_mock_client.call_args[0][0]
assert int(response_headers["retry-after"]) == cooldown_time
@pytest.mark.parametrize("sync_mode", [True, False])
@ -2270,6 +2293,7 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
```
"""
litellm.set_verbose = True
cooldown_time = 30.0
router = Router(
model_list=[
{
@ -2287,20 +2311,42 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
],
set_verbose=True,
debug_level="DEBUG",
cooldown_time=cooldown_time,
)
openai_client = openai.OpenAI(api_key="")
cooldown_time = 30.0
def _return_exception(*args, **kwargs):
from fastapi import HTTPException
from httpx import Headers, Request, Response
raise HTTPException(
status_code=429,
detail="Rate Limited!",
headers={"retry-after": cooldown_time},
)
kwargs = {
"request": Request("POST", "https://www.google.com"),
"message": "Error code: 429 - Rate Limit Error!",
"body": {"detail": "Rate Limit Error!"},
"code": None,
"param": None,
"type": None,
"response": Response(
status_code=429,
headers=Headers(
{
"date": "Sat, 21 Sep 2024 22:56:53 GMT",
"server": "uvicorn",
"retry-after": f"{cooldown_time}",
"content-length": "30",
"content-type": "application/json",
}
),
request=Request("POST", "http://0.0.0.0:9000/chat/completions"),
),
"status_code": 429,
"request_id": None,
}
exception = Exception()
for k, v in kwargs.items():
setattr(exception, k, v)
raise exception
with patch.object(
openai_client.embeddings.with_raw_response,

View file

@ -1,13 +1,20 @@
# What is this?
## Unit tests for the max_parallel_requests feature on Router
import sys, os, time, inspect, asyncio, traceback
import asyncio
import inspect
import os
import sys
import time
import traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional
import litellm
from litellm.utils import calculate_max_parallel_requests
from typing import Optional
"""
- only rpm
@ -113,3 +120,91 @@ def test_setting_mpr_limits_per_model(
assert mpr_client is None
# raise Exception("it worked!")
async def _handle_router_calls(router):
import random
pre_fill = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc ut finibus massa. Quisque a magna magna. Quisque neque diam, varius sit amet tellus eu, elementum fermentum sapien. Integer ut erat eget arcu rutrum blandit. Morbi a metus purus. Nulla porta, urna at finibus malesuada, velit ante suscipit orci, vitae laoreet dui ligula ut augue. Cras elementum pretium dui, nec luctus nulla aliquet ut. Nam faucibus, diam nec semper interdum, nisl nisi viverra nulla, vitae sodales elit ex a purus. Donec tristique malesuada lobortis. Donec posuere iaculis nisl, vitae accumsan libero dignissim dignissim. Suspendisse finibus leo et ex mattis tempor. Praesent at nisl vitae quam egestas lacinia. Donec in justo non erat aliquam accumsan sed vitae ex. Vivamus gravida diam vel ipsum tincidunt dignissim.
Cras vitae efficitur tortor. Curabitur vel erat mollis, euismod diam quis, consequat nibh. Ut vel est eu nulla euismod finibus. Aliquam euismod at risus quis dignissim. Integer non auctor massa. Nullam vitae aliquet mauris. Etiam risus enim, dignissim ut volutpat eget, pulvinar ac augue. Mauris elit est, ultricies vel convallis at, rhoncus nec elit. Aenean ornare maximus orci, ut maximus felis cursus venenatis. Nulla facilisi.
Maecenas aliquet ante massa, at ullamcorper nibh dictum quis. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Quisque id egestas justo. Suspendisse fringilla in massa in consectetur. Quisque scelerisque egestas lacus at posuere. Vestibulum dui sem, bibendum vehicula ultricies vel, blandit id nisi. Curabitur ullamcorper semper metus, vitae commodo magna. Nulla mi metus, suscipit in neque vitae, porttitor pharetra erat. Vestibulum libero velit, congue in diam non, efficitur suscipit diam. Integer arcu velit, fermentum vel tortor sit amet, venenatis rutrum felis. Donec ultricies enim sit amet iaculis mattis.
Integer at purus posuere, malesuada tortor vitae, mattis nibh. Mauris ex quam, tincidunt et fermentum vitae, iaculis non elit. Nullam dapibus non nisl ac sagittis. Duis lacinia eros iaculis lectus consectetur vehicula. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Interdum et malesuada fames ac ante ipsum primis in faucibus. Ut cursus semper est, vel interdum turpis ultrices dictum. Suspendisse posuere lorem et accumsan ultrices. Duis sagittis bibendum consequat. Ut convallis vestibulum enim, non dapibus est porttitor et. Quisque suscipit pulvinar turpis, varius tempor turpis. Vestibulum semper dui nunc, vel vulputate elit convallis quis. Fusce aliquam enim nulla, eu congue nunc tempus eu.
Nam vitae finibus eros, eu eleifend erat. Maecenas hendrerit magna quis molestie dictum. Ut consequat quam eu massa auctor pulvinar. Pellentesque vitae eros ornare urna accumsan tempor. Maecenas porta id quam at sodales. Donec quis accumsan leo, vel viverra nibh. Vestibulum congue blandit nulla, sed rhoncus libero eleifend ac. In risus lorem, rutrum et tincidunt a, interdum a lectus. Pellentesque aliquet pulvinar mauris, ut ultrices nibh ultricies nec. Mauris mi mauris, facilisis nec metus non, egestas luctus ligula. Quisque ac ligula at felis mollis blandit id nec risus. Nam sollicitudin lacus sed sapien fringilla ullamcorper. Etiam dui quam, posuere sit amet velit id, aliquet molestie ante. Integer cursus eget sapien fringilla elementum. Integer molestie, mi ac scelerisque ultrices, nunc purus condimentum est, in posuere quam nibh vitae velit.
"""
completion = await router.acompletion(
"gpt-4o-2024-08-06",
[
{
"role": "user",
"content": f"{pre_fill * 3}\n\nRecite the Declaration of independence at a speed of {random.random() * 100} words per minute.",
}
],
stream=True,
temperature=0.0,
stream_options={"include_usage": True},
)
async for chunk in completion:
pass
print("done", chunk)
@pytest.mark.asyncio
async def test_max_parallel_requests_rpm_rate_limiting():
"""
- make sure requests > model limits are retried successfully.
"""
from litellm import Router
router = Router(
routing_strategy="usage-based-routing-v2",
enable_pre_call_checks=True,
model_list=[
{
"model_name": "gpt-4o-2024-08-06",
"litellm_params": {
"model": "gpt-4o-2024-08-06",
"temperature": 0.0,
"rpm": 5,
},
}
],
)
await asyncio.gather(*[_handle_router_calls(router) for _ in range(16)])
@pytest.mark.asyncio
async def test_max_parallel_requests_tpm_rate_limiting_base_case():
"""
- check error raised if defined tpm limit crossed.
"""
from litellm import Router, token_counter
_messages = [{"role": "user", "content": "Hey, how's it going?"}]
router = Router(
routing_strategy="usage-based-routing-v2",
enable_pre_call_checks=True,
model_list=[
{
"model_name": "gpt-4o-2024-08-06",
"litellm_params": {
"model": "gpt-4o-2024-08-06",
"temperature": 0.0,
"tpm": 1,
},
}
],
num_retries=0,
)
with pytest.raises(litellm.RateLimitError):
for _ in range(2):
await router.acompletion(
model="gpt-4o-2024-08-06",
messages=_messages,
)
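
The two tests above drive a Router past per-deployment `rpm` and `tpm` limits. As a hedged aside, the `token_counter` helper imported in the tpm test can be used to estimate a request's token cost up front; values below are placeholders, not taken from this commit.

```python
from litellm import token_counter

# Sketch: estimate the prompt-token cost of a request, e.g. to compare
# against a deployment's `tpm` budget as in the test above.
messages = [{"role": "user", "content": "Hey, how's it going?"}]
tokens = token_counter(model="gpt-4o-2024-08-06", messages=messages)
print(f"request would consume ~{tokens} prompt tokens")
```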

View file

@ -113,7 +113,9 @@ async def test_check_blocked_team():
team_obj = LiteLLM_TeamTableCachedObj(
team_id=_team_id, blocked=False, last_refreshed_at=time.time()
)
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
hashed_token = hash_token(user_key)
print(f"STORING TOKEN UNDER KEY={hashed_token}")
user_api_key_cache.set_cache(key=hashed_token, value=valid_token)
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)

View file

@ -302,6 +302,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
rpm: Optional[int]
order: Optional[int]
weight: Optional[int]
max_parallel_requests: Optional[int]
api_key: Optional[str]
api_base: Optional[str]
api_version: Optional[str]
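
With `max_parallel_requests` now part of `LiteLLMParamsTypedDict`, the cap can be declared per deployment. A minimal sketch, with placeholder values not taken from this commit:

```python
from litellm import Router

# Sketch: cap concurrent in-flight requests for a single deployment via the
# `max_parallel_requests` field added to LiteLLMParamsTypedDict above.
router = Router(
    model_list=[
        {
            "model_name": "gpt-4o-2024-08-06",
            "litellm_params": {
                "model": "gpt-4o-2024-08-06",
                "max_parallel_requests": 10,  # placeholder limit
            },
        }
    ]
)
```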

View file

@ -86,6 +86,7 @@ class ModelInfo(TypedDict, total=False):
supports_vision: Optional[bool]
supports_function_calling: Optional[bool]
supports_assistant_prefill: Optional[bool]
supports_prompt_caching: Optional[bool]
class GenericStreamingChunk(TypedDict, total=False):

View file

@ -4762,6 +4762,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
supports_response_schema: Optional[bool]
supports_vision: Optional[bool]
supports_function_calling: Optional[bool]
supports_prompt_caching: Optional[bool]
Raises:
Exception: If the model is not mapped yet.
@ -4849,6 +4850,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
supports_response_schema=None,
supports_function_calling=None,
supports_assistant_prefill=None,
supports_prompt_caching=None,
)
else:
"""
@ -5008,6 +5010,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
supports_assistant_prefill=_model_info.get(
"supports_assistant_prefill", False
),
supports_prompt_caching=_model_info.get(
"supports_prompt_caching", False
),
)
except Exception as e:
raise Exception(
@ -6261,6 +6266,9 @@ def _get_response_headers(original_exception: Exception) -> Optional[httpx.Heade
_response_headers: Optional[httpx.Headers] = None
try:
_response_headers = getattr(original_exception, "headers", None)
error_response = getattr(original_exception, "response", None)
if _response_headers is None and error_response:
_response_headers = getattr(error_response, "headers", None)
except Exception:
return None
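
The change above adds a fallback: when the exception carries no top-level `headers` attribute, the headers are read from its `response` object instead. A standalone sketch of that lookup, for illustration only (function name is not from the library):

```python
from typing import Optional

import httpx


def headers_with_fallback(exc: Exception) -> Optional[httpx.Headers]:
    # Prefer `exc.headers`; otherwise fall back to `exc.response.headers`,
    # which is where httpx/openai-style exceptions carry them.
    headers = getattr(exc, "headers", None)
    if headers is None:
        response = getattr(exc, "response", None)
        if response is not None:
            headers = getattr(response, "headers", None)
    return headers
```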

View file

@ -1245,7 +1245,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_prompt_caching": true
},
"codestral/codestral-latest": {
"max_tokens": 8191,
@ -1300,7 +1301,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_prompt_caching": true
},
"groq/llama2-70b-4096": {
"max_tokens": 4096,
@ -1502,7 +1504,8 @@
"supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 264,
"supports_assistant_prefill": true
"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"claude-3-opus-20240229": {
"max_tokens": 4096,
@ -1517,7 +1520,8 @@
"supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 395,
"supports_assistant_prefill": true
"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"claude-3-sonnet-20240229": {
"max_tokens": 4096,
@ -1530,7 +1534,8 @@
"supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 159,
"supports_assistant_prefill": true
"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"claude-3-5-sonnet-20240620": {
"max_tokens": 8192,
@ -1545,7 +1550,8 @@
"supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 159,
"supports_assistant_prefill": true
"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"text-bison": {
"max_tokens": 2048,
@ -2664,6 +2670,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash": {
@ -2783,6 +2790,24 @@
"supports_response_schema": true,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-001": {
"max_tokens": 8192,
"max_input_tokens": 2097152,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0000035,
"input_cost_per_token_above_128k_tokens": 0.000007,
"output_cost_per_token": 0.0000105,
"output_cost_per_token_above_128k_tokens": 0.000021,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-exp-0801": {
"max_tokens": 8192,
"max_input_tokens": 2097152,

View file

@ -111,7 +111,11 @@ async def test_chat_completion_check_otel_spans():
print("Parent trace spans: ", parent_trace_spans)
# either 5, 6, or 7 traces depending on how many redis calls were made
assert len(parent_trace_spans) == 6 or len(parent_trace_spans) == 5
assert (
len(parent_trace_spans) == 6
or len(parent_trace_spans) == 5
or len(parent_trace_spans) == 7
)
# 'postgres', 'redis', 'raw_gen_ai_request', 'litellm_request', 'Received Proxy Server Request' in the span
assert "postgres" in parent_trace_spans