From 8039b95aaf61d363a7f3f381d2071fe6a8e203e9 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sat, 21 Sep 2024 18:51:53 -0700 Subject: [PATCH] LiteLLM Minor Fixes & Improvements (09/21/2024) (#5819) * fix(router.py): fix error message * Litellm disable keys (#5814) * build(schema.prisma): allow blocking/unblocking keys Fixes https://github.com/BerriAI/litellm/issues/5328 * fix(key_management_endpoints.py): fix pop * feat(auth_checks.py): allow admin to enable/disable virtual keys Closes https://github.com/BerriAI/litellm/issues/5328 * docs(vertex.md): add auth section for vertex ai Addresses - https://github.com/BerriAI/litellm/issues/5768#issuecomment-2365284223 * build(model_prices_and_context_window.json): show which models support prompt_caching Closes https://github.com/BerriAI/litellm/issues/5776 * fix(router.py): allow setting default priority for requests * fix(router.py): add 'retry-after' header for concurrent request limit errors Fixes https://github.com/BerriAI/litellm/issues/5783 * fix(router.py): correctly raise and use retry-after header from azure+openai Fixes https://github.com/BerriAI/litellm/issues/5783 * fix(user_api_key_auth.py): fix valid token being none * fix(auth_checks.py): fix model dump for cache management object * fix(user_api_key_auth.py): pass prisma_client to obj * test(test_otel.py): update test for new key check * test: fix test --- docs/my-website/docs/providers/vertex.md | 85 +++++++ litellm/_redis.py | 2 +- litellm/caching.py | 99 +++++--- litellm/exceptions.py | 12 +- litellm/llms/AzureOpenAI/azure.py | 9 + litellm/llms/OpenAI/openai.py | 83 ++++++- litellm/llms/azure_text.py | 11 +- ...odel_prices_and_context_window_backup.json | 37 ++- litellm/proxy/_new_secret_config.yaml | 3 +- litellm/proxy/_types.py | 8 +- litellm/proxy/auth/auth_checks.py | 144 ++++++++++- litellm/proxy/auth/user_api_key_auth.py | 89 ++++--- .../key_management_endpoints.py | 230 +++++++++++++++++- litellm/router.py | 43 ++-- litellm/router_strategy/lowest_tpm_rpm_v2.py | 43 +++- litellm/tests/test_exceptions.py | 40 ++- litellm/tests/test_get_model_info.py | 8 + litellm/tests/test_router.py | 86 +++++-- .../test_router_max_parallel_requests.py | 99 +++++++- litellm/tests/test_user_api_key_auth.py | 4 +- litellm/types/router.py | 1 + litellm/types/utils.py | 1 + litellm/utils.py | 8 + model_prices_and_context_window.json | 37 ++- tests/otel_tests/test_otel.py | 6 +- 25 files changed, 1006 insertions(+), 182 deletions(-) diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index 8e48780b0..be31506f7 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -2118,8 +2118,93 @@ print("response from proxy", response) +## Authentication - vertex_project, vertex_location, etc. + +Set your vertex credentials via: +- dynamic params +OR +- env vars +### **Dynamic Params** + +You can set: +- `vertex_credentials` (str) - can be a json string or filepath to your vertex ai service account.json +- `vertex_location` (str) - place where vertex model is deployed (us-central1, asia-southeast1, etc.) +- `vertex_project` Optional[str] - use if vertex project different from the one in vertex_credentials + +as dynamic params for a `litellm.completion` call. + + + + +```python +from litellm import completion +import json + +## GET CREDENTIALS +file_path = 'path/to/vertex_ai_service_account.json' + +# Load the JSON file +with open(file_path, 'r') as file: + vertex_credentials = json.load(file) + +# Convert to JSON string +vertex_credentials_json = json.dumps(vertex_credentials) + + +response = completion( + model="vertex_ai/gemini-pro", + messages=[{"content": "You are a good bot.","role": "system"}, {"content": "Hello, how are you?","role": "user"}], + vertex_credentials=vertex_credentials_json, + vertex_project="my-special-project", + vertex_location="my-special-location" +) +``` + + + + +```yaml +model_list: + - model_name: gemini-1.5-pro + litellm_params: + model: gemini-1.5-pro + vertex_credentials: os.environ/VERTEX_FILE_PATH_ENV_VAR # os.environ["VERTEX_FILE_PATH_ENV_VAR"] = "/path/to/service_account.json" + vertex_project: "my-special-project" + vertex_location: "my-special-location: +``` + + + + + + + +### **Environment Variables** + +You can set: +- `GOOGLE_APPLICATION_CREDENTIALS` - store the filepath for your service_account.json in here (used by vertex sdk directly). +- VERTEXAI_LOCATION - place where vertex model is deployed (us-central1, asia-southeast1, etc.) +- VERTEXAI_PROJECT - Optional[str] - use if vertex project different from the one in vertex_credentials + +1. GOOGLE_APPLICATION_CREDENTIALS + +```bash +export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service_account.json" +``` + +2. VERTEXAI_LOCATION + +```bash +export VERTEXAI_LOCATION="us-central1" # can be any vertex location +``` + +3. VERTEXAI_PROJECT + +```bash +export VERTEXAI_PROJECT="my-test-project" # ONLY use if model project is different from service account project +``` ## Extra diff --git a/litellm/_redis.py b/litellm/_redis.py index 152f7f09e..0c8c3be68 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -253,7 +253,7 @@ def get_redis_client(**env_overrides): return redis.Redis(**redis_kwargs) -def get_redis_async_client(**env_overrides): +def get_redis_async_client(**env_overrides) -> async_redis.Redis: redis_kwargs = _get_redis_client_logic(**env_overrides) if "url" in redis_kwargs and redis_kwargs["url"] is not None: args = _get_redis_url_kwargs(client=async_redis.Redis.from_url) diff --git a/litellm/caching.py b/litellm/caching.py index 0a9fef417..8501b32c1 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -207,7 +207,7 @@ class RedisCache(BaseCache): host=None, port=None, password=None, - redis_flush_size=100, + redis_flush_size: Optional[int] = 100, namespace: Optional[str] = None, startup_nodes: Optional[List] = None, # for redis-cluster **kwargs, @@ -244,7 +244,10 @@ class RedisCache(BaseCache): self.namespace = namespace # for high traffic, we store the redis results in memory and then batch write to redis self.redis_batch_writing_buffer: list = [] - self.redis_flush_size = redis_flush_size + if redis_flush_size is None: + self.redis_flush_size: int = 100 + else: + self.redis_flush_size = redis_flush_size self.redis_version = "Unknown" try: self.redis_version = self.redis_client.info()["redis_version"] @@ -317,7 +320,7 @@ class RedisCache(BaseCache): current_ttl = _redis_client.ttl(key) if current_ttl == -1: # Key has no expiration - _redis_client.expire(key, ttl) + _redis_client.expire(key, ttl) # type: ignore return result except Exception as e: ## LOGGING ## @@ -331,10 +334,13 @@ class RedisCache(BaseCache): raise e async def async_scan_iter(self, pattern: str, count: int = 100) -> list: + from redis.asyncio import Redis + start_time = time.time() try: keys = [] - _redis_client = self.init_async_client() + _redis_client: Redis = self.init_async_client() # type: ignore + async with _redis_client as redis_client: async for key in redis_client.scan_iter( match=pattern + "*", count=count @@ -374,9 +380,11 @@ class RedisCache(BaseCache): raise e async def async_set_cache(self, key, value, **kwargs): + from redis.asyncio import Redis + start_time = time.time() try: - _redis_client = self.init_async_client() + _redis_client: Redis = self.init_async_client() # type: ignore except Exception as e: end_time = time.time() _duration = end_time - start_time @@ -397,6 +405,7 @@ class RedisCache(BaseCache): str(e), value, ) + raise e key = self.check_and_fix_namespace(key=key) async with _redis_client as redis_client: @@ -405,6 +414,10 @@ class RedisCache(BaseCache): f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" ) try: + if not hasattr(redis_client, "set"): + raise Exception( + "Redis client cannot set cache. Attribute not found." + ) await redis_client.set(name=key, value=json.dumps(value), ex=ttl) print_verbose( f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" @@ -446,12 +459,15 @@ class RedisCache(BaseCache): """ Use Redis Pipelines for bulk write operations """ - _redis_client = self.init_async_client() + from redis.asyncio import Redis + + _redis_client: Redis = self.init_async_client() # type: ignore start_time = time.time() print_verbose( f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}" ) + cache_value: Any = None try: async with _redis_client as redis_client: async with redis_client.pipeline(transaction=True) as pipe: @@ -463,6 +479,7 @@ class RedisCache(BaseCache): ) json_cache_value = json.dumps(cache_value) # Set the value with a TTL if it's provided. + if ttl is not None: pipe.setex(cache_key, ttl, json_cache_value) else: @@ -511,9 +528,11 @@ class RedisCache(BaseCache): async def async_set_cache_sadd( self, key, value: List, ttl: Optional[float], **kwargs ): + from redis.asyncio import Redis + start_time = time.time() try: - _redis_client = self.init_async_client() + _redis_client: Redis = self.init_async_client() # type: ignore except Exception as e: end_time = time.time() _duration = end_time - start_time @@ -592,9 +611,11 @@ class RedisCache(BaseCache): await self.flush_cache_buffer() # logging done in here async def async_increment( - self, key, value: float, ttl: Optional[float] = None, **kwargs + self, key, value: float, ttl: Optional[int] = None, **kwargs ) -> float: - _redis_client = self.init_async_client() + from redis.asyncio import Redis + + _redis_client: Redis = self.init_async_client() # type: ignore start_time = time.time() try: async with _redis_client as redis_client: @@ -708,7 +729,9 @@ class RedisCache(BaseCache): return key_value_dict async def async_get_cache(self, key, **kwargs): - _redis_client = self.init_async_client() + from redis.asyncio import Redis + + _redis_client: Redis = self.init_async_client() # type: ignore key = self.check_and_fix_namespace(key=key) start_time = time.time() async with _redis_client as redis_client: @@ -903,6 +926,12 @@ class RedisCache(BaseCache): async def disconnect(self): await self.async_redis_conn_pool.disconnect(inuse_connections=True) + async def async_delete_cache(self, key: str): + _redis_client = self.init_async_client() + # keys is str + async with _redis_client as redis_client: + await redis_client.delete(key) + def delete_cache(self, key): self.redis_client.delete(key) @@ -1241,6 +1270,7 @@ class QdrantSemanticCache(BaseCache): get_async_httpx_client, httpxSpecialProvider, ) + from litellm.secret_managers.main import get_secret_str if collection_name is None: raise Exception("collection_name must be provided, passed None") @@ -1261,12 +1291,12 @@ class QdrantSemanticCache(BaseCache): if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith( "os.environ/" ): - qdrant_api_base = litellm.get_secret(qdrant_api_base) + qdrant_api_base = get_secret_str(qdrant_api_base) if qdrant_api_key: if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith( "os.environ/" ): - qdrant_api_key = litellm.get_secret(qdrant_api_key) + qdrant_api_key = get_secret_str(qdrant_api_key) qdrant_api_base = ( qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE") @@ -1633,7 +1663,7 @@ class S3Cache(BaseCache): s3_bucket_name, s3_region_name=None, s3_api_version=None, - s3_use_ssl=True, + s3_use_ssl: Optional[bool] = True, s3_verify=None, s3_endpoint_url=None, s3_aws_access_key_id=None, @@ -1721,7 +1751,7 @@ class S3Cache(BaseCache): Bucket=self.bucket_name, Key=key ) - if cached_response != None: + if cached_response is not None: # cached_response is in `b{} convert it to ModelResponse cached_response = ( cached_response["Body"].read().decode("utf-8") @@ -1739,7 +1769,7 @@ class S3Cache(BaseCache): ) return cached_response - except botocore.exceptions.ClientError as e: + except botocore.exceptions.ClientError as e: # type: ignore if e.response["Error"]["Code"] == "NoSuchKey": verbose_logger.debug( f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket." @@ -2081,6 +2111,15 @@ class DualCache(BaseCache): if self.redis_cache is not None: self.redis_cache.delete_cache(key) + async def async_delete_cache(self, key: str): + """ + Delete a key from the cache + """ + if self.in_memory_cache is not None: + self.in_memory_cache.delete_cache(key) + if self.redis_cache is not None: + await self.redis_cache.async_delete_cache(key) + #### LiteLLM.Completion / Embedding Cache #### class Cache: @@ -2137,7 +2176,7 @@ class Cache: s3_path: Optional[str] = None, redis_semantic_cache_use_async=False, redis_semantic_cache_embedding_model="text-embedding-ada-002", - redis_flush_size=None, + redis_flush_size: Optional[int] = None, redis_startup_nodes: Optional[List] = None, disk_cache_dir=None, qdrant_api_base: Optional[str] = None, @@ -2501,10 +2540,9 @@ class Cache: if self.ttl is not None: kwargs["ttl"] = self.ttl ## Get Cache-Controls ## - if kwargs.get("cache", None) is not None and isinstance( - kwargs.get("cache"), dict - ): - for k, v in kwargs.get("cache").items(): + _cache_kwargs = kwargs.get("cache", None) + if isinstance(_cache_kwargs, dict): + for k, v in _cache_kwargs.items(): if k == "ttl": kwargs["ttl"] = v @@ -2574,14 +2612,15 @@ class Cache: **kwargs, ) cache_list.append((cache_key, cached_data)) - if hasattr(self.cache, "async_set_cache_pipeline"): - await self.cache.async_set_cache_pipeline(cache_list=cache_list) + async_set_cache_pipeline = getattr( + self.cache, "async_set_cache_pipeline", None + ) + if async_set_cache_pipeline: + await async_set_cache_pipeline(cache_list=cache_list) else: tasks = [] for val in cache_list: - tasks.append( - self.cache.async_set_cache(cache_key, cached_data, **kwargs) - ) + tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs)) await asyncio.gather(*tasks) except Exception as e: verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}") @@ -2611,13 +2650,15 @@ class Cache: await self.cache.batch_cache_write(cache_key, cached_data, **kwargs) async def ping(self): - if hasattr(self.cache, "ping"): - return await self.cache.ping() + cache_ping = getattr(self.cache, "ping") + if cache_ping: + return await cache_ping() return None async def delete_cache_keys(self, keys): - if hasattr(self.cache, "delete_cache_keys"): - return await self.cache.delete_cache_keys(keys) + cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys") + if cache_delete_cache_keys: + return await cache_delete_cache_keys(keys) return None async def disconnect(self): diff --git a/litellm/exceptions.py b/litellm/exceptions.py index dd9953a32..423ccd603 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -292,8 +292,12 @@ class RateLimitError(openai.RateLimitError): # type: ignore self.litellm_debug_info = litellm_debug_info self.max_retries = max_retries self.num_retries = num_retries + _response_headers = ( + getattr(response, "headers", None) if response is not None else None + ) self.response = httpx.Response( status_code=429, + headers=_response_headers, request=httpx.Request( method="POST", url=" https://cloud.google.com/vertex-ai/", @@ -750,8 +754,14 @@ class InvalidRequestError(openai.BadRequestError): # type: ignore self.message = message self.model = model self.llm_provider = llm_provider + self.response = httpx.Response( + status_code=400, + request=httpx.Request( + method="GET", url="https://litellm.ai" + ), # mock request object + ) super().__init__( - self.message, f"{self.model}" + message=self.message, response=self.response, body=None ) # Call the base class constructor with the parameters it needs diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py index 914126e99..4b8f5ad22 100644 --- a/litellm/llms/AzureOpenAI/azure.py +++ b/litellm/llms/AzureOpenAI/azure.py @@ -767,6 +767,9 @@ class AzureChatCompletion(BaseLLM): except Exception as e: status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise AzureOpenAIError( status_code=status_code, message=str(e), headers=error_headers ) @@ -1023,6 +1026,9 @@ class AzureChatCompletion(BaseLLM): except Exception as e: status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise AzureOpenAIError( status_code=status_code, message=str(e), headers=error_headers ) @@ -1165,6 +1171,9 @@ class AzureChatCompletion(BaseLLM): except Exception as e: status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise AzureOpenAIError( status_code=status_code, message=str(e), headers=error_headers ) diff --git a/litellm/llms/OpenAI/openai.py b/litellm/llms/OpenAI/openai.py index b37af10bf..c67968a5d 100644 --- a/litellm/llms/OpenAI/openai.py +++ b/litellm/llms/OpenAI/openai.py @@ -704,7 +704,6 @@ class OpenAIChatCompletion(BaseLLM): drop_params: Optional[bool] = None, ): super().completion() - exception_mapping_worked = False try: if headers: optional_params["extra_headers"] = headers @@ -911,6 +910,9 @@ class OpenAIChatCompletion(BaseLLM): status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise OpenAIError( status_code=status_code, message=error_text, headers=error_headers ) @@ -1003,8 +1005,12 @@ class OpenAIChatCompletion(BaseLLM): raise e # e.message except Exception as e: + exception_response = getattr(e, "response", None) status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + if error_headers is None and exception_response: + error_headers = getattr(exception_response, "headers", None) + raise OpenAIError( status_code=status_code, message=str(e), headers=error_headers ) @@ -1144,10 +1150,13 @@ class OpenAIChatCompletion(BaseLLM): raise e error_headers = getattr(e, "headers", None) + status_code = getattr(e, "status_code", 500) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) if response is not None and hasattr(response, "text"): - error_headers = getattr(e, "headers", None) raise OpenAIError( - status_code=500, + status_code=status_code, message=f"{str(e)}\n\nOriginal Response: {response.text}", # type: ignore headers=error_headers, ) @@ -1272,8 +1281,12 @@ class OpenAIChatCompletion(BaseLLM): ) status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise OpenAIError( - status_code=status_code, message=str(e), headers=error_headers + status_code=status_code, message=error_text, headers=error_headers ) def embedding( # type: ignore @@ -1352,8 +1365,12 @@ class OpenAIChatCompletion(BaseLLM): except Exception as e: status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise OpenAIError( - status_code=status_code, message=str(e), headers=error_headers + status_code=status_code, message=error_text, headers=error_headers ) async def aimage_generation( @@ -1774,7 +1791,15 @@ class OpenAITextCompletion(BaseLLM): ## RESPONSE OBJECT return TextCompletionResponse(**response_json) except Exception as e: - raise e + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) async def acompletion( self, @@ -1825,7 +1850,15 @@ class OpenAITextCompletion(BaseLLM): response_obj._hidden_params.original_response = json.dumps(response_json) return response_obj except Exception as e: - raise e + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) def streaming( self, @@ -1860,8 +1893,12 @@ class OpenAITextCompletion(BaseLLM): except Exception as e: status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise OpenAIError( - status_code=status_code, message=str(e), headers=error_headers + status_code=status_code, message=error_text, headers=error_headers ) streamwrapper = CustomStreamWrapper( completion_stream=response, @@ -1871,8 +1908,19 @@ class OpenAITextCompletion(BaseLLM): stream_options=data.get("stream_options", None), ) - for chunk in streamwrapper: - yield chunk + try: + for chunk in streamwrapper: + yield chunk + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) async def async_streaming( self, @@ -1910,8 +1958,19 @@ class OpenAITextCompletion(BaseLLM): stream_options=data.get("stream_options", None), ) - async for transformed_chunk in streamwrapper: - yield transformed_chunk + try: + async for transformed_chunk in streamwrapper: + yield transformed_chunk + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) class OpenAIFilesAPI(BaseLLM): diff --git a/litellm/llms/azure_text.py b/litellm/llms/azure_text.py index 6defd58ff..f127680d5 100644 --- a/litellm/llms/azure_text.py +++ b/litellm/llms/azure_text.py @@ -319,6 +319,9 @@ class AzureTextCompletion(BaseLLM): except Exception as e: status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise AzureOpenAIError( status_code=status_code, message=str(e), headers=error_headers ) @@ -332,9 +335,9 @@ class AzureTextCompletion(BaseLLM): data: dict, timeout: Any, model_response: ModelResponse, + logging_obj: Any, azure_ad_token: Optional[str] = None, client=None, # this is the AsyncAzureOpenAI - logging_obj=None, ): response = None try: @@ -395,6 +398,9 @@ class AzureTextCompletion(BaseLLM): except Exception as e: status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise AzureOpenAIError( status_code=status_code, message=str(e), headers=error_headers ) @@ -526,6 +532,9 @@ class AzureTextCompletion(BaseLLM): except Exception as e: status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) raise AzureOpenAIError( status_code=status_code, message=str(e), headers=error_headers ) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 8772c3100..9d5c35f3b 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -1245,7 +1245,8 @@ "mode": "chat", "supports_function_calling": true, "supports_assistant_prefill": true, - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_prompt_caching": true }, "codestral/codestral-latest": { "max_tokens": 8191, @@ -1300,7 +1301,8 @@ "mode": "chat", "supports_function_calling": true, "supports_assistant_prefill": true, - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_prompt_caching": true }, "groq/llama2-70b-4096": { "max_tokens": 4096, @@ -1502,7 +1504,8 @@ "supports_function_calling": true, "supports_vision": true, "tool_use_system_prompt_tokens": 264, - "supports_assistant_prefill": true + "supports_assistant_prefill": true, + "supports_prompt_caching": true }, "claude-3-opus-20240229": { "max_tokens": 4096, @@ -1517,7 +1520,8 @@ "supports_function_calling": true, "supports_vision": true, "tool_use_system_prompt_tokens": 395, - "supports_assistant_prefill": true + "supports_assistant_prefill": true, + "supports_prompt_caching": true }, "claude-3-sonnet-20240229": { "max_tokens": 4096, @@ -1530,7 +1534,8 @@ "supports_function_calling": true, "supports_vision": true, "tool_use_system_prompt_tokens": 159, - "supports_assistant_prefill": true + "supports_assistant_prefill": true, + "supports_prompt_caching": true }, "claude-3-5-sonnet-20240620": { "max_tokens": 8192, @@ -1545,7 +1550,8 @@ "supports_function_calling": true, "supports_vision": true, "tool_use_system_prompt_tokens": 159, - "supports_assistant_prefill": true + "supports_assistant_prefill": true, + "supports_prompt_caching": true }, "text-bison": { "max_tokens": 2048, @@ -2664,6 +2670,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_response_schema": true, + "supports_prompt_caching": true, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash": { @@ -2783,6 +2790,24 @@ "supports_response_schema": true, "source": "https://ai.google.dev/pricing" }, + "gemini/gemini-1.5-pro-001": { + "max_tokens": 8192, + "max_input_tokens": 2097152, + "max_output_tokens": 8192, + "input_cost_per_token": 0.0000035, + "input_cost_per_token_above_128k_tokens": 0.000007, + "output_cost_per_token": 0.0000105, + "output_cost_per_token_above_128k_tokens": 0.000021, + "litellm_provider": "gemini", + "mode": "chat", + "supports_system_messages": true, + "supports_function_calling": true, + "supports_vision": true, + "supports_tool_choice": true, + "supports_response_schema": true, + "supports_prompt_caching": true, + "source": "https://ai.google.dev/pricing" + }, "gemini/gemini-1.5-pro-exp-0801": { "max_tokens": 8192, "max_input_tokens": 2097152, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 1b811fe23..f130e4918 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -26,4 +26,5 @@ model_list: litellm_settings: - success_callback: ["langfuse"] \ No newline at end of file + success_callback: ["langfuse"] + cache: true \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index a8aa759fb..9b2bfbe24 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -632,6 +632,7 @@ class _GenerateKeyRequest(GenerateRequestBase): model_rpm_limit: Optional[dict] = None model_tpm_limit: Optional[dict] = None guardrails: Optional[List[str]] = None + blocked: Optional[bool] = None class GenerateKeyRequest(_GenerateKeyRequest): @@ -967,6 +968,10 @@ class BlockTeamRequest(LiteLLMBase): team_id: str # required +class BlockKeyRequest(LiteLLMBase): + key: str # required + + class AddTeamCallback(LiteLLMBase): callback_name: str callback_type: Literal["success", "failure", "success_and_failure"] @@ -1359,6 +1364,7 @@ class LiteLLM_VerificationToken(LiteLLMBase): model_spend: Dict = {} model_max_budget: Dict = {} soft_budget_cooldown: bool = False + blocked: Optional[bool] = None litellm_budget_table: Optional[dict] = None org_id: Optional[str] = None # org id for a given key @@ -1516,7 +1522,7 @@ class LiteLLM_AuditLogs(LiteLLMBase): updated_at: datetime changed_by: str changed_by_api_key: Optional[str] = None - action: Literal["created", "updated", "deleted"] + action: Literal["created", "updated", "deleted", "blocked"] table_name: Literal[ LitellmTableNames.TEAM_TABLE_NAME, LitellmTableNames.USER_TABLE_NAME, diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index a3b0179c2..70aa3ef39 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -12,6 +12,8 @@ import time from datetime import datetime from typing import TYPE_CHECKING, Any, List, Literal, Optional +from pydantic import BaseModel + import litellm from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache @@ -424,6 +426,24 @@ async def get_user_object( ) +async def _cache_management_object( + key: str, + value: BaseModel, + user_api_key_cache: DualCache, + proxy_logging_obj: Optional[ProxyLogging], +): + await user_api_key_cache.async_set_cache(key=key, value=value) + + ## UPDATE REDIS CACHE ## + if proxy_logging_obj is not None: + _value = value.model_dump_json( + exclude_unset=True, exclude={"parent_otel_span": True} + ) + await proxy_logging_obj.internal_usage_cache.async_set_cache( + key=key, value=_value + ) + + async def _cache_team_object( team_id: str, team_table: LiteLLM_TeamTableCachedObj, @@ -435,20 +455,45 @@ async def _cache_team_object( ## CACHE REFRESH TIME! team_table.last_refreshed_at = time.time() - value = team_table.model_dump_json(exclude_unset=True) - await user_api_key_cache.async_set_cache(key=key, value=value) + await _cache_management_object( + key=key, + value=team_table, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + +async def _cache_key_object( + hashed_token: str, + user_api_key_obj: UserAPIKeyAuth, + user_api_key_cache: DualCache, + proxy_logging_obj: Optional[ProxyLogging], +): + key = hashed_token + + ## CACHE REFRESH TIME! + user_api_key_obj.last_refreshed_at = time.time() + + await _cache_management_object( + key=key, + value=user_api_key_obj, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + +async def _delete_cache_key_object( + hashed_token: str, + user_api_key_cache: DualCache, + proxy_logging_obj: Optional[ProxyLogging], +): + key = hashed_token + + user_api_key_cache.delete_cache(key=key) ## UPDATE REDIS CACHE ## if proxy_logging_obj is not None: - await proxy_logging_obj.internal_usage_cache.async_set_cache( - key=key, value=value - ) - - ## UPDATE REDIS CACHE ## - if proxy_logging_obj is not None: - await proxy_logging_obj.internal_usage_cache.async_set_cache( - key=key, value=team_table - ) + await proxy_logging_obj.internal_usage_cache.async_delete_cache(key=key) @log_to_opentelemetry @@ -524,6 +569,83 @@ async def get_team_object( ) +@log_to_opentelemetry +async def get_key_object( + hashed_token: str, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, + check_cache_only: Optional[bool] = None, +) -> UserAPIKeyAuth: + """ + - Check if team id in proxy Team Table + - if valid, return LiteLLM_TeamTable object with defined limits + - if not, then raise an error + """ + if prisma_client is None: + raise Exception( + "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys" + ) + + # check if in cache + key = hashed_token + cached_team_obj: Optional[UserAPIKeyAuth] = None + + ## CHECK REDIS CACHE ## + if ( + proxy_logging_obj is not None + and proxy_logging_obj.internal_usage_cache.redis_cache is not None + ): + cached_team_obj = ( + await proxy_logging_obj.internal_usage_cache.redis_cache.async_get_cache( + key=key + ) + ) + + if cached_team_obj is None: + cached_team_obj = await user_api_key_cache.async_get_cache(key=key) + + if cached_team_obj is not None: + if isinstance(cached_team_obj, dict): + return UserAPIKeyAuth(**cached_team_obj) + elif isinstance(cached_team_obj, UserAPIKeyAuth): + return cached_team_obj + + if check_cache_only: + raise Exception( + f"Key doesn't exist in cache + check_cache_only=True. key={key}." + ) + + # else, check db + try: + _valid_token: Optional[BaseModel] = await prisma_client.get_data( + token=hashed_token, + table_name="combined_view", + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + + if _valid_token is None: + raise Exception + + _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True)) + + # save the key object to cache + await _cache_key_object( + hashed_token=hashed_token, + user_api_key_obj=_response, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + return _response + except Exception: + raise Exception( + f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call." + ) + + @log_to_opentelemetry async def get_org_object( org_id: str, diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index d49d9843e..85c252d5d 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -46,11 +46,13 @@ import litellm from litellm._logging import verbose_logger, verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.auth_checks import ( + _cache_key_object, allowed_routes_check, can_key_call_model, common_checks, get_actual_routes, get_end_user_object, + get_key_object, get_org_object, get_team_object, get_user_object, @@ -525,9 +527,19 @@ async def user_api_key_auth( ### CHECK IF ADMIN ### # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead ## Check CACHE - valid_token: Optional[UserAPIKeyAuth] = user_api_key_cache.get_cache( - key=hash_token(api_key) - ) + try: + valid_token = await get_key_object( + hashed_token=hash_token(api_key), + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + check_cache_only=True, + ) + except Exception: + verbose_logger.debug("api key not found in cache.") + valid_token = None + if ( valid_token is not None and isinstance(valid_token, UserAPIKeyAuth) @@ -578,7 +590,7 @@ async def user_api_key_auth( try: is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore - except Exception as e: + except Exception: is_master_key_valid = False ## VALIDATE MASTER KEY ## @@ -602,8 +614,11 @@ async def user_api_key_auth( parent_otel_span=parent_otel_span, **end_user_params, ) - await user_api_key_cache.async_set_cache( - key=hash_token(master_key), value=_user_api_key_obj + await _cache_key_object( + hashed_token=hash_token(master_key), + user_api_key_obj=_user_api_key_obj, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, ) return _user_api_key_obj @@ -640,38 +655,31 @@ async def user_api_key_auth( _user_role = None if api_key.startswith("sk-"): api_key = hash_token(token=api_key) - valid_token: Optional[UserAPIKeyAuth] = user_api_key_cache.get_cache( # type: ignore - key=api_key - ) + if valid_token is None: - ## check db - verbose_proxy_logger.debug("api key: %s", api_key) - if prisma_client is not None: - _valid_token: Optional[BaseModel] = await prisma_client.get_data( - token=api_key, - table_name="combined_view", + try: + valid_token = await get_key_object( + hashed_token=api_key, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) + # update end-user params on valid token + # These can change per request - it's important to update them here + valid_token.end_user_id = end_user_params.get("end_user_id") + valid_token.end_user_tpm_limit = end_user_params.get( + "end_user_tpm_limit" + ) + valid_token.end_user_rpm_limit = end_user_params.get( + "end_user_rpm_limit" + ) + valid_token.allowed_model_region = end_user_params.get( + "allowed_model_region" + ) - if _valid_token is not None: - ## update cached token - valid_token = UserAPIKeyAuth( - **_valid_token.model_dump(exclude_none=True) - ) - - verbose_proxy_logger.debug("Token from db: %s", valid_token) - elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth): - verbose_proxy_logger.debug("API Key Cache Hit!") - - # update end-user params on valid token - # These can change per request - it's important to update them here - valid_token.end_user_id = end_user_params.get("end_user_id") - valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit") - valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit") - valid_token.allowed_model_region = end_user_params.get( - "allowed_model_region" - ) + except Exception: + valid_token = None user_obj: Optional[LiteLLM_UserTable] = None valid_token_dict: dict = {} @@ -689,6 +697,12 @@ async def user_api_key_auth( # 8. If token spend is under team budget # 9. If team spend is under team budget + ## base case ## key is disabled + if valid_token.blocked is True: + raise Exception( + "Key is blocked. Update via `/key/unblock` if you're admin." + ) + # Check 1. If token can call model _model_alias_map = {} model: Optional[str] = None @@ -1006,10 +1020,13 @@ async def user_api_key_auth( api_key = valid_token.token # Add hashed token to cache - await user_api_key_cache.async_set_cache( - key=api_key, - value=valid_token, + await _cache_key_object( + hashed_token=api_key, + user_api_key_obj=valid_token, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, ) + valid_token_dict = valid_token.model_dump(exclude_none=True) valid_token_dict.pop("token", None) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index d3c2c942f..cf01597db 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -25,6 +25,11 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * +from litellm.proxy.auth.auth_checks import ( + _cache_key_object, + _delete_cache_key_object, + get_key_object, +) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.utils import _duration_in_seconds @@ -302,15 +307,18 @@ async def prepare_key_update_data( data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row ): data_json: dict = data.dict(exclude_unset=True) - key = data_json.pop("key", None) - + data_json.pop("key", None) _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"] non_default_values = {} for k, v in data_json.items(): if k in _metadata_fields: continue - if v is not None and v not in ([], {}, 0): - non_default_values[k] = v + if v is not None: + if not isinstance(v, bool) and v in ([], {}, 0): + pass + else: + non_default_values[k] = v + if "duration" in non_default_values: duration = non_default_values.pop("duration") if duration and (isinstance(duration, str)) and len(duration) > 0: @@ -364,12 +372,10 @@ async def update_key_fn( """ from litellm.proxy.proxy_server import ( create_audit_log_for_update, - general_settings, litellm_proxy_admin_name, prisma_client, proxy_logging_obj, user_api_key_cache, - user_custom_key_generate, ) try: @@ -399,9 +405,11 @@ async def update_key_fn( # Delete - key from cache, since it's been updated! # key updated - a new model could have been added to this key. it should not block requests after this is done - user_api_key_cache.delete_cache(key) - hashed_token = hash_token(key) - user_api_key_cache.delete_cache(hashed_token) + await _delete_cache_key_object( + hashed_token=hash_token(key), + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True if litellm.store_audit_logs is True: @@ -434,6 +442,11 @@ async def update_key_fn( return {"key": key, **response["data"]} # update based on remaining passed in values except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.update_key_fn(): Exception occured - {}".format( + str(e) + ) + ) if isinstance(e, HTTPException): raise ProxyException( message=getattr(e, "detail", f"Authentication Error({str(e)})"), @@ -771,6 +784,7 @@ async def generate_key_helper_fn( float ] = None, # soft_budget is used to set soft Budgets Per user max_budget: Optional[float] = None, # max_budget is used to Budget Per user + blocked: Optional[bool] = None, budget_duration: Optional[str] = None, # max_budget is used to Budget Per user token: Optional[str] = None, key: Optional[ @@ -899,6 +913,7 @@ async def generate_key_helper_fn( "permissions": permissions_json, "model_max_budget": model_max_budget_json, "budget_id": budget_id, + "blocked": blocked, } if ( @@ -1047,6 +1062,7 @@ async def regenerate_key_fn( hash_token, premium_user, prisma_client, + proxy_logging_obj, user_api_key_cache, ) @@ -1124,10 +1140,18 @@ async def regenerate_key_fn( ### 3. remove existing key entry from cache ###################################################################### if key: - user_api_key_cache.delete_cache(key) + await _delete_cache_key_object( + hashed_token=hash_token(key), + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) if hashed_api_key: - user_api_key_cache.delete_cache(hashed_api_key) + await _delete_cache_key_object( + hashed_token=hash_token(key), + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) return GenerateKeyResponse( **updated_token_dict, @@ -1240,3 +1264,187 @@ async def list_keys( param=getattr(e, "param", "None"), code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) + + +@router.post( + "/key/block", tags=["key management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def block_key( + data: BlockKeyRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Blocks all calls from keys with this team id. + """ + from litellm.proxy.proxy_server import ( + create_audit_log_for_update, + hash_token, + litellm_proxy_admin_name, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + if prisma_client is None: + raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value)) + + if data.key.startswith("sk-"): + hashed_token = hash_token(token=data.key) + else: + hashed_token = data.key + + if litellm.store_audit_logs is True: + # make an audit log for key update + record = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed_token} + ) + if record is None: + raise ProxyException( + message=f"Key {data.key} not found", + type=ProxyErrorTypes.bad_request_error, + param="key", + code=status.HTTP_404_NOT_FOUND, + ) + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=hashed_token, + action="blocked", + updated_values="{}", + before_value=record.model_dump_json(), + ) + ) + ) + + record = await prisma_client.db.litellm_verificationtoken.update( + where={"token": hashed_token}, data={"blocked": True} # type: ignore + ) + + ## UPDATE KEY CACHE + + ### get cached object ### + key_object = await get_key_object( + hashed_token=hashed_token, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=proxy_logging_obj, + ) + + ### update cached object ### + key_object.blocked = True + + ### store cached object ### + await _cache_key_object( + hashed_token=hashed_token, + user_api_key_obj=key_object, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + return record + + +@router.post( + "/key/unblock", tags=["key management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def unblock_key( + data: BlockKeyRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Unblocks all calls from this key. + """ + from litellm.proxy.proxy_server import ( + create_audit_log_for_update, + hash_token, + litellm_proxy_admin_name, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + if prisma_client is None: + raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value)) + + if data.key.startswith("sk-"): + hashed_token = hash_token(token=data.key) + else: + hashed_token = data.key + + if litellm.store_audit_logs is True: + # make an audit log for key update + record = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed_token} + ) + if record is None: + raise ProxyException( + message=f"Key {data.key} not found", + type=ProxyErrorTypes.bad_request_error, + param="key", + code=status.HTTP_404_NOT_FOUND, + ) + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=hashed_token, + action="blocked", + updated_values="{}", + before_value=record.model_dump_json(), + ) + ) + ) + + record = await prisma_client.db.litellm_verificationtoken.update( + where={"token": hashed_token}, data={"blocked": False} # type: ignore + ) + + ## UPDATE KEY CACHE + + ### get cached object ### + key_object = await get_key_object( + hashed_token=hashed_token, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=proxy_logging_obj, + ) + + ### update cached object ### + key_object.blocked = False + + ### store cached object ### + await _cache_key_object( + hashed_token=hashed_token, + user_api_key_obj=key_object, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + return record diff --git a/litellm/router.py b/litellm/router.py index 0159a0b17..47b2c8b15 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -154,6 +154,7 @@ class Router: client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds ## SCHEDULER ## polling_interval: Optional[float] = None, + default_priority: Optional[int] = None, ## RELIABILITY ## num_retries: Optional[int] = None, timeout: Optional[float] = None, @@ -220,6 +221,7 @@ class Router: caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None. client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600. polling_interval: (Optional[float]): frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms. + default_priority: (Optional[int]): the default priority for a request. Only for '.scheduler_acompletion()'. Default is None. num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2. timeout (Optional[float]): Timeout for requests. Defaults to None. default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}. @@ -336,6 +338,7 @@ class Router: self.scheduler = Scheduler( polling_interval=polling_interval, redis_cache=redis_cache ) + self.default_priority = default_priority self.default_deployment = None # use this to track the users default deployment, when they want to use model = * self.default_max_parallel_requests = default_max_parallel_requests self.provider_default_deployments: Dict[str, List] = {} @@ -712,12 +715,11 @@ class Router: kwargs["original_function"] = self._acompletion kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - timeout = kwargs.get("request_timeout", self.timeout) kwargs.setdefault("metadata", {}).update({"model_group": model}) - if kwargs.get("priority", None) is not None and isinstance( - kwargs.get("priority"), int - ): + request_priority = kwargs.get("priority") or self.default_priority + + if request_priority is not None and isinstance(request_priority, int): response = await self.schedule_acompletion(**kwargs) else: response = await self.async_function_with_fallbacks(**kwargs) @@ -3085,9 +3087,9 @@ class Router: except Exception as e: current_attempt = None original_exception = e + """ Retry Logic - """ _healthy_deployments, _all_deployments = ( await self._async_get_healthy_deployments( @@ -3105,16 +3107,6 @@ class Router: content_policy_fallbacks=content_policy_fallbacks, ) - # decides how long to sleep before retry - _timeout = self._time_to_sleep_before_retry( - e=original_exception, - remaining_retries=num_retries, - num_retries=num_retries, - healthy_deployments=_healthy_deployments, - ) - # sleeps for the length of the timeout - await asyncio.sleep(_timeout) - if ( self.retry_policy is not None or self.model_group_retry_policy is not None @@ -3128,11 +3120,19 @@ class Router: ## LOGGING if num_retries > 0: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) + else: + raise + # decides how long to sleep before retry + _timeout = self._time_to_sleep_before_retry( + e=original_exception, + remaining_retries=num_retries, + num_retries=num_retries, + healthy_deployments=_healthy_deployments, + ) + # sleeps for the length of the timeout + await asyncio.sleep(_timeout) for current_attempt in range(num_retries): - verbose_router_logger.debug( - f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}" - ) try: # if the function call is successful, no exception will be raised and we'll break out of the loop response = await original_function(*args, **kwargs) @@ -3370,14 +3370,14 @@ class Router: if ( healthy_deployments is not None and isinstance(healthy_deployments, list) - and len(healthy_deployments) > 0 + and len(healthy_deployments) > 1 ): return 0 response_headers: Optional[httpx.Headers] = None if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore response_headers = e.response.headers # type: ignore - elif hasattr(e, "litellm_response_headers"): + if hasattr(e, "litellm_response_headers"): response_headers = e.litellm_response_headers # type: ignore if response_headers is not None: @@ -3561,7 +3561,7 @@ class Router: except Exception as e: verbose_router_logger.exception( - "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( + "litellm.router.Router::deployment_callback_on_success(): Exception occured - {}".format( str(e) ) ) @@ -5324,7 +5324,6 @@ class Router: return deployment except Exception as e: - traceback_exception = traceback.format_exc() # if router rejects call -> log to langfuse/otel/etc. if request_kwargs is not None: diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index cefebf5e7..75f9f794c 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -21,7 +21,7 @@ class LiteLLMBase(BaseModel): Implements default functions, all pydantic objects should have. """ - def json(self, **kwargs): + def json(self, **kwargs): # type: ignore try: return self.model_dump() # noqa except Exception as e: @@ -185,6 +185,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): deployment_rpm, local_result, ), + headers={"retry-after": 60}, # type: ignore request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore ), ) @@ -207,6 +208,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): deployment_rpm, result, ), + headers={"retry-after": 60}, # type: ignore request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore ), ) @@ -321,15 +323,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger): model_group: str, healthy_deployments: list, tpm_keys: list, - tpm_values: list, + tpm_values: Optional[list], rpm_keys: list, - rpm_values: list, + rpm_values: Optional[list], messages: Optional[List[Dict[str, str]]] = None, input: Optional[Union[str, List]] = None, - ): + ) -> Optional[dict]: """ Common checks for get available deployment, across sync + async implementations """ + if tpm_values is None or rpm_values is None: + return None + tpm_dict = {} # {model_id: 1, ..} for idx, key in enumerate(tpm_keys): tpm_dict[tpm_keys[idx]] = tpm_values[idx] @@ -455,8 +460,12 @@ class LowestTPMLoggingHandler_v2(CustomLogger): keys=combined_tpm_rpm_keys ) # [1, 2, None, ..] - tpm_values = combined_tpm_rpm_values[: len(tpm_keys)] - rpm_values = combined_tpm_rpm_values[len(tpm_keys) :] + if combined_tpm_rpm_values is not None: + tpm_values = combined_tpm_rpm_values[: len(tpm_keys)] + rpm_values = combined_tpm_rpm_values[len(tpm_keys) :] + else: + tpm_values = None + rpm_values = None deployment = self._common_checks_available_deployment( model_group=model_group, @@ -472,7 +481,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): try: assert deployment is not None return deployment - except Exception as e: + except Exception: ### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ### deployment_dict = {} for index, _deployment in enumerate(healthy_deployments): @@ -494,7 +503,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): _deployment_tpm = float("inf") ### GET CURRENT TPM ### - current_tpm = tpm_values[index] + current_tpm = tpm_values[index] if tpm_values else 0 ### GET DEPLOYMENT TPM LIMIT ### _deployment_rpm = None @@ -512,7 +521,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): _deployment_rpm = float("inf") ### GET CURRENT RPM ### - current_rpm = rpm_values[index] + current_rpm = rpm_values[index] if rpm_values else 0 deployment_dict[id] = { "current_tpm": current_tpm, @@ -520,8 +529,16 @@ class LowestTPMLoggingHandler_v2(CustomLogger): "current_rpm": current_rpm, "rpm_limit": _deployment_rpm, } - raise ValueError( - f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}" + raise litellm.RateLimitError( + message=f"{RouterErrors.no_deployments_available.value}. 12345 Passed model={model_group}. Deployments={deployment_dict}", + llm_provider="", + model=model_group, + response=httpx.Response( + status_code=429, + content="", + headers={"retry-after": str(60)}, # type: ignore + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), ) def get_available_deployments( @@ -597,7 +614,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): _deployment_tpm = float("inf") ### GET CURRENT TPM ### - current_tpm = tpm_values[index] + current_tpm = tpm_values[index] if tpm_values else 0 ### GET DEPLOYMENT TPM LIMIT ### _deployment_rpm = None @@ -615,7 +632,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): _deployment_rpm = float("inf") ### GET CURRENT RPM ### - current_rpm = rpm_values[index] + current_rpm = rpm_values[index] if rpm_values else 0 deployment_dict[id] = { "current_tpm": current_tpm, diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py index 0d04aa8b6..e23285422 100644 --- a/litellm/tests/test_exceptions.py +++ b/litellm/tests/test_exceptions.py @@ -910,6 +910,7 @@ async def test_exception_with_headers(sync_mode, provider, model, call_type, str {"message": "litellm.proxy.proxy_server.embeddings(): Exception occured - No deployments available for selected model, Try again in 60 seconds. Passed model=text-embedding-ada-002. pre-call-checks=False, allowed_model_region=n/a, cooldown_list=[('b49cbc9314273db7181fe69b1b19993f04efb88f2c1819947c538bac08097e4c', {'Exception Received': 'litellm.RateLimitError: AzureException RateLimitError - Requests to the Embeddings_Create Operation under Azure OpenAI API version 2023-09-01-preview have exceeded call rate limit of your current OpenAI S0 pricing tier. Please retry after 9 seconds. Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit.', 'Status Code': '429'})]", "level": "ERROR", "timestamp": "2024-08-22T03:25:36.900476"} ``` """ + print(f"Received args: {locals()}") import openai if sync_mode: @@ -939,13 +940,38 @@ async def test_exception_with_headers(sync_mode, provider, model, call_type, str cooldown_time = 30.0 def _return_exception(*args, **kwargs): - from fastapi import HTTPException + import datetime - raise HTTPException( - status_code=429, - detail="Rate Limited!", - headers={"retry-after": cooldown_time}, # type: ignore - ) + from httpx import Headers, Request, Response + + kwargs = { + "request": Request("POST", "https://www.google.com"), + "message": "Error code: 429 - Rate Limit Error!", + "body": {"detail": "Rate Limit Error!"}, + "code": None, + "param": None, + "type": None, + "response": Response( + status_code=429, + headers=Headers( + { + "date": "Sat, 21 Sep 2024 22:56:53 GMT", + "server": "uvicorn", + "retry-after": "30", + "content-length": "30", + "content-type": "application/json", + } + ), + request=Request("POST", "http://0.0.0.0:9000/chat/completions"), + ), + "status_code": 429, + "request_id": None, + } + + exception = Exception() + for k, v in kwargs.items(): + setattr(exception, k, v) + raise exception with patch.object( mapped_target, @@ -975,7 +1001,7 @@ async def test_exception_with_headers(sync_mode, provider, model, call_type, str except litellm.RateLimitError as e: exception_raised = True assert e.litellm_response_headers is not None - assert e.litellm_response_headers["retry-after"] == cooldown_time + assert int(e.litellm_response_headers["retry-after"]) == cooldown_time if exception_raised is False: print(resp) diff --git a/litellm/tests/test_get_model_info.py b/litellm/tests/test_get_model_info.py index 657fdf3ba..3a923bd1e 100644 --- a/litellm/tests/test_get_model_info.py +++ b/litellm/tests/test_get_model_info.py @@ -54,3 +54,11 @@ def test_get_model_info_shows_assistant_prefill(): info = litellm.get_model_info("deepseek/deepseek-chat") print("info", info) assert info.get("supports_assistant_prefill") is True + + +def test_get_model_info_shows_supports_prompt_caching(): + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + info = litellm.get_model_info("deepseek/deepseek-chat") + print("info", info) + assert info.get("supports_prompt_caching") is True diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 5069fc2dc..76277874f 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -2215,16 +2215,39 @@ def test_router_dynamic_cooldown_correct_retry_after_time(): openai_client = openai.OpenAI(api_key="") - cooldown_time = 30.0 + cooldown_time = 30 def _return_exception(*args, **kwargs): - from fastapi import HTTPException + from httpx import Headers, Request, Response - raise HTTPException( - status_code=429, - detail="Rate Limited!", - headers={"retry-after": cooldown_time}, # type: ignore - ) + kwargs = { + "request": Request("POST", "https://www.google.com"), + "message": "Error code: 429 - Rate Limit Error!", + "body": {"detail": "Rate Limit Error!"}, + "code": None, + "param": None, + "type": None, + "response": Response( + status_code=429, + headers=Headers( + { + "date": "Sat, 21 Sep 2024 22:56:53 GMT", + "server": "uvicorn", + "retry-after": f"{cooldown_time}", + "content-length": "30", + "content-type": "application/json", + } + ), + request=Request("POST", "http://0.0.0.0:9000/chat/completions"), + ), + "status_code": 429, + "request_id": None, + } + + exception = Exception() + for k, v in kwargs.items(): + setattr(exception, k, v) + raise exception with patch.object( openai_client.embeddings.with_raw_response, @@ -2250,12 +2273,12 @@ def test_router_dynamic_cooldown_correct_retry_after_time(): print( f"new_retry_after_mock_client.call_args.kwargs: {new_retry_after_mock_client.call_args.kwargs}" ) + print( + f"new_retry_after_mock_client.call_args: {new_retry_after_mock_client.call_args[0][0]}" + ) - response_headers: httpx.Headers = new_retry_after_mock_client.call_args.kwargs[ - "response_headers" - ] - assert "retry-after" in response_headers - assert response_headers["retry-after"] == cooldown_time + response_headers: httpx.Headers = new_retry_after_mock_client.call_args[0][0] + assert int(response_headers["retry-after"]) == cooldown_time @pytest.mark.parametrize("sync_mode", [True, False]) @@ -2270,6 +2293,7 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode): ``` """ litellm.set_verbose = True + cooldown_time = 30.0 router = Router( model_list=[ { @@ -2287,20 +2311,42 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode): ], set_verbose=True, debug_level="DEBUG", + cooldown_time=cooldown_time, ) openai_client = openai.OpenAI(api_key="") - cooldown_time = 30.0 - def _return_exception(*args, **kwargs): - from fastapi import HTTPException + from httpx import Headers, Request, Response - raise HTTPException( - status_code=429, - detail="Rate Limited!", - headers={"retry-after": cooldown_time}, - ) + kwargs = { + "request": Request("POST", "https://www.google.com"), + "message": "Error code: 429 - Rate Limit Error!", + "body": {"detail": "Rate Limit Error!"}, + "code": None, + "param": None, + "type": None, + "response": Response( + status_code=429, + headers=Headers( + { + "date": "Sat, 21 Sep 2024 22:56:53 GMT", + "server": "uvicorn", + "retry-after": f"{cooldown_time}", + "content-length": "30", + "content-type": "application/json", + } + ), + request=Request("POST", "http://0.0.0.0:9000/chat/completions"), + ), + "status_code": 429, + "request_id": None, + } + + exception = Exception() + for k, v in kwargs.items(): + setattr(exception, k, v) + raise exception with patch.object( openai_client.embeddings.with_raw_response, diff --git a/litellm/tests/test_router_max_parallel_requests.py b/litellm/tests/test_router_max_parallel_requests.py index f9cac6aaf..33ca17d8b 100644 --- a/litellm/tests/test_router_max_parallel_requests.py +++ b/litellm/tests/test_router_max_parallel_requests.py @@ -1,13 +1,20 @@ # What is this? ## Unit tests for the max_parallel_requests feature on Router -import sys, os, time, inspect, asyncio, traceback +import asyncio +import inspect +import os +import sys +import time +import traceback from datetime import datetime + import pytest sys.path.insert(0, os.path.abspath("../..")) +from typing import Optional + import litellm from litellm.utils import calculate_max_parallel_requests -from typing import Optional """ - only rpm @@ -113,3 +120,91 @@ def test_setting_mpr_limits_per_model( assert mpr_client is None # raise Exception("it worked!") + + +async def _handle_router_calls(router): + import random + + pre_fill = """ + Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc ut finibus massa. Quisque a magna magna. Quisque neque diam, varius sit amet tellus eu, elementum fermentum sapien. Integer ut erat eget arcu rutrum blandit. Morbi a metus purus. Nulla porta, urna at finibus malesuada, velit ante suscipit orci, vitae laoreet dui ligula ut augue. Cras elementum pretium dui, nec luctus nulla aliquet ut. Nam faucibus, diam nec semper interdum, nisl nisi viverra nulla, vitae sodales elit ex a purus. Donec tristique malesuada lobortis. Donec posuere iaculis nisl, vitae accumsan libero dignissim dignissim. Suspendisse finibus leo et ex mattis tempor. Praesent at nisl vitae quam egestas lacinia. Donec in justo non erat aliquam accumsan sed vitae ex. Vivamus gravida diam vel ipsum tincidunt dignissim. + + Cras vitae efficitur tortor. Curabitur vel erat mollis, euismod diam quis, consequat nibh. Ut vel est eu nulla euismod finibus. Aliquam euismod at risus quis dignissim. Integer non auctor massa. Nullam vitae aliquet mauris. Etiam risus enim, dignissim ut volutpat eget, pulvinar ac augue. Mauris elit est, ultricies vel convallis at, rhoncus nec elit. Aenean ornare maximus orci, ut maximus felis cursus venenatis. Nulla facilisi. + + Maecenas aliquet ante massa, at ullamcorper nibh dictum quis. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Quisque id egestas justo. Suspendisse fringilla in massa in consectetur. Quisque scelerisque egestas lacus at posuere. Vestibulum dui sem, bibendum vehicula ultricies vel, blandit id nisi. Curabitur ullamcorper semper metus, vitae commodo magna. Nulla mi metus, suscipit in neque vitae, porttitor pharetra erat. Vestibulum libero velit, congue in diam non, efficitur suscipit diam. Integer arcu velit, fermentum vel tortor sit amet, venenatis rutrum felis. Donec ultricies enim sit amet iaculis mattis. + + Integer at purus posuere, malesuada tortor vitae, mattis nibh. Mauris ex quam, tincidunt et fermentum vitae, iaculis non elit. Nullam dapibus non nisl ac sagittis. Duis lacinia eros iaculis lectus consectetur vehicula. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Interdum et malesuada fames ac ante ipsum primis in faucibus. Ut cursus semper est, vel interdum turpis ultrices dictum. Suspendisse posuere lorem et accumsan ultrices. Duis sagittis bibendum consequat. Ut convallis vestibulum enim, non dapibus est porttitor et. Quisque suscipit pulvinar turpis, varius tempor turpis. Vestibulum semper dui nunc, vel vulputate elit convallis quis. Fusce aliquam enim nulla, eu congue nunc tempus eu. + + Nam vitae finibus eros, eu eleifend erat. Maecenas hendrerit magna quis molestie dictum. Ut consequat quam eu massa auctor pulvinar. Pellentesque vitae eros ornare urna accumsan tempor. Maecenas porta id quam at sodales. Donec quis accumsan leo, vel viverra nibh. Vestibulum congue blandit nulla, sed rhoncus libero eleifend ac. In risus lorem, rutrum et tincidunt a, interdum a lectus. Pellentesque aliquet pulvinar mauris, ut ultrices nibh ultricies nec. Mauris mi mauris, facilisis nec metus non, egestas luctus ligula. Quisque ac ligula at felis mollis blandit id nec risus. Nam sollicitudin lacus sed sapien fringilla ullamcorper. Etiam dui quam, posuere sit amet velit id, aliquet molestie ante. Integer cursus eget sapien fringilla elementum. Integer molestie, mi ac scelerisque ultrices, nunc purus condimentum est, in posuere quam nibh vitae velit. + """ + completion = await router.acompletion( + "gpt-4o-2024-08-06", + [ + { + "role": "user", + "content": f"{pre_fill * 3}\n\nRecite the Declaration of independence at a speed of {random.random() * 100} words per minute.", + } + ], + stream=True, + temperature=0.0, + stream_options={"include_usage": True}, + ) + + async for chunk in completion: + pass + print("done", chunk) + + +@pytest.mark.asyncio +async def test_max_parallel_requests_rpm_rate_limiting(): + """ + - make sure requests > model limits are retried successfully. + """ + from litellm import Router + + router = Router( + routing_strategy="usage-based-routing-v2", + enable_pre_call_checks=True, + model_list=[ + { + "model_name": "gpt-4o-2024-08-06", + "litellm_params": { + "model": "gpt-4o-2024-08-06", + "temperature": 0.0, + "rpm": 5, + }, + } + ], + ) + await asyncio.gather(*[_handle_router_calls(router) for _ in range(16)]) + + +@pytest.mark.asyncio +async def test_max_parallel_requests_tpm_rate_limiting_base_case(): + """ + - check error raised if defined tpm limit crossed. + """ + from litellm import Router, token_counter + + _messages = [{"role": "user", "content": "Hey, how's it going?"}] + router = Router( + routing_strategy="usage-based-routing-v2", + enable_pre_call_checks=True, + model_list=[ + { + "model_name": "gpt-4o-2024-08-06", + "litellm_params": { + "model": "gpt-4o-2024-08-06", + "temperature": 0.0, + "tpm": 1, + }, + } + ], + num_retries=0, + ) + + with pytest.raises(litellm.RateLimitError): + for _ in range(2): + await router.acompletion( + model="gpt-4o-2024-08-06", + messages=_messages, + ) diff --git a/litellm/tests/test_user_api_key_auth.py b/litellm/tests/test_user_api_key_auth.py index e7b01aa3f..6f0132312 100644 --- a/litellm/tests/test_user_api_key_auth.py +++ b/litellm/tests/test_user_api_key_auth.py @@ -113,7 +113,9 @@ async def test_check_blocked_team(): team_obj = LiteLLM_TeamTableCachedObj( team_id=_team_id, blocked=False, last_refreshed_at=time.time() ) - user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) + hashed_token = hash_token(user_key) + print(f"STORING TOKEN UNDER KEY={hashed_token}") + user_api_key_cache.set_cache(key=hashed_token, value=valid_token) user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj) setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) diff --git a/litellm/types/router.py b/litellm/types/router.py index 67870c313..304e6fd43 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -302,6 +302,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False): rpm: Optional[int] order: Optional[int] weight: Optional[int] + max_parallel_requests: Optional[int] api_key: Optional[str] api_base: Optional[str] api_version: Optional[str] diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 54a4a920a..0c92dabf5 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -86,6 +86,7 @@ class ModelInfo(TypedDict, total=False): supports_vision: Optional[bool] supports_function_calling: Optional[bool] supports_assistant_prefill: Optional[bool] + supports_prompt_caching: Optional[bool] class GenericStreamingChunk(TypedDict, total=False): diff --git a/litellm/utils.py b/litellm/utils.py index 2b97e9d71..1c9d7bde7 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4762,6 +4762,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod supports_response_schema: Optional[bool] supports_vision: Optional[bool] supports_function_calling: Optional[bool] + supports_prompt_caching: Optional[bool] Raises: Exception: If the model is not mapped yet. @@ -4849,6 +4850,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod supports_response_schema=None, supports_function_calling=None, supports_assistant_prefill=None, + supports_prompt_caching=None, ) else: """ @@ -5008,6 +5010,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod supports_assistant_prefill=_model_info.get( "supports_assistant_prefill", False ), + supports_prompt_caching=_model_info.get( + "supports_prompt_caching", False + ), ) except Exception as e: raise Exception( @@ -6261,6 +6266,9 @@ def _get_response_headers(original_exception: Exception) -> Optional[httpx.Heade _response_headers: Optional[httpx.Headers] = None try: _response_headers = getattr(original_exception, "headers", None) + error_response = getattr(original_exception, "response", None) + if _response_headers is None and error_response: + _response_headers = getattr(error_response, "headers", None) except Exception: return None diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 8772c3100..9d5c35f3b 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -1245,7 +1245,8 @@ "mode": "chat", "supports_function_calling": true, "supports_assistant_prefill": true, - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_prompt_caching": true }, "codestral/codestral-latest": { "max_tokens": 8191, @@ -1300,7 +1301,8 @@ "mode": "chat", "supports_function_calling": true, "supports_assistant_prefill": true, - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_prompt_caching": true }, "groq/llama2-70b-4096": { "max_tokens": 4096, @@ -1502,7 +1504,8 @@ "supports_function_calling": true, "supports_vision": true, "tool_use_system_prompt_tokens": 264, - "supports_assistant_prefill": true + "supports_assistant_prefill": true, + "supports_prompt_caching": true }, "claude-3-opus-20240229": { "max_tokens": 4096, @@ -1517,7 +1520,8 @@ "supports_function_calling": true, "supports_vision": true, "tool_use_system_prompt_tokens": 395, - "supports_assistant_prefill": true + "supports_assistant_prefill": true, + "supports_prompt_caching": true }, "claude-3-sonnet-20240229": { "max_tokens": 4096, @@ -1530,7 +1534,8 @@ "supports_function_calling": true, "supports_vision": true, "tool_use_system_prompt_tokens": 159, - "supports_assistant_prefill": true + "supports_assistant_prefill": true, + "supports_prompt_caching": true }, "claude-3-5-sonnet-20240620": { "max_tokens": 8192, @@ -1545,7 +1550,8 @@ "supports_function_calling": true, "supports_vision": true, "tool_use_system_prompt_tokens": 159, - "supports_assistant_prefill": true + "supports_assistant_prefill": true, + "supports_prompt_caching": true }, "text-bison": { "max_tokens": 2048, @@ -2664,6 +2670,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_response_schema": true, + "supports_prompt_caching": true, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash": { @@ -2783,6 +2790,24 @@ "supports_response_schema": true, "source": "https://ai.google.dev/pricing" }, + "gemini/gemini-1.5-pro-001": { + "max_tokens": 8192, + "max_input_tokens": 2097152, + "max_output_tokens": 8192, + "input_cost_per_token": 0.0000035, + "input_cost_per_token_above_128k_tokens": 0.000007, + "output_cost_per_token": 0.0000105, + "output_cost_per_token_above_128k_tokens": 0.000021, + "litellm_provider": "gemini", + "mode": "chat", + "supports_system_messages": true, + "supports_function_calling": true, + "supports_vision": true, + "supports_tool_choice": true, + "supports_response_schema": true, + "supports_prompt_caching": true, + "source": "https://ai.google.dev/pricing" + }, "gemini/gemini-1.5-pro-exp-0801": { "max_tokens": 8192, "max_input_tokens": 2097152, diff --git a/tests/otel_tests/test_otel.py b/tests/otel_tests/test_otel.py index d0d312128..529b499d9 100644 --- a/tests/otel_tests/test_otel.py +++ b/tests/otel_tests/test_otel.py @@ -111,7 +111,11 @@ async def test_chat_completion_check_otel_spans(): print("Parent trace spans: ", parent_trace_spans) # either 5 or 6 traces depending on how many redis calls were made - assert len(parent_trace_spans) == 6 or len(parent_trace_spans) == 5 + assert ( + len(parent_trace_spans) == 6 + or len(parent_trace_spans) == 5 + or len(parent_trace_spans) == 7 + ) # 'postgres', 'redis', 'raw_gen_ai_request', 'litellm_request', 'Received Proxy Server Request' in the span assert "postgres" in parent_trace_spans