Merge branch 'main' into litellm_redact_messages_slack_alerting

This commit is contained in:
Ishaan Jaff 2024-06-07 12:43:53 -07:00 committed by GitHub
commit d2857fc24c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
81 changed files with 3529 additions and 5173 deletions

1
.gitignore vendored
View file

@ -59,3 +59,4 @@ myenv/*
litellm/proxy/_experimental/out/404/index.html litellm/proxy/_experimental/out/404/index.html
litellm/proxy/_experimental/out/model_hub/index.html litellm/proxy/_experimental/out/model_hub/index.html
litellm/proxy/_experimental/out/onboarding/index.html litellm/proxy/_experimental/out/onboarding/index.html
litellm/tests/log.txt

View file

@ -38,7 +38,7 @@ class MyCustomHandler(CustomLogger):
print(f"On Async Success") print(f"On Async Success")
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Success") print(f"On Async Failure")
customHandler = MyCustomHandler() customHandler = MyCustomHandler()

View file

@ -0,0 +1,3 @@
llmcord.py lets you and your friends chat with LLMs directly in your Discord server. It works with practically any LLM, remote or locally hosted.
Github: https://github.com/jakobdylanc/discord-llm-chatbot

View file

@ -14,6 +14,7 @@ Features:
- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features) - ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
- ✅ [Audit Logs](#audit-logs) - ✅ [Audit Logs](#audit-logs)
- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags) - ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)
- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests)
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Google Text Moderations](#content-moderation) - ✅ [Content Moderation with LLM Guard, LlamaGuard, Google Text Moderations](#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai) - ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
- ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding) - ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding)
@ -204,6 +205,109 @@ curl -X GET "http://0.0.0.0:4000/spend/tags" \
``` ```
## Enforce Required Params for LLM Requests
Use this when you want to enforce all requests to include certain params. Example you need all requests to include the `user` and `["metadata]["generation_name"]` params.
**Step 1** Define all Params you want to enforce on config.yaml
This means `["user"]` and `["metadata]["generation_name"]` are required in all LLM Requests to LiteLLM
```yaml
general_settings:
master_key: sk-1234
enforced_params:
- user
- metadata.generation_name
```
Start LiteLLM Proxy
**Step 2 Verify if this works**
<Tabs>
<TabItem value="bad" label="Invalid Request (No `user` passed)">
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-5fmYeaUEbAMpwBNT-QpxyA' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "hi"
}
]
}'
```
Expected Response
```shell
{"error":{"message":"Authentication Error, BadRequest please pass param=user in request body. This is a required param","type":"auth_error","param":"None","code":401}}%
```
</TabItem>
<TabItem value="bad2" label="Invalid Request (No `metadata` passed)">
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-5fmYeaUEbAMpwBNT-QpxyA' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"user": "gm",
"messages": [
{
"role": "user",
"content": "hi"
}
],
"metadata": {}
}'
```
Expected Response
```shell
{"error":{"message":"Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body. This is a required param","type":"auth_error","param":"None","code":401}}%
```
</TabItem>
<TabItem value="good" label="Valid Request">
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-5fmYeaUEbAMpwBNT-QpxyA' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"user": "gm",
"messages": [
{
"role": "user",
"content": "hi"
}
],
"metadata": {"generation_name": "prod-app"}
}'
```
Expected Response
```shell
{"id":"chatcmpl-9XALnHqkCBMBKrOx7Abg0hURHqYtY","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Hello! How can I assist you today?","role":"assistant"}}],"created":1717691639,"model":"gpt-3.5-turbo-0125","object":"chat.completion","system_fingerprint":null,"usage":{"completion_tokens":9,"prompt_tokens":8,"total_tokens":17}}%
```
</TabItem>
</Tabs>

View file

@ -313,6 +313,18 @@ You will see `raw_request` in your Langfuse Metadata. This is the RAW CURL comma
## Logging Proxy Input/Output in OpenTelemetry format ## Logging Proxy Input/Output in OpenTelemetry format
:::info
[Optional] Customize OTEL Service Name and OTEL TRACER NAME by setting the following variables in your environment
```shell
OTEL_TRACER_NAME=<your-trace-name> # default="litellm"
OTEL_SERVICE_NAME=<your-service-name>` # default="litellm"
```
:::
<Tabs> <Tabs>

View file

@ -101,3 +101,75 @@ print(response)
</TabItem> </TabItem>
</Tabs> </Tabs>
## Advanced - Redis Caching
Use redis caching to do request prioritization across multiple instances of LiteLLM.
### SDK
```python
from litellm import Router
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"mock_response": "Hello world this is Macintosh!", # fakes the LLM API call
"rpm": 1,
},
},
],
### REDIS PARAMS ###
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
)
try:
_response = await router.schedule_acompletion( # 👈 ADDS TO QUEUE + POLLS + MAKES CALL
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey!"}],
priority=0, # 👈 LOWER IS BETTER
)
except Exception as e:
print("didn't make request")
```
### PROXY
```yaml
model_list:
- model_name: gpt-3.5-turbo-fake-model
litellm_params:
model: gpt-3.5-turbo
mock_response: "hello world!"
api_key: my-good-key
router_settings:
redis_host; os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
```
```bash
$ litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000s
```
```bash
curl -X POST 'http://localhost:4000/queue/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-D '{
"model": "gpt-3.5-turbo-fake-model",
"messages": [
{
"role": "user",
"content": "what is the meaning of the universe? 1234"
}],
"priority": 0 👈 SET VALUE HERE
}'
```

View file

@ -1,11 +1,31 @@
# Secret Manager # Secret Manager
LiteLLM supports reading secrets from Azure Key Vault and Infisical LiteLLM supports reading secrets from Azure Key Vault and Infisical
- AWS Key Managemenet Service
- AWS Secret Manager
- [Azure Key Vault](#azure-key-vault) - [Azure Key Vault](#azure-key-vault)
- Google Key Management Service - Google Key Management Service
- [Infisical Secret Manager](#infisical-secret-manager) - [Infisical Secret Manager](#infisical-secret-manager)
- [.env Files](#env-files) - [.env Files](#env-files)
## AWS Key Management Service
Use AWS KMS to storing a hashed copy of your Proxy Master Key in the environment.
```bash
export LITELLM_MASTER_KEY="djZ9xjVaZ..." # 👈 ENCRYPTED KEY
export AWS_REGION_NAME="us-west-2"
```
```yaml
general_settings:
key_management_system: "aws_kms"
key_management_settings:
hosted_keys: ["LITELLM_MASTER_KEY"] # 👈 WHICH KEYS ARE STORED ON KMS
```
[**See Decryption Code**](https://github.com/BerriAI/litellm/blob/a2da2a8f168d45648b61279d4795d647d94f90c9/litellm/utils.py#L10182)
## AWS Secret Manager ## AWS Secret Manager
Store your proxy keys in AWS Secret Manager. Store your proxy keys in AWS Secret Manager.

View file

@ -5975,9 +5975,9 @@
} }
}, },
"node_modules/caniuse-lite": { "node_modules/caniuse-lite": {
"version": "1.0.30001519", "version": "1.0.30001629",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001519.tgz", "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001629.tgz",
"integrity": "sha512-0QHgqR+Jv4bxHMp8kZ1Kn8CH55OikjKJ6JmKkZYP1F3D7w+lnFXF70nG5eNfsZS89jadi5Ywy5UCSKLAglIRkg==", "integrity": "sha512-c3dl911slnQhmxUIT4HhYzT7wnBK/XYpGnYLOj4nJBaRiw52Ibe7YxlDaAeRECvA786zCuExhxIUJ2K7nHMrBw==",
"funding": [ "funding": [
{ {
"type": "opencollective", "type": "opencollective",

File diff suppressed because it is too large Load diff

View file

@ -18,10 +18,6 @@ async def log_event(request: Request):
return {"message": "Request received successfully"} return {"message": "Request received successfully"}
except Exception as e: except Exception as e:
print(f"Error processing request: {str(e)}")
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail="Internal Server Error") raise HTTPException(status_code=500, detail="Internal Server Error")

View file

@ -120,6 +120,5 @@ class GenericAPILogger:
) )
return response return response
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(f"Generic - {str(e)}\n{traceback.format_exc()}")
verbose_logger.debug(f"Generic - {str(e)}\n{traceback.format_exc()}")
pass pass

View file

@ -82,7 +82,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(traceback.format_exc())
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,

View file

@ -118,4 +118,4 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(traceback.format_exc())

View file

@ -92,7 +92,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
}, },
) )
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(traceback.format_exc())
raise e raise e
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool: def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:

View file

@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
### INIT VARIABLES ### ### INIT VARIABLES ###
import threading, requests, os import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal from typing import Callable, List, Optional, Dict, Union, Any, Literal
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.caching import Cache from litellm.caching import Cache
from litellm._logging import ( from litellm._logging import (
set_verbose, set_verbose,
@ -234,6 +234,7 @@ max_end_user_budget: Optional[float] = None
#### RELIABILITY #### #### RELIABILITY ####
request_timeout: float = 6000 request_timeout: float = 6000
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout) module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
module_level_client = HTTPHandler(timeout=request_timeout)
num_retries: Optional[int] = None # per model endpoint num_retries: Optional[int] = None # per model endpoint
default_fallbacks: Optional[List] = None default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None fallbacks: Optional[List] = None
@ -767,7 +768,7 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
from .llms.bedrock import ( from .llms.bedrock import (
AmazonTitanConfig, AmazonTitanConfig,
AmazonAI21Config, AmazonAI21Config,
@ -809,6 +810,7 @@ from .exceptions import (
APIConnectionError, APIConnectionError,
APIResponseValidationError, APIResponseValidationError,
UnprocessableEntityError, UnprocessableEntityError,
InternalServerError,
LITELLM_EXCEPTION_TYPES, LITELLM_EXCEPTION_TYPES,
) )
from .budget_manager import BudgetManager from .budget_manager import BudgetManager

View file

@ -1,5 +1,6 @@
import logging, os, json import logging, os, json
from logging import Formatter from logging import Formatter
import traceback
set_verbose = False set_verbose = False
json_logs = bool(os.getenv("JSON_LOGS", False)) json_logs = bool(os.getenv("JSON_LOGS", False))

View file

@ -253,7 +253,6 @@ class RedisCache(BaseCache):
str(e), str(e),
value, value,
) )
traceback.print_exc()
raise e raise e
async def async_scan_iter(self, pattern: str, count: int = 100) -> list: async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
@ -313,7 +312,6 @@ class RedisCache(BaseCache):
str(e), str(e),
value, value,
) )
traceback.print_exc()
key = self.check_and_fix_namespace(key=key) key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client: async with _redis_client as redis_client:
@ -352,7 +350,6 @@ class RedisCache(BaseCache):
str(e), str(e),
value, value,
) )
traceback.print_exc()
async def async_set_cache_pipeline(self, cache_list, ttl=None): async def async_set_cache_pipeline(self, cache_list, ttl=None):
""" """
@ -413,7 +410,6 @@ class RedisCache(BaseCache):
str(e), str(e),
cache_value, cache_value,
) )
traceback.print_exc()
async def batch_cache_write(self, key, value, **kwargs): async def batch_cache_write(self, key, value, **kwargs):
print_verbose( print_verbose(
@ -458,7 +454,6 @@ class RedisCache(BaseCache):
str(e), str(e),
value, value,
) )
traceback.print_exc()
raise e raise e
async def flush_cache_buffer(self): async def flush_cache_buffer(self):
@ -495,8 +490,9 @@ class RedisCache(BaseCache):
return self._get_cache_logic(cached_response=cached_response) return self._get_cache_logic(cached_response=cached_response)
except Exception as e: except Exception as e:
# NON blocking - notify users Redis is throwing an exception # NON blocking - notify users Redis is throwing an exception
traceback.print_exc() verbose_logger.error(
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) "LiteLLM Caching: get() - Got exception from REDIS: ", e
)
def batch_get_cache(self, key_list) -> dict: def batch_get_cache(self, key_list) -> dict:
""" """
@ -646,10 +642,9 @@ class RedisCache(BaseCache):
error=e, error=e,
call_type="sync_ping", call_type="sync_ping",
) )
print_verbose( verbose_logger.error(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}" f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
) )
traceback.print_exc()
raise e raise e
async def ping(self) -> bool: async def ping(self) -> bool:
@ -683,10 +678,9 @@ class RedisCache(BaseCache):
call_type="async_ping", call_type="async_ping",
) )
) )
print_verbose( verbose_logger.error(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}" f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
) )
traceback.print_exc()
raise e raise e
async def delete_cache_keys(self, keys): async def delete_cache_keys(self, keys):
@ -1138,22 +1132,23 @@ class S3Cache(BaseCache):
cached_response = ast.literal_eval(cached_response) cached_response = ast.literal_eval(cached_response)
if type(cached_response) is not dict: if type(cached_response) is not dict:
cached_response = dict(cached_response) cached_response = dict(cached_response)
print_verbose( verbose_logger.debug(
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}" f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
) )
return cached_response return cached_response
except botocore.exceptions.ClientError as e: except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey": if e.response["Error"]["Code"] == "NoSuchKey":
print_verbose( verbose_logger.error(
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket." f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
) )
return None return None
except Exception as e: except Exception as e:
# NON blocking - notify users S3 is throwing an exception # NON blocking - notify users S3 is throwing an exception
traceback.print_exc() verbose_logger.error(
print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}") f"S3 Caching: get_cache() - Got exception from S3: {e}"
)
async def async_get_cache(self, key, **kwargs): async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs) return self.get_cache(key=key, **kwargs)
@ -1234,8 +1229,7 @@ class DualCache(BaseCache):
return result return result
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
raise e raise e
def get_cache(self, key, local_only: bool = False, **kwargs): def get_cache(self, key, local_only: bool = False, **kwargs):
@ -1262,7 +1256,7 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}") print_verbose(f"get cache: cache result: {result}")
return result return result
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(traceback.format_exc())
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs): def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
try: try:
@ -1295,7 +1289,7 @@ class DualCache(BaseCache):
print_verbose(f"async batch get cache: cache result: {result}") print_verbose(f"async batch get cache: cache result: {result}")
return result return result
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(traceback.format_exc())
async def async_get_cache(self, key, local_only: bool = False, **kwargs): async def async_get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first # Try to fetch from in-memory cache first
@ -1328,7 +1322,7 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}") print_verbose(f"get cache: cache result: {result}")
return result return result
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(traceback.format_exc())
async def async_batch_get_cache( async def async_batch_get_cache(
self, keys: list, local_only: bool = False, **kwargs self, keys: list, local_only: bool = False, **kwargs
@ -1368,7 +1362,7 @@ class DualCache(BaseCache):
return result return result
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(traceback.format_exc())
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
print_verbose( print_verbose(
@ -1381,8 +1375,8 @@ class DualCache(BaseCache):
if self.redis_cache is not None and local_only == False: if self.redis_cache is not None and local_only == False:
await self.redis_cache.async_set_cache(key, value, **kwargs) await self.redis_cache.async_set_cache(key, value, **kwargs)
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc() verbose_logger.debug(traceback.format_exc())
async def async_batch_set_cache( async def async_batch_set_cache(
self, cache_list: list, local_only: bool = False, **kwargs self, cache_list: list, local_only: bool = False, **kwargs
@ -1404,8 +1398,8 @@ class DualCache(BaseCache):
cache_list=cache_list, ttl=kwargs.get("ttl", None) cache_list=cache_list, ttl=kwargs.get("ttl", None)
) )
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc() verbose_logger.debug(traceback.format_exc())
async def async_increment_cache( async def async_increment_cache(
self, key, value: float, local_only: bool = False, **kwargs self, key, value: float, local_only: bool = False, **kwargs
@ -1429,8 +1423,8 @@ class DualCache(BaseCache):
return result return result
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc() verbose_logger.debug(traceback.format_exc())
raise e raise e
def flush_cache(self): def flush_cache(self):
@ -1846,8 +1840,8 @@ class Cache:
) )
self.cache.set_cache(cache_key, cached_data, **kwargs) self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") verbose_logger.error(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc() verbose_logger.debug(traceback.format_exc())
pass pass
async def async_add_cache(self, result, *args, **kwargs): async def async_add_cache(self, result, *args, **kwargs):
@ -1864,8 +1858,8 @@ class Cache:
) )
await self.cache.async_set_cache(cache_key, cached_data, **kwargs) await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") verbose_logger.error(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc() verbose_logger.debug(traceback.format_exc())
async def async_add_cache_pipeline(self, result, *args, **kwargs): async def async_add_cache_pipeline(self, result, *args, **kwargs):
""" """
@ -1897,8 +1891,8 @@ class Cache:
) )
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") verbose_logger.error(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc() verbose_logger.debug(traceback.format_exc())
async def batch_cache_write(self, result, *args, **kwargs): async def batch_cache_write(self, result, *args, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic( cache_key, cached_data, kwargs = self._add_cache_logic(

View file

@ -638,6 +638,7 @@ LITELLM_EXCEPTION_TYPES = [
APIConnectionError, APIConnectionError,
APIResponseValidationError, APIResponseValidationError,
OpenAIError, OpenAIError,
InternalServerError,
] ]

View file

@ -169,6 +169,5 @@ class AISpendLogger:
print_verbose(f"AISpend Logging - final data object: {data}") print_verbose(f"AISpend Logging - final data object: {data}")
except: except:
# traceback.print_exc()
print_verbose(f"AISpend Logging Error - {traceback.format_exc()}") print_verbose(f"AISpend Logging Error - {traceback.format_exc()}")
pass pass

View file

@ -178,6 +178,5 @@ class BerriSpendLogger:
print_verbose(f"BerriSpend Logging - final data object: {data}") print_verbose(f"BerriSpend Logging - final data object: {data}")
response = requests.post(url, headers=headers, json=data) response = requests.post(url, headers=headers, json=data)
except: except:
# traceback.print_exc()
print_verbose(f"BerriSpend Logging Error - {traceback.format_exc()}") print_verbose(f"BerriSpend Logging Error - {traceback.format_exc()}")
pass pass

View file

@ -297,6 +297,5 @@ class ClickhouseLogger:
# make request to endpoint with payload # make request to endpoint with payload
verbose_logger.debug(f"Clickhouse Logger - final response = {response}") verbose_logger.debug(f"Clickhouse Logger - final response = {response}")
except Exception as e: except Exception as e:
traceback.print_exc()
verbose_logger.debug(f"Clickhouse - {str(e)}\n{traceback.format_exc()}") verbose_logger.debug(f"Clickhouse - {str(e)}\n{traceback.format_exc()}")
pass pass

View file

@ -115,7 +115,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
) )
print_verbose(f"Custom Logger - model call details: {kwargs}") print_verbose(f"Custom Logger - model call details: {kwargs}")
except: except:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event( async def async_log_input_event(
@ -130,7 +129,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
) )
print_verbose(f"Custom Logger - model call details: {kwargs}") print_verbose(f"Custom Logger - model call details: {kwargs}")
except: except:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event( def log_event(
@ -146,7 +144,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
end_time, end_time,
) )
except: except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass pass
@ -163,6 +160,5 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
end_time, end_time,
) )
except: except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass pass

View file

@ -134,7 +134,6 @@ class DataDogLogger:
f"Datadog Layer Logging - final response object: {response_obj}" f"Datadog Layer Logging - final response object: {response_obj}"
) )
except Exception as e: except Exception as e:
traceback.print_exc()
verbose_logger.debug( verbose_logger.debug(
f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}" f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
) )

View file

@ -85,6 +85,5 @@ class DyanmoDBLogger:
) )
return response return response
except: except:
traceback.print_exc()
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}") print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
pass pass

View file

@ -112,6 +112,5 @@ class HeliconeLogger:
) )
print_verbose(f"Helicone Logging - Error {response.text}") print_verbose(f"Helicone Logging - Error {response.text}")
except: except:
# traceback.print_exc()
print_verbose(f"Helicone Logging Error - {traceback.format_exc()}") print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
pass pass

View file

@ -220,9 +220,11 @@ class LangFuseLogger:
verbose_logger.info(f"Langfuse Layer Logging - logging success") verbose_logger.info(f"Langfuse Layer Logging - logging success")
return {"trace_id": trace_id, "generation_id": generation_id} return {"trace_id": trace_id, "generation_id": generation_id}
except: except Exception as e:
traceback.print_exc() verbose_logger.error(
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}") "Langfuse Layer Error(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
return {"trace_id": None, "generation_id": None} return {"trace_id": None, "generation_id": None}
async def _async_log_event( async def _async_log_event(

View file

@ -44,7 +44,9 @@ class LangsmithLogger:
print_verbose( print_verbose(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}" f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
) )
langsmith_base_url = os.getenv("LANGSMITH_BASE_URL", "https://api.smith.langchain.com") langsmith_base_url = os.getenv(
"LANGSMITH_BASE_URL", "https://api.smith.langchain.com"
)
try: try:
print_verbose( print_verbose(
@ -89,9 +91,7 @@ class LangsmithLogger:
} }
url = f"{langsmith_base_url}/runs" url = f"{langsmith_base_url}/runs"
print_verbose( print_verbose(f"Langsmith Logging - About to send data to {url} ...")
f"Langsmith Logging - About to send data to {url} ..."
)
response = requests.post( response = requests.post(
url=url, url=url,
json=data, json=data,
@ -106,6 +106,5 @@ class LangsmithLogger:
f"Langsmith Layer Logging - final response object: {response_obj}" f"Langsmith Layer Logging - final response object: {response_obj}"
) )
except: except:
# traceback.print_exc()
print_verbose(f"Langsmith Layer Error - {traceback.format_exc()}") print_verbose(f"Langsmith Layer Error - {traceback.format_exc()}")
pass pass

View file

@ -171,7 +171,6 @@ class LogfireLogger:
f"Logfire Layer Logging - final response object: {response_obj}" f"Logfire Layer Logging - final response object: {response_obj}"
) )
except Exception as e: except Exception as e:
traceback.print_exc()
verbose_logger.debug( verbose_logger.debug(
f"Logfire Layer Error - {str(e)}\n{traceback.format_exc()}" f"Logfire Layer Error - {str(e)}\n{traceback.format_exc()}"
) )

View file

@ -14,6 +14,7 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0, "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
} }
def parse_tool_calls(tool_calls): def parse_tool_calls(tool_calls):
if tool_calls is None: if tool_calls is None:
return None return None
@ -26,7 +27,7 @@ def parse_tool_calls(tool_calls):
"function": { "function": {
"name": tool_call.function.name, "name": tool_call.function.name,
"arguments": tool_call.function.arguments, "arguments": tool_call.function.arguments,
} },
} }
return serialized return serialized
@ -176,6 +177,5 @@ class LunaryLogger:
) )
except: except:
# traceback.print_exc()
print_verbose(f"Lunary Logging Error - {traceback.format_exc()}") print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
pass pass

View file

@ -5,8 +5,11 @@ from dataclasses import dataclass
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
LITELLM_TRACER_NAME = "litellm"
LITELLM_RESOURCE = {"service.name": "litellm"} LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm")
LITELLM_RESOURCE = {
"service.name": os.getenv("OTEL_SERVICE_NAME", "litellm"),
}
@dataclass @dataclass

View file

@ -109,8 +109,8 @@ class PrometheusLogger:
end_user_id, user_api_key, model, user_api_team, user_id end_user_id, user_api_key, model, user_api_team, user_id
).inc() ).inc()
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
verbose_logger.debug( "prometheus Layer Error(): Exception occured - {}".format(str(e))
f"prometheus Layer Error - {str(e)}\n{traceback.format_exc()}"
) )
verbose_logger.debug(traceback.format_exc())
pass pass

View file

@ -180,6 +180,5 @@ class S3Logger:
print_verbose(f"s3 Layer Logging - final response object: {response_obj}") print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
return response return response
except Exception as e: except Exception as e:
traceback.print_exc()
verbose_logger.debug(f"s3 Layer Error - {str(e)}\n{traceback.format_exc()}") verbose_logger.debug(f"s3 Layer Error - {str(e)}\n{traceback.format_exc()}")
pass pass

View file

@ -110,6 +110,5 @@ class Supabase:
) )
except: except:
# traceback.print_exc()
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}") print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
pass pass

View file

@ -217,6 +217,5 @@ class WeightsBiasesLogger:
f"W&B Logging Logging - final response object: {response_obj}" f"W&B Logging Logging - final response object: {response_obj}"
) )
except: except:
# traceback.print_exc()
print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}") print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
pass pass

View file

@ -38,6 +38,8 @@ from .prompt_templates.factory import (
extract_between_tags, extract_between_tags,
parse_xml_params, parse_xml_params,
contains_tag, contains_tag,
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
) )
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM from .base import BaseLLM
@ -45,6 +47,11 @@ import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
from litellm.types.llms.bedrock import * from litellm.types.llms.bedrock import *
import urllib.parse import urllib.parse
from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
)
class AmazonCohereChatConfig: class AmazonCohereChatConfig:
@ -118,6 +125,8 @@ class AmazonCohereChatConfig:
"presence_penalty", "presence_penalty",
"seed", "seed",
"stop", "stop",
"tools",
"tool_choice",
] ]
def map_openai_params( def map_openai_params(
@ -176,6 +185,37 @@ async def make_call(
return completion_stream return completion_stream
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.read())
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream
class BedrockLLM(BaseLLM): class BedrockLLM(BaseLLM):
""" """
Example call Example call
@ -1000,12 +1040,12 @@ class BedrockLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int): if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout) timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout _params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore client = AsyncHTTPHandler(**_params) # type: ignore
else: else:
self.client = client # type: ignore client = client # type: ignore
try: try:
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status() response.raise_for_status()
except httpx.HTTPStatusError as err: except httpx.HTTPStatusError as err:
error_code = err.response.status_code error_code = err.response.status_code
@ -1069,6 +1109,745 @@ class BedrockLLM(BaseLLM):
return super().embedding(*args, **kwargs) return super().embedding(*args, **kwargs)
class AmazonConverseConfig:
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
"""
maxTokens: Optional[int]
stopSequences: Optional[List[str]]
temperature: Optional[int]
topP: Optional[int]
def __init__(
self,
maxTokens: Optional[int] = None,
stopSequences: Optional[List[str]] = None,
temperature: Optional[int] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> List[str]:
supported_params = [
"max_tokens",
"stream",
"stream_options",
"stop",
"temperature",
"top_p",
"extra_headers",
]
if (
model.startswith("anthropic")
or model.startswith("mistral")
or model.startswith("cohere")
):
supported_params.append("tools")
if model.startswith("anthropic") or model.startswith("mistral"):
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
supported_params.append("tool_choice")
return supported_params
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict], drop_params: bool
) -> Optional[ToolChoiceValuesBlock]:
if tool_choice == "none":
if litellm.drop_params is True or drop_params is True:
return None
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
elif tool_choice == "required":
return ToolChoiceValuesBlock(any={})
elif tool_choice == "auto":
return ToolChoiceValuesBlock(auto={})
elif isinstance(tool_choice, dict):
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
specific_tool = SpecificToolChoiceBlock(
name=tool_choice.get("function", {}).get("name", "")
)
return ToolChoiceValuesBlock(tool=specific_tool)
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
def get_supported_image_types(self) -> List[str]:
return ["png", "jpeg", "gif", "webp"]
def map_openai_params(
self,
model: str,
non_default_params: dict,
optional_params: dict,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["maxTokens"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
value = [value]
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["topP"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice_value = self.map_tool_choice_values(
model=model, tool_choice=value, drop_params=drop_params # type: ignore
)
if _tool_choice_value is not None:
optional_params["tool_choice"] = _tool_choice_value
return optional_params
class BedrockConverseLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
except Exception as e:
raise BedrockError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
),
status_code=422,
)
"""
Bedrock Response Object has optional message block
completion_response["output"].get("message", None)
A message block looks like this (Example 1):
"output": {
"message": {
"role": "assistant",
"content": [
{
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
}
]
}
},
(Example 2):
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
"name": "top_song",
"input": {
"sign": "WZPZ"
}
}
}
]
}
}
"""
message: Optional[MessageBlock] = completion_response["output"]["message"]
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
if message is not None:
for content in message["content"]:
"""
- Content is either a tool response or text
"""
if "text" in content:
content_str += content["text"]
if "toolUse" in content:
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=content["toolUse"]["name"],
arguments=json.dumps(content["toolUse"]["input"]),
)
_tool_response_chunk = ChatCompletionToolCallChunk(
id=content["toolUse"]["toolUseId"],
type="function",
function=_function_chunk,
)
tools.append(_tool_response_chunk)
chat_completion_message["content"] = content_str
chat_completion_message["tool_calls"] = tools
## CALCULATING USAGE - bedrock returns usage in the headers
input_tokens = completion_response["usage"]["inputTokens"]
output_tokens = completion_response["usage"]["outputTokens"]
total_tokens = completion_response["usage"]["totalTokens"]
model_response.choices = [
litellm.Choices(
finish_reason=map_finish_reason(completion_response["stopReason"]),
index=0,
message=litellm.Message(**chat_completion_message),
)
]
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=total_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def encode_model_id(self, model_id: str) -> str:
"""
Double encode the model ID to ensure it matches the expected double-encoded format.
Args:
model_id (str): The model ID to encode.
Returns:
str: The double-encoded model ID.
"""
return urllib.parse.quote(model_id, safe="")
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
return session.get_credentials()
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
from botocore.credentials import Credentials
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
client=client,
api_base=api_base,
headers=headers,
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore
else:
client = client # type: ignore
try:
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream if isinstance(stream, bool) else False,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ##
stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None)
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
else:
modelId = model
provider = model.split(".")[0]
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[SystemContentBlock] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = SystemContentBlock(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
inference_params = copy.deepcopy(optional_params)
additional_request_keys = []
additional_request_params = {}
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
supported_tool_call_params = ["tools", "tool_choice"]
## TRANSFORMATION ##
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if (
k not in supported_converse_params
and k not in supported_tool_call_params
):
additional_request_params[k] = v
additional_request_keys.append(k)
for key in additional_request_keys:
inference_params.pop(key, None)
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
)
bedrock_tool_config: Optional[ToolConfigBlock] = None
if len(bedrock_tools) > 0:
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
"tool_choice", None
)
bedrock_tool_config = ToolConfigBlock(
tools=bedrock_tools,
)
if tool_choice_values is not None:
bedrock_tool_config["toolChoice"] = tool_choice_values
_data: RequestObject = {
"messages": bedrock_messages,
"additionalModelRequestFields": additional_request_params,
"system": system_content_blocks,
"inferenceConfig": InferenceConfig(**inference_params),
}
if bedrock_tool_config is not None:
_data["toolConfig"] = bedrock_tool_config
data = json.dumps(_data)
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream is True and provider != "ai21":
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if (stream is not None and stream is True) and provider != "ai21":
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
api_base=prepped.url,
headers=prepped.headers, # type: ignore
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=streaming_response,
additional_args={"complete_input_dict": data},
)
return streaming_response
### COMPLETION
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = HTTPHandler(**_params) # type: ignore
else:
client = client
try:
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
def get_response_stream_shape(): def get_response_stream_shape():
from botocore.model import ServiceModel from botocore.model import ServiceModel
from botocore.loaders import Loader from botocore.loaders import Loader
@ -1086,6 +1865,31 @@ class AWSEventStreamDecoder:
self.model = model self.model = model
self.parser = EventStreamJSONParser() self.parser = EventStreamJSONParser()
def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
tool_str = ""
is_finished = False
finish_reason = ""
usage: Optional[ConverseTokenUsageBlock] = None
if "delta" in chunk_data:
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_str = delta_obj["toolUse"]["input"]
elif "stopReason" in chunk_data:
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
elif "usage" in chunk_data:
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
response = GenericStreamingChunk(
text=text,
tool_str=tool_str,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
)
return response
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = "" text = ""
is_finished = False is_finished = False
@ -1098,19 +1902,8 @@ class AWSEventStreamDecoder:
is_finished = True is_finished = True
finish_reason = "stop" finish_reason = "stop"
######## bedrock.anthropic mappings ############### ######## bedrock.anthropic mappings ###############
elif "completion" in chunk_data: # not claude-3
text = chunk_data["completion"] # bedrock.anthropic
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
elif "delta" in chunk_data: elif "delta" in chunk_data:
if chunk_data["delta"].get("text", None) is not None: return self.converse_chunk_parser(chunk_data=chunk_data)
text = chunk_data["delta"]["text"]
stop_reason = chunk_data["delta"].get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
######## bedrock.mistral mappings ############### ######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data: elif "outputs" in chunk_data:
if ( if (
@ -1137,11 +1930,11 @@ class AWSEventStreamDecoder:
is_finished = True is_finished = True
finish_reason = chunk_data["completionReason"] finish_reason = chunk_data["completionReason"]
return GenericStreamingChunk( return GenericStreamingChunk(
**{ text=text,
"text": text, is_finished=is_finished,
"is_finished": is_finished, finish_reason=finish_reason,
"finish_reason": finish_reason, tool_str="",
} usage=None,
) )
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]: def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
@ -1178,9 +1971,14 @@ class AWSEventStreamDecoder:
parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200: if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}") raise ValueError(f"Bad response code, expected 200: {response_dict}")
if "chunk" in parsed_response:
chunk = parsed_response.get("chunk") chunk = parsed_response.get("chunk")
if not chunk: if not chunk:
return None return None
return chunk.get("bytes").decode() # type: ignore[no-any-return] return chunk.get("bytes").decode() # type: ignore[no-any-return]
else:
chunk = response_dict.get("body")
if not chunk:
return None
return chunk.decode() # type: ignore[no-any-return]

View file

@ -156,12 +156,13 @@ class HTTPHandler:
self, self,
url: str, url: str,
data: Optional[Union[dict, str]] = None, data: Optional[Union[dict, str]] = None,
json: Optional[Union[dict, str]] = None,
params: Optional[dict] = None, params: Optional[dict] = None,
headers: Optional[dict] = None, headers: Optional[dict] = None,
stream: bool = False, stream: bool = False,
): ):
req = self.client.build_request( req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore "POST", url, data=data, json=json, params=params, headers=headers # type: ignore
) )
response = self.client.send(req, stream=stream) response = self.client.send(req, stream=stream)
return response return response

View file

@ -1,13 +1,14 @@
import os, types, traceback, copy, asyncio import types
import json import traceback
from enum import Enum import copy
import time import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm import litellm
import sys, httpx import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version from packaging.version import Version
from litellm import verbose_logger
class GeminiError(Exception): class GeminiError(Exception):
@ -264,7 +265,8 @@ def completion(
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
verbose_logger.debug(traceback.format_exc())
raise GeminiError( raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code message=traceback.format_exc(), status_code=response.status_code
) )
@ -356,7 +358,8 @@ async def async_completion(
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
verbose_logger.debug(traceback.format_exc())
raise GeminiError( raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code message=traceback.format_exc(), status_code=response.status_code
) )

View file

@ -7,6 +7,7 @@ import litellm
from litellm.types.utils import ProviderField from litellm.types.utils import ProviderField
import httpx, aiohttp, asyncio # type: ignore import httpx, aiohttp, asyncio # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm import verbose_logger
class OllamaError(Exception): class OllamaError(Exception):
@ -137,6 +138,7 @@ class OllamaConfig:
) )
] ]
def get_supported_openai_params( def get_supported_openai_params(
self, self,
): ):
@ -151,10 +153,12 @@ class OllamaConfig:
"response_format", "response_format",
] ]
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI # ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary. # and convert to jpeg if necessary.
def _convert_image(image): def _convert_image(image):
import base64, io import base64, io
try: try:
from PIL import Image from PIL import Image
except: except:
@ -404,7 +408,13 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"LiteLLM.ollama.py::ollama_async_streaming(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e raise e
@ -468,7 +478,12 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
) )
return model_response return model_response
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"LiteLLM.ollama.py::ollama_acompletion(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e raise e

View file

@ -1,11 +1,15 @@
from itertools import chain from itertools import chain
import requests, types, time import requests
import json, uuid import types
import time
import json
import uuid
import traceback import traceback
from typing import Optional from typing import Optional
from litellm import verbose_logger
import litellm import litellm
import httpx, aiohttp, asyncio import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt import aiohttp
class OllamaError(Exception): class OllamaError(Exception):
@ -299,7 +303,10 @@ def get_ollama_response(
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])}, "function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function", "type": "function",
} }
], ],
@ -307,7 +314,9 @@ def get_ollama_response(
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"] model_response["choices"][0]["message"]["content"] = response_json["message"][
"content"
]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
@ -361,7 +370,10 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])}, "function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function", "type": "function",
} }
], ],
@ -412,7 +424,8 @@ async def ollama_async_streaming(
[ [
chunk.choices[0].delta.content chunk.choices[0].delta.content
async for chunk in streamwrapper async for chunk in streamwrapper
if chunk.choices[0].delta.content] if chunk.choices[0].delta.content
]
) )
function_call = json.loads(response_content) function_call = json.loads(response_content)
delta = litellm.utils.Delta( delta = litellm.utils.Delta(
@ -420,7 +433,10 @@ async def ollama_async_streaming(
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])}, "function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function", "type": "function",
} }
], ],
@ -433,7 +449,8 @@ async def ollama_async_streaming(
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error("LiteLLM.gemini(): Exception occured - {}".format(str(e)))
verbose_logger.debug(traceback.format_exc())
async def ollama_acompletion( async def ollama_acompletion(
@ -483,7 +500,10 @@ async def ollama_acompletion(
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])}, "function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function", "type": "function",
} }
], ],
@ -491,7 +511,9 @@ async def ollama_acompletion(
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"] model_response["choices"][0]["message"]["content"] = response_json[
"message"
]["content"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama_chat/" + data["model"] model_response["model"] = "ollama_chat/" + data["model"]
@ -509,5 +531,9 @@ async def ollama_acompletion(
) )
return model_response return model_response
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"LiteLLM.ollama_acompletion(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
raise e raise e

View file

@ -1,11 +1,12 @@
import os, types, traceback, copy import types
import json import traceback
from enum import Enum import copy
import time import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm import litellm
import sys, httpx import httpx
from litellm import verbose_logger
class PalmError(Exception): class PalmError(Exception):
@ -165,7 +166,10 @@ def completion(
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.llms.palm.py::completion(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
raise PalmError( raise PalmError(
message=traceback.format_exc(), status_code=response.status_code message=traceback.format_exc(), status_code=response.status_code
) )

View file

@ -3,14 +3,7 @@ import requests, traceback
import json, re, xml.etree.ElementTree as ET import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, meta, BaseLoader from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.sandbox import ImmutableSandboxedEnvironment
from typing import ( from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
Any,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
)
import litellm import litellm
import litellm.types import litellm.types
from litellm.types.completion import ( from litellm.types.completion import (
@ -24,7 +17,7 @@ from litellm.types.completion import (
import litellm.types.llms import litellm.types.llms
from litellm.types.llms.anthropic import * from litellm.types.llms.anthropic import *
import uuid import uuid
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
import litellm.types.llms.vertex_ai import litellm.types.llms.vertex_ai
@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
try: try:
from PIL import Image from PIL import Image
except: except:
raise Exception( raise Exception("image conversion failed please run `pip install Pillow`")
"gemini image conversion failed please run `pip install Pillow`"
)
from io import BytesIO from io import BytesIO
try: try:
@ -1613,6 +1604,380 @@ def azure_text_pt(messages: list):
return prompt return prompt
###### AMAZON BEDROCK #######
from litellm.types.llms.bedrock import (
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultBlock as BedrockToolResultBlock,
ToolConfigBlock as BedrockToolConfigBlock,
ToolUseBlock as BedrockToolUseBlock,
ImageSourceBlock as BedrockImageSourceBlock,
ImageBlock as BedrockImageBlock,
ContentBlock as BedrockContentBlock,
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
ToolSpecBlock as BedrockToolSpecBlock,
ToolBlock as BedrockToolBlock,
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
)
def get_image_details(image_url) -> Tuple[str, str]:
try:
import base64
# Send a GET request to the image URL
response = requests.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image
content_type = response.headers.get("content-type")
if not content_type or "image" not in content_type:
raise ValueError(
f"URL does not point to a valid image (content-type: {content_type})"
)
# Convert the image content to base64 bytes
base64_bytes = base64.b64encode(response.content).decode("utf-8")
# Get mime-type
mime_type = content_type.split("/")[
1
] # Extract mime-type from content-type header
return base64_bytes, mime_type
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e:
raise e
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
if "base64" in image_url:
# Case 1: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = image_url.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
mime_type = mime_type_match.group(1)
image_format = mime_type.split("/")[1]
else:
mime_type = "image/jpeg"
image_format = "jpeg"
_blob = BedrockImageSourceBlock(bytes=img_without_base_64)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, image_format = get_image_details(image_url)
_blob = BedrockImageSourceBlock(bytes=image_bytes)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
else:
raise ValueError(
"Unsupported image type. Expected either image url or base64 encoded string - \
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
)
def _convert_to_bedrock_tool_call_invoke(
tool_calls: list,
) -> List[BedrockContentBlock]:
"""
OpenAI tool invokes:
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
}
}
]
},
"""
"""
Bedrock tool invokes:
[
{
"role": "assistant",
"toolUse": {
"input": {"location": "Boston, MA", ..},
"name": "get_current_weather",
"toolUseId": "call_abc123"
}
}
]
"""
"""
- json.loads argument
- extract name
- extract id
"""
try:
_parts_list: List[BedrockContentBlock] = []
for tool in tool_calls:
if "function" in tool:
id = tool["id"]
name = tool["function"].get("name", "")
arguments = tool["function"].get("arguments", "")
arguments_dict = json.loads(arguments)
bedrock_tool = BedrockToolUseBlock(
input=arguments_dict, name=name, toolUseId=id
)
bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
_parts_list.append(bedrock_content_block)
return _parts_list
except Exception as e:
raise Exception(
"Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
tool_calls, str(e)
)
)
def _convert_to_bedrock_tool_call_result(
message: dict,
) -> BedrockMessageBlock:
"""
OpenAI message with a tool result looks like:
{
"tool_call_id": "tool_1",
"role": "tool",
"name": "get_current_weather",
"content": "function result goes here",
},
OpenAI message with a function call result looks like:
{
"role": "function",
"name": "get_current_weather",
"content": "function result goes here",
}
"""
"""
Bedrock result looks like this:
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
"content": [
{
"json": {
"song": "Elemental Hotel",
"artist": "8 Storey Hike"
}
}
]
}
}
]
}
"""
"""
-
"""
content = message.get("content", "")
name = message.get("name", "")
id = message.get("tool_call_id", str(uuid.uuid4()))
tool_result_content_block = BedrockToolResultContentBlock(text=content)
tool_result = BedrockToolResultBlock(
content=[tool_result_content_block],
toolUseId=id,
)
content_block = BedrockContentBlock(toolResult=tool_result)
return BedrockMessageBlock(role="user", content=[content_block])
def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
"""
Converts given messages from OpenAI format to Bedrock format
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
- Please ensure that function response turn comes immediately after a function call turn
"""
contents: List[BedrockMessageBlock] = []
msg_i = 0
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "user":
if isinstance(messages[msg_i]["content"], list):
_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
_parts.append(BedrockContentBlock(image=_part)) # type: ignore
user_content.extend(_parts)
else:
_part = BedrockContentBlock(text=messages[msg_i]["content"])
user_content.append(_part)
msg_i += 1
if user_content:
contents.append(BedrockMessageBlock(role="user", content=user_content))
assistant_content: List[BedrockContentBlock] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
if isinstance(messages[msg_i]["content"], list):
assistants_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
assistants_part = BedrockContentBlock(text=element["text"])
assistants_parts.append(assistants_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
assistants_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
assistants_parts.append(
BedrockContentBlock(image=assistants_part) # type: ignore
)
assistant_content.extend(assistants_parts)
elif messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion
assistant_content.extend(
_convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
)
else:
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append(BedrockContentBlock(text=assistant_text))
msg_i += 1
if assistant_content:
contents.append(
BedrockMessageBlock(role="assistant", content=assistant_content)
)
## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
contents.append(tool_call_result)
msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
return contents
def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
"""
OpenAI tools looks like:
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
}
]
"""
"""
Bedrock toolConfig looks like:
"tools": [
{
"toolSpec": {
"name": "top_song",
"description": "Get the most popular song played on a radio station.",
"inputSchema": {
"json": {
"type": "object",
"properties": {
"sign": {
"type": "string",
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
}
},
"required": [
"sign"
]
}
}
}
}
]
"""
tool_block_list: List[BedrockToolBlock] = []
for tool in tools:
parameters = tool.get("function", {}).get("parameters", None)
name = tool.get("function", {}).get("name", "")
description = tool.get("function", {}).get("description", "")
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
tool_spec = BedrockToolSpecBlock(
inputSchema=tool_input_schema, name=name, description=description
)
tool_block = BedrockToolBlock(toolSpec=tool_spec)
tool_block_list.append(tool_block)
return tool_block_list
# Function call template # Function call template
def function_call_prompt(messages: list, functions: list): def function_call_prompt(messages: list, functions: list):
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:""" function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""

View file

@ -647,9 +647,9 @@ def completion(
prompt = " ".join( prompt = " ".join(
[ [
message["content"] message.get("content")
for message in messages for message in messages
if isinstance(message["content"], str) if isinstance(message.get("content", None), str)
] ]
) )

View file

@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import ( from .llms.prompt_templates.factory import (
@ -122,6 +122,7 @@ huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion() predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion() triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM() bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM() vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################ ####### COMPLETION ENDPOINTS ################
@ -364,7 +365,10 @@ async def acompletion(
) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls) ) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
return response return response
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.acompletion(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
custom_llm_provider = custom_llm_provider or "openai" custom_llm_provider = custom_llm_provider or "openai"
raise exception_type( raise exception_type(
model=model, model=model,
@ -477,7 +481,10 @@ def mock_completion(
except Exception as e: except Exception as e:
if isinstance(e, openai.APIError): if isinstance(e, openai.APIError):
raise e raise e
traceback.print_exc() verbose_logger.error(
"litellm.mock_completion(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
raise Exception("Mock completion response failed") raise Exception("Mock completion response failed")
@ -2096,6 +2103,24 @@ def completion(
custom_llm_provider="bedrock", custom_llm_provider="bedrock",
logging_obj=logging, logging_obj=logging,
) )
else:
if model.startswith("anthropic"):
response = bedrock_converse_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
else: else:
response = bedrock_chat_completion.completion( response = bedrock_chat_completion.completion(
model=model, model=model,
@ -4430,7 +4455,10 @@ async def ahealth_check(
response = {} # args like remaining ratelimit etc. response = {} # args like remaining ratelimit etc.
return response return response
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.ahealth_check(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
stack_trace = traceback.format_exc() stack_trace = traceback.format_exc()
if isinstance(stack_trace, str): if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000] stack_trace = stack_trace[:1000]

View file

@ -1,6 +1,7 @@
import json import json
import logging import logging
from logging import Formatter from logging import Formatter
import sys
class JsonFormatter(Formatter): class JsonFormatter(Formatter):

View file

@ -56,8 +56,10 @@ router_settings:
litellm_settings: litellm_settings:
success_callback: ["langfuse"] success_callback: ["langfuse"]
json_logs: true
general_settings: general_settings:
alerting: ["email"] alerting: ["email"]
key_management_system: "aws_kms"
key_management_settings:
hosted_keys: ["LITELLM_MASTER_KEY"]

View file

@ -946,6 +946,7 @@ class KeyManagementSystem(enum.Enum):
AZURE_KEY_VAULT = "azure_key_vault" AZURE_KEY_VAULT = "azure_key_vault"
AWS_SECRET_MANAGER = "aws_secret_manager" AWS_SECRET_MANAGER = "aws_secret_manager"
LOCAL = "local" LOCAL = "local"
AWS_KMS = "aws_kms"
class KeyManagementSettings(LiteLLMBase): class KeyManagementSettings(LiteLLMBase):

View file

@ -97,6 +97,40 @@ def common_checks(
raise Exception( raise Exception(
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}" f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
) )
if general_settings.get("enforced_params", None) is not None:
# Enterprise ONLY Feature
# we already validate if user is premium_user when reading the config
# Add an extra premium_usercheck here too, just incase
from litellm.proxy.proxy_server import premium_user, CommonProxyErrors
if premium_user is not True:
raise ValueError(
"Trying to use `enforced_params`"
+ CommonProxyErrors.not_premium_user.value
)
if route in LiteLLMRoutes.openai_routes.value:
# loop through each enforced param
# example enforced_params ['user', 'metadata', 'metadata.generation_name']
for enforced_param in general_settings["enforced_params"]:
_enforced_params = enforced_param.split(".")
if len(_enforced_params) == 1:
if _enforced_params[0] not in request_body:
raise ValueError(
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
)
elif len(_enforced_params) == 2:
# this is a scenario where user requires request['metadata']['generation_name'] to exist
if _enforced_params[0] not in request_body:
raise ValueError(
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
)
if _enforced_params[1] not in request_body[_enforced_params[0]]:
raise ValueError(
f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
)
pass
# 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
if ( if (
litellm.max_budget > 0 litellm.max_budget > 0

View file

@ -88,7 +88,7 @@ class _PROXY_AzureContentSafety(
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"Error in Azure Content-Safety: %s", traceback.format_exc() "Error in Azure Content-Safety: %s", traceback.format_exc()
) )
traceback.print_exc() verbose_proxy_logger.debug(traceback.format_exc())
raise raise
result = self._compute_result(response) result = self._compute_result(response)
@ -123,7 +123,12 @@ class _PROXY_AzureContentSafety(
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,

View file

@ -94,7 +94,12 @@ class _PROXY_BatchRedisRequests(CustomLogger):
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_get_cache(self, *args, **kwargs): async def async_get_cache(self, *args, **kwargs):
""" """

View file

@ -1,13 +1,13 @@
# What this does? # What this does?
## Checks if key is allowed to use the cache controls passed in to the completion() call ## Checks if key is allowed to use the cache controls passed in to the completion() call
from typing import Optional
import litellm import litellm
from litellm import verbose_logger
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
import json, traceback import traceback
class _PROXY_CacheControlCheck(CustomLogger): class _PROXY_CacheControlCheck(CustomLogger):
@ -54,4 +54,9 @@ class _PROXY_CacheControlCheck(CustomLogger):
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())

View file

@ -1,10 +1,10 @@
from typing import Optional from litellm import verbose_logger
import litellm import litellm
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
import json, traceback import traceback
class _PROXY_MaxBudgetLimiter(CustomLogger): class _PROXY_MaxBudgetLimiter(CustomLogger):
@ -44,4 +44,9 @@ class _PROXY_MaxBudgetLimiter(CustomLogger):
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())

View file

@ -8,8 +8,8 @@
# Tell us how we can improve! - Krrish & Ishaan # Tell us how we can improve! - Krrish & Ishaan
from typing import Optional, Literal, Union from typing import Optional, Union
import litellm, traceback, sys, uuid, json import litellm, traceback, uuid, json # noqa: E401
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -21,8 +21,8 @@ from litellm.utils import (
ImageResponse, ImageResponse,
StreamingChoices, StreamingChoices,
) )
from datetime import datetime import aiohttp
import aiohttp, asyncio import asyncio
class _OPTIONAL_PresidioPIIMasking(CustomLogger): class _OPTIONAL_PresidioPIIMasking(CustomLogger):
@ -138,7 +138,12 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
else: else:
raise Exception(f"Invalid anonymizer response: {redacted_text}") raise Exception(f"Invalid anonymizer response: {redacted_text}")
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e raise e
async def async_pre_call_hook( async def async_pre_call_hook(

View file

@ -204,7 +204,12 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
return e.detail["error"] return e.detail["error"]
raise e raise e
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_moderation_hook( async def async_moderation_hook(
self, self,

View file

@ -27,3 +27,8 @@ litellm_settings:
callbacks: ["otel"] callbacks: ["otel"]
store_audit_logs: true store_audit_logs: true
redact_messages_in_exceptions: True redact_messages_in_exceptions: True
enforced_params:
- user
- metadata
- metadata.generation_name

View file

@ -103,6 +103,7 @@ from litellm.proxy.utils import (
update_spend, update_spend,
encrypt_value, encrypt_value,
decrypt_value, decrypt_value,
get_error_message_str,
) )
from litellm import ( from litellm import (
CreateBatchRequest, CreateBatchRequest,
@ -112,7 +113,10 @@ from litellm import (
CreateFileRequest, CreateFileRequest,
) )
from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
load_aws_kms,
)
import pydantic import pydantic
from litellm.proxy._types import * from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache from litellm.caching import DualCache, RedisCache
@ -125,7 +129,10 @@ from litellm.router import (
AssistantsTypedDict, AssistantsTypedDict,
) )
from litellm.router import ModelInfo as RouterModelInfo from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger from litellm._logging import (
verbose_router_logger,
verbose_proxy_logger,
)
from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.litellm_license import LicenseCheck from litellm.proxy.auth.litellm_license import LicenseCheck
from litellm.proxy.auth.model_checks import ( from litellm.proxy.auth.model_checks import (
@ -1471,7 +1478,12 @@ async def user_api_key_auth(
else: else:
raise Exception() raise Exception()
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, litellm.BudgetExceededError): if isinstance(e, litellm.BudgetExceededError):
raise ProxyException( raise ProxyException(
message=e.message, type="auth_error", param=None, code=400 message=e.message, type="auth_error", param=None, code=400
@ -2736,10 +2748,12 @@ class ProxyConfig:
load_google_kms(use_google_kms=True) load_google_kms(use_google_kms=True)
elif ( elif (
key_management_system key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
): ):
### LOAD FROM AWS SECRET MANAGER ### ### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True) load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
else: else:
raise ValueError("Invalid Key Management System selected") raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get( key_management_settings = general_settings.get(
@ -2773,6 +2787,7 @@ class ProxyConfig:
master_key = general_settings.get( master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
) )
if master_key and master_key.startswith("os.environ/"): if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key) master_key = litellm.get_secret(master_key)
if not isinstance(master_key, str): if not isinstance(master_key, str):
@ -2863,6 +2878,16 @@ class ProxyConfig:
) )
health_check_interval = general_settings.get("health_check_interval", 300) health_check_interval = general_settings.get("health_check_interval", 300)
## check if user has set a premium feature in general_settings
if (
general_settings.get("enforced_params") is not None
and premium_user is not True
):
raise ValueError(
"Trying to use `enforced_params`"
+ CommonProxyErrors.not_premium_user.value
)
router_params: dict = { router_params: dict = {
"cache_responses": litellm.cache "cache_responses": litellm.cache
!= None, # cache if user passed in cache values != None, # cache if user passed in cache values
@ -3476,7 +3501,12 @@ async def generate_key_helper_fn(
) )
key_data["token_id"] = getattr(create_key_response, "token", None) key_data["token_id"] = getattr(create_key_response, "token", None)
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.generate_key_helper_fn(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise e raise e
raise HTTPException( raise HTTPException(
@ -3515,7 +3545,12 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
else: else:
raise Exception("DB not connected. prisma_client is None") raise Exception("DB not connected. prisma_client is None")
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.delete_verification_token(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e raise e
return deleted_tokens return deleted_tokens
@ -3676,7 +3711,12 @@ async def async_assistants_data_generator(
done_message = "[DONE]" done_message = "[DONE]"
yield f"data: {done_message}\n\n" yield f"data: {done_message}\n\n"
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.async_assistants_data_generator(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
original_exception=e, original_exception=e,
@ -3686,9 +3726,6 @@ async def async_assistants_data_generator(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`" f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
) )
router_model_names = llm_router.model_names if llm_router is not None else [] router_model_names = llm_router.model_names if llm_router is not None else []
if user_debug:
traceback.print_exc()
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise e raise e
else: else:
@ -3728,7 +3765,12 @@ async def async_data_generator(
done_message = "[DONE]" done_message = "[DONE]"
yield f"data: {done_message}\n\n" yield f"data: {done_message}\n\n"
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
original_exception=e, original_exception=e,
@ -3738,8 +3780,6 @@ async def async_data_generator(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`" f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
) )
router_model_names = llm_router.model_names if llm_router is not None else [] router_model_names = llm_router.model_names if llm_router is not None else []
if user_debug:
traceback.print_exc()
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise e raise e
@ -3800,6 +3840,18 @@ def on_backoff(details):
verbose_proxy_logger.debug("Backing off... this was attempt # %s", details["tries"]) verbose_proxy_logger.debug("Backing off... this was attempt # %s", details["tries"])
def giveup(e):
result = not (
isinstance(e, ProxyException)
and getattr(e, "message", None) is not None
and isinstance(e.message, str)
and "Max parallel request limit reached" in e.message
)
if result:
verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
return result
@router.on_event("startup") @router.on_event("startup")
async def startup_event(): async def startup_event():
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
@ -4084,12 +4136,8 @@ def model_list(
max_tries=litellm.num_retries or 3, # maximum number of retries max_tries=litellm.num_retries or 3, # maximum number of retries
max_time=litellm.request_timeout or 60, # maximum total time to retry for max_time=litellm.request_timeout or 60, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
giveup=lambda e: not ( giveup=giveup,
isinstance(e, ProxyException) logger=verbose_proxy_logger,
and getattr(e, "message", None) is not None
and isinstance(e.message, str)
and "Max parallel request limit reached" in e.message
), # the result of the logical expression is on the second position
) )
async def chat_completion( async def chat_completion(
request: Request, request: Request,
@ -4098,6 +4146,7 @@ async def chat_completion(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
): ):
global general_settings, user_debug, proxy_logging_obj, llm_model_list global general_settings, user_debug, proxy_logging_obj, llm_model_list
data = {} data = {}
try: try:
body = await request.body() body = await request.body()
@ -4386,7 +4435,12 @@ async def chat_completion(
return _chat_response return _chat_response
except Exception as e: except Exception as e:
data["litellm_status"] = "fail" # used for alerting data["litellm_status"] = "fail" # used for alerting
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format(
get_error_message_str(e=e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
@ -4397,8 +4451,6 @@ async def chat_completion(
litellm_debug_info, litellm_debug_info,
) )
router_model_names = llm_router.model_names if llm_router is not None else [] router_model_names = llm_router.model_names if llm_router is not None else []
if user_debug:
traceback.print_exc()
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
@ -4630,15 +4682,12 @@ async def completion(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
verbose_proxy_logger.debug("EXCEPTION RAISED IN PROXY MAIN.PY") verbose_proxy_logger.error(
litellm_debug_info = getattr(e, "litellm_debug_info", "") "litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
verbose_proxy_logger.debug( str(e)
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
e,
litellm_debug_info,
) )
traceback.print_exc() )
error_traceback = traceback.format_exc() verbose_proxy_logger.debug(traceback.format_exc())
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -4848,7 +4897,12 @@ async def embeddings(
e, e,
litellm_debug_info, litellm_debug_info,
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.embeddings(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e)), message=getattr(e, "message", str(e)),
@ -5027,7 +5081,12 @@ async def image_generation(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.image_generation(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e)), message=getattr(e, "message", str(e)),
@ -5205,7 +5264,12 @@ async def audio_speech(
) )
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.audio_speech(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e raise e
@ -5394,7 +5458,12 @@ async def audio_transcriptions(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.audio_transcription(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -5403,7 +5472,6 @@ async def audio_transcriptions(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -5531,7 +5599,12 @@ async def get_assistants(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.get_assistants(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -5540,7 +5613,6 @@ async def get_assistants(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -5660,7 +5732,12 @@ async def create_threads(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.create_threads(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -5669,7 +5746,6 @@ async def create_threads(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -5788,7 +5864,12 @@ async def get_thread(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.get_thread(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -5797,7 +5878,6 @@ async def get_thread(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -5919,7 +5999,12 @@ async def add_messages(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.add_messages(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -5928,7 +6013,6 @@ async def add_messages(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -6046,7 +6130,12 @@ async def get_messages(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.get_messages(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -6055,7 +6144,6 @@ async def get_messages(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -6187,7 +6275,12 @@ async def run_thread(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.run_thread(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -6196,7 +6289,6 @@ async def run_thread(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -6335,7 +6427,12 @@ async def create_batch(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -6344,7 +6441,6 @@ async def create_batch(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -6478,7 +6574,12 @@ async def retrieve_batch(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -6631,7 +6732,12 @@ async def create_file(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.create_file(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),
@ -6640,7 +6746,6 @@ async def create_file(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -6816,7 +6921,12 @@ async def moderations(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e)), message=getattr(e, "message", str(e)),
@ -6825,7 +6935,6 @@ async def moderations(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
) )
else: else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}" error_msg = f"{str(e)}"
raise ProxyException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
@ -7136,7 +7245,12 @@ async def generate_key_fn(
return GenerateKeyResponse(**response) return GenerateKeyResponse(**response)
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.generate_key_fn(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -9591,7 +9705,12 @@ async def user_info(
} }
return response_data return response_data
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_info(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -9686,7 +9805,12 @@ async def user_update(data: UpdateUserRequest):
return response return response
# update based on remaining passed in values # update based on remaining passed in values
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_update(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -9739,7 +9863,12 @@ async def user_request_model(request: Request):
return {"status": "success"} return {"status": "success"}
# update based on remaining passed in values # update based on remaining passed in values
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_request_model(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -9781,7 +9910,12 @@ async def user_get_requests():
return {"requests": response} return {"requests": response}
# update based on remaining passed in values # update based on remaining passed in values
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_get_requests(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -10171,7 +10305,12 @@ async def update_end_user(
# update based on remaining passed in values # update based on remaining passed in values
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.update_end_user(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Internal Server Error({str(e)})"), message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
@ -10255,7 +10394,12 @@ async def delete_end_user(
# update based on remaining passed in values # update based on remaining passed in values
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.delete_end_user(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Internal Server Error({str(e)})"), message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
@ -11558,7 +11702,12 @@ async def add_new_model(
return model_response return model_response
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.add_new_model(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -11672,7 +11821,12 @@ async def update_model(
return model_response return model_response
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.update_model(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -13906,7 +14060,12 @@ async def update_config(config_info: ConfigYAML):
return {"message": "Config updated successfully"} return {"message": "Config updated successfully"}
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.update_config(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -14379,7 +14538,12 @@ async def get_config():
"available_callbacks": all_available_callbacks, "available_callbacks": all_available_callbacks,
} }
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.get_config(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -14630,7 +14794,12 @@ async def health_services_endpoint(
} }
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.health_services_endpoint(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -14709,7 +14878,12 @@ async def health_endpoint(
"unhealthy_count": len(unhealthy_endpoints), "unhealthy_count": len(unhealthy_endpoints),
} }
except Exception as e: except Exception as e:
traceback.print_exc() verbose_proxy_logger.error(
"litellm.proxy.proxy_server.py::health_endpoint(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e raise e

View file

@ -8,7 +8,8 @@ Requires:
* `pip install boto3>=1.28.57` * `pip install boto3>=1.28.57`
""" """
import litellm, os import litellm
import os
from typing import Optional from typing import Optional
from litellm.proxy._types import KeyManagementSystem from litellm.proxy._types import KeyManagementSystem
@ -38,3 +39,21 @@ def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
except Exception as e: except Exception as e:
raise e raise e
def load_aws_kms(use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False:
return
try:
import boto3
validate_environment()
# Create a Secrets Manager client
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
litellm.secret_manager_client = kms_client
litellm._key_management_system = KeyManagementSystem.AWS_KMS
except Exception as e:
raise e

View file

@ -2837,3 +2837,17 @@ missing_keys_html_form = """
</body> </body>
</html> </html>
""" """
def get_error_message_str(e: Exception) -> str:
error_message = ""
if isinstance(e, HTTPException):
if isinstance(e.detail, str):
error_message = e.detail
elif isinstance(e.detail, dict):
error_message = json.dumps(e.detail)
else:
error_message = str(e)
else:
error_message = str(e)
return error_message

2
litellm/py.typed Normal file
View file

@ -0,0 +1,2 @@
# Marker file to instruct type checkers to look for inline type annotations in this package.
# See PEP 561 for more information.

View file

@ -220,8 +220,6 @@ class Router:
[] []
) # names of models under litellm_params. ex. azure/chatgpt-v-2 ) # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {} self.deployment_latency_map = {}
### SCHEDULER ###
self.scheduler = Scheduler(polling_interval=polling_interval)
### CACHING ### ### CACHING ###
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
redis_cache = None redis_cache = None
@ -259,6 +257,10 @@ class Router:
redis_cache=redis_cache, in_memory_cache=InMemoryCache() redis_cache=redis_cache, in_memory_cache=InMemoryCache()
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc. ) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### SCHEDULER ###
self.scheduler = Scheduler(
polling_interval=polling_interval, redis_cache=redis_cache
)
self.default_deployment = None # use this to track the users default deployment, when they want to use model = * self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
self.default_max_parallel_requests = default_max_parallel_requests self.default_max_parallel_requests = default_max_parallel_requests
@ -2096,8 +2098,8 @@ class Router:
except Exception as e: except Exception as e:
raise e raise e
except Exception as e: except Exception as e:
verbose_router_logger.debug(f"An exception occurred - {str(e)}") verbose_router_logger.error(f"An exception occurred - {str(e)}")
traceback.print_exc() verbose_router_logger.debug(traceback.format_exc())
raise original_exception raise original_exception
async def async_function_with_retries(self, *args, **kwargs): async def async_function_with_retries(self, *args, **kwargs):
@ -4048,6 +4050,12 @@ class Router:
for idx in reversed(invalid_model_indices): for idx in reversed(invalid_model_indices):
_returned_deployments.pop(idx) _returned_deployments.pop(idx)
## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2)
if len(_returned_deployments) > 0:
_returned_deployments = litellm.utils._get_order_filtered_deployments(
_returned_deployments
)
return _returned_deployments return _returned_deployments
def _common_checks_available_deployment( def _common_checks_available_deployment(

View file

@ -1,11 +1,9 @@
#### What this does #### #### What this does ####
# picks based on response time (for streaming, this is time to first token) # picks based on response time (for streaming, this is time to first token)
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel
import os, requests, random # type: ignore
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta from datetime import datetime, timedelta
import random from litellm import verbose_logger
import traceback import traceback
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -119,7 +117,12 @@ class LowestCostLoggingHandler(CustomLogger):
if self.test_flag: if self.test_flag:
self.logged_success += 1 self.logged_success += 1
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -201,7 +204,12 @@ class LowestCostLoggingHandler(CustomLogger):
if self.test_flag: if self.test_flag:
self.logged_success += 1 self.logged_success += 1
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass pass
async def async_get_available_deployments( async def async_get_available_deployments(

View file

@ -1,16 +1,16 @@
#### What this does #### #### What this does ####
# picks based on response time (for streaming, this is time to first token) # picks based on response time (for streaming, this is time to first token)
from pydantic import BaseModel, Extra, Field, root_validator # type: ignore from pydantic import BaseModel
import dotenv, os, requests, random # type: ignore import random
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta from datetime import datetime, timedelta
import random
import traceback import traceback
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm import ModelResponse from litellm import ModelResponse
from litellm import token_counter from litellm import token_counter
import litellm import litellm
from litellm import verbose_logger
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
@ -165,7 +165,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
if self.test_flag: if self.test_flag:
self.logged_success += 1 self.logged_success += 1
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
@ -229,7 +234,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
# do nothing if it's not a timeout error # do nothing if it's not a timeout error
return return
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -352,7 +362,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
if self.test_flag: if self.test_flag:
self.logged_success += 1 self.logged_success += 1
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.router_strategy.lowest_latency.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass pass
def get_available_deployments( def get_available_deployments(

View file

@ -11,6 +11,7 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_router_logger from litellm._logging import verbose_router_logger
from litellm.utils import print_verbose from litellm.utils import print_verbose
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
""" """
Implements default functions, all pydantic objects should have. Implements default functions, all pydantic objects should have.
@ -23,16 +24,20 @@ class LiteLLMBase(BaseModel):
# if using pydantic v1 # if using pydantic v1
return self.dict() return self.dict()
class RoutingArgs(LiteLLMBase): class RoutingArgs(LiteLLMBase):
ttl: int = 1 * 60 # 1min (RPM/TPM expire key) ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
class LowestTPMLoggingHandler(CustomLogger): class LowestTPMLoggingHandler(CustomLogger):
test_flag: bool = False test_flag: bool = False
logged_success: int = 0 logged_success: int = 0
logged_failure: int = 0 logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}): def __init__(
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
):
self.router_cache = router_cache self.router_cache = router_cache
self.model_list = model_list self.model_list = model_list
self.routing_args = RoutingArgs(**routing_args) self.routing_args = RoutingArgs(**routing_args)
@ -72,19 +77,28 @@ class LowestTPMLoggingHandler(CustomLogger):
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {} request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl) self.router_cache.set_cache(
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
## RPM ## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {} request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + 1 request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl) self.router_cache.set_cache(
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
### TESTING ### ### TESTING ###
if self.test_flag: if self.test_flag:
self.logged_success += 1 self.logged_success += 1
except Exception as e: except Exception as e:
traceback.print_exc() verbose_router_logger.error(
"litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
verbose_router_logger.debug(traceback.format_exc())
pass pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -123,19 +137,28 @@ class LowestTPMLoggingHandler(CustomLogger):
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {} request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl) self.router_cache.set_cache(
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
## RPM ## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {} request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + 1 request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl) self.router_cache.set_cache(
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
### TESTING ### ### TESTING ###
if self.test_flag: if self.test_flag:
self.logged_success += 1 self.logged_success += 1
except Exception as e: except Exception as e:
traceback.print_exc() verbose_router_logger.error(
"litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
verbose_router_logger.debug(traceback.format_exc())
pass pass
def get_available_deployments( def get_available_deployments(

View file

@ -1,19 +1,19 @@
#### What this does #### #### What this does ####
# identifies lowest tpm deployment # identifies lowest tpm deployment
from pydantic import BaseModel from pydantic import BaseModel
import dotenv, os, requests, random import random
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
import datetime as datetime_og import traceback
from datetime import datetime import httpx
import traceback, asyncio, httpx
import litellm import litellm
from litellm import token_counter from litellm import token_counter
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_router_logger from litellm._logging import verbose_router_logger, verbose_logger
from litellm.utils import print_verbose, get_utc_datetime from litellm.utils import print_verbose, get_utc_datetime
from litellm.types.router import RouterErrors from litellm.types.router import RouterErrors
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
""" """
Implements default functions, all pydantic objects should have. Implements default functions, all pydantic objects should have.
@ -22,13 +22,15 @@ class LiteLLMBase(BaseModel):
def json(self, **kwargs): def json(self, **kwargs):
try: try:
return self.model_dump() # noqa return self.model_dump() # noqa
except: except Exception as e:
# if using pydantic v1 # if using pydantic v1
return self.dict() return self.dict()
class RoutingArgs(LiteLLMBase): class RoutingArgs(LiteLLMBase):
ttl: int = 1 * 60 # 1min (RPM/TPM expire key) ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
class LowestTPMLoggingHandler_v2(CustomLogger): class LowestTPMLoggingHandler_v2(CustomLogger):
""" """
Updated version of TPM/RPM Logging. Updated version of TPM/RPM Logging.
@ -47,7 +49,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
logged_failure: int = 0 logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}): def __init__(
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
):
self.router_cache = router_cache self.router_cache = router_cache
self.model_list = model_list self.model_list = model_list
self.routing_args = RoutingArgs(**routing_args) self.routing_args = RoutingArgs(**routing_args)
@ -104,7 +108,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
) )
else: else:
# if local result below limit, check redis ## prevent unnecessary redis checks # if local result below limit, check redis ## prevent unnecessary redis checks
result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl=self.routing_args.ttl) result = self.router_cache.increment_cache(
key=rpm_key, value=1, ttl=self.routing_args.ttl
)
if result is not None and result > deployment_rpm: if result is not None and result > deployment_rpm:
raise litellm.RateLimitError( raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format( message="Deployment over defined rpm limit={}. current usage={}".format(
@ -244,12 +250,19 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# update cache # update cache
## TPM ## TPM
self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl) self.router_cache.increment_cache(
key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
)
### TESTING ### ### TESTING ###
if self.test_flag: if self.test_flag:
self.logged_success += 1 self.logged_success += 1
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -295,7 +308,12 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
if self.test_flag: if self.test_flag:
self.logged_success += 1 self.logged_success += 1
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass pass
def _common_checks_available_deployment( def _common_checks_available_deployment(

View file

@ -1,13 +1,14 @@
import heapq, time import heapq
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
import enum import enum
from litellm.caching import DualCache from litellm.caching import DualCache, RedisCache
from litellm import print_verbose from litellm import print_verbose
class SchedulerCacheKeys(enum.Enum): class SchedulerCacheKeys(enum.Enum):
queue = "scheduler:queue" queue = "scheduler:queue"
default_in_memory_ttl = 5 # cache queue in-memory for 5s when redis cache available
class DefaultPriorities(enum.Enum): class DefaultPriorities(enum.Enum):
@ -25,18 +26,24 @@ class FlowItem(BaseModel):
class Scheduler: class Scheduler:
cache: DualCache cache: DualCache
def __init__(self, polling_interval: Optional[float] = None): def __init__(
self,
polling_interval: Optional[float] = None,
redis_cache: Optional[RedisCache] = None,
):
""" """
polling_interval: float or null - frequency of polling queue. Default is 3ms. polling_interval: float or null - frequency of polling queue. Default is 3ms.
""" """
self.queue: list = [] self.queue: list = []
self.cache = DualCache() default_in_memory_ttl: Optional[float] = None
if redis_cache is not None:
# if redis-cache available frequently poll that instead of using in-memory.
default_in_memory_ttl = SchedulerCacheKeys.default_in_memory_ttl.value
self.cache = DualCache(
redis_cache=redis_cache, default_in_memory_ttl=default_in_memory_ttl
)
self.polling_interval = polling_interval or 0.03 # default to 3ms self.polling_interval = polling_interval or 0.03 # default to 3ms
def update_variables(self, cache: Optional[DualCache] = None):
if cache is not None:
self.cache = cache
async def add_request(self, request: FlowItem): async def add_request(self, request: FlowItem):
# We use the priority directly, as lower values indicate higher priority # We use the priority directly, as lower values indicate higher priority
# get the queue # get the queue

File diff suppressed because it is too large Load diff

View file

@ -198,7 +198,11 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
) )
assert isinstance(messages.data[0], Message) assert isinstance(messages.data[0], Message)
else: else:
pytest.fail("An unexpected error occurred when running the thread") pytest.fail(
"An unexpected error occurred when running the thread, {}".format(
run
)
)
else: else:
added_message = await litellm.a_add_message(**data) added_message = await litellm.a_add_message(**data)
@ -226,4 +230,8 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
) )
assert isinstance(messages.data[0], Message) assert isinstance(messages.data[0], Message)
else: else:
pytest.fail("An unexpected error occurred when running the thread") pytest.fail(
"An unexpected error occurred when running the thread, {}".format(
run
)
)

View file

@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None, os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
reason="Cannot run without being in CircleCI Runner", reason="Cannot run without being in CircleCI Runner",
@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3():
@pytest.mark.parametrize(
"image_url",
[
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC",
"https://avatars.githubusercontent.com/u/29436595?v=",
],
)
def test_bedrock_claude_3(image_url):
try: try:
litellm.set_verbose = True litellm.set_verbose = True
data = { data = {
@ -294,7 +303,7 @@ def test_bedrock_claude_3():
{ {
"image_url": { "image_url": {
"detail": "high", "detail": "high",
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC", "url": image_url,
}, },
"type": "image_url", "type": "image_url",
}, },
@ -313,7 +322,6 @@ def test_bedrock_claude_3():
# Add any assertions here to check the response # Add any assertions here to check the response
assert len(response.choices) > 0 assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0 assert len(response.choices[0].message.content) > 0
except RateLimitError: except RateLimitError:
pass pass
except Exception as e: except Exception as e:
@ -552,7 +560,7 @@ def test_bedrock_ptu():
assert "url" in mock_client_post.call_args.kwargs assert "url" in mock_client_post.call_args.kwargs
assert ( assert (
mock_client_post.call_args.kwargs["url"] mock_client_post.call_args.kwargs["url"]
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke" == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse"
) )
mock_client_post.assert_called_once() mock_client_post.assert_called_once()

View file

@ -300,7 +300,11 @@ def test_completion_claude_3():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_function_call(): @pytest.mark.parametrize(
"model",
["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
)
def test_completion_claude_3_function_call(model):
litellm.set_verbose = True litellm.set_verbose = True
tools = [ tools = [
{ {
@ -331,13 +335,14 @@ def test_completion_claude_3_function_call():
try: try:
# test without max tokens # test without max tokens
response = completion( response = completion(
model="anthropic/claude-3-opus-20240229", model=model,
messages=messages, messages=messages,
tools=tools, tools=tools,
tool_choice={ tool_choice={
"type": "function", "type": "function",
"function": {"name": "get_current_weather"}, "function": {"name": "get_current_weather"},
}, },
drop_params=True,
) )
# Add any assertions, here to check response args # Add any assertions, here to check response args
@ -364,10 +369,11 @@ def test_completion_claude_3_function_call():
) )
# In the second response, Claude should deduce answer from tool results # In the second response, Claude should deduce answer from tool results
second_response = completion( second_response = completion(
model="anthropic/claude-3-opus-20240229", model=model,
messages=messages, messages=messages,
tools=tools, tools=tools,
tool_choice="auto", tool_choice="auto",
drop_params=True,
) )
print(second_response) print(second_response)
except Exception as e: except Exception as e:
@ -2163,6 +2169,7 @@ def test_completion_azure_key_completion_arg():
logprobs=True, logprobs=True,
max_tokens=10, max_tokens=10,
) )
print(f"response: {response}") print(f"response: {response}")
print("Hidden Params", response._hidden_params) print("Hidden Params", response._hidden_params)
@ -2538,6 +2545,7 @@ def test_replicate_custom_prompt_dict():
"content": "what is yc write 1 paragraph", "content": "what is yc write 1 paragraph",
} }
], ],
mock_response="Hello world",
repetition_penalty=0.1, repetition_penalty=0.1,
num_retries=3, num_retries=3,
) )

View file

@ -76,7 +76,7 @@ def test_image_generation_azure_dall_e_3():
) )
print(f"response: {response}") print(f"response: {response}")
assert len(response.data) > 0 assert len(response.data) > 0
except litellm.RateLimitError as e: except litellm.InternalServerError as e:
pass pass
except litellm.ContentPolicyViolationError: except litellm.ContentPolicyViolationError:
pass # OpenAI randomly raises these errors - skip when they occur pass # OpenAI randomly raises these errors - skip when they occur

View file

@ -2248,3 +2248,55 @@ async def test_create_update_team(prisma_client):
assert _team_info["budget_reset_at"] is not None and isinstance( assert _team_info["budget_reset_at"] is not None and isinstance(
_team_info["budget_reset_at"], datetime.datetime _team_info["budget_reset_at"], datetime.datetime
) )
@pytest.mark.asyncio()
async def test_enforced_params(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
from litellm.proxy.proxy_server import general_settings
general_settings["enforced_params"] = [
"user",
"metadata",
"metadata.generation_name",
]
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest()
key = await new_user(request)
print(key)
generated_key = key.key
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# Case 1: Missing user
async def return_body():
return b'{"model": "gemini-pro-vision"}'
request.body = return_body
try:
await user_api_key_auth(request=request, api_key=bearer_token)
pytest.fail(f"This should have failed!. IT's an invalid request")
except Exception as e:
assert (
"BadRequest please pass param=user in request body. This is a required param"
in e.message
)
# Case 2: Missing metadata["generation_name"]
async def return_body_2():
return b'{"model": "gemini-pro-vision", "user": "1234", "metadata": {}}'
request.body = return_body_2
try:
await user_api_key_auth(request=request, api_key=bearer_token)
pytest.fail(f"This should have failed!. IT's an invalid request")
except Exception as e:
assert (
"Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body"
in e.message
)

View file

@ -275,7 +275,7 @@ async def _deploy(lowest_latency_logger, deployment_id, tokens_used, duration):
} }
start_time = time.time() start_time = time.time()
response_obj = {"usage": {"total_tokens": tokens_used}} response_obj = {"usage": {"total_tokens": tokens_used}}
time.sleep(duration) await asyncio.sleep(duration)
end_time = time.time() end_time = time.time()
lowest_latency_logger.log_success_event( lowest_latency_logger.log_success_event(
response_obj=response_obj, response_obj=response_obj,
@ -325,6 +325,7 @@ def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
d1 = [(lowest_latency_logger, "1234", 50, 0.01)] * non_ans_rpm d1 = [(lowest_latency_logger, "1234", 50, 0.01)] * non_ans_rpm
d2 = [(lowest_latency_logger, "5678", 50, 0.01)] * non_ans_rpm d2 = [(lowest_latency_logger, "5678", 50, 0.01)] * non_ans_rpm
asyncio.run(_gather_deploy([*d1, *d2])) asyncio.run(_gather_deploy([*d1, *d2]))
time.sleep(3)
## CHECK WHAT'S SELECTED ## ## CHECK WHAT'S SELECTED ##
d_ans = lowest_latency_logger.get_available_deployments( d_ans = lowest_latency_logger.get_available_deployments(
model_group=model_group, healthy_deployments=model_list model_group=model_group, healthy_deployments=model_list

View file

@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
claude_2_1_pt, claude_2_1_pt,
llama_2_chat_pt, llama_2_chat_pt,
prompt_factory, prompt_factory,
_bedrock_tools_pt,
) )
@ -128,3 +129,27 @@ def test_anthropic_messages_pt():
# codellama_prompt_format() # codellama_prompt_format()
def test_bedrock_tool_calling_pt():
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
converted_tools = _bedrock_tools_pt(tools=tools)
print(converted_tools)

View file

@ -38,6 +38,48 @@ def test_router_sensitive_keys():
assert "special-key" not in str(e) assert "special-key" not in str(e)
def test_router_order():
"""
Asserts for 2 models in a model group, model with order=1 always called first
"""
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
"mock_response": "Hello world",
"order": 1,
},
"model_info": {"id": "1"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-4o",
"api_key": "bad-key",
"mock_response": Exception("this is a bad key"),
"order": 2,
},
"model_info": {"id": "2"},
},
],
num_retries=0,
allowed_fails=0,
enable_pre_call_checks=True,
)
for _ in range(100):
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
assert response._hidden_params["model_id"] == "1"
@pytest.mark.parametrize("num_retries", [None, 2]) @pytest.mark.parametrize("num_retries", [None, 2])
@pytest.mark.parametrize("max_retries", [None, 4]) @pytest.mark.parametrize("max_retries", [None, 4])
def test_router_num_retries_init(num_retries, max_retries): def test_router_num_retries_init(num_retries, max_retries):

View file

@ -1284,18 +1284,18 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
# pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True]) # False
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
# "bedrock/cohere.command-r-plus-v1:0", "bedrock/cohere.command-r-plus-v1:0",
# "anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0",
# "anthropic.claude-instant-v1", "anthropic.claude-instant-v1",
# "bedrock/ai21.j2-mid", "bedrock/ai21.j2-mid",
# "mistral.mistral-7b-instruct-v0:2", "mistral.mistral-7b-instruct-v0:2",
# "bedrock/amazon.titan-tg1-large", "bedrock/amazon.titan-tg1-large",
# "meta.llama3-8b-instruct-v1:0", "meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14" "cohere.command-text-v14",
], ],
) )
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -186,3 +186,13 @@ def test_load_test_token_counter(model):
total_time = end_time - start_time total_time = end_time - start_time
print("model={}, total test time={}".format(model, total_time)) print("model={}, total test time={}".format(model, total_time))
assert total_time < 10, f"Total encoding time > 10s, {total_time}" assert total_time < 10, f"Total encoding time > 10s, {total_time}"
def test_openai_token_with_image_and_text():
model = "gpt-4o"
full_request = {'model': 'gpt-4o', 'tools': [{'type': 'function', 'function': {'name': 'json', 'parameters': {'type': 'object', 'required': ['clause'], 'properties': {'clause': {'type': 'string'}}}, 'description': 'Respond with a JSON object.'}}], 'logprobs': False, 'messages': [{'role': 'user', 'content': [{'text': '\n Just some long text, long long text, and you know it will be longer than 7 tokens definetly.', 'type': 'text'}]}], 'tool_choice': {'type': 'function', 'function': {'name': 'json'}}, 'exclude_models': [], 'disable_fallback': False, 'exclude_providers': []}
messages = full_request.get("messages", [])
token_count = token_counter(model=model, messages=messages)
print(token_count)
test_openai_token_with_image_and_text()

View file

@ -1,4 +1,4 @@
from typing import TypedDict, Any, Union, Optional from typing import TypedDict, Any, Union, Optional, Literal, List
import json import json
from typing_extensions import ( from typing_extensions import (
Self, Self,
@ -11,10 +11,137 @@ from typing_extensions import (
) )
class SystemContentBlock(TypedDict):
text: str
class ImageSourceBlock(TypedDict):
bytes: Optional[str] # base 64 encoded string
class ImageBlock(TypedDict):
format: Literal["png", "jpeg", "gif", "webp"]
source: ImageSourceBlock
class ToolResultContentBlock(TypedDict, total=False):
image: ImageBlock
json: dict
text: str
class ToolResultBlock(TypedDict, total=False):
content: Required[List[ToolResultContentBlock]]
toolUseId: Required[str]
status: Literal["success", "error"]
class ToolUseBlock(TypedDict):
input: dict
name: str
toolUseId: str
class ContentBlock(TypedDict, total=False):
text: str
image: ImageBlock
toolResult: ToolResultBlock
toolUse: ToolUseBlock
class MessageBlock(TypedDict):
content: List[ContentBlock]
role: Literal["user", "assistant"]
class ConverseMetricsBlock(TypedDict):
latencyMs: float # time in ms
class ConverseResponseOutputBlock(TypedDict):
message: Optional[MessageBlock]
class ConverseTokenUsageBlock(TypedDict):
inputTokens: int
outputTokens: int
totalTokens: int
class ConverseResponseBlock(TypedDict):
additionalModelResponseFields: dict
metrics: ConverseMetricsBlock
output: ConverseResponseOutputBlock
stopReason: (
str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
)
usage: ConverseTokenUsageBlock
class ToolInputSchemaBlock(TypedDict):
json: Optional[dict]
class ToolSpecBlock(TypedDict, total=False):
inputSchema: Required[ToolInputSchemaBlock]
name: Required[str]
description: str
class ToolBlock(TypedDict):
toolSpec: Optional[ToolSpecBlock]
class SpecificToolChoiceBlock(TypedDict):
name: str
class ToolChoiceValuesBlock(TypedDict, total=False):
any: dict
auto: dict
tool: SpecificToolChoiceBlock
class ToolConfigBlock(TypedDict, total=False):
tools: Required[List[ToolBlock]]
toolChoice: Union[str, ToolChoiceValuesBlock]
class InferenceConfig(TypedDict, total=False):
maxTokens: int
stopSequences: List[str]
temperature: float
topP: float
class ToolBlockDeltaEvent(TypedDict):
input: str
class ContentBlockDeltaEvent(TypedDict, total=False):
"""
Either 'text' or 'toolUse' will be specified for Converse API streaming response.
"""
text: str
toolUse: ToolBlockDeltaEvent
class RequestObject(TypedDict, total=False):
additionalModelRequestFields: dict
additionalModelResponseFieldPaths: List[str]
inferenceConfig: InferenceConfig
messages: Required[List[MessageBlock]]
system: List[SystemContentBlock]
toolConfig: ToolConfigBlock
class GenericStreamingChunk(TypedDict): class GenericStreamingChunk(TypedDict):
text: Required[str] text: Required[str]
tool_str: Required[str]
is_finished: Required[bool] is_finished: Required[bool]
finish_reason: Required[str] finish_reason: Required[str]
usage: Optional[ConverseTokenUsageBlock]
class Document(TypedDict): class Document(TypedDict):

View file

@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
extra_headers: Optional[Dict[str, str]] extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]] extra_body: Optional[Dict[str, str]]
timeout: Optional[float] timeout: Optional[float]
class ChatCompletionToolCallFunctionChunk(TypedDict):
name: str
arguments: str
class ChatCompletionToolCallChunk(TypedDict):
id: str
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
class ChatCompletionResponseMessage(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionToolCallChunk]
role: Literal["assistant"]

View file

@ -239,6 +239,8 @@ def map_finish_reason(
return "length" return "length"
elif finish_reason == "tool_use": # anthropic elif finish_reason == "tool_use": # anthropic
return "tool_calls" return "tool_calls"
elif finish_reason == "content_filtered":
return "content_filter"
return finish_reason return finish_reason
@ -1372,8 +1374,12 @@ class Logging:
callback_func=callback, callback_func=callback,
) )
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
print_verbose( "litellm.Logging.pre_call(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while input logging with integrations {traceback.format_exc()}" f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while input logging with integrations {traceback.format_exc()}"
) )
print_verbose( print_verbose(
@ -4060,6 +4066,9 @@ def openai_token_counter(
for c in value: for c in value:
if c["type"] == "text": if c["type"] == "text":
text += c["text"] text += c["text"]
num_tokens += len(
encoding.encode(c["text"], disallowed_special=())
)
elif c["type"] == "image_url": elif c["type"] == "image_url":
if isinstance(c["image_url"], dict): if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"] image_url_dict = c["image_url"]
@ -5632,8 +5641,7 @@ def get_optional_params(
optional_params["stream"] = stream optional_params["stream"] = stream
elif "anthropic" in model: elif "anthropic" in model:
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
# anthropic params on bedrock if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
if model.startswith("anthropic.claude-3"): if model.startswith("anthropic.claude-3"):
optional_params = ( optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params( litellm.AmazonAnthropicClaude3Config().map_openai_params(
@ -5646,6 +5654,17 @@ def get_optional_params(
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
) )
else: # bedrock httpx route
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif "amazon" in model: # amazon titan llms elif "amazon" in model: # amazon titan llms
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
@ -6196,6 +6215,27 @@ def calculate_max_parallel_requests(
return None return None
def _get_order_filtered_deployments(healthy_deployments: List[Dict]) -> List:
min_order = min(
(
deployment["litellm_params"]["order"]
for deployment in healthy_deployments
if "order" in deployment["litellm_params"]
),
default=None,
)
if min_order is not None:
filtered_deployments = [
deployment
for deployment in healthy_deployments
if deployment["litellm_params"].get("order") == min_order
]
return filtered_deployments
return healthy_deployments
def _get_model_region( def _get_model_region(
custom_llm_provider: str, litellm_params: LiteLLM_Params custom_llm_provider: str, litellm_params: LiteLLM_Params
) -> Optional[str]: ) -> Optional[str]:
@ -6401,20 +6441,7 @@ def get_supported_openai_params(
- None if unmapped - None if unmapped
""" """
if custom_llm_provider == "bedrock": if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"): return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
elif model.startswith("anthropic"):
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
elif model.startswith("ai21"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("amazon"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
elif model.startswith("meta"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("cohere"):
return ["stream", "temperature", "max_tokens"]
elif model.startswith("mistral"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
elif custom_llm_provider == "ollama": elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_supported_openai_params() return litellm.OllamaConfig().get_supported_openai_params()
elif custom_llm_provider == "ollama_chat": elif custom_llm_provider == "ollama_chat":
@ -9805,8 +9832,7 @@ def exception_type(
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
if "Internal server error" in error_str: if "Internal server error" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise APIError( raise litellm.InternalServerError(
status_code=500,
message=f"AzureException Internal server error - {original_exception.message}", message=f"AzureException Internal server error - {original_exception.message}",
llm_provider="azure", llm_provider="azure",
model=model, model=model,
@ -10056,6 +10082,8 @@ def get_secret(
): ):
key_management_system = litellm._key_management_system key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings key_management_settings = litellm._key_management_settings
args = locals()
if secret_name.startswith("os.environ/"): if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "") secret_name = secret_name.replace("os.environ/", "")
@ -10143,13 +10171,13 @@ def get_secret(
key_manager = "local" key_manager = "local"
if ( if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__ or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient" == "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = client.get_secret(secret_name).value secret = client.get_secret(secret_name).value
elif ( elif (
key_manager == KeyManagementSystem.GOOGLE_KMS key_manager == KeyManagementSystem.GOOGLE_KMS.value
or client.__class__.__name__ == "KeyManagementServiceClient" or client.__class__.__name__ == "KeyManagementServiceClient"
): ):
encrypted_secret: Any = os.getenv(secret_name) encrypted_secret: Any = os.getenv(secret_name)
@ -10177,6 +10205,25 @@ def get_secret(
secret = response.plaintext.decode( secret = response.plaintext.decode(
"utf-8" "utf-8"
) # assumes the original value was encoded with utf-8 ) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception("encrypted value for AWS KMS cannot be None.")
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try: try:
get_secret_value_response = client.get_secret_value( get_secret_value_response = client.get_secret_value(
@ -10197,10 +10244,14 @@ def get_secret(
for k, v in secret_dict.items(): for k, v in secret_dict.items():
secret = v secret = v
print_verbose(f"secret: {secret}") print_verbose(f"secret: {secret}")
elif key_manager == "local":
secret = os.getenv(secret_name)
else: # assume the default is infisicial client else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ except Exception as e: # check if it's in os.environ
print_verbose(f"An exception occurred - {str(e)}") verbose_logger.error(
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
)
secret = os.getenv(secret_name) secret = os.getenv(secret_name)
try: try:
secret_value_as_bool = ast.literal_eval(secret) secret_value_as_bool = ast.literal_eval(secret)
@ -10534,7 +10585,12 @@ class CustomStreamWrapper:
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.CustomStreamWrapper.handle_predibase_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e raise e
def handle_huggingface_chunk(self, chunk): def handle_huggingface_chunk(self, chunk):
@ -10578,7 +10634,12 @@ class CustomStreamWrapper:
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.CustomStreamWrapper.handle_huggingface_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e raise e
def handle_ai21_chunk(self, chunk): # fake streaming def handle_ai21_chunk(self, chunk): # fake streaming
@ -10813,7 +10874,12 @@ class CustomStreamWrapper:
"usage": usage, "usage": usage,
} }
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.CustomStreamWrapper.handle_openai_chat_completion_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e raise e
def handle_azure_text_completion_chunk(self, chunk): def handle_azure_text_completion_chunk(self, chunk):
@ -10894,7 +10960,12 @@ class CustomStreamWrapper:
else: else:
return "" return ""
except: except:
traceback.print_exc() verbose_logger.error(
"litellm.CustomStreamWrapper.handle_baseten_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
return "" return ""
def handle_cloudlfare_stream(self, chunk): def handle_cloudlfare_stream(self, chunk):
@ -11093,7 +11164,12 @@ class CustomStreamWrapper:
"is_finished": True, "is_finished": True,
} }
except: except:
traceback.print_exc() verbose_logger.error(
"litellm.CustomStreamWrapper.handle_clarifai_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
return "" return ""
def model_response_creator(self): def model_response_creator(self):
@ -11334,12 +11410,27 @@ class CustomStreamWrapper:
if response_obj["is_finished"]: if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"] self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock": elif self.custom_llm_provider == "bedrock":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None: if self.received_finish_reason is not None:
raise StopIteration raise StopIteration
response_obj = self.handle_bedrock_stream(chunk) response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"] completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]: if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"] self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
elif self.custom_llm_provider == "sagemaker": elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk) response_obj = self.handle_sagemaker_stream(chunk)
@ -11565,7 +11656,12 @@ class CustomStreamWrapper:
tool["type"] = "function" tool["type"] = "function"
model_response.choices[0].delta = Delta(**_json_delta) model_response.choices[0].delta = Delta(**_json_delta)
except Exception as e: except Exception as e:
traceback.print_exc() verbose_logger.error(
"litellm.CustomStreamWrapper.chunk_creator(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
model_response.choices[0].delta = Delta() model_response.choices[0].delta = Delta()
else: else:
try: try:
@ -11601,7 +11697,7 @@ class CustomStreamWrapper:
and hasattr(model_response, "usage") and hasattr(model_response, "usage")
and hasattr(model_response.usage, "prompt_tokens") and hasattr(model_response.usage, "prompt_tokens")
): ):
if self.sent_first_chunk == False: if self.sent_first_chunk is False:
completion_obj["role"] = "assistant" completion_obj["role"] = "assistant"
self.sent_first_chunk = True self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj) model_response.choices[0].delta = Delta(**completion_obj)
@ -11769,6 +11865,8 @@ class CustomStreamWrapper:
def __next__(self): def __next__(self):
try: try:
if self.completion_stream is None:
self.fetch_sync_stream()
while True: while True:
if ( if (
isinstance(self.completion_stream, str) isinstance(self.completion_stream, str)
@ -11843,6 +11941,14 @@ class CustomStreamWrapper:
custom_llm_provider=self.custom_llm_provider, custom_llm_provider=self.custom_llm_provider,
) )
def fetch_sync_stream(self):
if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream
self.completion_stream = self.make_call(client=litellm.module_level_client)
self._stream_iter = self.completion_stream.__iter__()
return self.completion_stream
async def fetch_stream(self): async def fetch_stream(self):
if self.completion_stream is None and self.make_call is not None: if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream # Call make_call to get the completion stream

View file

@ -1,10 +1,14 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "1.40.4" version = "1.40.5"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT" license = "MIT"
readme = "README.md" readme = "README.md"
packages = [
{ include = "litellm" },
{ include = "litellm/py.typed"},
]
[tool.poetry.urls] [tool.poetry.urls]
homepage = "https://litellm.ai" homepage = "https://litellm.ai"
@ -80,7 +84,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.commitizen] [tool.commitizen]
version = "1.40.4" version = "1.40.5"
version_files = [ version_files = [
"pyproject.toml:^version" "pyproject.toml:^version"
] ]

3
ruff.toml Normal file
View file

@ -0,0 +1,3 @@
ignore = ["F405"]
extend-select = ["E501"]
line-length = 120