Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)

Merge branch 'main' into litellm_svc_logger
Commit: 2cf3133669
84 changed files with 3848 additions and 5302 deletions

@@ -344,4 +344,4 @@ workflows:
    filters:
      branches:
        only:
          - main
          - main

.gitignore (vendored): 1 change
@@ -59,3 +59,4 @@ myenv/*
litellm/proxy/_experimental/out/404/index.html
litellm/proxy/_experimental/out/model_hub/index.html
litellm/proxy/_experimental/out/onboarding/index.html
litellm/tests/log.txt

@@ -38,7 +38,7 @@ class MyCustomHandler(CustomLogger):
         print(f"On Async Success")

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
-        print(f"On Async Success")
+        print(f"On Async Failure")

 customHandler = MyCustomHandler()

@@ -0,0 +1,3 @@
llmcord.py lets you and your friends chat with LLMs directly in your Discord server. It works with practically any LLM, remote or locally hosted.

GitHub: https://github.com/jakobdylanc/discord-llm-chatbot

@@ -62,6 +62,23 @@ curl -X GET 'http://localhost:4000/health/services?service=slack' \
-H 'Authorization: Bearer sk-1234'
```

## Advanced - Redacting Messages from Alerts

By default, alerts show the `messages/input` passed to the LLM. If you want to redact this from Slack alerting, set the following in your config:

```yaml
general_settings:
  alerting: ["slack"]
  alert_types: ["spend_reports"]

litellm_settings:
  redact_messages_in_exceptions: True
```

## Advanced - Opting into specific alert types

Set `alert_types` if you want to opt into only specific alert types, as sketched below.

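For illustration, a minimal sketch of such a config. `"spend_reports"` comes from the example above; `"daily_reports"` is an assumed second value added here only to show the list form:

```yaml
general_settings:
  alerting: ["slack"]
  alert_types: ["spend_reports", "daily_reports"]  # "daily_reports" is an assumed example value
```
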
@@ -14,6 +14,7 @@ Features:
- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
- ✅ [Audit Logs](#audit-logs)
- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)
- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests)
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Google Text Moderations](#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
- ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding)

@@ -204,6 +205,109 @@ curl -X GET "http://0.0.0.0:4000/spend/tags" \
```

## Enforce Required Params for LLM Requests

Use this when you want to enforce that all requests include certain params. For example, you may need every request to include the `user` and `["metadata"]["generation_name"]` params.

**Step 1** Define the params you want to enforce in config.yaml

This means `["user"]` and `["metadata"]["generation_name"]` are required in all LLM requests to LiteLLM.

```yaml
general_settings:
  master_key: sk-1234
  enforced_params:
    - user
    - metadata.generation_name
```

Start the LiteLLM Proxy.

**Step 2** Verify it works

<Tabs>

<TabItem value="bad" label="Invalid Request (No `user` passed)">

```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-5fmYeaUEbAMpwBNT-QpxyA' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "hi"
        }
    ]
}'
```

Expected Response

```shell
{"error":{"message":"Authentication Error, BadRequest please pass param=user in request body. This is a required param","type":"auth_error","param":"None","code":401}}
```

</TabItem>

<TabItem value="bad2" label="Invalid Request (No `metadata` passed)">

```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-5fmYeaUEbAMpwBNT-QpxyA' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gpt-3.5-turbo",
    "user": "gm",
    "messages": [
        {
            "role": "user",
            "content": "hi"
        }
    ],
    "metadata": {}
}'
```

Expected Response

```shell
{"error":{"message":"Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body. This is a required param","type":"auth_error","param":"None","code":401}}
```

</TabItem>
<TabItem value="good" label="Valid Request">

```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-5fmYeaUEbAMpwBNT-QpxyA' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gpt-3.5-turbo",
    "user": "gm",
    "messages": [
        {
            "role": "user",
            "content": "hi"
        }
    ],
    "metadata": {"generation_name": "prod-app"}
}'
```

Expected Response

```shell
{"id":"chatcmpl-9XALnHqkCBMBKrOx7Abg0hURHqYtY","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Hello! How can I assist you today?","role":"assistant"}}],"created":1717691639,"model":"gpt-3.5-turbo-0125","object":"chat.completion","system_fingerprint":null,"usage":{"completion_tokens":9,"prompt_tokens":8,"total_tokens":17}}
```

</TabItem>
</Tabs>

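For completeness, a sketch of the same valid request using the OpenAI Python SDK pointed at the proxy. The key and URL are taken from the curl examples above; `metadata` is passed via `extra_body` since it is not a standard OpenAI parameter:

```python
import openai

# Point the OpenAI client at the LiteLLM proxy
client = openai.OpenAI(
    api_key="sk-5fmYeaUEbAMpwBNT-QpxyA",
    base_url="http://localhost:4000",
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    user="gm",  # enforced param
    messages=[{"role": "user", "content": "hi"}],
    extra_body={"metadata": {"generation_name": "prod-app"}},  # enforced nested param
)
print(response.choices[0].message.content)
```
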
@@ -313,6 +313,18 @@ You will see `raw_request` in your Langfuse Metadata. This is the RAW CURL command

## Logging Proxy Input/Output in OpenTelemetry format

:::info

[Optional] Customize the OTEL service name and OTEL tracer name by setting the following variables in your environment:

```shell
OTEL_TRACER_NAME=<your-trace-name> # default="litellm"
OTEL_SERVICE_NAME=<your-service-name> # default="litellm"
```

:::
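As a quick sketch of how these are used (example values only), export them in the proxy's environment before startup:

```shell
export OTEL_TRACER_NAME="litellm-tracer"   # example value
export OTEL_SERVICE_NAME="litellm-proxy"   # example value
litellm --config /path/to/config.yaml
```
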

<Tabs>

@@ -100,4 +100,76 @@ print(response)
```

</TabItem>
</Tabs>
</Tabs>

## Advanced - Redis Caching

Use Redis caching to do request prioritization across multiple instances of LiteLLM.

### SDK
```python
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hello world this is Macintosh!", # fakes the LLM API call
                "rpm": 1,
            },
        },
    ],
    ### REDIS PARAMS ###
    redis_host=os.environ["REDIS_HOST"],
    redis_password=os.environ["REDIS_PASSWORD"],
    redis_port=os.environ["REDIS_PORT"],
)

try:
    _response = await router.schedule_acompletion( # 👈 ADDS TO QUEUE + POLLS + MAKES CALL
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey!"}],
        priority=0, # 👈 LOWER IS BETTER
    )
except Exception as e:
    print("didn't make request")
```
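Note that `schedule_acompletion` is an async call, so the snippet above has to run inside an event loop. A minimal runnable wrapper (same Router setup as above, with the `REDIS_*` environment variables set) could look like this:

```python
import asyncio
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hello world this is Macintosh!",  # fakes the LLM API call
                "rpm": 1,
            },
        },
    ],
    # Redis coordinates the request queue across LiteLLM instances
    redis_host=os.environ["REDIS_HOST"],
    redis_password=os.environ["REDIS_PASSWORD"],
    redis_port=os.environ["REDIS_PORT"],
)


async def main():
    try:
        response = await router.schedule_acompletion(  # adds to queue + polls + makes call
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey!"}],
            priority=0,  # lower value = higher priority
        )
        print(response)
    except Exception:
        print("didn't make request")


asyncio.run(main())
```
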

### PROXY

```yaml
model_list:
  - model_name: gpt-3.5-turbo-fake-model
    litellm_params:
      model: gpt-3.5-turbo
      mock_response: "hello world!"
      api_key: my-good-key

router_settings:
  redis_host: os.environ/REDIS_HOST
  redis_password: os.environ/REDIS_PASSWORD
  redis_port: os.environ/REDIS_PORT
```

```bash
$ litellm --config /path/to/config.yaml

# RUNNING on http://0.0.0.0:4000
```

```bash
curl -X POST 'http://localhost:4000/queue/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "gpt-3.5-turbo-fake-model",
    "messages": [
        {
            "role": "user",
            "content": "what is the meaning of the universe? 1234"
        }],
    "priority": 0 👈 SET VALUE HERE
}'
```

@@ -1,11 +1,31 @@
# Secret Manager

LiteLLM supports reading secrets from Azure Key Vault, Infisical, and more:

- AWS Key Management Service
- AWS Secret Manager
- [Azure Key Vault](#azure-key-vault)
- Google Key Management Service
- [Infisical Secret Manager](#infisical-secret-manager)
- [.env Files](#env-files)

## AWS Key Management Service

Use AWS KMS to store a hashed copy of your Proxy Master Key in the environment.

```bash
export LITELLM_MASTER_KEY="djZ9xjVaZ..." # 👈 ENCRYPTED KEY
export AWS_REGION_NAME="us-west-2"
```

```yaml
general_settings:
  key_management_system: "aws_kms"
  key_management_settings:
    hosted_keys: ["LITELLM_MASTER_KEY"] # 👈 WHICH KEYS ARE STORED ON KMS
```

[**See Decryption Code**](https://github.com/BerriAI/litellm/blob/a2da2a8f168d45648b61279d4795d647d94f90c9/litellm/utils.py#L10182)

## AWS Secret Manager

Store your proxy keys in AWS Secret Manager.

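As an illustrative sketch only, mirroring the shape of the `aws_kms` block above (the `"aws_secret_manager"` value and the secret reference are assumptions, not confirmed by this diff):

```yaml
general_settings:
  key_management_system: "aws_secret_manager"  # assumed value, mirroring "aws_kms" above
  master_key: os.environ/LITELLM_MASTER_KEY    # hypothetical secret reference
```
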

docs/my-website/package-lock.json (generated): 6 changes

@@ -5975,9 +5975,9 @@
     }
   },
   "node_modules/caniuse-lite": {
-    "version": "1.0.30001519",
-    "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001519.tgz",
-    "integrity": "sha512-0QHgqR+Jv4bxHMp8kZ1Kn8CH55OikjKJ6JmKkZYP1F3D7w+lnFXF70nG5eNfsZS89jadi5Ywy5UCSKLAglIRkg==",
+    "version": "1.0.30001629",
+    "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001629.tgz",
+    "integrity": "sha512-c3dl911slnQhmxUIT4HhYzT7wnBK/XYpGnYLOj4nJBaRiw52Ibe7YxlDaAeRECvA786zCuExhxIUJ2K7nHMrBw==",
     "funding": [
       {
         "type": "opencollective",


(File diff suppressed because it is too large.)

@@ -18,10 +18,6 @@ async def log_event(request: Request):

        return {"message": "Request received successfully"}
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Internal Server Error")

@@ -120,6 +120,5 @@ class GenericAPILogger:
             )
             return response
         except Exception as e:
-            traceback.print_exc()
-            verbose_logger.debug(f"Generic - {str(e)}\n{traceback.format_exc()}")
+            verbose_logger.error(f"Generic - {str(e)}\n{traceback.format_exc()}")
             pass

@ -82,7 +82,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
|
|||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(traceback.format_exc())
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
|
|
|
@ -118,4 +118,4 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
|
|||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(traceback.format_exc())
|
||||
|
|
|
@ -92,7 +92,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
},
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
|
||||
|
|
|
@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
|
|||
### INIT VARIABLES ###
|
||||
import threading, requests, os
|
||||
from typing import Callable, List, Optional, Dict, Union, Any, Literal
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.caching import Cache
|
||||
from litellm._logging import (
|
||||
set_verbose,
|
||||
|
@ -60,6 +60,7 @@ _async_failure_callback: List[Callable] = (
|
|||
pre_call_rules: List[Callable] = []
|
||||
post_call_rules: List[Callable] = []
|
||||
turn_off_message_logging: Optional[bool] = False
|
||||
redact_messages_in_exceptions: Optional[bool] = False
|
||||
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
||||
## end of callbacks #############
|
||||
|
||||
|
@ -233,6 +234,7 @@ max_end_user_budget: Optional[float] = None
|
|||
#### RELIABILITY ####
|
||||
request_timeout: float = 6000
|
||||
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
|
||||
module_level_client = HTTPHandler(timeout=request_timeout)
|
||||
num_retries: Optional[int] = None # per model endpoint
|
||||
default_fallbacks: Optional[List] = None
|
||||
fallbacks: Optional[List] = None
|
||||
|
@ -766,7 +768,7 @@ from .llms.sagemaker import SagemakerConfig
|
|||
from .llms.ollama import OllamaConfig
|
||||
from .llms.ollama_chat import OllamaChatConfig
|
||||
from .llms.maritalk import MaritTalkConfig
|
||||
from .llms.bedrock_httpx import AmazonCohereChatConfig
|
||||
from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
|
||||
from .llms.bedrock import (
|
||||
AmazonTitanConfig,
|
||||
AmazonAI21Config,
|
||||
|
@ -808,6 +810,7 @@ from .exceptions import (
|
|||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
UnprocessableEntityError,
|
||||
InternalServerError,
|
||||
LITELLM_EXCEPTION_TYPES,
|
||||
)
|
||||
from .budget_manager import BudgetManager
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import logging, os, json
|
||||
from logging import Formatter
|
||||
import traceback
|
||||
|
||||
set_verbose = False
|
||||
json_logs = bool(os.getenv("JSON_LOGS", False))
|
||||
|
|
|
@ -253,7 +253,6 @@ class RedisCache(BaseCache):
|
|||
str(e),
|
||||
value,
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
|
||||
|
@ -313,7 +312,6 @@ class RedisCache(BaseCache):
|
|||
str(e),
|
||||
value,
|
||||
)
|
||||
traceback.print_exc()
|
||||
|
||||
key = self.check_and_fix_namespace(key=key)
|
||||
async with _redis_client as redis_client:
|
||||
|
@ -352,7 +350,6 @@ class RedisCache(BaseCache):
|
|||
str(e),
|
||||
value,
|
||||
)
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, ttl=None):
|
||||
"""
|
||||
|
@ -413,7 +410,6 @@ class RedisCache(BaseCache):
|
|||
str(e),
|
||||
cache_value,
|
||||
)
|
||||
traceback.print_exc()
|
||||
|
||||
async def batch_cache_write(self, key, value, **kwargs):
|
||||
print_verbose(
|
||||
|
@ -458,7 +454,6 @@ class RedisCache(BaseCache):
|
|||
str(e),
|
||||
value,
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
async def flush_cache_buffer(self):
|
||||
|
@ -495,8 +490,9 @@ class RedisCache(BaseCache):
|
|||
return self._get_cache_logic(cached_response=cached_response)
|
||||
except Exception as e:
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
traceback.print_exc()
|
||||
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
|
||||
verbose_logger.error(
|
||||
"LiteLLM Caching: get() - Got exception from REDIS: ", e
|
||||
)
|
||||
|
||||
def batch_get_cache(self, key_list) -> dict:
|
||||
"""
|
||||
|
@ -646,10 +642,9 @@ class RedisCache(BaseCache):
|
|||
error=e,
|
||||
call_type="sync_ping",
|
||||
)
|
||||
print_verbose(
|
||||
verbose_logger.error(
|
||||
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
async def ping(self) -> bool:
|
||||
|
@ -683,10 +678,9 @@ class RedisCache(BaseCache):
|
|||
call_type="async_ping",
|
||||
)
|
||||
)
|
||||
print_verbose(
|
||||
verbose_logger.error(
|
||||
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
async def delete_cache_keys(self, keys):
|
||||
|
@ -1138,22 +1132,23 @@ class S3Cache(BaseCache):
|
|||
cached_response = ast.literal_eval(cached_response)
|
||||
if type(cached_response) is not dict:
|
||||
cached_response = dict(cached_response)
|
||||
print_verbose(
|
||||
verbose_logger.debug(
|
||||
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
|
||||
)
|
||||
|
||||
return cached_response
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
print_verbose(
|
||||
verbose_logger.error(
|
||||
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# NON blocking - notify users S3 is throwing an exception
|
||||
traceback.print_exc()
|
||||
print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")
|
||||
verbose_logger.error(
|
||||
f"S3 Caching: get_cache() - Got exception from S3: {e}"
|
||||
)
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
@ -1234,8 +1229,7 @@ class DualCache(BaseCache):
|
|||
|
||||
return result
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
raise e
|
||||
|
||||
def get_cache(self, key, local_only: bool = False, **kwargs):
|
||||
|
@ -1262,7 +1256,7 @@ class DualCache(BaseCache):
|
|||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
|
||||
try:
|
||||
|
@ -1295,7 +1289,7 @@ class DualCache(BaseCache):
|
|||
print_verbose(f"async batch get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
|
||||
# Try to fetch from in-memory cache first
|
||||
|
@ -1328,7 +1322,7 @@ class DualCache(BaseCache):
|
|||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
async def async_batch_get_cache(
|
||||
self, keys: list, local_only: bool = False, **kwargs
|
||||
|
@ -1368,7 +1362,7 @@ class DualCache(BaseCache):
|
|||
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||
print_verbose(
|
||||
|
@ -1381,8 +1375,8 @@ class DualCache(BaseCache):
|
|||
if self.redis_cache is not None and local_only == False:
|
||||
await self.redis_cache.async_set_cache(key, value, **kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
|
||||
async def async_batch_set_cache(
|
||||
self, cache_list: list, local_only: bool = False, **kwargs
|
||||
|
@ -1404,8 +1398,8 @@ class DualCache(BaseCache):
|
|||
cache_list=cache_list, ttl=kwargs.get("ttl", None)
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
|
||||
async def async_increment_cache(
|
||||
self, key, value: float, local_only: bool = False, **kwargs
|
||||
|
@ -1429,8 +1423,8 @@ class DualCache(BaseCache):
|
|||
|
||||
return result
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
def flush_cache(self):
|
||||
|
@ -1846,8 +1840,8 @@ class Cache:
|
|||
)
|
||||
self.cache.set_cache(cache_key, cached_data, **kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
async def async_add_cache(self, result, *args, **kwargs):
|
||||
|
@ -1864,8 +1858,8 @@ class Cache:
|
|||
)
|
||||
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
|
||||
async def async_add_cache_pipeline(self, result, *args, **kwargs):
|
||||
"""
|
||||
|
@ -1897,8 +1891,8 @@ class Cache:
|
|||
)
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
|
||||
async def batch_cache_write(self, result, *args, **kwargs):
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(
|
||||
|
|
|
@ -638,6 +638,7 @@ LITELLM_EXCEPTION_TYPES = [
|
|||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
OpenAIError,
|
||||
InternalServerError,
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -169,6 +169,5 @@ class AISpendLogger:
|
|||
|
||||
print_verbose(f"AISpend Logging - final data object: {data}")
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"AISpend Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -178,6 +178,5 @@ class BerriSpendLogger:
|
|||
print_verbose(f"BerriSpend Logging - final data object: {data}")
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"BerriSpend Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -297,6 +297,5 @@ class ClickhouseLogger:
|
|||
# make request to endpoint with payload
|
||||
verbose_logger.debug(f"Clickhouse Logger - final response = {response}")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.debug(f"Clickhouse - {str(e)}\n{traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -115,7 +115,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
)
|
||||
print_verbose(f"Custom Logger - model call details: {kwargs}")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
|
||||
async def async_log_input_event(
|
||||
|
@ -130,7 +129,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
)
|
||||
print_verbose(f"Custom Logger - model call details: {kwargs}")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
|
||||
def log_event(
|
||||
|
@ -146,7 +144,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
end_time,
|
||||
)
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
||||
|
@ -163,6 +160,5 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
end_time,
|
||||
)
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -134,7 +134,6 @@ class DataDogLogger:
|
|||
f"Datadog Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.debug(
|
||||
f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
|
|
@ -85,6 +85,5 @@ class DyanmoDBLogger:
|
|||
)
|
||||
return response
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -112,6 +112,5 @@ class HeliconeLogger:
|
|||
)
|
||||
print_verbose(f"Helicone Logging - Error {response.text}")
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -220,9 +220,11 @@ class LangFuseLogger:
|
|||
verbose_logger.info(f"Langfuse Layer Logging - logging success")
|
||||
|
||||
return {"trace_id": trace_id, "generation_id": generation_id}
|
||||
except:
|
||||
traceback.print_exc()
|
||||
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
"Langfuse Layer Error(): Exception occured - {}".format(str(e))
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
return {"trace_id": None, "generation_id": None}
|
||||
|
||||
async def _async_log_event(
|
||||
|
|
|
@ -44,7 +44,9 @@ class LangsmithLogger:
|
|||
print_verbose(
|
||||
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
|
||||
)
|
||||
langsmith_base_url = os.getenv("LANGSMITH_BASE_URL", "https://api.smith.langchain.com")
|
||||
langsmith_base_url = os.getenv(
|
||||
"LANGSMITH_BASE_URL", "https://api.smith.langchain.com"
|
||||
)
|
||||
|
||||
try:
|
||||
print_verbose(
|
||||
|
@ -89,9 +91,7 @@ class LangsmithLogger:
|
|||
}
|
||||
|
||||
url = f"{langsmith_base_url}/runs"
|
||||
print_verbose(
|
||||
f"Langsmith Logging - About to send data to {url} ..."
|
||||
)
|
||||
print_verbose(f"Langsmith Logging - About to send data to {url} ...")
|
||||
response = requests.post(
|
||||
url=url,
|
||||
json=data,
|
||||
|
@ -106,6 +106,5 @@ class LangsmithLogger:
|
|||
f"Langsmith Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"Langsmith Layer Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -171,7 +171,6 @@ class LogfireLogger:
|
|||
f"Logfire Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.debug(
|
||||
f"Logfire Layer Error - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ def parse_usage(usage):
|
|||
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
|
||||
}
|
||||
|
||||
|
||||
def parse_tool_calls(tool_calls):
|
||||
if tool_calls is None:
|
||||
return None
|
||||
|
@ -26,13 +27,13 @@ def parse_tool_calls(tool_calls):
|
|||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return serialized
|
||||
|
||||
|
||||
return [clean_tool_call(tool_call) for tool_call in tool_calls]
|
||||
|
||||
|
||||
|
||||
def parse_messages(input):
|
||||
|
||||
|
@ -176,6 +177,5 @@ class LunaryLogger:
|
|||
)
|
||||
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -14,8 +14,11 @@ if TYPE_CHECKING:
|
|||
else:
|
||||
Span = Any
|
||||
|
||||
LITELLM_TRACER_NAME = "litellm"
|
||||
LITELLM_RESOURCE = {"service.name": "litellm"}
|
||||
|
||||
LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm")
|
||||
LITELLM_RESOURCE = {
|
||||
"service.name": os.getenv("OTEL_SERVICE_NAME", "litellm"),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -109,8 +109,8 @@ class PrometheusLogger:
|
|||
end_user_id, user_api_key, model, user_api_team, user_id
|
||||
).inc()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.debug(
|
||||
f"prometheus Layer Error - {str(e)}\n{traceback.format_exc()}"
|
||||
verbose_logger.error(
|
||||
"prometheus Layer Error(): Exception occured - {}".format(str(e))
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
|
|
@ -180,6 +180,5 @@ class S3Logger:
|
|||
print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
|
||||
return response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.debug(f"s3 Layer Error - {str(e)}\n{traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -326,8 +326,8 @@ class SlackAlerting(CustomLogger):
|
|||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
if litellm.turn_off_message_logging:
|
||||
messages = "Message not logged. `litellm.turn_off_message_logging=True`."
|
||||
if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
|
||||
messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
|
||||
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
|
||||
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
|
||||
if time_difference_float > self.alerting_threshold:
|
||||
|
@ -567,9 +567,12 @@ class SlackAlerting(CustomLogger):
|
|||
except:
|
||||
messages = ""
|
||||
|
||||
if litellm.turn_off_message_logging:
|
||||
if (
|
||||
litellm.turn_off_message_logging
|
||||
or litellm.redact_messages_in_exceptions
|
||||
):
|
||||
messages = (
|
||||
"Message not logged. `litellm.turn_off_message_logging=True`."
|
||||
"Message not logged. litellm.redact_messages_in_exceptions=True"
|
||||
)
|
||||
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
|
||||
else:
|
||||
|
|
|
@ -110,6 +110,5 @@ class Supabase:
|
|||
)
|
||||
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -217,6 +217,5 @@ class WeightsBiasesLogger:
|
|||
f"W&B Logging Logging - final response object: {response_obj}"
|
||||
)
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -38,6 +38,8 @@ from .prompt_templates.factory import (
|
|||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
contains_tag,
|
||||
_bedrock_converse_messages_pt,
|
||||
_bedrock_tools_pt,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from .base import BaseLLM
|
||||
|
@ -45,6 +47,11 @@ import httpx # type: ignore
|
|||
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
|
||||
from litellm.types.llms.bedrock import *
|
||||
import urllib.parse
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionResponseMessage,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
)
|
||||
|
||||
|
||||
class AmazonCohereChatConfig:
|
||||
|
@ -118,6 +125,8 @@ class AmazonCohereChatConfig:
|
|||
"presence_penalty",
|
||||
"seed",
|
||||
"stop",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
|
@ -176,6 +185,37 @@ async def make_call(
|
|||
return completion_stream
|
||||
|
||||
|
||||
def make_sync_call(
|
||||
client: Optional[HTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
):
|
||||
if client is None:
|
||||
client = HTTPHandler() # Create a new client if none provided
|
||||
|
||||
response = client.post(api_base, headers=headers, data=data, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(status_code=response.status_code, message=response.read())
|
||||
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=completion_stream, # Pass the completion stream for logging
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream
|
||||
|
||||
|
||||
class BedrockLLM(BaseLLM):
|
||||
"""
|
||||
Example call
|
||||
|
@ -1000,12 +1040,12 @@ class BedrockLLM(BaseLLM):
|
|||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.client = client # type: ignore
|
||||
client = client # type: ignore
|
||||
|
||||
try:
|
||||
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
response = await client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
|
@ -1069,6 +1109,745 @@ class BedrockLLM(BaseLLM):
|
|||
return super().embedding(*args, **kwargs)
|
||||
|
||||
|
||||
class AmazonConverseConfig:
|
||||
"""
|
||||
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
|
||||
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
|
||||
"""
|
||||
|
||||
maxTokens: Optional[int]
|
||||
stopSequences: Optional[List[str]]
|
||||
temperature: Optional[int]
|
||||
topP: Optional[int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokens: Optional[int] = None,
|
||||
stopSequences: Optional[List[str]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
supported_params = [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
if (
|
||||
model.startswith("anthropic")
|
||||
or model.startswith("mistral")
|
||||
or model.startswith("cohere")
|
||||
):
|
||||
supported_params.append("tools")
|
||||
|
||||
if model.startswith("anthropic") or model.startswith("mistral"):
|
||||
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
supported_params.append("tool_choice")
|
||||
|
||||
return supported_params
|
||||
|
||||
def map_tool_choice_values(
|
||||
self, model: str, tool_choice: Union[str, dict], drop_params: bool
|
||||
) -> Optional[ToolChoiceValuesBlock]:
|
||||
if tool_choice == "none":
|
||||
if litellm.drop_params is True or drop_params is True:
|
||||
return None
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
tool_choice
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
elif tool_choice == "required":
|
||||
return ToolChoiceValuesBlock(any={})
|
||||
elif tool_choice == "auto":
|
||||
return ToolChoiceValuesBlock(auto={})
|
||||
elif isinstance(tool_choice, dict):
|
||||
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
specific_tool = SpecificToolChoiceBlock(
|
||||
name=tool_choice.get("function", {}).get("name", "")
|
||||
)
|
||||
return ToolChoiceValuesBlock(tool=specific_tool)
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
tool_choice
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
def get_supported_image_types(self) -> List[str]:
|
||||
return ["png", "jpeg", "gif", "webp"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
model: str,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["maxTokens"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
value = [value]
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["topP"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "tool_choice":
|
||||
_tool_choice_value = self.map_tool_choice_values(
|
||||
model=model, tool_choice=value, drop_params=drop_params # type: ignore
|
||||
)
|
||||
if _tool_choice_value is not None:
|
||||
optional_params["tool_choice"] = _tool_choice_value
|
||||
return optional_params
|
||||
|
||||
|
||||
class BedrockConverseLLM(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
|
||||
except Exception as e:
|
||||
raise BedrockError(
|
||||
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
|
||||
response.text, str(e)
|
||||
),
|
||||
status_code=422,
|
||||
)
|
||||
|
||||
"""
|
||||
Bedrock Response Object has optional message block
|
||||
|
||||
completion_response["output"].get("message", None)
|
||||
|
||||
A message block looks like this (Example 1):
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
(Example 2):
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
|
||||
"name": "top_song",
|
||||
"input": {
|
||||
"sign": "WZPZ"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
"""
|
||||
message: Optional[MessageBlock] = completion_response["output"]["message"]
|
||||
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
|
||||
content_str = ""
|
||||
tools: List[ChatCompletionToolCallChunk] = []
|
||||
if message is not None:
|
||||
for content in message["content"]:
|
||||
"""
|
||||
- Content is either a tool response or text
|
||||
"""
|
||||
if "text" in content:
|
||||
content_str += content["text"]
|
||||
if "toolUse" in content:
|
||||
_function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||
name=content["toolUse"]["name"],
|
||||
arguments=json.dumps(content["toolUse"]["input"]),
|
||||
)
|
||||
_tool_response_chunk = ChatCompletionToolCallChunk(
|
||||
id=content["toolUse"]["toolUseId"],
|
||||
type="function",
|
||||
function=_function_chunk,
|
||||
)
|
||||
tools.append(_tool_response_chunk)
|
||||
chat_completion_message["content"] = content_str
|
||||
chat_completion_message["tool_calls"] = tools
|
||||
|
||||
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||
input_tokens = completion_response["usage"]["inputTokens"]
|
||||
output_tokens = completion_response["usage"]["outputTokens"]
|
||||
total_tokens = completion_response["usage"]["totalTokens"]
|
||||
|
||||
model_response.choices = [
|
||||
litellm.Choices(
|
||||
finish_reason=map_finish_reason(completion_response["stopReason"]),
|
||||
index=0,
|
||||
message=litellm.Message(**chat_completion_message),
|
||||
)
|
||||
]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
usage = Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
def encode_model_id(self, model_id: str) -> str:
|
||||
"""
|
||||
Double encode the model ID to ensure it matches the expected double-encoded format.
|
||||
Args:
|
||||
model_id (str): The model ID to encode.
|
||||
Returns:
|
||||
str: The double-encoded model ID.
|
||||
"""
|
||||
return urllib.parse.quote(model_id, safe="")
|
||||
|
||||
def get_credentials(
|
||||
self,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
aws_web_identity_token: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Return a boto3.Credentials object
|
||||
"""
|
||||
import boto3
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
params_to_check: List[Optional[str]] = [
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith("os.environ/"):
|
||||
_v = get_secret(param)
|
||||
if _v is not None and isinstance(_v, str):
|
||||
params_to_check[i] = _v
|
||||
# Assign updated values back to parameters
|
||||
(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
) = params_to_check
|
||||
|
||||
### CHECK STS ###
|
||||
if (
|
||||
aws_web_identity_token is not None
|
||||
and aws_role_name is not None
|
||||
and aws_session_name is not None
|
||||
):
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
sts_client = boto3.client("sts")
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
|
||||
return session.get_credentials()
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
# Extract the credentials from the response and convert to Session Credentials
|
||||
sts_credentials = sts_response["Credentials"]
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
credentials = Credentials(
|
||||
access_key=sts_credentials["AccessKeyId"],
|
||||
secret_key=sts_credentials["SecretAccessKey"],
|
||||
token=sts_credentials["SessionToken"],
|
||||
)
|
||||
return credentials
|
||||
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||
client = boto3.Session(profile_name=aws_profile_name)
|
||||
|
||||
return client.get_credentials()
|
||||
else:
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
|
||||
return session.get_credentials()
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=None,
|
||||
make_call=partial(
|
||||
make_call,
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
client = client # type: ignore
|
||||
|
||||
try:
|
||||
response = await client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException as e:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream if isinstance(stream, bool) else False,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion: bool,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||
):
|
||||
try:
|
||||
import boto3
|
||||
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
modelId = optional_params.pop("model_id", None)
|
||||
if modelId is not None:
|
||||
modelId = self.encode_model_id(model_id=modelId)
|
||||
else:
|
||||
modelId = model
|
||||
|
||||
provider = model.split(".")[0]
|
||||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if litellm_aws_region_name is not None and isinstance(
|
||||
litellm_aws_region_name, str
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if standard_aws_region_name is not None and isinstance(
|
||||
standard_aws_region_name, str
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint_url = ""
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
if (stream is not None and stream is True) and provider != "ai21":
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
|
||||
# Separate system prompt from rest of message
|
||||
system_prompt_indices = []
|
||||
system_content_blocks: List[SystemContentBlock] = []
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
_system_content_block = SystemContentBlock(text=message["content"])
|
||||
system_content_blocks.append(_system_content_block)
|
||||
system_prompt_indices.append(idx)
|
||||
if len(system_prompt_indices) > 0:
|
||||
for idx in reversed(system_prompt_indices):
|
||||
messages.pop(idx)
|
||||
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
additional_request_keys = []
|
||||
additional_request_params = {}
|
||||
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
|
||||
supported_tool_call_params = ["tools", "tool_choice"]
|
||||
## TRANSFORMATION ##
|
||||
# send all model-specific params in 'additional_request_params'
|
||||
for k, v in inference_params.items():
|
||||
if (
|
||||
k not in supported_converse_params
|
||||
and k not in supported_tool_call_params
|
||||
):
|
||||
additional_request_params[k] = v
|
||||
additional_request_keys.append(k)
|
||||
for key in additional_request_keys:
|
||||
inference_params.pop(key, None)
|
||||
|
||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||
messages=messages
|
||||
)
|
||||
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||
inference_params.pop("tools", [])
|
||||
)
|
||||
bedrock_tool_config: Optional[ToolConfigBlock] = None
|
||||
if len(bedrock_tools) > 0:
|
||||
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
|
||||
"tool_choice", None
|
||||
)
|
||||
bedrock_tool_config = ToolConfigBlock(
|
||||
tools=bedrock_tools,
|
||||
)
|
||||
if tool_choice_values is not None:
|
||||
bedrock_tool_config["toolChoice"] = tool_choice_values
|
||||
|
||||
_data: RequestObject = {
|
||||
"messages": bedrock_messages,
|
||||
"additionalModelRequestFields": additional_request_params,
|
||||
"system": system_content_blocks,
|
||||
"inferenceConfig": InferenceConfig(**inference_params),
|
||||
}
|
||||
if bedrock_tool_config is not None:
|
||||
_data["toolConfig"] = bedrock_tool_config
|
||||
data = json.dumps(_data)
|
||||
## COMPLETION CALL
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": prepped.url,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
|
||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream is True and provider != "ai21":
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=prepped.url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=True,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
) # type: ignore
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=prepped.url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream, # type: ignore
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
) # type: ignore
|
||||
|
||||
if (stream is not None and stream is True) and provider != "ai21":
|
||||
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=None,
|
||||
make_call=partial(
|
||||
make_sync_call,
|
||||
client=None,
|
||||
api_base=prepped.url,
|
||||
headers=prepped.headers, # type: ignore
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=streaming_response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
return streaming_response
|
||||
### COMPLETION
|
||||
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = HTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
client = client
|
||||
try:
|
||||
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
|
||||
def get_response_stream_shape():
|
||||
from botocore.model import ServiceModel
|
||||
from botocore.loaders import Loader
|
||||
|
@ -1086,6 +1865,31 @@ class AWSEventStreamDecoder:
|
|||
self.model = model
|
||||
self.parser = EventStreamJSONParser()
|
||||
|
||||
def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
|
||||
text = ""
|
||||
tool_str = ""
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ConverseTokenUsageBlock] = None
|
||||
if "delta" in chunk_data:
|
||||
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
|
||||
if "text" in delta_obj:
|
||||
text = delta_obj["text"]
|
||||
elif "toolUse" in delta_obj:
|
||||
tool_str = delta_obj["toolUse"]["input"]
|
||||
elif "stopReason" in chunk_data:
|
||||
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
|
||||
elif "usage" in chunk_data:
|
||||
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
|
||||
response = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_str=tool_str,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
return response
|
||||
|
||||
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
|
||||
text = ""
|
||||
is_finished = False
|
||||
|
@ -1098,19 +1902,8 @@ class AWSEventStreamDecoder:
|
|||
is_finished = True
|
||||
finish_reason = "stop"
|
||||
######## bedrock.anthropic mappings ###############
|
||||
elif "completion" in chunk_data: # not claude-3
|
||||
text = chunk_data["completion"] # bedrock.anthropic
|
||||
stop_reason = chunk_data.get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
elif "delta" in chunk_data:
|
||||
if chunk_data["delta"].get("text", None) is not None:
|
||||
text = chunk_data["delta"]["text"]
|
||||
stop_reason = chunk_data["delta"].get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
return self.converse_chunk_parser(chunk_data=chunk_data)
|
||||
######## bedrock.mistral mappings ###############
|
||||
elif "outputs" in chunk_data:
|
||||
if (
|
||||
|
@ -1137,11 +1930,11 @@ class AWSEventStreamDecoder:
|
|||
is_finished = True
|
||||
finish_reason = chunk_data["completionReason"]
|
||||
return GenericStreamingChunk(
|
||||
**{
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
text=text,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
tool_str="",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
|
||||
|
@ -1178,9 +1971,14 @@ class AWSEventStreamDecoder:
|
|||
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
|
||||
if response_dict["status_code"] != 200:
|
||||
raise ValueError(f"Bad response code, expected 200: {response_dict}")
|
||||
if "chunk" in parsed_response:
|
||||
chunk = parsed_response.get("chunk")
|
||||
if not chunk:
|
||||
return None
|
||||
return chunk.get("bytes").decode() # type: ignore[no-any-return]
|
||||
else:
|
||||
chunk = response_dict.get("body")
|
||||
if not chunk:
|
||||
return None
|
||||
|
||||
chunk = parsed_response.get("chunk")
|
||||
if not chunk:
|
||||
return None
|
||||
|
||||
return chunk.get("bytes").decode() # type: ignore[no-any-return]
|
||||
return chunk.decode() # type: ignore[no-any-return]
|
||||
|
|
|
@ -156,12 +156,13 @@ class HTTPHandler:
|
|||
self,
|
||||
url: str,
|
||||
data: Optional[Union[dict, str]] = None,
|
||||
json: Optional[Union[dict, str]] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
req = self.client.build_request(
|
||||
"POST", url, data=data, params=params, headers=headers # type: ignore
|
||||
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
|
||||
)
|
||||
response = self.client.send(req, stream=stream)
|
||||
return response
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import os, types, traceback, copy, asyncio
|
||||
import json
|
||||
from enum import Enum
|
||||
import types
|
||||
import traceback
|
||||
import copy
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
|
||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import litellm
|
||||
import sys, httpx
|
||||
import httpx
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
|
||||
from packaging.version import Version
|
||||
from litellm import verbose_logger
|
||||
|
||||
|
||||
class GeminiError(Exception):
|
||||
|
@ -264,7 +265,8 @@ def completion(
|
|||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error("LiteLLM.gemini.py: Exception occurred - {}".format(str(e)))
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
raise GeminiError(
|
||||
message=traceback.format_exc(), status_code=response.status_code
|
||||
)
|
||||
|
@ -356,7 +358,8 @@ async def async_completion(
|
|||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error("LiteLLM.gemini.py: Exception occurred - {}".format(str(e)))
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
raise GeminiError(
|
||||
message=traceback.format_exc(), status_code=response.status_code
|
||||
)
|
||||
|
|
|
@ -7,6 +7,7 @@ import litellm
|
|||
from litellm.types.utils import ProviderField
|
||||
import httpx, aiohttp, asyncio # type: ignore
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm import verbose_logger
|
||||
|
||||
|
||||
class OllamaError(Exception):
|
||||
|
@ -137,6 +138,7 @@ class OllamaConfig:
|
|||
)
|
||||
]
|
||||
|
||||
|
||||
def get_supported_openai_params(
|
||||
self,
|
||||
):
|
||||
|
@ -151,10 +153,12 @@ class OllamaConfig:
|
|||
"response_format",
|
||||
]
|
||||
|
||||
|
||||
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
|
||||
# and convert to jpeg if necessary.
|
||||
def _convert_image(image):
|
||||
import base64, io
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
except:
|
||||
|
@ -404,7 +408,13 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
|
|||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"LiteLLM.ollama.py::ollama_async_streaming(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
|
||||
raise e
|
||||
|
||||
|
||||
|
@ -468,7 +478,12 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
|||
)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"LiteLLM.ollama.py::ollama_acompletion(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
from itertools import chain
|
||||
import requests, types, time
|
||||
import json, uuid
|
||||
import requests
|
||||
import types
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
import traceback
|
||||
from typing import Optional
|
||||
from litellm import verbose_logger
|
||||
import litellm
|
||||
import httpx, aiohttp, asyncio
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
import httpx
|
||||
import aiohttp
|
||||
|
||||
|
||||
class OllamaError(Exception):
|
||||
|
@ -299,7 +303,10 @@ def get_ollama_response(
|
|||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||
"function": {
|
||||
"name": function_call["name"],
|
||||
"arguments": json.dumps(function_call["arguments"]),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
|
@ -307,7 +314,9 @@ def get_ollama_response(
|
|||
model_response["choices"][0]["message"] = message
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
else:
|
||||
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
|
||||
model_response["choices"][0]["message"]["content"] = response_json["message"][
|
||||
"content"
|
||||
]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + model
|
||||
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
|
||||
|
@ -361,7 +370,10 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
|
|||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||
"function": {
|
||||
"name": function_call["name"],
|
||||
"arguments": json.dumps(function_call["arguments"]),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
|
@ -410,9 +422,10 @@ async def ollama_async_streaming(
|
|||
first_chunk_content = first_chunk.choices[0].delta.content or ""
|
||||
response_content = first_chunk_content + "".join(
|
||||
[
|
||||
chunk.choices[0].delta.content
|
||||
async for chunk in streamwrapper
|
||||
if chunk.choices[0].delta.content]
|
||||
chunk.choices[0].delta.content
|
||||
async for chunk in streamwrapper
|
||||
if chunk.choices[0].delta.content
|
||||
]
|
||||
)
|
||||
function_call = json.loads(response_content)
|
||||
delta = litellm.utils.Delta(
|
||||
|
@ -420,7 +433,10 @@ async def ollama_async_streaming(
|
|||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||
"function": {
|
||||
"name": function_call["name"],
|
||||
"arguments": json.dumps(function_call["arguments"]),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
|
@ -433,7 +449,8 @@ async def ollama_async_streaming(
|
|||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error("LiteLLM.gemini(): Exception occurred - {}".format(str(e)))
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
|
||||
|
||||
async def ollama_acompletion(
|
||||
|
@ -483,7 +500,10 @@ async def ollama_acompletion(
|
|||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||
"function": {
|
||||
"name": function_call["name"],
|
||||
"arguments": json.dumps(function_call["arguments"]),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
|
@ -491,7 +511,9 @@ async def ollama_acompletion(
|
|||
model_response["choices"][0]["message"] = message
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
else:
|
||||
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
|
||||
model_response["choices"][0]["message"]["content"] = response_json[
|
||||
"message"
|
||||
]["content"]
|
||||
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama_chat/" + data["model"]
|
||||
|
@ -509,5 +531,9 @@ async def ollama_acompletion(
|
|||
)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"LiteLLM.ollama_acompletion(): Exception occurred - {}".format(str(e))
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
|
||||
raise e
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import os, types, traceback, copy
|
||||
import json
|
||||
from enum import Enum
|
||||
import types
|
||||
import traceback
|
||||
import copy
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
|
||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import litellm
|
||||
import sys, httpx
|
||||
import httpx
|
||||
from litellm import verbose_logger
|
||||
|
||||
|
||||
class PalmError(Exception):
|
||||
|
@ -165,7 +166,10 @@ def completion(
|
|||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.llms.palm.py::completion(): Exception occurred - {}".format(str(e))
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
raise PalmError(
|
||||
message=traceback.format_exc(), status_code=response.status_code
|
||||
)
|
||||
|
|
|
@ -3,14 +3,7 @@ import requests, traceback
|
|||
import json, re, xml.etree.ElementTree as ET
|
||||
from jinja2 import Template, exceptions, meta, BaseLoader
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
)
|
||||
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
|
||||
import litellm
|
||||
import litellm.types
|
||||
from litellm.types.completion import (
|
||||
|
@ -24,7 +17,7 @@ from litellm.types.completion import (
|
|||
import litellm.types.llms
|
||||
from litellm.types.llms.anthropic import *
|
||||
import uuid
|
||||
|
||||
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
|
||||
import litellm.types.llms.vertex_ai
|
||||
|
||||
|
||||
|
@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
|
|||
try:
|
||||
from PIL import Image
|
||||
except:
|
||||
raise Exception(
|
||||
"gemini image conversion failed please run `pip install Pillow`"
|
||||
)
|
||||
raise Exception("image conversion failed please run `pip install Pillow`")
|
||||
from io import BytesIO
|
||||
|
||||
try:
|
||||
|
@ -1613,6 +1604,380 @@ def azure_text_pt(messages: list):
|
|||
return prompt
|
||||
|
||||
|
||||
###### AMAZON BEDROCK #######
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
ToolResultContentBlock as BedrockToolResultContentBlock,
|
||||
ToolResultBlock as BedrockToolResultBlock,
|
||||
ToolConfigBlock as BedrockToolConfigBlock,
|
||||
ToolUseBlock as BedrockToolUseBlock,
|
||||
ImageSourceBlock as BedrockImageSourceBlock,
|
||||
ImageBlock as BedrockImageBlock,
|
||||
ContentBlock as BedrockContentBlock,
|
||||
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
|
||||
ToolSpecBlock as BedrockToolSpecBlock,
|
||||
ToolBlock as BedrockToolBlock,
|
||||
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
|
||||
)
|
||||
|
||||
|
||||
def get_image_details(image_url) -> Tuple[str, str]:
|
||||
try:
|
||||
import base64
|
||||
|
||||
# Send a GET request to the image URL
|
||||
response = requests.get(image_url)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
|
||||
# Check the response's content type to ensure it is an image
|
||||
content_type = response.headers.get("content-type")
|
||||
if not content_type or "image" not in content_type:
|
||||
raise ValueError(
|
||||
f"URL does not point to a valid image (content-type: {content_type})"
|
||||
)
|
||||
|
||||
# Convert the image content to base64 bytes
|
||||
base64_bytes = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
# Get mime-type
|
||||
mime_type = content_type.split("/")[
|
||||
1
|
||||
] # Extract mime-type from content-type header
|
||||
|
||||
return base64_bytes, mime_type
|
||||
|
||||
except requests.RequestException as e:
|
||||
raise Exception(f"Request failed: {e}")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
|
||||
if "base64" in image_url:
|
||||
# Case 1: Images with base64 encoding
|
||||
import base64, re
|
||||
|
||||
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
|
||||
image_metadata, img_without_base_64 = image_url.split(",")
|
||||
|
||||
# read mime_type from img_without_base_64=data:image/jpeg;base64
|
||||
# Extract MIME type using regular expression
|
||||
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
|
||||
if mime_type_match:
|
||||
mime_type = mime_type_match.group(1)
|
||||
image_format = mime_type.split("/")[1]
|
||||
else:
|
||||
mime_type = "image/jpeg"
|
||||
image_format = "jpeg"
|
||||
_blob = BedrockImageSourceBlock(bytes=img_without_base_64)
|
||||
supported_image_formats = (
|
||||
litellm.AmazonConverseConfig().get_supported_image_types()
|
||||
)
|
||||
if image_format in supported_image_formats:
|
||||
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
|
||||
else:
|
||||
# Handle the case when the image format is not supported
|
||||
raise ValueError(
|
||||
"Unsupported image format: {}. Supported formats: {}".format(
|
||||
image_format, supported_image_formats
|
||||
)
|
||||
)
|
||||
elif "https:/" in image_url:
|
||||
# Case 2: Images with direct links
|
||||
image_bytes, image_format = get_image_details(image_url)
|
||||
_blob = BedrockImageSourceBlock(bytes=image_bytes)
|
||||
supported_image_formats = (
|
||||
litellm.AmazonConverseConfig().get_supported_image_types()
|
||||
)
|
||||
if image_format in supported_image_formats:
|
||||
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
|
||||
else:
|
||||
# Handle the case when the image format is not supported
|
||||
raise ValueError(
|
||||
"Unsupported image format: {}. Supported formats: {}".format(
|
||||
image_format, supported_image_formats
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported image type. Expected either image url or base64 encoded string - \
|
||||
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
|
||||
)
|
||||
|
||||
|
||||
def _convert_to_bedrock_tool_call_invoke(
|
||||
tool_calls: list,
|
||||
) -> List[BedrockContentBlock]:
|
||||
"""
|
||||
OpenAI tool invokes:
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"""
|
||||
"""
|
||||
Bedrock tool invokes:
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"toolUse": {
|
||||
"input": {"location": "Boston, MA", ..},
|
||||
"name": "get_current_weather",
|
||||
"toolUseId": "call_abc123"
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
"""
|
||||
- json.loads argument
|
||||
- extract name
|
||||
- extract id
|
||||
"""
|
||||
|
||||
try:
|
||||
_parts_list: List[BedrockContentBlock] = []
|
||||
for tool in tool_calls:
|
||||
if "function" in tool:
|
||||
id = tool["id"]
|
||||
name = tool["function"].get("name", "")
|
||||
arguments = tool["function"].get("arguments", "")
|
||||
arguments_dict = json.loads(arguments)
|
||||
bedrock_tool = BedrockToolUseBlock(
|
||||
input=arguments_dict, name=name, toolUseId=id
|
||||
)
|
||||
bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
|
||||
_parts_list.append(bedrock_content_block)
|
||||
return _parts_list
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
"Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
|
||||
tool_calls, str(e)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_to_bedrock_tool_call_result(
|
||||
message: dict,
|
||||
) -> BedrockMessageBlock:
|
||||
"""
|
||||
OpenAI message with a tool result looks like:
|
||||
{
|
||||
"tool_call_id": "tool_1",
|
||||
"role": "tool",
|
||||
"name": "get_current_weather",
|
||||
"content": "function result goes here",
|
||||
},
|
||||
|
||||
OpenAI message with a function call result looks like:
|
||||
{
|
||||
"role": "function",
|
||||
"name": "get_current_weather",
|
||||
"content": "function result goes here",
|
||||
}
|
||||
"""
|
||||
"""
|
||||
Bedrock result looks like this:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
|
||||
"content": [
|
||||
{
|
||||
"json": {
|
||||
"song": "Elemental Hotel",
|
||||
"artist": "8 Storey Hike"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
"""
|
||||
-
|
||||
"""
|
||||
content = message.get("content", "")
|
||||
name = message.get("name", "")
|
||||
id = message.get("tool_call_id", str(uuid.uuid4()))
|
||||
|
||||
tool_result_content_block = BedrockToolResultContentBlock(text=content)
|
||||
tool_result = BedrockToolResultBlock(
|
||||
content=[tool_result_content_block],
|
||||
toolUseId=id,
|
||||
)
|
||||
content_block = BedrockContentBlock(toolResult=tool_result)
|
||||
|
||||
return BedrockMessageBlock(role="user", content=[content_block])
|
||||
|
||||
|
||||
def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
|
||||
"""
|
||||
Converts given messages from OpenAI format to Bedrock format
|
||||
|
||||
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
|
||||
- Please ensure that function response turn comes immediately after a function call turn
|
||||
"""
|
||||
|
||||
contents: List[BedrockMessageBlock] = []
|
||||
msg_i = 0
|
||||
while msg_i < len(messages):
|
||||
user_content: List[BedrockContentBlock] = []
|
||||
init_msg_i = msg_i
|
||||
## MERGE CONSECUTIVE USER CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "user":
|
||||
if isinstance(messages[msg_i]["content"], list):
|
||||
_parts: List[BedrockContentBlock] = []
|
||||
for element in messages[msg_i]["content"]:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
_part = BedrockContentBlock(text=element["text"])
|
||||
_parts.append(_part)
|
||||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
_part = _process_bedrock_converse_image_block( # type: ignore
|
||||
image_url=image_url
|
||||
)
|
||||
_parts.append(BedrockContentBlock(image=_part)) # type: ignore
|
||||
user_content.extend(_parts)
|
||||
else:
|
||||
_part = BedrockContentBlock(text=messages[msg_i]["content"])
|
||||
user_content.append(_part)
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if user_content:
|
||||
contents.append(BedrockMessageBlock(role="user", content=user_content))
|
||||
assistant_content: List[BedrockContentBlock] = []
|
||||
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
||||
if isinstance(messages[msg_i]["content"], list):
|
||||
assistants_parts: List[BedrockContentBlock] = []
|
||||
for element in messages[msg_i]["content"]:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
assistants_part = BedrockContentBlock(text=element["text"])
|
||||
assistants_parts.append(assistants_part)
|
||||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
assistants_part = _process_bedrock_converse_image_block( # type: ignore
|
||||
image_url=image_url
|
||||
)
|
||||
assistants_parts.append(
|
||||
BedrockContentBlock(image=assistants_part) # type: ignore
|
||||
)
|
||||
assistant_content.extend(assistants_parts)
|
||||
elif messages[msg_i].get(
|
||||
"tool_calls", []
|
||||
): # support assistant tool invoke conversion
|
||||
assistant_content.extend(
|
||||
_convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
|
||||
)
|
||||
else:
|
||||
assistant_text = (
|
||||
messages[msg_i].get("content") or ""
|
||||
) # either string or none
|
||||
if assistant_text:
|
||||
assistant_content.append(BedrockContentBlock(text=assistant_text))
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if assistant_content:
|
||||
contents.append(
|
||||
BedrockMessageBlock(role="assistant", content=assistant_content)
|
||||
)
|
||||
|
||||
## APPEND TOOL CALL MESSAGES ##
|
||||
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
|
||||
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
|
||||
contents.append(tool_call_result)
|
||||
msg_i += 1
|
||||
if msg_i == init_msg_i: # prevent infinite loops
|
||||
raise Exception(
|
||||
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
|
||||
messages[msg_i]
|
||||
)
|
||||
)
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
|
||||
"""
|
||||
OpenAI tools looks like:
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
"""
|
||||
Bedrock toolConfig looks like:
|
||||
"tools": [
|
||||
{
|
||||
"toolSpec": {
|
||||
"name": "top_song",
|
||||
"description": "Get the most popular song played on a radio station.",
|
||||
"inputSchema": {
|
||||
"json": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sign": {
|
||||
"type": "string",
|
||||
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"sign"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
tool_block_list: List[BedrockToolBlock] = []
|
||||
for tool in tools:
|
||||
parameters = tool.get("function", {}).get("parameters", None)
|
||||
name = tool.get("function", {}).get("name", "")
|
||||
description = tool.get("function", {}).get("description", "")
|
||||
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
|
||||
tool_spec = BedrockToolSpecBlock(
|
||||
inputSchema=tool_input_schema, name=name, description=description
|
||||
)
|
||||
tool_block = BedrockToolBlock(toolSpec=tool_spec)
|
||||
tool_block_list.append(tool_block)
|
||||
|
||||
return tool_block_list
|
||||
|
||||
|
||||
# Function call template
|
||||
def function_call_prompt(messages: list, functions: list):
|
||||
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""
|
||||
|
|
|
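As a reference for the Bedrock tool-handling changes above, the following standalone sketch (not part of the patch) shows how an OpenAI-style tool definition and tool call map onto the Bedrock converse `toolSpec` / `toolUse` shapes documented in `_bedrock_tools_pt` and `_convert_to_bedrock_tool_call_invoke`. It uses plain dicts rather than litellm's typed blocks, so any field beyond those shown in the docstrings above is an assumption.

```python
import json

# Illustrative only: the shape conversion performed by the helpers above,
# written with plain dicts instead of litellm's typed Bedrock blocks.
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}

# Equivalent Bedrock converse toolConfig entry (what _bedrock_tools_pt emits).
bedrock_tool = {
    "toolSpec": {
        "name": openai_tool["function"]["name"],
        "description": openai_tool["function"]["description"],
        "inputSchema": {"json": openai_tool["function"]["parameters"]},
    }
}

# An OpenAI assistant tool call becomes a Bedrock toolUse content block
# (what _convert_to_bedrock_tool_call_invoke emits).
openai_tool_call = {
    "id": "call_abc123",
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "arguments": '{"location": "Boston, MA"}',
    },
}
bedrock_tool_use = {
    "toolUse": {
        "toolUseId": openai_tool_call["id"],
        "name": openai_tool_call["function"]["name"],
        "input": json.loads(openai_tool_call["function"]["arguments"]),
    }
}

print(json.dumps({"tools": [bedrock_tool], "content": [bedrock_tool_use]}, indent=2))
```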
@ -12,6 +12,7 @@ from litellm.llms.prompt_templates.factory import (
|
|||
convert_to_gemini_tool_call_result,
|
||||
convert_to_gemini_tool_call_invoke,
|
||||
)
|
||||
from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
|
@ -297,29 +298,31 @@ def _convert_gemini_role(role: str) -> Literal["user", "model"]:
|
|||
|
||||
def _process_gemini_image(image_url: str) -> PartType:
|
||||
try:
|
||||
if ".mp4" in image_url and "gs://" in image_url:
|
||||
# Case 1: Videos with Cloud Storage URIs
|
||||
part_mime = "video/mp4"
|
||||
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
|
||||
return PartType(file_data=_file_data)
|
||||
elif ".pdf" in image_url and "gs://" in image_url:
|
||||
# Case 2: PDF's with Cloud Storage URIs
|
||||
part_mime = "application/pdf"
|
||||
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
|
||||
return PartType(file_data=_file_data)
|
||||
elif "gs://" in image_url:
|
||||
# Case 3: Images with Cloud Storage URIs
|
||||
# The supported MIME types for images include image/png and image/jpeg.
|
||||
part_mime = "image/png" if "png" in image_url else "image/jpeg"
|
||||
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
|
||||
return PartType(file_data=_file_data)
|
||||
# GCS URIs
|
||||
if "gs://" in image_url:
|
||||
# Figure out file type
|
||||
extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
|
||||
extension = extension_with_dot[1:] # Ex: "png"
|
||||
|
||||
file_type = get_file_type_from_extension(extension)
|
||||
|
||||
# Validate the file type is supported by Gemini
|
||||
if not is_gemini_1_5_accepted_file_type(file_type):
|
||||
raise Exception(f"File type not supported by gemini - {file_type}")
|
||||
|
||||
mime_type = get_file_mime_type_for_file_type(file_type)
|
||||
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
|
||||
|
||||
return PartType(file_data=file_data)
|
||||
|
||||
# Direct links
|
||||
elif "https:/" in image_url:
|
||||
# Case 4: Images with direct links
|
||||
image = _load_image_from_url(image_url)
|
||||
_blob = BlobType(data=image.data, mime_type=image._mime_type)
|
||||
return PartType(inline_data=_blob)
|
||||
|
||||
# Base64 encoding
|
||||
elif "base64" in image_url:
|
||||
# Case 5: Images with base64 encoding
|
||||
import base64, re
|
||||
|
||||
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
|
||||
|
@ -426,112 +429,6 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
|
|||
return contents
|
||||
|
||||
|
||||
def _gemini_vision_convert_messages(messages: list):
|
||||
"""
|
||||
Converts given messages for GPT-4 Vision to Gemini format.
|
||||
|
||||
Args:
|
||||
messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
|
||||
- If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
|
||||
- If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
|
||||
|
||||
Raises:
|
||||
VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
|
||||
Exception: If any other exception occurs during the execution of the function.
|
||||
|
||||
Note:
|
||||
This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub.
|
||||
The supported MIME types for images include 'image/png' and 'image/jpeg'.
|
||||
|
||||
Examples:
|
||||
>>> messages = [
|
||||
... {"content": "Hello, world!"},
|
||||
... {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]},
|
||||
... ]
|
||||
>>> _gemini_vision_convert_messages(messages)
|
||||
('Hello, world!This is a text message.', [<Part object>, <Part object>])
|
||||
"""
|
||||
try:
|
||||
import vertexai
|
||||
except:
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
|
||||
)
|
||||
try:
|
||||
from vertexai.preview.language_models import (
|
||||
ChatModel,
|
||||
CodeChatModel,
|
||||
InputOutputTextPair,
|
||||
)
|
||||
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
|
||||
from vertexai.preview.generative_models import (
|
||||
GenerativeModel,
|
||||
Part,
|
||||
GenerationConfig,
|
||||
Image,
|
||||
)
|
||||
|
||||
# given messages for gpt-4 vision, convert them for gemini
|
||||
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
|
||||
prompt = ""
|
||||
images = []
|
||||
for message in messages:
|
||||
if isinstance(message["content"], str):
|
||||
prompt += message["content"]
|
||||
elif isinstance(message["content"], list):
|
||||
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
|
||||
for element in message["content"]:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
prompt += element["text"]
|
||||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
images.append(image_url)
|
||||
# processing images passed to gemini
|
||||
processed_images = []
|
||||
for img in images:
|
||||
if "gs://" in img:
|
||||
# Case 1: Images with Cloud Storage URIs
|
||||
# The supported MIME types for images include image/png and image/jpeg.
|
||||
part_mime = "image/png" if "png" in img else "image/jpeg"
|
||||
google_clooud_part = Part.from_uri(img, mime_type=part_mime)
|
||||
processed_images.append(google_clooud_part)
|
||||
elif "https:/" in img:
|
||||
# Case 2: Images with direct links
|
||||
image = _load_image_from_url(img)
|
||||
processed_images.append(image)
|
||||
elif ".mp4" in img and "gs://" in img:
|
||||
# Case 3: Videos with Cloud Storage URIs
|
||||
part_mime = "video/mp4"
|
||||
google_clooud_part = Part.from_uri(img, mime_type=part_mime)
|
||||
processed_images.append(google_clooud_part)
|
||||
elif "base64" in img:
|
||||
# Case 4: Images with base64 encoding
|
||||
import base64, re
|
||||
|
||||
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
|
||||
image_metadata, img_without_base_64 = img.split(",")
|
||||
|
||||
# read mime_type from img_without_base_64=data:image/jpeg;base64
|
||||
# Extract MIME type using regular expression
|
||||
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
|
||||
|
||||
if mime_type_match:
|
||||
mime_type = mime_type_match.group(1)
|
||||
else:
|
||||
mime_type = "image/jpeg"
|
||||
decoded_img = base64.b64decode(img_without_base_64)
|
||||
processed_image = Part.from_data(data=decoded_img, mime_type=mime_type)
|
||||
processed_images.append(processed_image)
|
||||
return prompt, processed_images
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
|
||||
_cache_key = f"{model}-{vertex_project}-{vertex_location}"
|
||||
return _cache_key
|
||||
|
@ -647,9 +544,9 @@ def completion(
|
|||
|
||||
prompt = " ".join(
|
||||
[
|
||||
message["content"]
|
||||
message.get("content")
|
||||
for message in messages
|
||||
if isinstance(message["content"], str)
|
||||
if isinstance(message.get("content", None), str)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
|
|||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.bedrock_httpx import BedrockLLM
|
||||
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
|
||||
from .llms.vertex_httpx import VertexLLM
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
|
@ -122,6 +122,7 @@ huggingface = Huggingface()
|
|||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
triton_chat_completions = TritonChatCompletion()
|
||||
bedrock_chat_completion = BedrockLLM()
|
||||
bedrock_converse_chat_completion = BedrockConverseLLM()
|
||||
vertex_chat_completion = VertexLLM()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
@ -364,7 +365,10 @@ async def acompletion(
|
|||
) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
|
||||
return response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.acompletion(): Exception occurred - {}".format(str(e))
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
custom_llm_provider = custom_llm_provider or "openai"
|
||||
raise exception_type(
|
||||
model=model,
|
||||
|
@ -477,7 +481,10 @@ def mock_completion(
|
|||
except Exception as e:
|
||||
if isinstance(e, openai.APIError):
|
||||
raise e
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.mock_completion(): Exception occurred - {}".format(str(e))
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
raise Exception("Mock completion response failed")
|
||||
|
||||
|
||||
|
@ -2100,22 +2107,40 @@ def completion(
|
|||
logging_obj=logging,
|
||||
)
|
||||
else:
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
client=client,
|
||||
)
|
||||
if model.startswith("anthropic"):
|
||||
response = bedrock_converse_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
client=client,
|
||||
)
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
|
@ -4433,7 +4458,10 @@ async def ahealth_check(
|
|||
response = {} # args like remaining ratelimit etc.
|
||||
return response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.ahealth_check(): Exception occurred - {}".format(str(e))
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
stack_trace = traceback.format_exc()
|
||||
if isinstance(stack_trace, str):
|
||||
stack_trace = stack_trace[:1000]
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
from logging import Formatter
|
||||
import sys
|
||||
|
||||
|
||||
class JsonFormatter(Formatter):
|
||||
|
|
|
@ -56,8 +56,10 @@ router_settings:
|
|||
|
||||
litellm_settings:
|
||||
success_callback: ["langfuse"]
|
||||
json_logs: true
|
||||
|
||||
general_settings:
|
||||
alerting: ["email"]
|
||||
key_management_system: "aws_kms"
|
||||
key_management_settings:
|
||||
hosted_keys: ["LITELLM_MASTER_KEY"]
|
||||
|
||||
|
|
|
@ -955,6 +955,7 @@ class KeyManagementSystem(enum.Enum):
|
|||
AZURE_KEY_VAULT = "azure_key_vault"
|
||||
AWS_SECRET_MANAGER = "aws_secret_manager"
|
||||
LOCAL = "local"
|
||||
AWS_KMS = "aws_kms"
|
||||
|
||||
|
||||
class KeyManagementSettings(LiteLLMBase):
|
||||
|
|
|
@ -106,6 +106,40 @@ def common_checks(
|
|||
raise Exception(
|
||||
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
|
||||
)
|
||||
if general_settings.get("enforced_params", None) is not None:
|
||||
# Enterprise ONLY Feature
|
||||
# we already validate if user is premium_user when reading the config
|
||||
# Add an extra premium_user check here too, just in case
|
||||
from litellm.proxy.proxy_server import premium_user, CommonProxyErrors
|
||||
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
"Trying to use `enforced_params`"
|
||||
+ CommonProxyErrors.not_premium_user.value
|
||||
)
|
||||
|
||||
if route in LiteLLMRoutes.openai_routes.value:
|
||||
# loop through each enforced param
|
||||
# example enforced_params ['user', 'metadata', 'metadata.generation_name']
|
||||
for enforced_param in general_settings["enforced_params"]:
|
||||
_enforced_params = enforced_param.split(".")
|
||||
if len(_enforced_params) == 1:
|
||||
if _enforced_params[0] not in request_body:
|
||||
raise ValueError(
|
||||
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
|
||||
)
|
||||
elif len(_enforced_params) == 2:
|
||||
# this is a scenario where user requires request['metadata']['generation_name'] to exist
|
||||
if _enforced_params[0] not in request_body:
|
||||
raise ValueError(
|
||||
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
|
||||
)
|
||||
if _enforced_params[1] not in request_body[_enforced_params[0]]:
|
||||
raise ValueError(
|
||||
f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
|
||||
)
|
||||
|
||||
pass
|
||||
# 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||
if (
|
||||
litellm.max_budget > 0
|
||||
|
|
|
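To make the `enforced_params` check above concrete, here is a minimal standalone sketch of the dotted-path validation it performs on an incoming request body. `check_enforced_params` is a hypothetical helper name used only for illustration; it is not a function introduced by this patch.

```python
# Illustrative sketch of the enforced_params logic shown above.
def check_enforced_params(request_body: dict, enforced_params: list) -> None:
    for enforced_param in enforced_params:
        parts = enforced_param.split(".")
        if parts[0] not in request_body:
            raise ValueError(
                f"BadRequest please pass param={parts[0]} in request body. This is a required param"
            )
        if len(parts) == 2 and parts[1] not in request_body[parts[0]]:
            raise ValueError(
                f"BadRequest please pass param=[{parts[0]}][{parts[1]}] in request body. This is a required param"
            )

# Passes: both top-level and nested required params are present.
check_enforced_params(
    {"user": "u1", "metadata": {"generation_name": "demo"}},
    ["user", "metadata.generation_name"],
)

# Fails: metadata.generation_name is missing.
try:
    check_enforced_params({"user": "u1", "metadata": {}}, ["user", "metadata.generation_name"])
except ValueError as e:
    print(e)
```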
@ -88,7 +88,7 @@ class _PROXY_AzureContentSafety(
|
|||
verbose_proxy_logger.debug(
|
||||
"Error in Azure Content-Safety: %s", traceback.format_exc()
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
raise
|
||||
|
||||
result = self._compute_result(response)
|
||||
|
@ -123,7 +123,12 @@ class _PROXY_AzureContentSafety(
|
|||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
|
|
|
@ -94,7 +94,12 @@ class _PROXY_BatchRedisRequests(CustomLogger):
|
|||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
|
||||
async def async_get_cache(self, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
# What this does?
|
||||
## Checks if key is allowed to use the cache controls passed in to the completion() call
|
||||
|
||||
from typing import Optional
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from fastapi import HTTPException
|
||||
import json, traceback
|
||||
import traceback
|
||||
|
||||
|
||||
class _PROXY_CacheControlCheck(CustomLogger):
|
||||
|
@ -54,4 +54,9 @@ class _PROXY_CacheControlCheck(CustomLogger):
|
|||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Optional
|
||||
from litellm import verbose_logger
|
||||
import litellm
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from fastapi import HTTPException
|
||||
import json, traceback
|
||||
import traceback
|
||||
|
||||
|
||||
class _PROXY_MaxBudgetLimiter(CustomLogger):
|
||||
|
@ -44,4 +44,9 @@ class _PROXY_MaxBudgetLimiter(CustomLogger):
|
|||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
|
|
|
@ -8,8 +8,8 @@
|
|||
# Tell us how we can improve! - Krrish & Ishaan
|
||||
|
||||
|
||||
from typing import Optional, Literal, Union
|
||||
import litellm, traceback, sys, uuid, json
|
||||
from typing import Optional, Union
|
||||
import litellm, traceback, uuid, json # noqa: E401
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
@ -21,8 +21,8 @@ from litellm.utils import (
|
|||
ImageResponse,
|
||||
StreamingChoices,
|
||||
)
|
||||
from datetime import datetime
|
||||
import aiohttp, asyncio
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
|
||||
class _OPTIONAL_PresidioPIIMasking(CustomLogger):
|
||||
|
@ -138,7 +138,12 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
|
|||
else:
|
||||
raise Exception(f"Invalid anonymizer response: {redacted_text}")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
async def async_pre_call_hook(
|
||||
|
|
|
@ -204,7 +204,12 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
|||
return e.detail["error"]
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
|
|
|
@ -21,7 +21,14 @@ model_list:
|
|||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
alerting: ["slack"]
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["otel"]
|
||||
store_audit_logs: true
|
||||
store_audit_logs: true
|
||||
redact_messages_in_exceptions: True
|
||||
enforced_params:
|
||||
- user
|
||||
- metadata
|
||||
- metadata.generation_name
|
||||
|
||||
|
|
|
@ -111,6 +111,7 @@ from litellm.proxy.utils import (
|
|||
encrypt_value,
|
||||
decrypt_value,
|
||||
_to_ns,
|
||||
get_error_message_str,
|
||||
)
|
||||
from litellm import (
|
||||
CreateBatchRequest,
|
||||
|
@ -120,7 +121,10 @@ from litellm import (
|
|||
CreateFileRequest,
|
||||
)
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||
load_aws_secret_manager,
|
||||
load_aws_kms,
|
||||
)
|
||||
import pydantic
|
||||
from litellm.proxy._types import *
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
|
@ -133,7 +137,10 @@ from litellm.router import (
|
|||
AssistantsTypedDict,
|
||||
)
|
||||
from litellm.router import ModelInfo as RouterModelInfo
|
||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
||||
from litellm._logging import (
|
||||
verbose_router_logger,
|
||||
verbose_proxy_logger,
|
||||
)
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.litellm_license import LicenseCheck
|
||||
from litellm.proxy.auth.model_checks import (
|
||||
|
@ -1515,7 +1522,12 @@ async def user_api_key_auth(
|
|||
else:
|
||||
raise Exception()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, litellm.BudgetExceededError):
|
||||
raise ProxyException(
|
||||
message=e.message, type="auth_error", param=None, code=400
|
||||
|
@ -2782,10 +2794,12 @@ class ProxyConfig:
|
|||
load_google_kms(use_google_kms=True)
|
||||
elif (
|
||||
key_management_system
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||
):
|
||||
### LOAD FROM AWS SECRET MANAGER ###
|
||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
key_management_settings = general_settings.get(
|
||||
|
@ -2819,6 +2833,7 @@ class ProxyConfig:
|
|||
master_key = general_settings.get(
|
||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||
)
|
||||
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key = litellm.get_secret(master_key)
|
||||
if not isinstance(master_key, str):
|
||||
|
@ -2909,6 +2924,16 @@ class ProxyConfig:
|
|||
)
|
||||
health_check_interval = general_settings.get("health_check_interval", 300)
|
||||
|
||||
## check if user has set a premium feature in general_settings
|
||||
if (
|
||||
general_settings.get("enforced_params") is not None
|
||||
and premium_user is not True
|
||||
):
|
||||
raise ValueError(
|
||||
"Trying to use `enforced_params`"
|
||||
+ CommonProxyErrors.not_premium_user.value
|
||||
)
|
||||
|
||||
router_params: dict = {
|
||||
"cache_responses": litellm.cache
|
||||
!= None, # cache if user passed in cache values
|
||||
|
@ -3522,7 +3547,12 @@ async def generate_key_helper_fn(
|
|||
)
|
||||
key_data["token_id"] = getattr(create_key_response, "token", None)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.generate_key_helper_fn(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
raise HTTPException(
|
||||
|
@ -3561,7 +3591,12 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
|
|||
else:
|
||||
raise Exception("DB not connected. prisma_client is None")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.delete_verification_token(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
raise e
|
||||
return deleted_tokens
|
||||
|
||||
|
@ -3722,7 +3757,12 @@ async def async_assistants_data_generator(
|
|||
done_message = "[DONE]"
|
||||
yield f"data: {done_message}\n\n"
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.async_assistants_data_generator(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
|
@ -3732,9 +3772,6 @@ async def async_assistants_data_generator(
|
|||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
)
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
if user_debug:
|
||||
traceback.print_exc()
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
|
@ -3774,7 +3811,12 @@ async def async_data_generator(
|
|||
done_message = "[DONE]"
|
||||
yield f"data: {done_message}\n\n"
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.async_data_generator(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
|
@ -3784,8 +3826,6 @@ async def async_data_generator(
|
|||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
)
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
if user_debug:
|
||||
traceback.print_exc()
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
|
@@ -3846,6 +3886,18 @@ def on_backoff(details):
    verbose_proxy_logger.debug("Backing off... this was attempt # %s", details["tries"])


def giveup(e):
    result = not (
        isinstance(e, ProxyException)
        and getattr(e, "message", None) is not None
        and isinstance(e.message, str)
        and "Max parallel request limit reached" in e.message
    )
    if result:
        verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
    return result


@router.on_event("startup")
async def startup_event():
    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
@@ -4130,12 +4182,8 @@ def model_list(
    max_tries=litellm.num_retries or 3, # maximum number of retries
    max_time=litellm.request_timeout or 60, # maximum total time to retry for
    on_backoff=on_backoff, # specifying the function to call on backoff
    giveup=lambda e: not (
        isinstance(e, ProxyException)
        and getattr(e, "message", None) is not None
        and isinstance(e.message, str)
        and "Max parallel request limit reached" in e.message
    ), # the result of the logical expression is on the second position
    giveup=giveup,
    logger=verbose_proxy_logger,
)
async def chat_completion(
    request: Request,
@@ -4144,6 +4192,7 @@ async def chat_completion(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    global general_settings, user_debug, proxy_logging_obj, llm_model_list

    data = {}
    try:
        body = await request.body()
@@ -4434,7 +4483,12 @@ async def chat_completion(
            return _chat_response
    except Exception as e:
        data["litellm_status"] = "fail" # used for alerting
        traceback.print_exc()
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format(
                get_error_message_str(e=e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
@@ -4445,8 +4499,6 @@ async def chat_completion(
            litellm_debug_info,
        )
        router_model_names = llm_router.model_names if llm_router is not None else []
        if user_debug:
            traceback.print_exc()

        if isinstance(e, HTTPException):
            raise ProxyException(
@@ -4678,15 +4730,12 @@ async def completion(
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.debug("EXCEPTION RAISED IN PROXY MAIN.PY")
        litellm_debug_info = getattr(e, "litellm_debug_info", "")
        verbose_proxy_logger.debug(
            "\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
            e,
            litellm_debug_info,
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
                str(e)
            )
        )
        traceback.print_exc()
        error_traceback = traceback.format_exc()
        verbose_proxy_logger.debug(traceback.format_exc())
        error_msg = f"{str(e)}"
        raise ProxyException(
            message=getattr(e, "message", error_msg),
@@ -4896,7 +4945,12 @@ async def embeddings(
            e,
            litellm_debug_info,
        )
        traceback.print_exc()
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.embeddings(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e)),
@@ -5075,7 +5129,12 @@ async def image_generation(
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        traceback.print_exc()
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.image_generation(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e)),
@@ -5253,7 +5312,12 @@ async def audio_speech(
        )

    except Exception as e:
        traceback.print_exc()
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.audio_speech(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        raise e


@@ -5442,7 +5506,12 @@ async def audio_transcriptions(
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        traceback.print_exc()
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.audio_transcription(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e.detail)),
@@ -5451,7 +5520,6 @@ async def audio_transcriptions(
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_traceback = traceback.format_exc()
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
@ -5579,7 +5647,12 @@ async def get_assistants(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.get_assistants(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
@ -5588,7 +5661,6 @@ async def get_assistants(
|
|||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
@ -5708,7 +5780,12 @@ async def create_threads(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.create_threads(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
@ -5717,7 +5794,6 @@ async def create_threads(
|
|||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
@ -5836,7 +5912,12 @@ async def get_thread(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.get_thread(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
@ -5845,7 +5926,6 @@ async def get_thread(
|
|||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
@ -5967,7 +6047,12 @@ async def add_messages(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.add_messages(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
@ -5976,7 +6061,6 @@ async def add_messages(
|
|||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
@ -6094,7 +6178,12 @@ async def get_messages(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.get_messages(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
@ -6103,7 +6192,6 @@ async def get_messages(
|
|||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
@ -6235,7 +6323,12 @@ async def run_thread(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.run_thread(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
@ -6244,7 +6337,6 @@ async def run_thread(
|
|||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
@ -6383,7 +6475,12 @@ async def create_batch(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
@ -6392,7 +6489,6 @@ async def create_batch(
|
|||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
@ -6526,7 +6622,12 @@ async def retrieve_batch(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
@ -6679,7 +6780,12 @@ async def create_file(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.create_file(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
@ -6688,7 +6794,6 @@ async def create_file(
|
|||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
@ -6864,7 +6969,12 @@ async def moderations(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e)),
|
||||
|
@ -6873,7 +6983,6 @@ async def moderations(
|
|||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
@ -7184,7 +7293,12 @@ async def generate_key_fn(
|
|||
|
||||
return GenerateKeyResponse(**response)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.generate_key_fn(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -9639,7 +9753,12 @@ async def user_info(
|
|||
}
|
||||
return response_data
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.user_info(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -9734,7 +9853,12 @@ async def user_update(data: UpdateUserRequest):
|
|||
return response
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.user_update(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -9787,7 +9911,12 @@ async def user_request_model(request: Request):
|
|||
return {"status": "success"}
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.user_request_model(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -9829,7 +9958,12 @@ async def user_get_requests():
|
|||
return {"requests": response}
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.user_get_requests(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -10219,7 +10353,12 @@ async def update_end_user(
|
|||
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.update_end_user(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
|
||||
|
@ -10303,7 +10442,12 @@ async def delete_end_user(
|
|||
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.delete_end_user(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
|
||||
|
@ -11606,7 +11750,12 @@ async def add_new_model(
|
|||
return model_response
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.add_new_model(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -11720,7 +11869,12 @@ async def update_model(
|
|||
|
||||
return model_response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.update_model(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -13954,7 +14108,12 @@ async def update_config(config_info: ConfigYAML):
|
|||
|
||||
return {"message": "Config updated successfully"}
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.update_config(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -14427,7 +14586,12 @@ async def get_config():
|
|||
"available_callbacks": all_available_callbacks,
|
||||
}
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.get_config(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -14678,7 +14842,12 @@ async def health_services_endpoint(
|
|||
}
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.health_services_endpoint(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -14757,7 +14926,12 @@ async def health_endpoint(
|
|||
"unhealthy_count": len(unhealthy_endpoints),
|
||||
}
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.py::health_endpoint(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
|
||||
|
@@ -8,7 +8,8 @@ Requires:
* `pip install boto3>=1.28.57`
"""

import litellm, os
import litellm
import os
from typing import Optional
from litellm.proxy._types import KeyManagementSystem

@@ -38,3 +39,21 @@ def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):

    except Exception as e:
        raise e


def load_aws_kms(use_aws_kms: Optional[bool]):
    if use_aws_kms is None or use_aws_kms is False:
        return
    try:
        import boto3

        validate_environment()

        # Create a Secrets Manager client
        kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))

        litellm.secret_manager_client = kms_client
        litellm._key_management_system = KeyManagementSystem.AWS_KMS

    except Exception as e:
        raise e
@@ -2889,3 +2889,16 @@ missing_keys_html_form = """

def _to_ns(dt):
    return int(dt.timestamp() * 1e9)

def get_error_message_str(e: Exception) -> str:
    error_message = ""
    if isinstance(e, HTTPException):
        if isinstance(e.detail, str):
            error_message = e.detail
        elif isinstance(e.detail, dict):
            error_message = json.dumps(e.detail)
        else:
            error_message = str(e)
    else:
        error_message = str(e)
    return error_message
litellm/py.typed (new file, 2 lines)
@@ -0,0 +1,2 @@
# Marker file to instruct type checkers to look for inline type annotations in this package.
# See PEP 561 for more information.
@@ -220,8 +220,6 @@ class Router:
            []
        ) # names of models under litellm_params. ex. azure/chatgpt-v-2
        self.deployment_latency_map = {}
        ### SCHEDULER ###
        self.scheduler = Scheduler(polling_interval=polling_interval)
        ### CACHING ###
        cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
        redis_cache = None
@@ -259,6 +257,10 @@ class Router:
            redis_cache=redis_cache, in_memory_cache=InMemoryCache()
        ) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.

        ### SCHEDULER ###
        self.scheduler = Scheduler(
            polling_interval=polling_interval, redis_cache=redis_cache
        )
        self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
        self.default_max_parallel_requests = default_max_parallel_requests

@@ -2096,8 +2098,8 @@ class Router:
            except Exception as e:
                raise e
        except Exception as e:
            verbose_router_logger.debug(f"An exception occurred - {str(e)}")
            traceback.print_exc()
            verbose_router_logger.error(f"An exception occurred - {str(e)}")
            verbose_router_logger.debug(traceback.format_exc())
            raise original_exception

    async def async_function_with_retries(self, *args, **kwargs):
@@ -4048,6 +4050,12 @@ class Router:
        for idx in reversed(invalid_model_indices):
            _returned_deployments.pop(idx)

        ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2)
        if len(_returned_deployments) > 0:
            _returned_deployments = litellm.utils._get_order_filtered_deployments(
                _returned_deployments
            )

        return _returned_deployments

    def _common_checks_available_deployment(
|
@ -1,11 +1,9 @@
|
|||
#### What this does ####
|
||||
# picks based on response time (for streaming, this is time to first token)
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
import os, requests, random # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Union, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
|
||||
from litellm import verbose_logger
|
||||
import traceback
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
@ -119,7 +117,12 @@ class LowestCostLoggingHandler(CustomLogger):
|
|||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
@ -201,7 +204,12 @@ class LowestCostLoggingHandler(CustomLogger):
|
|||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
async def async_get_available_deployments(
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
#### What this does ####
|
||||
# picks based on response time (for streaming, this is time to first token)
|
||||
from pydantic import BaseModel, Extra, Field, root_validator # type: ignore
|
||||
import dotenv, os, requests, random # type: ignore
|
||||
from pydantic import BaseModel
|
||||
import random
|
||||
from typing import Optional, Union, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
import traceback
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm import ModelResponse
|
||||
from litellm import token_counter
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
|
||||
|
||||
class LiteLLMBase(BaseModel):
|
||||
|
@ -165,7 +165,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
|
|||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
@ -229,7 +234,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
|
|||
# do nothing if it's not a timeout error
|
||||
return
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
@ -352,7 +362,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
|
|||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.router_strategy.lowest_latency.py::async_log_success_event(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
def get_available_deployments(
|
||||
|
|
|
@ -11,6 +11,7 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
from litellm._logging import verbose_router_logger
|
||||
from litellm.utils import print_verbose
|
||||
|
||||
|
||||
class LiteLLMBase(BaseModel):
|
||||
"""
|
||||
Implements default functions, all pydantic objects should have.
|
||||
|
@ -23,16 +24,20 @@ class LiteLLMBase(BaseModel):
|
|||
# if using pydantic v1
|
||||
return self.dict()
|
||||
|
||||
|
||||
class RoutingArgs(LiteLLMBase):
|
||||
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
|
||||
|
||||
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
|
||||
|
||||
|
||||
class LowestTPMLoggingHandler(CustomLogger):
|
||||
test_flag: bool = False
|
||||
logged_success: int = 0
|
||||
logged_failure: int = 0
|
||||
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
|
||||
|
||||
def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
|
||||
def __init__(
|
||||
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
|
||||
):
|
||||
self.router_cache = router_cache
|
||||
self.model_list = model_list
|
||||
self.routing_args = RoutingArgs(**routing_args)
|
||||
|
@ -72,19 +77,28 @@ class LowestTPMLoggingHandler(CustomLogger):
|
|||
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
|
||||
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
|
||||
|
||||
self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
|
||||
self.router_cache.set_cache(
|
||||
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
|
||||
)
|
||||
|
||||
## RPM
|
||||
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
|
||||
request_count_dict[id] = request_count_dict.get(id, 0) + 1
|
||||
|
||||
self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
|
||||
self.router_cache.set_cache(
|
||||
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
|
||||
)
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_router_logger.error(
|
||||
"litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_router_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
@ -123,19 +137,28 @@ class LowestTPMLoggingHandler(CustomLogger):
|
|||
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
|
||||
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
|
||||
|
||||
self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
|
||||
self.router_cache.set_cache(
|
||||
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
|
||||
)
|
||||
|
||||
## RPM
|
||||
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
|
||||
request_count_dict[id] = request_count_dict.get(id, 0) + 1
|
||||
|
||||
self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
|
||||
self.router_cache.set_cache(
|
||||
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
|
||||
)
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_router_logger.error(
|
||||
"litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_router_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
def get_available_deployments(
|
||||
|
|
|
@ -1,19 +1,19 @@
|
|||
#### What this does ####
|
||||
# identifies lowest tpm deployment
|
||||
from pydantic import BaseModel
|
||||
import dotenv, os, requests, random
|
||||
import random
|
||||
from typing import Optional, Union, List, Dict
|
||||
import datetime as datetime_og
|
||||
from datetime import datetime
|
||||
import traceback, asyncio, httpx
|
||||
import traceback
|
||||
import httpx
|
||||
import litellm
|
||||
from litellm import token_counter
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm._logging import verbose_router_logger, verbose_logger
|
||||
from litellm.utils import print_verbose, get_utc_datetime
|
||||
from litellm.types.router import RouterErrors
|
||||
|
||||
|
||||
class LiteLLMBase(BaseModel):
|
||||
"""
|
||||
Implements default functions, all pydantic objects should have.
|
||||
|
@ -22,12 +22,14 @@ class LiteLLMBase(BaseModel):
|
|||
def json(self, **kwargs):
|
||||
try:
|
||||
return self.model_dump() # noqa
|
||||
except:
|
||||
except Exception as e:
|
||||
# if using pydantic v1
|
||||
return self.dict()
|
||||
|
||||
|
||||
class RoutingArgs(LiteLLMBase):
|
||||
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
|
||||
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
|
||||
|
||||
|
||||
class LowestTPMLoggingHandler_v2(CustomLogger):
|
||||
"""
|
||||
|
@ -47,7 +49,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
logged_failure: int = 0
|
||||
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
|
||||
|
||||
def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
|
||||
def __init__(
|
||||
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
|
||||
):
|
||||
self.router_cache = router_cache
|
||||
self.model_list = model_list
|
||||
self.routing_args = RoutingArgs(**routing_args)
|
||||
|
@ -104,7 +108,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
)
|
||||
else:
|
||||
# if local result below limit, check redis ## prevent unnecessary redis checks
|
||||
result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl=self.routing_args.ttl)
|
||||
result = self.router_cache.increment_cache(
|
||||
key=rpm_key, value=1, ttl=self.routing_args.ttl
|
||||
)
|
||||
if result is not None and result > deployment_rpm:
|
||||
raise litellm.RateLimitError(
|
||||
message="Deployment over defined rpm limit={}. current usage={}".format(
|
||||
|
@ -244,12 +250,19 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
# update cache
|
||||
|
||||
## TPM
|
||||
self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl)
|
||||
self.router_cache.increment_cache(
|
||||
key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
|
||||
)
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
@ -295,7 +308,12 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.error(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(traceback.format_exc())
|
||||
pass
|
||||
|
||||
def _common_checks_available_deployment(
|
||||
|
|
|
@@ -1,13 +1,14 @@
import heapq, time
import heapq
from pydantic import BaseModel
from typing import Optional
import enum
from litellm.caching import DualCache
from litellm.caching import DualCache, RedisCache
from litellm import print_verbose


class SchedulerCacheKeys(enum.Enum):
    queue = "scheduler:queue"
    default_in_memory_ttl = 5 # cache queue in-memory for 5s when redis cache available


class DefaultPriorities(enum.Enum):
@@ -25,18 +26,24 @@ class FlowItem(BaseModel):
class Scheduler:
    cache: DualCache

    def __init__(self, polling_interval: Optional[float] = None):
    def __init__(
        self,
        polling_interval: Optional[float] = None,
        redis_cache: Optional[RedisCache] = None,
    ):
        """
        polling_interval: float or null - frequency of polling queue. Default is 3ms.
        """
        self.queue: list = []
        self.cache = DualCache()
        default_in_memory_ttl: Optional[float] = None
        if redis_cache is not None:
            # if redis-cache available frequently poll that instead of using in-memory.
            default_in_memory_ttl = SchedulerCacheKeys.default_in_memory_ttl.value
        self.cache = DualCache(
            redis_cache=redis_cache, default_in_memory_ttl=default_in_memory_ttl
        )
        self.polling_interval = polling_interval or 0.03 # default to 3ms

    def update_variables(self, cache: Optional[DualCache] = None):
        if cache is not None:
            self.cache = cache

    async def add_request(self, request: FlowItem):
        # We use the priority directly, as lower values indicate higher priority
        # get the queue
File diff suppressed because it is too large
@ -198,7 +198,11 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
|
|||
)
|
||||
assert isinstance(messages.data[0], Message)
|
||||
else:
|
||||
pytest.fail("An unexpected error occurred when running the thread")
|
||||
pytest.fail(
|
||||
"An unexpected error occurred when running the thread, {}".format(
|
||||
run
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
added_message = await litellm.a_add_message(**data)
|
||||
|
@ -226,4 +230,8 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
|
|||
)
|
||||
assert isinstance(messages.data[0], Message)
|
||||
else:
|
||||
pytest.fail("An unexpected error occurred when running the thread")
|
||||
pytest.fail(
|
||||
"An unexpected error occurred when running the thread, {}".format(
|
||||
run
|
||||
)
|
||||
)
|
||||
|
|
|
@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
|
||||
reason="Cannot run without being in CircleCI Runner",
|
||||
|
@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
def test_bedrock_claude_3():
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"image_url",
|
||||
[
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC",
|
||||
"https://avatars.githubusercontent.com/u/29436595?v=",
|
||||
],
|
||||
)
|
||||
def test_bedrock_claude_3(image_url):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
data = {
|
||||
|
@ -294,7 +303,7 @@ def test_bedrock_claude_3():
|
|||
{
|
||||
"image_url": {
|
||||
"detail": "high",
|
||||
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC",
|
||||
"url": image_url,
|
||||
},
|
||||
"type": "image_url",
|
||||
},
|
||||
|
@ -313,7 +322,6 @@ def test_bedrock_claude_3():
|
|||
# Add any assertions here to check the response
|
||||
assert len(response.choices) > 0
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
@ -552,7 +560,7 @@ def test_bedrock_ptu():
|
|||
assert "url" in mock_client_post.call_args.kwargs
|
||||
assert (
|
||||
mock_client_post.call_args.kwargs["url"]
|
||||
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
|
||||
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse"
|
||||
)
|
||||
mock_client_post.assert_called_once()
|
||||
|
||||
|
|
|
@ -300,7 +300,11 @@ def test_completion_claude_3():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_claude_3_function_call():
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
|
||||
)
|
||||
def test_completion_claude_3_function_call(model):
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
{
|
||||
|
@ -331,13 +335,14 @@ def test_completion_claude_3_function_call():
|
|||
try:
|
||||
# test without max tokens
|
||||
response = completion(
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice={
|
||||
"type": "function",
|
||||
"function": {"name": "get_current_weather"},
|
||||
},
|
||||
drop_params=True,
|
||||
)
|
||||
|
||||
# Add any assertions, here to check response args
|
||||
|
@ -364,10 +369,11 @@ def test_completion_claude_3_function_call():
|
|||
)
|
||||
# In the second response, Claude should deduce answer from tool results
|
||||
second_response = completion(
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
drop_params=True,
|
||||
)
|
||||
print(second_response)
|
||||
except Exception as e:
|
||||
|
@ -2162,6 +2168,7 @@ def test_completion_azure_key_completion_arg():
|
|||
logprobs=True,
|
||||
max_tokens=10,
|
||||
)
|
||||
|
||||
print(f"response: {response}")
|
||||
|
||||
print("Hidden Params", response._hidden_params)
|
||||
|
@ -2534,6 +2541,7 @@ def test_replicate_custom_prompt_dict():
|
|||
"content": "what is yc write 1 paragraph",
|
||||
}
|
||||
],
|
||||
mock_response="Hello world",
|
||||
repetition_penalty=0.1,
|
||||
num_retries=3,
|
||||
)
|
||||
|
|
|
@ -76,7 +76,7 @@ def test_image_generation_azure_dall_e_3():
|
|||
)
|
||||
print(f"response: {response}")
|
||||
assert len(response.data) > 0
|
||||
except litellm.RateLimitError as e:
|
||||
except litellm.InternalServerError as e:
|
||||
pass
|
||||
except litellm.ContentPolicyViolationError:
|
||||
pass # OpenAI randomly raises these errors - skip when they occur
|
||||
|
|
|
@@ -2248,3 +2248,55 @@ async def test_create_update_team(prisma_client):
    assert _team_info["budget_reset_at"] is not None and isinstance(
        _team_info["budget_reset_at"], datetime.datetime
    )


@pytest.mark.asyncio()
async def test_enforced_params(prisma_client):
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    from litellm.proxy.proxy_server import general_settings

    general_settings["enforced_params"] = [
        "user",
        "metadata",
        "metadata.generation_name",
    ]

    await litellm.proxy.proxy_server.prisma_client.connect()
    request = NewUserRequest()
    key = await new_user(request)
    print(key)

    generated_key = key.key
    bearer_token = "Bearer " + generated_key

    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    # Case 1: Missing user
    async def return_body():
        return b'{"model": "gemini-pro-vision"}'

    request.body = return_body
    try:
        await user_api_key_auth(request=request, api_key=bearer_token)
        pytest.fail(f"This should have failed!. IT's an invalid request")
    except Exception as e:
        assert (
            "BadRequest please pass param=user in request body. This is a required param"
            in e.message
        )

    # Case 2: Missing metadata["generation_name"]
    async def return_body_2():
        return b'{"model": "gemini-pro-vision", "user": "1234", "metadata": {}}'

    request.body = return_body_2
    try:
        await user_api_key_auth(request=request, api_key=bearer_token)
        pytest.fail(f"This should have failed!. IT's an invalid request")
    except Exception as e:
        assert (
            "Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body"
            in e.message
        )
|
@ -275,7 +275,7 @@ async def _deploy(lowest_latency_logger, deployment_id, tokens_used, duration):
|
|||
}
|
||||
start_time = time.time()
|
||||
response_obj = {"usage": {"total_tokens": tokens_used}}
|
||||
time.sleep(duration)
|
||||
await asyncio.sleep(duration)
|
||||
end_time = time.time()
|
||||
lowest_latency_logger.log_success_event(
|
||||
response_obj=response_obj,
|
||||
|
@ -325,6 +325,7 @@ def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
|
|||
d1 = [(lowest_latency_logger, "1234", 50, 0.01)] * non_ans_rpm
|
||||
d2 = [(lowest_latency_logger, "5678", 50, 0.01)] * non_ans_rpm
|
||||
asyncio.run(_gather_deploy([*d1, *d2]))
|
||||
time.sleep(3)
|
||||
## CHECK WHAT'S SELECTED ##
|
||||
d_ans = lowest_latency_logger.get_available_deployments(
|
||||
model_group=model_group, healthy_deployments=model_list
|
||||
|
|
|
@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
|
|||
claude_2_1_pt,
|
||||
llama_2_chat_pt,
|
||||
prompt_factory,
|
||||
_bedrock_tools_pt,
|
||||
)
|
||||
|
||||
|
||||
|
@ -128,3 +129,27 @@ def test_anthropic_messages_pt():
|
|||
|
||||
|
||||
# codellama_prompt_format()
|
||||
def test_bedrock_tool_calling_pt():
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
converted_tools = _bedrock_tools_pt(tools=tools)
|
||||
|
||||
print(converted_tools)
|
||||
|
|
|
@@ -38,6 +38,48 @@ def test_router_sensitive_keys():
    assert "special-key" not in str(e)


def test_router_order():
    """
    Asserts for 2 models in a model group, model with order=1 always called first
    """
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "mock_response": "Hello world",
                    "order": 1,
                },
                "model_info": {"id": "1"},
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_key": "bad-key",
                    "mock_response": Exception("this is a bad key"),
                    "order": 2,
                },
                "model_info": {"id": "2"},
            },
        ],
        num_retries=0,
        allowed_fails=0,
        enable_pre_call_checks=True,
    )

    for _ in range(100):
        response = router.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )

        assert isinstance(response, litellm.ModelResponse)
        assert response._hidden_params["model_id"] == "1"


@pytest.mark.parametrize("num_retries", [None, 2])
@pytest.mark.parametrize("max_retries", [None, 4])
def test_router_num_retries_init(num_retries, max_retries):
|
|
@ -1284,18 +1284,18 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
|
|||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.parametrize("sync_mode", [True]) # False
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
# "bedrock/cohere.command-r-plus-v1:0",
|
||||
# "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
# "anthropic.claude-instant-v1",
|
||||
# "bedrock/ai21.j2-mid",
|
||||
# "mistral.mistral-7b-instruct-v0:2",
|
||||
# "bedrock/amazon.titan-tg1-large",
|
||||
# "meta.llama3-8b-instruct-v1:0",
|
||||
"cohere.command-text-v14"
|
||||
"bedrock/cohere.command-r-plus-v1:0",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-instant-v1",
|
||||
"bedrock/ai21.j2-mid",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"bedrock/amazon.titan-tg1-large",
|
||||
"meta.llama3-8b-instruct-v1:0",
|
||||
"cohere.command-text-v14",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -186,3 +186,13 @@ def test_load_test_token_counter(model):
|
|||
total_time = end_time - start_time
|
||||
print("model={}, total test time={}".format(model, total_time))
|
||||
assert total_time < 10, f"Total encoding time > 10s, {total_time}"
|
||||
|
||||
def test_openai_token_with_image_and_text():
|
||||
model = "gpt-4o"
|
||||
full_request = {'model': 'gpt-4o', 'tools': [{'type': 'function', 'function': {'name': 'json', 'parameters': {'type': 'object', 'required': ['clause'], 'properties': {'clause': {'type': 'string'}}}, 'description': 'Respond with a JSON object.'}}], 'logprobs': False, 'messages': [{'role': 'user', 'content': [{'text': '\n Just some long text, long long text, and you know it will be longer than 7 tokens definetly.', 'type': 'text'}]}], 'tool_choice': {'type': 'function', 'function': {'name': 'json'}}, 'exclude_models': [], 'disable_fallback': False, 'exclude_providers': []}
|
||||
messages = full_request.get("messages", [])
|
||||
|
||||
token_count = token_counter(model=model, messages=messages)
|
||||
print(token_count)
|
||||
|
||||
test_openai_token_with_image_and_text()
|
litellm/types/files.py (new file, 267 lines)

@@ -0,0 +1,267 @@
from enum import Enum
from types import MappingProxyType
from typing import List, Set

"""
Base Enums/Consts
"""


class FileType(Enum):
    AAC = "AAC"
    CSV = "CSV"
    DOC = "DOC"
    DOCX = "DOCX"
    FLAC = "FLAC"
    FLV = "FLV"
    GIF = "GIF"
    GOOGLE_DOC = "GOOGLE_DOC"
    GOOGLE_DRAWINGS = "GOOGLE_DRAWINGS"
    GOOGLE_SHEETS = "GOOGLE_SHEETS"
    GOOGLE_SLIDES = "GOOGLE_SLIDES"
    HEIC = "HEIC"
    HEIF = "HEIF"
    HTML = "HTML"
    JPEG = "JPEG"
    JSON = "JSON"
    M4A = "M4A"
    M4V = "M4V"
    MOV = "MOV"
    MP3 = "MP3"
    MP4 = "MP4"
    MPEG = "MPEG"
    MPEGPS = "MPEGPS"
    MPG = "MPG"
    MPA = "MPA"
    MPGA = "MPGA"
    OGG = "OGG"
    OPUS = "OPUS"
    PDF = "PDF"
    PCM = "PCM"
    PNG = "PNG"
    PPT = "PPT"
    PPTX = "PPTX"
    RTF = "RTF"
    THREE_GPP = "3GPP"
    TXT = "TXT"
    WAV = "WAV"
    WEBM = "WEBM"
    WEBP = "WEBP"
    WMV = "WMV"
    XLS = "XLS"
    XLSX = "XLSX"


FILE_EXTENSIONS: MappingProxyType[FileType, List[str]] = MappingProxyType({
    FileType.AAC: ["aac"],
    FileType.CSV: ["csv"],
    FileType.DOC: ["doc"],
    FileType.DOCX: ["docx"],
    FileType.FLAC: ["flac"],
    FileType.FLV: ["flv"],
    FileType.GIF: ["gif"],
    FileType.GOOGLE_DOC: ["gdoc"],
    FileType.GOOGLE_DRAWINGS: ["gdraw"],
    FileType.GOOGLE_SHEETS: ["gsheet"],
    FileType.GOOGLE_SLIDES: ["gslides"],
    FileType.HEIC: ["heic"],
    FileType.HEIF: ["heif"],
    FileType.HTML: ["html", "htm"],
    FileType.JPEG: ["jpeg", "jpg"],
    FileType.JSON: ["json"],
    FileType.M4A: ["m4a"],
    FileType.M4V: ["m4v"],
    FileType.MOV: ["mov"],
    FileType.MP3: ["mp3"],
    FileType.MP4: ["mp4"],
    FileType.MPEG: ["mpeg"],
    FileType.MPEGPS: ["mpegps"],
    FileType.MPG: ["mpg"],
    FileType.MPA: ["mpa"],
    FileType.MPGA: ["mpga"],
    FileType.OGG: ["ogg"],
    FileType.OPUS: ["opus"],
    FileType.PDF: ["pdf"],
    FileType.PCM: ["pcm"],
    FileType.PNG: ["png"],
    FileType.PPT: ["ppt"],
    FileType.PPTX: ["pptx"],
    FileType.RTF: ["rtf"],
    FileType.THREE_GPP: ["3gpp"],
    FileType.TXT: ["txt"],
    FileType.WAV: ["wav"],
    FileType.WEBM: ["webm"],
    FileType.WEBP: ["webp"],
    FileType.WMV: ["wmv"],
    FileType.XLS: ["xls"],
    FileType.XLSX: ["xlsx"],
})

FILE_MIME_TYPES: MappingProxyType[FileType, str] = MappingProxyType({
    FileType.AAC: "audio/aac",
    FileType.CSV: "text/csv",
    FileType.DOC: "application/msword",
    FileType.DOCX: "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    FileType.FLAC: "audio/flac",
    FileType.FLV: "video/x-flv",
    FileType.GIF: "image/gif",
    FileType.GOOGLE_DOC: "application/vnd.google-apps.document",
    FileType.GOOGLE_DRAWINGS: "application/vnd.google-apps.drawing",
    FileType.GOOGLE_SHEETS: "application/vnd.google-apps.spreadsheet",
    FileType.GOOGLE_SLIDES: "application/vnd.google-apps.presentation",
    FileType.HEIC: "image/heic",
    FileType.HEIF: "image/heif",
    FileType.HTML: "text/html",
    FileType.JPEG: "image/jpeg",
    FileType.JSON: "application/json",
    FileType.M4A: "audio/x-m4a",
    FileType.M4V: "video/x-m4v",
    FileType.MOV: "video/quicktime",
    FileType.MP3: "audio/mpeg",
    FileType.MP4: "video/mp4",
    FileType.MPEG: "video/mpeg",
    FileType.MPEGPS: "video/mpegps",
    FileType.MPG: "video/mpg",
    FileType.MPA: "audio/m4a",
    FileType.MPGA: "audio/mpga",
    FileType.OGG: "audio/ogg",
    FileType.OPUS: "audio/opus",
    FileType.PDF: "application/pdf",
    FileType.PCM: "audio/pcm",
    FileType.PNG: "image/png",
    FileType.PPT: "application/vnd.ms-powerpoint",
    FileType.PPTX: "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    FileType.RTF: "application/rtf",
    FileType.THREE_GPP: "video/3gpp",
    FileType.TXT: "text/plain",
    FileType.WAV: "audio/wav",
    FileType.WEBM: "video/webm",
    FileType.WEBP: "image/webp",
    FileType.WMV: "video/wmv",
    FileType.XLS: "application/vnd.ms-excel",
    FileType.XLSX: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
})

"""
Util Functions
"""


def get_file_mime_type_from_extension(extension: str) -> str:
    for file_type, extensions in FILE_EXTENSIONS.items():
        if extension in extensions:
            return FILE_MIME_TYPES[file_type]
    raise ValueError(f"Unknown mime type for extension: {extension}")


def get_file_extension_from_mime_type(mime_type: str) -> str:
    for file_type, mime in FILE_MIME_TYPES.items():
        if mime == mime_type:
            return FILE_EXTENSIONS[file_type][0]
    raise ValueError(f"Unknown extension for mime type: {mime_type}")


def get_file_type_from_extension(extension: str) -> FileType:
    for file_type, extensions in FILE_EXTENSIONS.items():
        if extension in extensions:
            return file_type

    raise ValueError(f"Unknown file type for extension: {extension}")


def get_file_extension_for_file_type(file_type: FileType) -> str:
    return FILE_EXTENSIONS[file_type][0]


def get_file_mime_type_for_file_type(file_type: FileType) -> str:
    return FILE_MIME_TYPES[file_type]


"""
FileType Type Groupings (Videos, Images, etc)
"""

# Images
IMAGE_FILE_TYPES = {
    FileType.PNG,
    FileType.JPEG,
    FileType.GIF,
    FileType.WEBP,
    FileType.HEIC,
    FileType.HEIF
}


def is_image_file_type(file_type):
    return file_type in IMAGE_FILE_TYPES


# Videos
VIDEO_FILE_TYPES = {
    FileType.MOV,
    FileType.MP4,
    FileType.MPEG,
    FileType.M4V,
    FileType.FLV,
    FileType.MPEGPS,
    FileType.MPG,
    FileType.WEBM,
    FileType.WMV,
    FileType.THREE_GPP
}


def is_video_file_type(file_type):
    return file_type in VIDEO_FILE_TYPES


# Audio
AUDIO_FILE_TYPES = {
    FileType.AAC,
    FileType.FLAC,
    FileType.MP3,
    FileType.MPA,
    FileType.MPGA,
    FileType.OPUS,
    FileType.PCM,
    FileType.WAV,
}


def is_audio_file_type(file_type):
    return file_type in AUDIO_FILE_TYPES


# Text
TEXT_FILE_TYPES = {
    FileType.CSV,
    FileType.HTML,
    FileType.RTF,
    FileType.TXT
}


def is_text_file_type(file_type):
    return file_type in TEXT_FILE_TYPES


"""
Other FileType Groupings
"""
# Accepted file types for GEMINI 1.5 through Vertex AI
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-images-nodejs
GEMINI_1_5_ACCEPTED_FILE_TYPES: Set[FileType] = {
    # Image
    FileType.PNG,
    FileType.JPEG,
    # Audio
    FileType.AAC,
    FileType.FLAC,
    FileType.MP3,
    FileType.MPA,
    FileType.MPGA,
    FileType.OPUS,
    FileType.PCM,
    FileType.WAV,
    # Video
    FileType.FLV,
    FileType.MOV,
    FileType.MPEG,
    FileType.MPEGPS,
    FileType.MPG,
    FileType.MP4,
    FileType.WEBM,
    FileType.WMV,
    FileType.THREE_GPP,
    # PDF
    FileType.PDF,
}


def is_gemini_1_5_accepted_file_type(file_type: FileType) -> bool:
    return file_type in GEMINI_1_5_ACCEPTED_FILE_TYPES
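Note (not part of the diff): a small usage sketch for the helpers added in this file; the extension string is an arbitrary example.

from litellm.types.files import (
    FileType,
    get_file_type_from_extension,
    get_file_mime_type_for_file_type,
    is_image_file_type,
    is_gemini_1_5_accepted_file_type,
)

file_type = get_file_type_from_extension("png")       # FileType.PNG
mime = get_file_mime_type_for_file_type(file_type)    # "image/png"

print(file_type, mime)
print(is_image_file_type(file_type))                   # True
print(is_gemini_1_5_accepted_file_type(FileType.WEBM)) # True, WEBM is in the accepted video types above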
@@ -1,4 +1,4 @@
from typing import TypedDict, Any, Union, Optional
from typing import TypedDict, Any, Union, Optional, Literal, List
import json
from typing_extensions import (
    Self,

@@ -11,10 +11,137 @@ from typing_extensions import (
)


class SystemContentBlock(TypedDict):
    text: str


class ImageSourceBlock(TypedDict):
    bytes: Optional[str]  # base 64 encoded string


class ImageBlock(TypedDict):
    format: Literal["png", "jpeg", "gif", "webp"]
    source: ImageSourceBlock


class ToolResultContentBlock(TypedDict, total=False):
    image: ImageBlock
    json: dict
    text: str


class ToolResultBlock(TypedDict, total=False):
    content: Required[List[ToolResultContentBlock]]
    toolUseId: Required[str]
    status: Literal["success", "error"]


class ToolUseBlock(TypedDict):
    input: dict
    name: str
    toolUseId: str


class ContentBlock(TypedDict, total=False):
    text: str
    image: ImageBlock
    toolResult: ToolResultBlock
    toolUse: ToolUseBlock


class MessageBlock(TypedDict):
    content: List[ContentBlock]
    role: Literal["user", "assistant"]


class ConverseMetricsBlock(TypedDict):
    latencyMs: float  # time in ms


class ConverseResponseOutputBlock(TypedDict):
    message: Optional[MessageBlock]


class ConverseTokenUsageBlock(TypedDict):
    inputTokens: int
    outputTokens: int
    totalTokens: int


class ConverseResponseBlock(TypedDict):
    additionalModelResponseFields: dict
    metrics: ConverseMetricsBlock
    output: ConverseResponseOutputBlock
    stopReason: (
        str  # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
    )
    usage: ConverseTokenUsageBlock


class ToolInputSchemaBlock(TypedDict):
    json: Optional[dict]


class ToolSpecBlock(TypedDict, total=False):
    inputSchema: Required[ToolInputSchemaBlock]
    name: Required[str]
    description: str


class ToolBlock(TypedDict):
    toolSpec: Optional[ToolSpecBlock]


class SpecificToolChoiceBlock(TypedDict):
    name: str


class ToolChoiceValuesBlock(TypedDict, total=False):
    any: dict
    auto: dict
    tool: SpecificToolChoiceBlock


class ToolConfigBlock(TypedDict, total=False):
    tools: Required[List[ToolBlock]]
    toolChoice: Union[str, ToolChoiceValuesBlock]


class InferenceConfig(TypedDict, total=False):
    maxTokens: int
    stopSequences: List[str]
    temperature: float
    topP: float


class ToolBlockDeltaEvent(TypedDict):
    input: str


class ContentBlockDeltaEvent(TypedDict, total=False):
    """
    Either 'text' or 'toolUse' will be specified for Converse API streaming response.
    """

    text: str
    toolUse: ToolBlockDeltaEvent


class RequestObject(TypedDict, total=False):
    additionalModelRequestFields: dict
    additionalModelResponseFieldPaths: List[str]
    inferenceConfig: InferenceConfig
    messages: Required[List[MessageBlock]]
    system: List[SystemContentBlock]
    toolConfig: ToolConfigBlock


class GenericStreamingChunk(TypedDict):
    text: Required[str]
    tool_str: Required[str]
    is_finished: Required[bool]
    finish_reason: Required[str]
    usage: Optional[ConverseTokenUsageBlock]


class Document(TypedDict):
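Note (not part of the diff): a hedged sketch of how these Converse-style TypedDicts compose into a request payload. Field values are illustrative; the import path is taken from the bedrock streaming hunk further below.

from litellm.types.llms.bedrock import (
    ContentBlock,
    InferenceConfig,
    MessageBlock,
    RequestObject,
)

# A minimal Converse-shaped request: one user message plus inference settings
request: RequestObject = {
    "messages": [
        MessageBlock(
            role="user",
            content=[ContentBlock(text="What is the weather in Boston?")],
        )
    ],
    "inferenceConfig": InferenceConfig(maxTokens=256, temperature=0.2),
    "system": [{"text": "You are a concise assistant."}],
}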
@@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
    extra_headers: Optional[Dict[str, str]]
    extra_body: Optional[Dict[str, str]]
    timeout: Optional[float]


class ChatCompletionToolCallFunctionChunk(TypedDict):
    name: str
    arguments: str


class ChatCompletionToolCallChunk(TypedDict):
    id: str
    type: Literal["function"]
    function: ChatCompletionToolCallFunctionChunk


class ChatCompletionResponseMessage(TypedDict, total=False):
    content: Optional[str]
    tool_calls: List[ChatCompletionToolCallChunk]
    role: Literal["assistant"]
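Note (not part of the diff): an illustrative sketch of the shape these new TypedDicts describe — an assistant message carrying one tool call. The id, function name, and arguments are made-up placeholders.

# Matches ChatCompletionResponseMessage / ChatCompletionToolCallChunk above
message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Boston"}'},
        }
    ],
}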
litellm/utils.py

@@ -239,6 +239,8 @@ def map_finish_reason(
        return "length"
    elif finish_reason == "tool_use":  # anthropic
        return "tool_calls"
    elif finish_reason == "content_filtered":
        return "content_filter"
    return finish_reason
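Note (not part of the diff): a quick sketch of the mapping this hunk extends, assuming `map_finish_reason` is importable from `litellm.utils`.

from litellm.utils import map_finish_reason

print(map_finish_reason("tool_use"))          # "tool_calls"  (anthropic)
print(map_finish_reason("content_filtered"))  # "content_filter"  (newly mapped above)
print(map_finish_reason("stop"))              # unrecognized values pass through unchanged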
@@ -1372,8 +1374,12 @@ class Logging:
                    callback_func=callback,
                )
            except Exception as e:
                traceback.print_exc()
                print_verbose(
                verbose_logger.error(
                    "litellm.Logging.pre_call(): Exception occured - {}".format(
                        str(e)
                    )
                )
                verbose_logger.debug(
                    f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while input logging with integrations {traceback.format_exc()}"
                )
                print_verbose(

@@ -4060,6 +4066,9 @@ def openai_token_counter(
            for c in value:
                if c["type"] == "text":
                    text += c["text"]
                    num_tokens += len(
                        encoding.encode(c["text"], disallowed_special=())
                    )
                elif c["type"] == "image_url":
                    if isinstance(c["image_url"], dict):
                        image_url_dict = c["image_url"]

@@ -5634,19 +5643,29 @@ def get_optional_params(
            optional_params["stream"] = stream
    elif "anthropic" in model:
        _check_valid_arg(supported_params=supported_params)
        # anthropic params on bedrock
        # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
        if model.startswith("anthropic.claude-3"):
            optional_params = (
                litellm.AmazonAnthropicClaude3Config().map_openai_params(
        if "aws_bedrock_client" in passed_params:  # deprecated boto3.invoke route.
            if model.startswith("anthropic.claude-3"):
                optional_params = (
                    litellm.AmazonAnthropicClaude3Config().map_openai_params(
                        non_default_params=non_default_params,
                        optional_params=optional_params,
                    )
                )
            else:
                optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                    non_default_params=non_default_params,
                    optional_params=optional_params,
                )
            )
        else:
            optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
        else:  # bedrock httpx route
            optional_params = litellm.AmazonConverseConfig().map_openai_params(
                model=model,
                non_default_params=non_default_params,
                optional_params=optional_params,
                drop_params=(
                    drop_params
                    if drop_params is not None and isinstance(drop_params, bool)
                    else False
                ),
            )
    elif "amazon" in model:  # amazon titan llms
        _check_valid_arg(supported_params=supported_params)

@@ -6198,6 +6217,27 @@ def calculate_max_parallel_requests(
    return None


def _get_order_filtered_deployments(healthy_deployments: List[Dict]) -> List:
    min_order = min(
        (
            deployment["litellm_params"]["order"]
            for deployment in healthy_deployments
            if "order" in deployment["litellm_params"]
        ),
        default=None,
    )

    if min_order is not None:
        filtered_deployments = [
            deployment
            for deployment in healthy_deployments
            if deployment["litellm_params"].get("order") == min_order
        ]

        return filtered_deployments
    return healthy_deployments


def _get_model_region(
    custom_llm_provider: str, litellm_params: LiteLLM_Params
) -> Optional[str]:
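Note (not part of the diff): a hedged sketch of what the new `_get_order_filtered_deployments` helper does with a router deployment list. The deployment dicts are illustrative; only the `litellm_params["order"]` field matters here.

from litellm.utils import _get_order_filtered_deployments  # internal helper added above

healthy_deployments = [
    {"model_name": "gpt-4o", "litellm_params": {"model": "azure/gpt-4o", "order": 2}},
    {"model_name": "gpt-4o", "litellm_params": {"model": "openai/gpt-4o", "order": 1}},
    {"model_name": "gpt-4o", "litellm_params": {"model": "bedrock/claude"}},  # no order set
]

# Keeps only deployments that share the lowest explicit "order" value (here: order == 1);
# if no deployment sets "order", the list is returned unchanged.
print(_get_order_filtered_deployments(healthy_deployments))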
@@ -6403,20 +6443,7 @@ def get_supported_openai_params(
    - None if unmapped
    """
    if custom_llm_provider == "bedrock":
        if model.startswith("anthropic.claude-3"):
            return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
        elif model.startswith("anthropic"):
            return litellm.AmazonAnthropicConfig().get_supported_openai_params()
        elif model.startswith("ai21"):
            return ["max_tokens", "temperature", "top_p", "stream"]
        elif model.startswith("amazon"):
            return ["max_tokens", "temperature", "stop", "top_p", "stream"]
        elif model.startswith("meta"):
            return ["max_tokens", "temperature", "top_p", "stream"]
        elif model.startswith("cohere"):
            return ["stream", "temperature", "max_tokens"]
        elif model.startswith("mistral"):
            return ["max_tokens", "temperature", "stop", "top_p", "stream"]
        return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "ollama":
        return litellm.OllamaConfig().get_supported_openai_params()
    elif custom_llm_provider == "ollama_chat":
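Note (not part of the diff): a short sketch of querying the mapped params after this change, assuming `get_supported_openai_params` is importable from `litellm.utils`; the model name is one of the bedrock models enabled in the test file above.

from litellm.utils import get_supported_openai_params

# With the per-model branches removed, bedrock models now resolve through AmazonConverseConfig
params = get_supported_openai_params(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    custom_llm_provider="bedrock",
)
print(params)  # e.g. max_tokens / temperature / top_p / stream / tools, per the Converse config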
@@ -8516,7 +8543,11 @@ def exception_type(
            extra_information = f"\nModel: {model}"
            if _api_base:
                extra_information += f"\nAPI Base: `{_api_base}`"
            if messages and len(messages) > 0:
            if (
                messages
                and len(messages) > 0
                and litellm.redact_messages_in_exceptions is False
            ):
                extra_information += f"\nMessages: `{messages}`"

            if _model_group is not None:
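Note (not part of the diff): a minimal sketch of opting into the behavior this condition checks; the flag name is taken from the hunk above.

import litellm

# When True, exception messages raised by litellm no longer embed the request messages
litellm.redact_messages_in_exceptions = True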
@@ -9803,8 +9834,7 @@ def exception_type(
        elif custom_llm_provider == "azure":
            if "Internal server error" in error_str:
                exception_mapping_worked = True
                raise APIError(
                    status_code=500,
                raise litellm.InternalServerError(
                    message=f"AzureException Internal server error - {original_exception.message}",
                    llm_provider="azure",
                    model=model,

@@ -10054,6 +10084,8 @@ def get_secret(
):
    key_management_system = litellm._key_management_system
    key_management_settings = litellm._key_management_settings
    args = locals()

    if secret_name.startswith("os.environ/"):
        secret_name = secret_name.replace("os.environ/", "")

@@ -10141,13 +10173,13 @@ def get_secret(
            key_manager = "local"

        if (
            key_manager == KeyManagementSystem.AZURE_KEY_VAULT
            key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
            or type(client).__module__ + "." + type(client).__name__
            == "azure.keyvault.secrets._client.SecretClient"
        ):  # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
            secret = client.get_secret(secret_name).value
        elif (
            key_manager == KeyManagementSystem.GOOGLE_KMS
            key_manager == KeyManagementSystem.GOOGLE_KMS.value
            or client.__class__.__name__ == "KeyManagementServiceClient"
        ):
            encrypted_secret: Any = os.getenv(secret_name)

@@ -10175,6 +10207,25 @@ def get_secret(
            secret = response.plaintext.decode(
                "utf-8"
            )  # assumes the original value was encoded with utf-8
        elif key_manager == KeyManagementSystem.AWS_KMS.value:
            """
            Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
            """
            encrypted_value = os.getenv(secret_name, None)
            if encrypted_value is None:
                raise Exception("encrypted value for AWS KMS cannot be None.")
            # Decode the base64 encoded ciphertext
            ciphertext_blob = base64.b64decode(encrypted_value)

            # Set up the parameters for the decrypt call
            params = {"CiphertextBlob": ciphertext_blob}

            # Perform the decryption
            response = client.decrypt(**params)

            # Extract and decode the plaintext
            plaintext = response["Plaintext"]
            secret = plaintext.decode("utf-8")
        elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
            try:
                get_secret_value_response = client.get_secret_value(
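Note (not part of the diff): a hedged sketch of producing a value the AWS KMS decrypt path above can consume — encrypt with KMS, then base64-encode before storing it in the environment. The key id, region, and plaintext are placeholders, not from the source.

import base64
import boto3

kms = boto3.client("kms", region_name="us-west-2")  # region is an assumption

# Encrypt a plaintext secret with a KMS key (KeyId is a placeholder)
encrypted = kms.encrypt(KeyId="alias/my-litellm-key", Plaintext=b"sk-1234")

# Store the base64-encoded ciphertext; get_secret() base64-decodes it and calls kms.decrypt()
print(base64.b64encode(encrypted["CiphertextBlob"]).decode("utf-8"))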
@@ -10195,10 +10246,14 @@ def get_secret(
                for k, v in secret_dict.items():
                    secret = v
                print_verbose(f"secret: {secret}")
            elif key_manager == "local":
                secret = os.getenv(secret_name)
            else:  # assume the default is infisicial client
                secret = client.get_secret(secret_name).secret_value
        except Exception as e:  # check if it's in os.environ
            print_verbose(f"An exception occurred - {str(e)}")
            verbose_logger.error(
                f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
            )
            secret = os.getenv(secret_name)
        try:
            secret_value_as_bool = ast.literal_eval(secret)

@@ -10532,7 +10587,12 @@ class CustomStreamWrapper:
                "finish_reason": finish_reason,
            }
        except Exception as e:
            traceback.print_exc()
            verbose_logger.error(
                "litellm.CustomStreamWrapper.handle_predibase_chunk(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_logger.debug(traceback.format_exc())
            raise e

    def handle_huggingface_chunk(self, chunk):

@@ -10576,7 +10636,12 @@ class CustomStreamWrapper:
                "finish_reason": finish_reason,
            }
        except Exception as e:
            traceback.print_exc()
            verbose_logger.error(
                "litellm.CustomStreamWrapper.handle_huggingface_chunk(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_logger.debug(traceback.format_exc())
            raise e

    def handle_ai21_chunk(self, chunk):  # fake streaming

@@ -10811,7 +10876,12 @@ class CustomStreamWrapper:
                "usage": usage,
            }
        except Exception as e:
            traceback.print_exc()
            verbose_logger.error(
                "litellm.CustomStreamWrapper.handle_openai_chat_completion_chunk(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_logger.debug(traceback.format_exc())
            raise e

    def handle_azure_text_completion_chunk(self, chunk):

@@ -10892,7 +10962,12 @@ class CustomStreamWrapper:
            else:
                return ""
        except:
            traceback.print_exc()
            verbose_logger.error(
                "litellm.CustomStreamWrapper.handle_baseten_chunk(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_logger.debug(traceback.format_exc())
            return ""

    def handle_cloudlfare_stream(self, chunk):

@@ -11091,7 +11166,12 @@ class CustomStreamWrapper:
                "is_finished": True,
            }
        except:
            traceback.print_exc()
            verbose_logger.error(
                "litellm.CustomStreamWrapper.handle_clarifai_chunk(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_logger.debug(traceback.format_exc())
            return ""

    def model_response_creator(self):

@@ -11332,12 +11412,27 @@ class CustomStreamWrapper:
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "bedrock":
                from litellm.types.llms.bedrock import GenericStreamingChunk

                if self.received_finish_reason is not None:
                    raise StopIteration
                response_obj = self.handle_bedrock_stream(chunk)
                response_obj: GenericStreamingChunk = chunk
                completion_obj["content"] = response_obj["text"]

                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]

                if (
                    self.stream_options
                    and self.stream_options.get("include_usage", False) is True
                    and response_obj["usage"] is not None
                ):
                    self.sent_stream_usage = True
                    model_response.usage = litellm.Usage(
                        prompt_tokens=response_obj["usage"]["inputTokens"],
                        completion_tokens=response_obj["usage"]["outputTokens"],
                        total_tokens=response_obj["usage"]["totalTokens"],
                    )
            elif self.custom_llm_provider == "sagemaker":
                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                response_obj = self.handle_sagemaker_stream(chunk)
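Note (not part of the diff): a hedged sketch of the streaming path this hunk wires up — request usage in the stream options and read the totals from the final chunk. The model name and prompt are illustrative.

import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Say hi in one word."}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in response:
    # The final chunk carries usage (inputTokens/outputTokens mapped to litellm.Usage above)
    if getattr(chunk, "usage", None):
        print(chunk.usage)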
@@ -11563,7 +11658,12 @@ class CustomStreamWrapper:
                            tool["type"] = "function"
                    model_response.choices[0].delta = Delta(**_json_delta)
                except Exception as e:
                    traceback.print_exc()
                    verbose_logger.error(
                        "litellm.CustomStreamWrapper.chunk_creator(): Exception occured - {}".format(
                            str(e)
                        )
                    )
                    verbose_logger.debug(traceback.format_exc())
                    model_response.choices[0].delta = Delta()
            else:
                try:

@@ -11599,7 +11699,7 @@ class CustomStreamWrapper:
                and hasattr(model_response, "usage")
                and hasattr(model_response.usage, "prompt_tokens")
            ):
                if self.sent_first_chunk == False:
                if self.sent_first_chunk is False:
                    completion_obj["role"] = "assistant"
                    self.sent_first_chunk = True
                model_response.choices[0].delta = Delta(**completion_obj)

@@ -11767,6 +11867,8 @@ class CustomStreamWrapper:

    def __next__(self):
        try:
            if self.completion_stream is None:
                self.fetch_sync_stream()
            while True:
                if (
                    isinstance(self.completion_stream, str)

@@ -11841,6 +11943,14 @@ class CustomStreamWrapper:
                custom_llm_provider=self.custom_llm_provider,
            )

    def fetch_sync_stream(self):
        if self.completion_stream is None and self.make_call is not None:
            # Call make_call to get the completion stream
            self.completion_stream = self.make_call(client=litellm.module_level_client)
            self._stream_iter = self.completion_stream.__iter__()

        return self.completion_stream

    async def fetch_stream(self):
        if self.completion_stream is None and self.make_call is not None:
            # Call make_call to get the completion stream

pyproject.toml

@@ -1,10 +1,14 @@
[tool.poetry]
name = "litellm"
version = "1.40.4"
version = "1.40.5"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
readme = "README.md"
packages = [
    { include = "litellm" },
    { include = "litellm/py.typed"},
]

[tool.poetry.urls]
homepage = "https://litellm.ai"

@@ -80,7 +84,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
version = "1.40.4"
version = "1.40.5"
version_files = [
    "pyproject.toml:^version"
]
ruff.toml (new file, 3 lines)

@@ -0,0 +1,3 @@
ignore = ["F405"]
extend-select = ["E501"]
line-length = 120