Merge branch 'main' into litellm_svc_logger

This commit is contained in:
Ishaan Jaff 2024-06-07 14:01:54 -07:00 committed by GitHub
commit 2cf3133669
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
84 changed files with 3848 additions and 5302 deletions

View file

@ -344,4 +344,4 @@ workflows:
filters:
branches:
only:
- main
- main

1
.gitignore vendored
View file

@ -59,3 +59,4 @@ myenv/*
litellm/proxy/_experimental/out/404/index.html
litellm/proxy/_experimental/out/model_hub/index.html
litellm/proxy/_experimental/out/onboarding/index.html
litellm/tests/log.txt

View file

@ -38,7 +38,7 @@ class MyCustomHandler(CustomLogger):
print(f"On Async Success")
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Success")
print(f"On Async Failure")
customHandler = MyCustomHandler()

View file

@ -0,0 +1,3 @@
llmcord.py lets you and your friends chat with LLMs directly in your Discord server. It works with practically any LLM, remote or locally hosted.
Github: https://github.com/jakobdylanc/discord-llm-chatbot

View file

@ -62,6 +62,23 @@ curl -X GET 'http://localhost:4000/health/services?service=slack' \
-H 'Authorization: Bearer sk-1234'
```
## Advanced - Redacting Messages from Alerts
By default alerts show the `messages/input` passed to the LLM. If you want to redact this from slack alerting set the following setting on your config
```shell
general_settings:
alerting: ["slack"]
alert_types: ["spend_reports"]
litellm_settings:
redact_messages_in_exceptions: True
```
## Advanced - Opting into specific alert types
Set `alert_types` if you want to Opt into only specific alert types

View file

@ -14,6 +14,7 @@ Features:
- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
- ✅ [Audit Logs](#audit-logs)
- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)
- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests)
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Google Text Moderations](#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
- ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding)
@ -204,6 +205,109 @@ curl -X GET "http://0.0.0.0:4000/spend/tags" \
```
## Enforce Required Params for LLM Requests
Use this when you want to enforce all requests to include certain params. Example you need all requests to include the `user` and `["metadata]["generation_name"]` params.
**Step 1** Define all Params you want to enforce on config.yaml
This means `["user"]` and `["metadata]["generation_name"]` are required in all LLM Requests to LiteLLM
```yaml
general_settings:
master_key: sk-1234
enforced_params:
- user
- metadata.generation_name
```
Start LiteLLM Proxy
**Step 2 Verify if this works**
<Tabs>
<TabItem value="bad" label="Invalid Request (No `user` passed)">
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-5fmYeaUEbAMpwBNT-QpxyA' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "hi"
}
]
}'
```
Expected Response
```shell
{"error":{"message":"Authentication Error, BadRequest please pass param=user in request body. This is a required param","type":"auth_error","param":"None","code":401}}%
```
</TabItem>
<TabItem value="bad2" label="Invalid Request (No `metadata` passed)">
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-5fmYeaUEbAMpwBNT-QpxyA' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"user": "gm",
"messages": [
{
"role": "user",
"content": "hi"
}
],
"metadata": {}
}'
```
Expected Response
```shell
{"error":{"message":"Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body. This is a required param","type":"auth_error","param":"None","code":401}}%
```
</TabItem>
<TabItem value="good" label="Valid Request">
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-5fmYeaUEbAMpwBNT-QpxyA' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"user": "gm",
"messages": [
{
"role": "user",
"content": "hi"
}
],
"metadata": {"generation_name": "prod-app"}
}'
```
Expected Response
```shell
{"id":"chatcmpl-9XALnHqkCBMBKrOx7Abg0hURHqYtY","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Hello! How can I assist you today?","role":"assistant"}}],"created":1717691639,"model":"gpt-3.5-turbo-0125","object":"chat.completion","system_fingerprint":null,"usage":{"completion_tokens":9,"prompt_tokens":8,"total_tokens":17}}%
```
</TabItem>
</Tabs>

View file

@ -313,6 +313,18 @@ You will see `raw_request` in your Langfuse Metadata. This is the RAW CURL comma
## Logging Proxy Input/Output in OpenTelemetry format
:::info
[Optional] Customize OTEL Service Name and OTEL TRACER NAME by setting the following variables in your environment
```shell
OTEL_TRACER_NAME=<your-trace-name> # default="litellm"
OTEL_SERVICE_NAME=<your-service-name>` # default="litellm"
```
:::
<Tabs>

View file

@ -100,4 +100,76 @@ print(response)
```
</TabItem>
</Tabs>
</Tabs>
## Advanced - Redis Caching
Use redis caching to do request prioritization across multiple instances of LiteLLM.
### SDK
```python
from litellm import Router
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"mock_response": "Hello world this is Macintosh!", # fakes the LLM API call
"rpm": 1,
},
},
],
### REDIS PARAMS ###
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
)
try:
_response = await router.schedule_acompletion( # 👈 ADDS TO QUEUE + POLLS + MAKES CALL
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey!"}],
priority=0, # 👈 LOWER IS BETTER
)
except Exception as e:
print("didn't make request")
```
### PROXY
```yaml
model_list:
- model_name: gpt-3.5-turbo-fake-model
litellm_params:
model: gpt-3.5-turbo
mock_response: "hello world!"
api_key: my-good-key
router_settings:
redis_host; os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
```
```bash
$ litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000s
```
```bash
curl -X POST 'http://localhost:4000/queue/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-D '{
"model": "gpt-3.5-turbo-fake-model",
"messages": [
{
"role": "user",
"content": "what is the meaning of the universe? 1234"
}],
"priority": 0 👈 SET VALUE HERE
}'
```

View file

@ -1,11 +1,31 @@
# Secret Manager
LiteLLM supports reading secrets from Azure Key Vault and Infisical
- AWS Key Managemenet Service
- AWS Secret Manager
- [Azure Key Vault](#azure-key-vault)
- Google Key Management Service
- [Infisical Secret Manager](#infisical-secret-manager)
- [.env Files](#env-files)
## AWS Key Management Service
Use AWS KMS to storing a hashed copy of your Proxy Master Key in the environment.
```bash
export LITELLM_MASTER_KEY="djZ9xjVaZ..." # 👈 ENCRYPTED KEY
export AWS_REGION_NAME="us-west-2"
```
```yaml
general_settings:
key_management_system: "aws_kms"
key_management_settings:
hosted_keys: ["LITELLM_MASTER_KEY"] # 👈 WHICH KEYS ARE STORED ON KMS
```
[**See Decryption Code**](https://github.com/BerriAI/litellm/blob/a2da2a8f168d45648b61279d4795d647d94f90c9/litellm/utils.py#L10182)
## AWS Secret Manager
Store your proxy keys in AWS Secret Manager.

View file

@ -5975,9 +5975,9 @@
}
},
"node_modules/caniuse-lite": {
"version": "1.0.30001519",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001519.tgz",
"integrity": "sha512-0QHgqR+Jv4bxHMp8kZ1Kn8CH55OikjKJ6JmKkZYP1F3D7w+lnFXF70nG5eNfsZS89jadi5Ywy5UCSKLAglIRkg==",
"version": "1.0.30001629",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001629.tgz",
"integrity": "sha512-c3dl911slnQhmxUIT4HhYzT7wnBK/XYpGnYLOj4nJBaRiw52Ibe7YxlDaAeRECvA786zCuExhxIUJ2K7nHMrBw==",
"funding": [
{
"type": "opencollective",

File diff suppressed because it is too large Load diff

View file

@ -18,10 +18,6 @@ async def log_event(request: Request):
return {"message": "Request received successfully"}
except Exception as e:
print(f"Error processing request: {str(e)}")
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail="Internal Server Error")

View file

@ -120,6 +120,5 @@ class GenericAPILogger:
)
return response
except Exception as e:
traceback.print_exc()
verbose_logger.debug(f"Generic - {str(e)}\n{traceback.format_exc()}")
verbose_logger.error(f"Generic - {str(e)}\n{traceback.format_exc()}")
pass

View file

@ -82,7 +82,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(traceback.format_exc())
async def async_post_call_success_hook(
self,

View file

@ -118,4 +118,4 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(traceback.format_exc())

View file

@ -92,7 +92,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
},
)
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(traceback.format_exc())
raise e
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:

View file

@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
### INIT VARIABLES ###
import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.caching import Cache
from litellm._logging import (
set_verbose,
@ -60,6 +60,7 @@ _async_failure_callback: List[Callable] = (
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
redact_messages_in_exceptions: Optional[bool] = False
store_audit_logs = False # Enterprise feature, allow users to see audit logs
## end of callbacks #############
@ -233,6 +234,7 @@ max_end_user_budget: Optional[float] = None
#### RELIABILITY ####
request_timeout: float = 6000
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
module_level_client = HTTPHandler(timeout=request_timeout)
num_retries: Optional[int] = None # per model endpoint
default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None
@ -766,7 +768,7 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
from .llms.bedrock import (
AmazonTitanConfig,
AmazonAI21Config,
@ -808,6 +810,7 @@ from .exceptions import (
APIConnectionError,
APIResponseValidationError,
UnprocessableEntityError,
InternalServerError,
LITELLM_EXCEPTION_TYPES,
)
from .budget_manager import BudgetManager

View file

@ -1,5 +1,6 @@
import logging, os, json
from logging import Formatter
import traceback
set_verbose = False
json_logs = bool(os.getenv("JSON_LOGS", False))

View file

@ -253,7 +253,6 @@ class RedisCache(BaseCache):
str(e),
value,
)
traceback.print_exc()
raise e
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
@ -313,7 +312,6 @@ class RedisCache(BaseCache):
str(e),
value,
)
traceback.print_exc()
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client:
@ -352,7 +350,6 @@ class RedisCache(BaseCache):
str(e),
value,
)
traceback.print_exc()
async def async_set_cache_pipeline(self, cache_list, ttl=None):
"""
@ -413,7 +410,6 @@ class RedisCache(BaseCache):
str(e),
cache_value,
)
traceback.print_exc()
async def batch_cache_write(self, key, value, **kwargs):
print_verbose(
@ -458,7 +454,6 @@ class RedisCache(BaseCache):
str(e),
value,
)
traceback.print_exc()
raise e
async def flush_cache_buffer(self):
@ -495,8 +490,9 @@ class RedisCache(BaseCache):
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
verbose_logger.error(
"LiteLLM Caching: get() - Got exception from REDIS: ", e
)
def batch_get_cache(self, key_list) -> dict:
"""
@ -646,10 +642,9 @@ class RedisCache(BaseCache):
error=e,
call_type="sync_ping",
)
print_verbose(
verbose_logger.error(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
)
traceback.print_exc()
raise e
async def ping(self) -> bool:
@ -683,10 +678,9 @@ class RedisCache(BaseCache):
call_type="async_ping",
)
)
print_verbose(
verbose_logger.error(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
)
traceback.print_exc()
raise e
async def delete_cache_keys(self, keys):
@ -1138,22 +1132,23 @@ class S3Cache(BaseCache):
cached_response = ast.literal_eval(cached_response)
if type(cached_response) is not dict:
cached_response = dict(cached_response)
print_verbose(
verbose_logger.debug(
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
)
return cached_response
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
print_verbose(
verbose_logger.error(
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
)
return None
except Exception as e:
# NON blocking - notify users S3 is throwing an exception
traceback.print_exc()
print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")
verbose_logger.error(
f"S3 Caching: get_cache() - Got exception from S3: {e}"
)
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
@ -1234,8 +1229,7 @@ class DualCache(BaseCache):
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
raise e
def get_cache(self, key, local_only: bool = False, **kwargs):
@ -1262,7 +1256,7 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
verbose_logger.error(traceback.format_exc())
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
try:
@ -1295,7 +1289,7 @@ class DualCache(BaseCache):
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
verbose_logger.error(traceback.format_exc())
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first
@ -1328,7 +1322,7 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
verbose_logger.error(traceback.format_exc())
async def async_batch_get_cache(
self, keys: list, local_only: bool = False, **kwargs
@ -1368,7 +1362,7 @@ class DualCache(BaseCache):
return result
except Exception as e:
traceback.print_exc()
verbose_logger.error(traceback.format_exc())
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
print_verbose(
@ -1381,8 +1375,8 @@ class DualCache(BaseCache):
if self.redis_cache is not None and local_only == False:
await self.redis_cache.async_set_cache(key, value, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
verbose_logger.debug(traceback.format_exc())
async def async_batch_set_cache(
self, cache_list: list, local_only: bool = False, **kwargs
@ -1404,8 +1398,8 @@ class DualCache(BaseCache):
cache_list=cache_list, ttl=kwargs.get("ttl", None)
)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
verbose_logger.debug(traceback.format_exc())
async def async_increment_cache(
self, key, value: float, local_only: bool = False, **kwargs
@ -1429,8 +1423,8 @@ class DualCache(BaseCache):
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
verbose_logger.debug(traceback.format_exc())
raise e
def flush_cache(self):
@ -1846,8 +1840,8 @@ class Cache:
)
self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
verbose_logger.error(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
verbose_logger.debug(traceback.format_exc())
pass
async def async_add_cache(self, result, *args, **kwargs):
@ -1864,8 +1858,8 @@ class Cache:
)
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
verbose_logger.error(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
verbose_logger.debug(traceback.format_exc())
async def async_add_cache_pipeline(self, result, *args, **kwargs):
"""
@ -1897,8 +1891,8 @@ class Cache:
)
await asyncio.gather(*tasks)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
verbose_logger.error(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
verbose_logger.debug(traceback.format_exc())
async def batch_cache_write(self, result, *args, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(

View file

@ -638,6 +638,7 @@ LITELLM_EXCEPTION_TYPES = [
APIConnectionError,
APIResponseValidationError,
OpenAIError,
InternalServerError,
]

View file

@ -169,6 +169,5 @@ class AISpendLogger:
print_verbose(f"AISpend Logging - final data object: {data}")
except:
# traceback.print_exc()
print_verbose(f"AISpend Logging Error - {traceback.format_exc()}")
pass

View file

@ -178,6 +178,5 @@ class BerriSpendLogger:
print_verbose(f"BerriSpend Logging - final data object: {data}")
response = requests.post(url, headers=headers, json=data)
except:
# traceback.print_exc()
print_verbose(f"BerriSpend Logging Error - {traceback.format_exc()}")
pass

View file

@ -297,6 +297,5 @@ class ClickhouseLogger:
# make request to endpoint with payload
verbose_logger.debug(f"Clickhouse Logger - final response = {response}")
except Exception as e:
traceback.print_exc()
verbose_logger.debug(f"Clickhouse - {str(e)}\n{traceback.format_exc()}")
pass

View file

@ -115,7 +115,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event(
@ -130,7 +129,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event(
@ -146,7 +144,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
end_time,
)
except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
@ -163,6 +160,5 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
end_time,
)
except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass

View file

@ -134,7 +134,6 @@ class DataDogLogger:
f"Datadog Layer Logging - final response object: {response_obj}"
)
except Exception as e:
traceback.print_exc()
verbose_logger.debug(
f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
)

View file

@ -85,6 +85,5 @@ class DyanmoDBLogger:
)
return response
except:
traceback.print_exc()
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
pass

View file

@ -112,6 +112,5 @@ class HeliconeLogger:
)
print_verbose(f"Helicone Logging - Error {response.text}")
except:
# traceback.print_exc()
print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
pass

View file

@ -220,9 +220,11 @@ class LangFuseLogger:
verbose_logger.info(f"Langfuse Layer Logging - logging success")
return {"trace_id": trace_id, "generation_id": generation_id}
except:
traceback.print_exc()
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
except Exception as e:
verbose_logger.error(
"Langfuse Layer Error(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
return {"trace_id": None, "generation_id": None}
async def _async_log_event(

View file

@ -44,7 +44,9 @@ class LangsmithLogger:
print_verbose(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
langsmith_base_url = os.getenv("LANGSMITH_BASE_URL", "https://api.smith.langchain.com")
langsmith_base_url = os.getenv(
"LANGSMITH_BASE_URL", "https://api.smith.langchain.com"
)
try:
print_verbose(
@ -89,9 +91,7 @@ class LangsmithLogger:
}
url = f"{langsmith_base_url}/runs"
print_verbose(
f"Langsmith Logging - About to send data to {url} ..."
)
print_verbose(f"Langsmith Logging - About to send data to {url} ...")
response = requests.post(
url=url,
json=data,
@ -106,6 +106,5 @@ class LangsmithLogger:
f"Langsmith Layer Logging - final response object: {response_obj}"
)
except:
# traceback.print_exc()
print_verbose(f"Langsmith Layer Error - {traceback.format_exc()}")
pass

View file

@ -171,7 +171,6 @@ class LogfireLogger:
f"Logfire Layer Logging - final response object: {response_obj}"
)
except Exception as e:
traceback.print_exc()
verbose_logger.debug(
f"Logfire Layer Error - {str(e)}\n{traceback.format_exc()}"
)

View file

@ -14,6 +14,7 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
}
def parse_tool_calls(tool_calls):
if tool_calls is None:
return None
@ -26,13 +27,13 @@ def parse_tool_calls(tool_calls):
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
}
},
}
return serialized
return [clean_tool_call(tool_call) for tool_call in tool_calls]
def parse_messages(input):
@ -176,6 +177,5 @@ class LunaryLogger:
)
except:
# traceback.print_exc()
print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
pass

View file

@ -14,8 +14,11 @@ if TYPE_CHECKING:
else:
Span = Any
LITELLM_TRACER_NAME = "litellm"
LITELLM_RESOURCE = {"service.name": "litellm"}
LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm")
LITELLM_RESOURCE = {
"service.name": os.getenv("OTEL_SERVICE_NAME", "litellm"),
}
@dataclass

View file

@ -109,8 +109,8 @@ class PrometheusLogger:
end_user_id, user_api_key, model, user_api_team, user_id
).inc()
except Exception as e:
traceback.print_exc()
verbose_logger.debug(
f"prometheus Layer Error - {str(e)}\n{traceback.format_exc()}"
verbose_logger.error(
"prometheus Layer Error(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
pass

View file

@ -180,6 +180,5 @@ class S3Logger:
print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
return response
except Exception as e:
traceback.print_exc()
verbose_logger.debug(f"s3 Layer Error - {str(e)}\n{traceback.format_exc()}")
pass

View file

@ -326,8 +326,8 @@ class SlackAlerting(CustomLogger):
end_time=end_time,
)
)
if litellm.turn_off_message_logging:
messages = "Message not logged. `litellm.turn_off_message_logging=True`."
if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold:
@ -567,9 +567,12 @@ class SlackAlerting(CustomLogger):
except:
messages = ""
if litellm.turn_off_message_logging:
if (
litellm.turn_off_message_logging
or litellm.redact_messages_in_exceptions
):
messages = (
"Message not logged. `litellm.turn_off_message_logging=True`."
"Message not logged. litellm.redact_messages_in_exceptions=True"
)
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
else:

View file

@ -110,6 +110,5 @@ class Supabase:
)
except:
# traceback.print_exc()
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
pass

View file

@ -217,6 +217,5 @@ class WeightsBiasesLogger:
f"W&B Logging Logging - final response object: {response_obj}"
)
except:
# traceback.print_exc()
print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
pass

View file

@ -38,6 +38,8 @@ from .prompt_templates.factory import (
extract_between_tags,
parse_xml_params,
contains_tag,
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
@ -45,6 +47,11 @@ import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
from litellm.types.llms.bedrock import *
import urllib.parse
from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
)
class AmazonCohereChatConfig:
@ -118,6 +125,8 @@ class AmazonCohereChatConfig:
"presence_penalty",
"seed",
"stop",
"tools",
"tool_choice",
]
def map_openai_params(
@ -176,6 +185,37 @@ async def make_call(
return completion_stream
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.read())
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream
class BedrockLLM(BaseLLM):
"""
Example call
@ -1000,12 +1040,12 @@ class BedrockLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
client = client # type: ignore
try:
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
@ -1069,6 +1109,745 @@ class BedrockLLM(BaseLLM):
return super().embedding(*args, **kwargs)
class AmazonConverseConfig:
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
"""
maxTokens: Optional[int]
stopSequences: Optional[List[str]]
temperature: Optional[int]
topP: Optional[int]
def __init__(
self,
maxTokens: Optional[int] = None,
stopSequences: Optional[List[str]] = None,
temperature: Optional[int] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> List[str]:
supported_params = [
"max_tokens",
"stream",
"stream_options",
"stop",
"temperature",
"top_p",
"extra_headers",
]
if (
model.startswith("anthropic")
or model.startswith("mistral")
or model.startswith("cohere")
):
supported_params.append("tools")
if model.startswith("anthropic") or model.startswith("mistral"):
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
supported_params.append("tool_choice")
return supported_params
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict], drop_params: bool
) -> Optional[ToolChoiceValuesBlock]:
if tool_choice == "none":
if litellm.drop_params is True or drop_params is True:
return None
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
elif tool_choice == "required":
return ToolChoiceValuesBlock(any={})
elif tool_choice == "auto":
return ToolChoiceValuesBlock(auto={})
elif isinstance(tool_choice, dict):
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
specific_tool = SpecificToolChoiceBlock(
name=tool_choice.get("function", {}).get("name", "")
)
return ToolChoiceValuesBlock(tool=specific_tool)
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
def get_supported_image_types(self) -> List[str]:
return ["png", "jpeg", "gif", "webp"]
def map_openai_params(
self,
model: str,
non_default_params: dict,
optional_params: dict,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["maxTokens"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
value = [value]
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["topP"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice_value = self.map_tool_choice_values(
model=model, tool_choice=value, drop_params=drop_params # type: ignore
)
if _tool_choice_value is not None:
optional_params["tool_choice"] = _tool_choice_value
return optional_params
class BedrockConverseLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
except Exception as e:
raise BedrockError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
),
status_code=422,
)
"""
Bedrock Response Object has optional message block
completion_response["output"].get("message", None)
A message block looks like this (Example 1):
"output": {
"message": {
"role": "assistant",
"content": [
{
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
}
]
}
},
(Example 2):
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
"name": "top_song",
"input": {
"sign": "WZPZ"
}
}
}
]
}
}
"""
message: Optional[MessageBlock] = completion_response["output"]["message"]
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
if message is not None:
for content in message["content"]:
"""
- Content is either a tool response or text
"""
if "text" in content:
content_str += content["text"]
if "toolUse" in content:
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=content["toolUse"]["name"],
arguments=json.dumps(content["toolUse"]["input"]),
)
_tool_response_chunk = ChatCompletionToolCallChunk(
id=content["toolUse"]["toolUseId"],
type="function",
function=_function_chunk,
)
tools.append(_tool_response_chunk)
chat_completion_message["content"] = content_str
chat_completion_message["tool_calls"] = tools
## CALCULATING USAGE - bedrock returns usage in the headers
input_tokens = completion_response["usage"]["inputTokens"]
output_tokens = completion_response["usage"]["outputTokens"]
total_tokens = completion_response["usage"]["totalTokens"]
model_response.choices = [
litellm.Choices(
finish_reason=map_finish_reason(completion_response["stopReason"]),
index=0,
message=litellm.Message(**chat_completion_message),
)
]
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=total_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def encode_model_id(self, model_id: str) -> str:
"""
Double encode the model ID to ensure it matches the expected double-encoded format.
Args:
model_id (str): The model ID to encode.
Returns:
str: The double-encoded model ID.
"""
return urllib.parse.quote(model_id, safe="")
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
return session.get_credentials()
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
from botocore.credentials import Credentials
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
client=client,
api_base=api_base,
headers=headers,
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore
else:
client = client # type: ignore
try:
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream if isinstance(stream, bool) else False,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ##
stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None)
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
else:
modelId = model
provider = model.split(".")[0]
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[SystemContentBlock] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = SystemContentBlock(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
inference_params = copy.deepcopy(optional_params)
additional_request_keys = []
additional_request_params = {}
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
supported_tool_call_params = ["tools", "tool_choice"]
## TRANSFORMATION ##
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if (
k not in supported_converse_params
and k not in supported_tool_call_params
):
additional_request_params[k] = v
additional_request_keys.append(k)
for key in additional_request_keys:
inference_params.pop(key, None)
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
)
bedrock_tool_config: Optional[ToolConfigBlock] = None
if len(bedrock_tools) > 0:
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
"tool_choice", None
)
bedrock_tool_config = ToolConfigBlock(
tools=bedrock_tools,
)
if tool_choice_values is not None:
bedrock_tool_config["toolChoice"] = tool_choice_values
_data: RequestObject = {
"messages": bedrock_messages,
"additionalModelRequestFields": additional_request_params,
"system": system_content_blocks,
"inferenceConfig": InferenceConfig(**inference_params),
}
if bedrock_tool_config is not None:
_data["toolConfig"] = bedrock_tool_config
data = json.dumps(_data)
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream is True and provider != "ai21":
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if (stream is not None and stream is True) and provider != "ai21":
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
api_base=prepped.url,
headers=prepped.headers, # type: ignore
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=streaming_response,
additional_args={"complete_input_dict": data},
)
return streaming_response
### COMPLETION
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = HTTPHandler(**_params) # type: ignore
else:
client = client
try:
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
def get_response_stream_shape():
from botocore.model import ServiceModel
from botocore.loaders import Loader
@ -1086,6 +1865,31 @@ class AWSEventStreamDecoder:
self.model = model
self.parser = EventStreamJSONParser()
def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
tool_str = ""
is_finished = False
finish_reason = ""
usage: Optional[ConverseTokenUsageBlock] = None
if "delta" in chunk_data:
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_str = delta_obj["toolUse"]["input"]
elif "stopReason" in chunk_data:
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
elif "usage" in chunk_data:
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
response = GenericStreamingChunk(
text=text,
tool_str=tool_str,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
)
return response
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
is_finished = False
@ -1098,19 +1902,8 @@ class AWSEventStreamDecoder:
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
elif "completion" in chunk_data: # not claude-3
text = chunk_data["completion"] # bedrock.anthropic
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
elif "delta" in chunk_data:
if chunk_data["delta"].get("text", None) is not None:
text = chunk_data["delta"]["text"]
stop_reason = chunk_data["delta"].get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
return self.converse_chunk_parser(chunk_data=chunk_data)
######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data:
if (
@ -1137,11 +1930,11 @@ class AWSEventStreamDecoder:
is_finished = True
finish_reason = chunk_data["completionReason"]
return GenericStreamingChunk(
**{
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
text=text,
is_finished=is_finished,
finish_reason=finish_reason,
tool_str="",
usage=None,
)
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
@ -1178,9 +1971,14 @@ class AWSEventStreamDecoder:
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
if "chunk" in parsed_response:
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
else:
chunk = response_dict.get("body")
if not chunk:
return None
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
return chunk.decode() # type: ignore[no-any-return]

View file

@ -156,12 +156,13 @@ class HTTPHandler:
self,
url: str,
data: Optional[Union[dict, str]] = None,
json: Optional[Union[dict, str]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
return response

View file

@ -1,13 +1,14 @@
import os, types, traceback, copy, asyncio
import json
from enum import Enum
import types
import traceback
import copy
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import sys, httpx
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version
from litellm import verbose_logger
class GeminiError(Exception):
@ -264,7 +265,8 @@ def completion(
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
traceback.print_exc()
verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
verbose_logger.debug(traceback.format_exc())
raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code
)
@ -356,7 +358,8 @@ async def async_completion(
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
traceback.print_exc()
verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
verbose_logger.debug(traceback.format_exc())
raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code
)

View file

@ -7,6 +7,7 @@ import litellm
from litellm.types.utils import ProviderField
import httpx, aiohttp, asyncio # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm import verbose_logger
class OllamaError(Exception):
@ -137,6 +138,7 @@ class OllamaConfig:
)
]
def get_supported_openai_params(
self,
):
@ -151,10 +153,12 @@ class OllamaConfig:
"response_format",
]
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary.
def _convert_image(image):
import base64, io
try:
from PIL import Image
except:
@ -404,7 +408,13 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"LiteLLM.ollama.py::ollama_async_streaming(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e
@ -468,7 +478,12 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
)
return model_response
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"LiteLLM.ollama.py::ollama_acompletion(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e

View file

@ -1,11 +1,15 @@
from itertools import chain
import requests, types, time
import json, uuid
import requests
import types
import time
import json
import uuid
import traceback
from typing import Optional
from litellm import verbose_logger
import litellm
import httpx, aiohttp, asyncio
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
import aiohttp
class OllamaError(Exception):
@ -299,7 +303,10 @@ def get_ollama_response(
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
@ -307,7 +314,9 @@ def get_ollama_response(
model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else:
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
model_response["choices"][0]["message"]["content"] = response_json["message"][
"content"
]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
@ -361,7 +370,10 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
@ -410,9 +422,10 @@ async def ollama_async_streaming(
first_chunk_content = first_chunk.choices[0].delta.content or ""
response_content = first_chunk_content + "".join(
[
chunk.choices[0].delta.content
async for chunk in streamwrapper
if chunk.choices[0].delta.content]
chunk.choices[0].delta.content
async for chunk in streamwrapper
if chunk.choices[0].delta.content
]
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
@ -420,7 +433,10 @@ async def ollama_async_streaming(
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
@ -433,7 +449,8 @@ async def ollama_async_streaming(
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
traceback.print_exc()
verbose_logger.error("LiteLLM.gemini(): Exception occured - {}".format(str(e)))
verbose_logger.debug(traceback.format_exc())
async def ollama_acompletion(
@ -483,7 +500,10 @@ async def ollama_acompletion(
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
@ -491,7 +511,9 @@ async def ollama_acompletion(
model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else:
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
model_response["choices"][0]["message"]["content"] = response_json[
"message"
]["content"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama_chat/" + data["model"]
@ -509,5 +531,9 @@ async def ollama_acompletion(
)
return model_response
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"LiteLLM.ollama_acompletion(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
raise e

View file

@ -1,11 +1,12 @@
import os, types, traceback, copy
import json
from enum import Enum
import types
import traceback
import copy
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import sys, httpx
import httpx
from litellm import verbose_logger
class PalmError(Exception):
@ -165,7 +166,10 @@ def completion(
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.llms.palm.py::completion(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
raise PalmError(
message=traceback.format_exc(), status_code=response.status_code
)

View file

@ -3,14 +3,7 @@ import requests, traceback
import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment
from typing import (
Any,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
)
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
import litellm
import litellm.types
from litellm.types.completion import (
@ -24,7 +17,7 @@ from litellm.types.completion import (
import litellm.types.llms
from litellm.types.llms.anthropic import *
import uuid
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
import litellm.types.llms.vertex_ai
@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
try:
from PIL import Image
except:
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
raise Exception("image conversion failed please run `pip install Pillow`")
from io import BytesIO
try:
@ -1613,6 +1604,380 @@ def azure_text_pt(messages: list):
return prompt
###### AMAZON BEDROCK #######
from litellm.types.llms.bedrock import (
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultBlock as BedrockToolResultBlock,
ToolConfigBlock as BedrockToolConfigBlock,
ToolUseBlock as BedrockToolUseBlock,
ImageSourceBlock as BedrockImageSourceBlock,
ImageBlock as BedrockImageBlock,
ContentBlock as BedrockContentBlock,
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
ToolSpecBlock as BedrockToolSpecBlock,
ToolBlock as BedrockToolBlock,
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
)
def get_image_details(image_url) -> Tuple[str, str]:
try:
import base64
# Send a GET request to the image URL
response = requests.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image
content_type = response.headers.get("content-type")
if not content_type or "image" not in content_type:
raise ValueError(
f"URL does not point to a valid image (content-type: {content_type})"
)
# Convert the image content to base64 bytes
base64_bytes = base64.b64encode(response.content).decode("utf-8")
# Get mime-type
mime_type = content_type.split("/")[
1
] # Extract mime-type from content-type header
return base64_bytes, mime_type
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e:
raise e
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
if "base64" in image_url:
# Case 1: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = image_url.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
mime_type = mime_type_match.group(1)
image_format = mime_type.split("/")[1]
else:
mime_type = "image/jpeg"
image_format = "jpeg"
_blob = BedrockImageSourceBlock(bytes=img_without_base_64)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, image_format = get_image_details(image_url)
_blob = BedrockImageSourceBlock(bytes=image_bytes)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
else:
raise ValueError(
"Unsupported image type. Expected either image url or base64 encoded string - \
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
)
def _convert_to_bedrock_tool_call_invoke(
tool_calls: list,
) -> List[BedrockContentBlock]:
"""
OpenAI tool invokes:
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
}
}
]
},
"""
"""
Bedrock tool invokes:
[
{
"role": "assistant",
"toolUse": {
"input": {"location": "Boston, MA", ..},
"name": "get_current_weather",
"toolUseId": "call_abc123"
}
}
]
"""
"""
- json.loads argument
- extract name
- extract id
"""
try:
_parts_list: List[BedrockContentBlock] = []
for tool in tool_calls:
if "function" in tool:
id = tool["id"]
name = tool["function"].get("name", "")
arguments = tool["function"].get("arguments", "")
arguments_dict = json.loads(arguments)
bedrock_tool = BedrockToolUseBlock(
input=arguments_dict, name=name, toolUseId=id
)
bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
_parts_list.append(bedrock_content_block)
return _parts_list
except Exception as e:
raise Exception(
"Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
tool_calls, str(e)
)
)
def _convert_to_bedrock_tool_call_result(
message: dict,
) -> BedrockMessageBlock:
"""
OpenAI message with a tool result looks like:
{
"tool_call_id": "tool_1",
"role": "tool",
"name": "get_current_weather",
"content": "function result goes here",
},
OpenAI message with a function call result looks like:
{
"role": "function",
"name": "get_current_weather",
"content": "function result goes here",
}
"""
"""
Bedrock result looks like this:
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
"content": [
{
"json": {
"song": "Elemental Hotel",
"artist": "8 Storey Hike"
}
}
]
}
}
]
}
"""
"""
-
"""
content = message.get("content", "")
name = message.get("name", "")
id = message.get("tool_call_id", str(uuid.uuid4()))
tool_result_content_block = BedrockToolResultContentBlock(text=content)
tool_result = BedrockToolResultBlock(
content=[tool_result_content_block],
toolUseId=id,
)
content_block = BedrockContentBlock(toolResult=tool_result)
return BedrockMessageBlock(role="user", content=[content_block])
def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
"""
Converts given messages from OpenAI format to Bedrock format
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
- Please ensure that function response turn comes immediately after a function call turn
"""
contents: List[BedrockMessageBlock] = []
msg_i = 0
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "user":
if isinstance(messages[msg_i]["content"], list):
_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
_parts.append(BedrockContentBlock(image=_part)) # type: ignore
user_content.extend(_parts)
else:
_part = BedrockContentBlock(text=messages[msg_i]["content"])
user_content.append(_part)
msg_i += 1
if user_content:
contents.append(BedrockMessageBlock(role="user", content=user_content))
assistant_content: List[BedrockContentBlock] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
if isinstance(messages[msg_i]["content"], list):
assistants_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
assistants_part = BedrockContentBlock(text=element["text"])
assistants_parts.append(assistants_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
assistants_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
assistants_parts.append(
BedrockContentBlock(image=assistants_part) # type: ignore
)
assistant_content.extend(assistants_parts)
elif messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion
assistant_content.extend(
_convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
)
else:
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append(BedrockContentBlock(text=assistant_text))
msg_i += 1
if assistant_content:
contents.append(
BedrockMessageBlock(role="assistant", content=assistant_content)
)
## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
contents.append(tool_call_result)
msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
return contents
def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
"""
OpenAI tools looks like:
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
}
]
"""
"""
Bedrock toolConfig looks like:
"tools": [
{
"toolSpec": {
"name": "top_song",
"description": "Get the most popular song played on a radio station.",
"inputSchema": {
"json": {
"type": "object",
"properties": {
"sign": {
"type": "string",
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
}
},
"required": [
"sign"
]
}
}
}
}
]
"""
tool_block_list: List[BedrockToolBlock] = []
for tool in tools:
parameters = tool.get("function", {}).get("parameters", None)
name = tool.get("function", {}).get("name", "")
description = tool.get("function", {}).get("description", "")
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
tool_spec = BedrockToolSpecBlock(
inputSchema=tool_input_schema, name=name, description=description
)
tool_block = BedrockToolBlock(toolSpec=tool_spec)
tool_block_list.append(tool_block)
return tool_block_list
# Function call template
def function_call_prompt(messages: list, functions: list):
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""

View file

@ -12,6 +12,7 @@ from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_result,
convert_to_gemini_tool_call_invoke,
)
from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
class VertexAIError(Exception):
@ -297,29 +298,31 @@ def _convert_gemini_role(role: str) -> Literal["user", "model"]:
def _process_gemini_image(image_url: str) -> PartType:
try:
if ".mp4" in image_url and "gs://" in image_url:
# Case 1: Videos with Cloud Storage URIs
part_mime = "video/mp4"
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
return PartType(file_data=_file_data)
elif ".pdf" in image_url and "gs://" in image_url:
# Case 2: PDF's with Cloud Storage URIs
part_mime = "application/pdf"
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
return PartType(file_data=_file_data)
elif "gs://" in image_url:
# Case 3: Images with Cloud Storage URIs
# The supported MIME types for images include image/png and image/jpeg.
part_mime = "image/png" if "png" in image_url else "image/jpeg"
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
return PartType(file_data=_file_data)
# GCS URIs
if "gs://" in image_url:
# Figure out file type
extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
extension = extension_with_dot[1:] # Ex: "png"
file_type = get_file_type_from_extension(extension)
# Validate the file type is supported by Gemini
if not is_gemini_1_5_accepted_file_type(file_type):
raise Exception(f"File type not supported by gemini - {file_type}")
mime_type = get_file_mime_type_for_file_type(file_type)
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
return PartType(file_data=file_data)
# Direct links
elif "https:/" in image_url:
# Case 4: Images with direct links
image = _load_image_from_url(image_url)
_blob = BlobType(data=image.data, mime_type=image._mime_type)
return PartType(inline_data=_blob)
# Base64 encoding
elif "base64" in image_url:
# Case 5: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
@ -426,112 +429,6 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
return contents
def _gemini_vision_convert_messages(messages: list):
"""
Converts given messages for GPT-4 Vision to Gemini format.
Args:
messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
- If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
- If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.
Returns:
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
Raises:
VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
Exception: If any other exception occurs during the execution of the function.
Note:
This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub.
The supported MIME types for images include 'image/png' and 'image/jpeg'.
Examples:
>>> messages = [
... {"content": "Hello, world!"},
... {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]},
... ]
>>> _gemini_vision_convert_messages(messages)
('Hello, world!This is a text message.', [<Part object>, <Part object>])
"""
try:
import vertexai
except:
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try:
from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
# given messages for gpt-4 vision, convert them for gemini
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
prompt = ""
images = []
for message in messages:
if isinstance(message["content"], str):
prompt += message["content"]
elif isinstance(message["content"], list):
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
for element in message["content"]:
if isinstance(element, dict):
if element["type"] == "text":
prompt += element["text"]
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
images.append(image_url)
# processing images passed to gemini
processed_images = []
for img in images:
if "gs://" in img:
# Case 1: Images with Cloud Storage URIs
# The supported MIME types for images include image/png and image/jpeg.
part_mime = "image/png" if "png" in img else "image/jpeg"
google_clooud_part = Part.from_uri(img, mime_type=part_mime)
processed_images.append(google_clooud_part)
elif "https:/" in img:
# Case 2: Images with direct links
image = _load_image_from_url(img)
processed_images.append(image)
elif ".mp4" in img and "gs://" in img:
# Case 3: Videos with Cloud Storage URIs
part_mime = "video/mp4"
google_clooud_part = Part.from_uri(img, mime_type=part_mime)
processed_images.append(google_clooud_part)
elif "base64" in img:
# Case 4: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = img.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
mime_type = mime_type_match.group(1)
else:
mime_type = "image/jpeg"
decoded_img = base64.b64decode(img_without_base_64)
processed_image = Part.from_data(data=decoded_img, mime_type=mime_type)
processed_images.append(processed_image)
return prompt, processed_images
except Exception as e:
raise e
def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
_cache_key = f"{model}-{vertex_project}-{vertex_location}"
return _cache_key
@ -647,9 +544,9 @@ def completion(
prompt = " ".join(
[
message["content"]
message.get("content")
for message in messages
if isinstance(message["content"], str)
if isinstance(message.get("content", None), str)
]
)

View file

@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
@ -122,6 +122,7 @@ huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################
@ -364,7 +365,10 @@ async def acompletion(
) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
return response
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.acompletion(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model,
@ -477,7 +481,10 @@ def mock_completion(
except Exception as e:
if isinstance(e, openai.APIError):
raise e
traceback.print_exc()
verbose_logger.error(
"litellm.mock_completion(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
raise Exception("Mock completion response failed")
@ -2100,22 +2107,40 @@ def completion(
logging_obj=logging,
)
else:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
if model.startswith("anthropic"):
response = bedrock_converse_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
else:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
@ -4433,7 +4458,10 @@ async def ahealth_check(
response = {} # args like remaining ratelimit etc.
return response
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.ahealth_check(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
stack_trace = traceback.format_exc()
if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000]

View file

@ -1,6 +1,7 @@
import json
import logging
from logging import Formatter
import sys
class JsonFormatter(Formatter):

View file

@ -56,8 +56,10 @@ router_settings:
litellm_settings:
success_callback: ["langfuse"]
json_logs: true
general_settings:
alerting: ["email"]
key_management_system: "aws_kms"
key_management_settings:
hosted_keys: ["LITELLM_MASTER_KEY"]

View file

@ -955,6 +955,7 @@ class KeyManagementSystem(enum.Enum):
AZURE_KEY_VAULT = "azure_key_vault"
AWS_SECRET_MANAGER = "aws_secret_manager"
LOCAL = "local"
AWS_KMS = "aws_kms"
class KeyManagementSettings(LiteLLMBase):

View file

@ -106,6 +106,40 @@ def common_checks(
raise Exception(
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
)
if general_settings.get("enforced_params", None) is not None:
# Enterprise ONLY Feature
# we already validate if user is premium_user when reading the config
# Add an extra premium_usercheck here too, just incase
from litellm.proxy.proxy_server import premium_user, CommonProxyErrors
if premium_user is not True:
raise ValueError(
"Trying to use `enforced_params`"
+ CommonProxyErrors.not_premium_user.value
)
if route in LiteLLMRoutes.openai_routes.value:
# loop through each enforced param
# example enforced_params ['user', 'metadata', 'metadata.generation_name']
for enforced_param in general_settings["enforced_params"]:
_enforced_params = enforced_param.split(".")
if len(_enforced_params) == 1:
if _enforced_params[0] not in request_body:
raise ValueError(
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
)
elif len(_enforced_params) == 2:
# this is a scenario where user requires request['metadata']['generation_name'] to exist
if _enforced_params[0] not in request_body:
raise ValueError(
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
)
if _enforced_params[1] not in request_body[_enforced_params[0]]:
raise ValueError(
f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
)
pass
# 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
if (
litellm.max_budget > 0

View file

@ -88,7 +88,7 @@ class _PROXY_AzureContentSafety(
verbose_proxy_logger.debug(
"Error in Azure Content-Safety: %s", traceback.format_exc()
)
traceback.print_exc()
verbose_proxy_logger.debug(traceback.format_exc())
raise
result = self._compute_result(response)
@ -123,7 +123,12 @@ class _PROXY_AzureContentSafety(
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_post_call_success_hook(
self,

View file

@ -94,7 +94,12 @@ class _PROXY_BatchRedisRequests(CustomLogger):
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_get_cache(self, *args, **kwargs):
"""

View file

@ -1,13 +1,13 @@
# What this does?
## Checks if key is allowed to use the cache controls passed in to the completion() call
from typing import Optional
import litellm
from litellm import verbose_logger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
import json, traceback
import traceback
class _PROXY_CacheControlCheck(CustomLogger):
@ -54,4 +54,9 @@ class _PROXY_CacheControlCheck(CustomLogger):
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())

View file

@ -1,10 +1,10 @@
from typing import Optional
from litellm import verbose_logger
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
import json, traceback
import traceback
class _PROXY_MaxBudgetLimiter(CustomLogger):
@ -44,4 +44,9 @@ class _PROXY_MaxBudgetLimiter(CustomLogger):
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())

View file

@ -8,8 +8,8 @@
# Tell us how we can improve! - Krrish & Ishaan
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid, json
from typing import Optional, Union
import litellm, traceback, uuid, json # noqa: E401
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
@ -21,8 +21,8 @@ from litellm.utils import (
ImageResponse,
StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio
import aiohttp
import asyncio
class _OPTIONAL_PresidioPIIMasking(CustomLogger):
@ -138,7 +138,12 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
else:
raise Exception(f"Invalid anonymizer response: {redacted_text}")
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e
async def async_pre_call_hook(

View file

@ -204,7 +204,12 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
return e.detail["error"]
raise e
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_moderation_hook(
self,

View file

@ -21,7 +21,14 @@ model_list:
general_settings:
master_key: sk-1234
alerting: ["slack"]
litellm_settings:
callbacks: ["otel"]
store_audit_logs: true
store_audit_logs: true
redact_messages_in_exceptions: True
enforced_params:
- user
- metadata
- metadata.generation_name

View file

@ -111,6 +111,7 @@ from litellm.proxy.utils import (
encrypt_value,
decrypt_value,
_to_ns,
get_error_message_str,
)
from litellm import (
CreateBatchRequest,
@ -120,7 +121,10 @@ from litellm import (
CreateFileRequest,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
load_aws_kms,
)
import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
@ -133,7 +137,10 @@ from litellm.router import (
AssistantsTypedDict,
)
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm._logging import (
verbose_router_logger,
verbose_proxy_logger,
)
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.litellm_license import LicenseCheck
from litellm.proxy.auth.model_checks import (
@ -1515,7 +1522,12 @@ async def user_api_key_auth(
else:
raise Exception()
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, litellm.BudgetExceededError):
raise ProxyException(
message=e.message, type="auth_error", param=None, code=400
@ -2782,10 +2794,12 @@ class ProxyConfig:
load_google_kms(use_google_kms=True)
elif (
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get(
@ -2819,6 +2833,7 @@ class ProxyConfig:
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
if not isinstance(master_key, str):
@ -2909,6 +2924,16 @@ class ProxyConfig:
)
health_check_interval = general_settings.get("health_check_interval", 300)
## check if user has set a premium feature in general_settings
if (
general_settings.get("enforced_params") is not None
and premium_user is not True
):
raise ValueError(
"Trying to use `enforced_params`"
+ CommonProxyErrors.not_premium_user.value
)
router_params: dict = {
"cache_responses": litellm.cache
!= None, # cache if user passed in cache values
@ -3522,7 +3547,12 @@ async def generate_key_helper_fn(
)
key_data["token_id"] = getattr(create_key_response, "token", None)
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.generate_key_helper_fn(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise e
raise HTTPException(
@ -3561,7 +3591,12 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
else:
raise Exception("DB not connected. prisma_client is None")
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.delete_verification_token(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e
return deleted_tokens
@ -3722,7 +3757,12 @@ async def async_assistants_data_generator(
done_message = "[DONE]"
yield f"data: {done_message}\n\n"
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.async_assistants_data_generator(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
@ -3732,9 +3772,6 @@ async def async_assistants_data_generator(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
)
router_model_names = llm_router.model_names if llm_router is not None else []
if user_debug:
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
else:
@ -3774,7 +3811,12 @@ async def async_data_generator(
done_message = "[DONE]"
yield f"data: {done_message}\n\n"
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
@ -3784,8 +3826,6 @@ async def async_data_generator(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
)
router_model_names = llm_router.model_names if llm_router is not None else []
if user_debug:
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
@ -3846,6 +3886,18 @@ def on_backoff(details):
verbose_proxy_logger.debug("Backing off... this was attempt # %s", details["tries"])
def giveup(e):
result = not (
isinstance(e, ProxyException)
and getattr(e, "message", None) is not None
and isinstance(e.message, str)
and "Max parallel request limit reached" in e.message
)
if result:
verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
return result
@router.on_event("startup")
async def startup_event():
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
@ -4130,12 +4182,8 @@ def model_list(
max_tries=litellm.num_retries or 3, # maximum number of retries
max_time=litellm.request_timeout or 60, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
giveup=lambda e: not (
isinstance(e, ProxyException)
and getattr(e, "message", None) is not None
and isinstance(e.message, str)
and "Max parallel request limit reached" in e.message
), # the result of the logical expression is on the second position
giveup=giveup,
logger=verbose_proxy_logger,
)
async def chat_completion(
request: Request,
@ -4144,6 +4192,7 @@ async def chat_completion(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global general_settings, user_debug, proxy_logging_obj, llm_model_list
data = {}
try:
body = await request.body()
@ -4434,7 +4483,12 @@ async def chat_completion(
return _chat_response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format(
get_error_message_str(e=e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@ -4445,8 +4499,6 @@ async def chat_completion(
litellm_debug_info,
)
router_model_names = llm_router.model_names if llm_router is not None else []
if user_debug:
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
@ -4678,15 +4730,12 @@ async def completion(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.debug("EXCEPTION RAISED IN PROXY MAIN.PY")
litellm_debug_info = getattr(e, "litellm_debug_info", "")
verbose_proxy_logger.debug(
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
e,
litellm_debug_info,
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
str(e)
)
)
traceback.print_exc()
error_traceback = traceback.format_exc()
verbose_proxy_logger.debug(traceback.format_exc())
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -4896,7 +4945,12 @@ async def embeddings(
e,
litellm_debug_info,
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.embeddings(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e)),
@ -5075,7 +5129,12 @@ async def image_generation(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.image_generation(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e)),
@ -5253,7 +5312,12 @@ async def audio_speech(
)
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.audio_speech(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e
@ -5442,7 +5506,12 @@ async def audio_transcriptions(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.audio_transcription(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -5451,7 +5520,6 @@ async def audio_transcriptions(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -5579,7 +5647,12 @@ async def get_assistants(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.get_assistants(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -5588,7 +5661,6 @@ async def get_assistants(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -5708,7 +5780,12 @@ async def create_threads(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.create_threads(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -5717,7 +5794,6 @@ async def create_threads(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -5836,7 +5912,12 @@ async def get_thread(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.get_thread(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -5845,7 +5926,6 @@ async def get_thread(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -5967,7 +6047,12 @@ async def add_messages(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.add_messages(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -5976,7 +6061,6 @@ async def add_messages(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -6094,7 +6178,12 @@ async def get_messages(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.get_messages(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -6103,7 +6192,6 @@ async def get_messages(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -6235,7 +6323,12 @@ async def run_thread(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.run_thread(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -6244,7 +6337,6 @@ async def run_thread(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -6383,7 +6475,12 @@ async def create_batch(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -6392,7 +6489,6 @@ async def create_batch(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -6526,7 +6622,12 @@ async def retrieve_batch(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -6679,7 +6780,12 @@ async def create_file(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.create_file(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
@ -6688,7 +6794,6 @@ async def create_file(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -6864,7 +6969,12 @@ async def moderations(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e)),
@ -6873,7 +6983,6 @@ async def moderations(
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
@ -7184,7 +7293,12 @@ async def generate_key_fn(
return GenerateKeyResponse(**response)
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.generate_key_fn(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -9639,7 +9753,12 @@ async def user_info(
}
return response_data
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_info(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -9734,7 +9853,12 @@ async def user_update(data: UpdateUserRequest):
return response
# update based on remaining passed in values
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_update(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -9787,7 +9911,12 @@ async def user_request_model(request: Request):
return {"status": "success"}
# update based on remaining passed in values
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_request_model(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -9829,7 +9958,12 @@ async def user_get_requests():
return {"requests": response}
# update based on remaining passed in values
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_get_requests(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -10219,7 +10353,12 @@ async def update_end_user(
# update based on remaining passed in values
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.update_end_user(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
@ -10303,7 +10442,12 @@ async def delete_end_user(
# update based on remaining passed in values
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.delete_end_user(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
@ -11606,7 +11750,12 @@ async def add_new_model(
return model_response
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.add_new_model(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -11720,7 +11869,12 @@ async def update_model(
return model_response
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.update_model(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -13954,7 +14108,12 @@ async def update_config(config_info: ConfigYAML):
return {"message": "Config updated successfully"}
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.update_config(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -14427,7 +14586,12 @@ async def get_config():
"available_callbacks": all_available_callbacks,
}
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.get_config(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -14678,7 +14842,12 @@ async def health_services_endpoint(
}
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.health_services_endpoint(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -14757,7 +14926,12 @@ async def health_endpoint(
"unhealthy_count": len(unhealthy_endpoints),
}
except Exception as e:
traceback.print_exc()
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.py::health_endpoint(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e

View file

@ -8,7 +8,8 @@ Requires:
* `pip install boto3>=1.28.57`
"""
import litellm, os
import litellm
import os
from typing import Optional
from litellm.proxy._types import KeyManagementSystem
@ -38,3 +39,21 @@ def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
except Exception as e:
raise e
def load_aws_kms(use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False:
return
try:
import boto3
validate_environment()
# Create a Secrets Manager client
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
litellm.secret_manager_client = kms_client
litellm._key_management_system = KeyManagementSystem.AWS_KMS
except Exception as e:
raise e

View file

@ -2889,3 +2889,16 @@ missing_keys_html_form = """
def _to_ns(dt):
return int(dt.timestamp() * 1e9)
def get_error_message_str(e: Exception) -> str:
error_message = ""
if isinstance(e, HTTPException):
if isinstance(e.detail, str):
error_message = e.detail
elif isinstance(e.detail, dict):
error_message = json.dumps(e.detail)
else:
error_message = str(e)
else:
error_message = str(e)
return error_message

2
litellm/py.typed Normal file
View file

@ -0,0 +1,2 @@
# Marker file to instruct type checkers to look for inline type annotations in this package.
# See PEP 561 for more information.

View file

@ -220,8 +220,6 @@ class Router:
[]
) # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
### SCHEDULER ###
self.scheduler = Scheduler(polling_interval=polling_interval)
### CACHING ###
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
redis_cache = None
@ -259,6 +257,10 @@ class Router:
redis_cache=redis_cache, in_memory_cache=InMemoryCache()
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### SCHEDULER ###
self.scheduler = Scheduler(
polling_interval=polling_interval, redis_cache=redis_cache
)
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
self.default_max_parallel_requests = default_max_parallel_requests
@ -2096,8 +2098,8 @@ class Router:
except Exception as e:
raise e
except Exception as e:
verbose_router_logger.debug(f"An exception occurred - {str(e)}")
traceback.print_exc()
verbose_router_logger.error(f"An exception occurred - {str(e)}")
verbose_router_logger.debug(traceback.format_exc())
raise original_exception
async def async_function_with_retries(self, *args, **kwargs):
@ -4048,6 +4050,12 @@ class Router:
for idx in reversed(invalid_model_indices):
_returned_deployments.pop(idx)
## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2)
if len(_returned_deployments) > 0:
_returned_deployments = litellm.utils._get_order_filtered_deployments(
_returned_deployments
)
return _returned_deployments
def _common_checks_available_deployment(

View file

@ -1,11 +1,9 @@
#### What this does ####
# picks based on response time (for streaming, this is time to first token)
from pydantic import BaseModel, Extra, Field, root_validator
import os, requests, random # type: ignore
from pydantic import BaseModel
from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta
import random
from litellm import verbose_logger
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
@ -119,7 +117,12 @@ class LowestCostLoggingHandler(CustomLogger):
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -201,7 +204,12 @@ class LowestCostLoggingHandler(CustomLogger):
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass
async def async_get_available_deployments(

View file

@ -1,16 +1,16 @@
#### What this does ####
# picks based on response time (for streaming, this is time to first token)
from pydantic import BaseModel, Extra, Field, root_validator # type: ignore
import dotenv, os, requests, random # type: ignore
from pydantic import BaseModel
import random
from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta
import random
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm import ModelResponse
from litellm import token_counter
import litellm
from litellm import verbose_logger
class LiteLLMBase(BaseModel):
@ -165,7 +165,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
@ -229,7 +234,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
# do nothing if it's not a timeout error
return
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -352,7 +362,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.router_strategy.lowest_latency.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass
def get_available_deployments(

View file

@ -11,6 +11,7 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_router_logger
from litellm.utils import print_verbose
class LiteLLMBase(BaseModel):
"""
Implements default functions, all pydantic objects should have.
@ -23,16 +24,20 @@ class LiteLLMBase(BaseModel):
# if using pydantic v1
return self.dict()
class RoutingArgs(LiteLLMBase):
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
class LowestTPMLoggingHandler(CustomLogger):
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
def __init__(
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
):
self.router_cache = router_cache
self.model_list = model_list
self.routing_args = RoutingArgs(**routing_args)
@ -72,19 +77,28 @@ class LowestTPMLoggingHandler(CustomLogger):
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
self.router_cache.set_cache(
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
self.router_cache.set_cache(
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
verbose_router_logger.error(
"litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
verbose_router_logger.debug(traceback.format_exc())
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -123,19 +137,28 @@ class LowestTPMLoggingHandler(CustomLogger):
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
self.router_cache.set_cache(
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
self.router_cache.set_cache(
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
verbose_router_logger.error(
"litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
verbose_router_logger.debug(traceback.format_exc())
pass
def get_available_deployments(

View file

@ -1,19 +1,19 @@
#### What this does ####
# identifies lowest tpm deployment
from pydantic import BaseModel
import dotenv, os, requests, random
import random
from typing import Optional, Union, List, Dict
import datetime as datetime_og
from datetime import datetime
import traceback, asyncio, httpx
import traceback
import httpx
import litellm
from litellm import token_counter
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_router_logger
from litellm._logging import verbose_router_logger, verbose_logger
from litellm.utils import print_verbose, get_utc_datetime
from litellm.types.router import RouterErrors
class LiteLLMBase(BaseModel):
"""
Implements default functions, all pydantic objects should have.
@ -22,12 +22,14 @@ class LiteLLMBase(BaseModel):
def json(self, **kwargs):
try:
return self.model_dump() # noqa
except:
except Exception as e:
# if using pydantic v1
return self.dict()
class RoutingArgs(LiteLLMBase):
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
class LowestTPMLoggingHandler_v2(CustomLogger):
"""
@ -47,7 +49,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
def __init__(
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
):
self.router_cache = router_cache
self.model_list = model_list
self.routing_args = RoutingArgs(**routing_args)
@ -104,7 +108,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
)
else:
# if local result below limit, check redis ## prevent unnecessary redis checks
result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl=self.routing_args.ttl)
result = self.router_cache.increment_cache(
key=rpm_key, value=1, ttl=self.routing_args.ttl
)
if result is not None and result > deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
@ -244,12 +250,19 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# update cache
## TPM
self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl)
self.router_cache.increment_cache(
key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -295,7 +308,12 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
pass
def _common_checks_available_deployment(

View file

@ -1,13 +1,14 @@
import heapq, time
import heapq
from pydantic import BaseModel
from typing import Optional
import enum
from litellm.caching import DualCache
from litellm.caching import DualCache, RedisCache
from litellm import print_verbose
class SchedulerCacheKeys(enum.Enum):
queue = "scheduler:queue"
default_in_memory_ttl = 5 # cache queue in-memory for 5s when redis cache available
class DefaultPriorities(enum.Enum):
@ -25,18 +26,24 @@ class FlowItem(BaseModel):
class Scheduler:
cache: DualCache
def __init__(self, polling_interval: Optional[float] = None):
def __init__(
self,
polling_interval: Optional[float] = None,
redis_cache: Optional[RedisCache] = None,
):
"""
polling_interval: float or null - frequency of polling queue. Default is 3ms.
"""
self.queue: list = []
self.cache = DualCache()
default_in_memory_ttl: Optional[float] = None
if redis_cache is not None:
# if redis-cache available frequently poll that instead of using in-memory.
default_in_memory_ttl = SchedulerCacheKeys.default_in_memory_ttl.value
self.cache = DualCache(
redis_cache=redis_cache, default_in_memory_ttl=default_in_memory_ttl
)
self.polling_interval = polling_interval or 0.03 # default to 3ms
def update_variables(self, cache: Optional[DualCache] = None):
if cache is not None:
self.cache = cache
async def add_request(self, request: FlowItem):
# We use the priority directly, as lower values indicate higher priority
# get the queue

File diff suppressed because it is too large Load diff

View file

@ -198,7 +198,11 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")
pytest.fail(
"An unexpected error occurred when running the thread, {}".format(
run
)
)
else:
added_message = await litellm.a_add_message(**data)
@ -226,4 +230,8 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")
pytest.fail(
"An unexpected error occurred when running the thread, {}".format(
run
)
)

View file

@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
reason="Cannot run without being in CircleCI Runner",
@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3():
@pytest.mark.parametrize(
"image_url",
[
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC",
"https://avatars.githubusercontent.com/u/29436595?v=",
],
)
def test_bedrock_claude_3(image_url):
try:
litellm.set_verbose = True
data = {
@ -294,7 +303,7 @@ def test_bedrock_claude_3():
{
"image_url": {
"detail": "high",
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC",
"url": image_url,
},
"type": "image_url",
},
@ -313,7 +322,6 @@ def test_bedrock_claude_3():
# Add any assertions here to check the response
assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
@ -552,7 +560,7 @@ def test_bedrock_ptu():
assert "url" in mock_client_post.call_args.kwargs
assert (
mock_client_post.call_args.kwargs["url"]
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse"
)
mock_client_post.assert_called_once()

View file

@ -300,7 +300,11 @@ def test_completion_claude_3():
pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_function_call():
@pytest.mark.parametrize(
"model",
["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
)
def test_completion_claude_3_function_call(model):
litellm.set_verbose = True
tools = [
{
@ -331,13 +335,14 @@ def test_completion_claude_3_function_call():
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
model=model,
messages=messages,
tools=tools,
tool_choice={
"type": "function",
"function": {"name": "get_current_weather"},
},
drop_params=True,
)
# Add any assertions, here to check response args
@ -364,10 +369,11 @@ def test_completion_claude_3_function_call():
)
# In the second response, Claude should deduce answer from tool results
second_response = completion(
model="anthropic/claude-3-opus-20240229",
model=model,
messages=messages,
tools=tools,
tool_choice="auto",
drop_params=True,
)
print(second_response)
except Exception as e:
@ -2162,6 +2168,7 @@ def test_completion_azure_key_completion_arg():
logprobs=True,
max_tokens=10,
)
print(f"response: {response}")
print("Hidden Params", response._hidden_params)
@ -2534,6 +2541,7 @@ def test_replicate_custom_prompt_dict():
"content": "what is yc write 1 paragraph",
}
],
mock_response="Hello world",
repetition_penalty=0.1,
num_retries=3,
)

View file

@ -76,7 +76,7 @@ def test_image_generation_azure_dall_e_3():
)
print(f"response: {response}")
assert len(response.data) > 0
except litellm.RateLimitError as e:
except litellm.InternalServerError as e:
pass
except litellm.ContentPolicyViolationError:
pass # OpenAI randomly raises these errors - skip when they occur

View file

@ -2248,3 +2248,55 @@ async def test_create_update_team(prisma_client):
assert _team_info["budget_reset_at"] is not None and isinstance(
_team_info["budget_reset_at"], datetime.datetime
)
@pytest.mark.asyncio()
async def test_enforced_params(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
from litellm.proxy.proxy_server import general_settings
general_settings["enforced_params"] = [
"user",
"metadata",
"metadata.generation_name",
]
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest()
key = await new_user(request)
print(key)
generated_key = key.key
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# Case 1: Missing user
async def return_body():
return b'{"model": "gemini-pro-vision"}'
request.body = return_body
try:
await user_api_key_auth(request=request, api_key=bearer_token)
pytest.fail(f"This should have failed!. IT's an invalid request")
except Exception as e:
assert (
"BadRequest please pass param=user in request body. This is a required param"
in e.message
)
# Case 2: Missing metadata["generation_name"]
async def return_body_2():
return b'{"model": "gemini-pro-vision", "user": "1234", "metadata": {}}'
request.body = return_body_2
try:
await user_api_key_auth(request=request, api_key=bearer_token)
pytest.fail(f"This should have failed!. IT's an invalid request")
except Exception as e:
assert (
"Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body"
in e.message
)

View file

@ -275,7 +275,7 @@ async def _deploy(lowest_latency_logger, deployment_id, tokens_used, duration):
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": tokens_used}}
time.sleep(duration)
await asyncio.sleep(duration)
end_time = time.time()
lowest_latency_logger.log_success_event(
response_obj=response_obj,
@ -325,6 +325,7 @@ def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
d1 = [(lowest_latency_logger, "1234", 50, 0.01)] * non_ans_rpm
d2 = [(lowest_latency_logger, "5678", 50, 0.01)] * non_ans_rpm
asyncio.run(_gather_deploy([*d1, *d2]))
time.sleep(3)
## CHECK WHAT'S SELECTED ##
d_ans = lowest_latency_logger.get_available_deployments(
model_group=model_group, healthy_deployments=model_list

View file

@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
claude_2_1_pt,
llama_2_chat_pt,
prompt_factory,
_bedrock_tools_pt,
)
@ -128,3 +129,27 @@ def test_anthropic_messages_pt():
# codellama_prompt_format()
def test_bedrock_tool_calling_pt():
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
converted_tools = _bedrock_tools_pt(tools=tools)
print(converted_tools)

View file

@ -38,6 +38,48 @@ def test_router_sensitive_keys():
assert "special-key" not in str(e)
def test_router_order():
"""
Asserts for 2 models in a model group, model with order=1 always called first
"""
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
"mock_response": "Hello world",
"order": 1,
},
"model_info": {"id": "1"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-4o",
"api_key": "bad-key",
"mock_response": Exception("this is a bad key"),
"order": 2,
},
"model_info": {"id": "2"},
},
],
num_retries=0,
allowed_fails=0,
enable_pre_call_checks=True,
)
for _ in range(100):
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
assert response._hidden_params["model_id"] == "1"
@pytest.mark.parametrize("num_retries", [None, 2])
@pytest.mark.parametrize("max_retries", [None, 4])
def test_router_num_retries_init(num_retries, max_retries):

View file

@ -1284,18 +1284,18 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
# pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("sync_mode", [True]) # False
@pytest.mark.parametrize(
"model",
[
# "bedrock/cohere.command-r-plus-v1:0",
# "anthropic.claude-3-sonnet-20240229-v1:0",
# "anthropic.claude-instant-v1",
# "bedrock/ai21.j2-mid",
# "mistral.mistral-7b-instruct-v0:2",
# "bedrock/amazon.titan-tg1-large",
# "meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14"
"bedrock/cohere.command-r-plus-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-instant-v1",
"bedrock/ai21.j2-mid",
"mistral.mistral-7b-instruct-v0:2",
"bedrock/amazon.titan-tg1-large",
"meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14",
],
)
@pytest.mark.asyncio

View file

@ -186,3 +186,13 @@ def test_load_test_token_counter(model):
total_time = end_time - start_time
print("model={}, total test time={}".format(model, total_time))
assert total_time < 10, f"Total encoding time > 10s, {total_time}"
def test_openai_token_with_image_and_text():
model = "gpt-4o"
full_request = {'model': 'gpt-4o', 'tools': [{'type': 'function', 'function': {'name': 'json', 'parameters': {'type': 'object', 'required': ['clause'], 'properties': {'clause': {'type': 'string'}}}, 'description': 'Respond with a JSON object.'}}], 'logprobs': False, 'messages': [{'role': 'user', 'content': [{'text': '\n Just some long text, long long text, and you know it will be longer than 7 tokens definetly.', 'type': 'text'}]}], 'tool_choice': {'type': 'function', 'function': {'name': 'json'}}, 'exclude_models': [], 'disable_fallback': False, 'exclude_providers': []}
messages = full_request.get("messages", [])
token_count = token_counter(model=model, messages=messages)
print(token_count)
test_openai_token_with_image_and_text()

267
litellm/types/files.py Normal file
View file

@ -0,0 +1,267 @@
from enum import Enum
from types import MappingProxyType
from typing import List, Set
"""
Base Enums/Consts
"""
class FileType(Enum):
AAC = "AAC"
CSV = "CSV"
DOC = "DOC"
DOCX = "DOCX"
FLAC = "FLAC"
FLV = "FLV"
GIF = "GIF"
GOOGLE_DOC = "GOOGLE_DOC"
GOOGLE_DRAWINGS = "GOOGLE_DRAWINGS"
GOOGLE_SHEETS = "GOOGLE_SHEETS"
GOOGLE_SLIDES = "GOOGLE_SLIDES"
HEIC = "HEIC"
HEIF = "HEIF"
HTML = "HTML"
JPEG = "JPEG"
JSON = "JSON"
M4A = "M4A"
M4V = "M4V"
MOV = "MOV"
MP3 = "MP3"
MP4 = "MP4"
MPEG = "MPEG"
MPEGPS = "MPEGPS"
MPG = "MPG"
MPA = "MPA"
MPGA = "MPGA"
OGG = "OGG"
OPUS = "OPUS"
PDF = "PDF"
PCM = "PCM"
PNG = "PNG"
PPT = "PPT"
PPTX = "PPTX"
RTF = "RTF"
THREE_GPP = "3GPP"
TXT = "TXT"
WAV = "WAV"
WEBM = "WEBM"
WEBP = "WEBP"
WMV = "WMV"
XLS = "XLS"
XLSX = "XLSX"
FILE_EXTENSIONS: MappingProxyType[FileType, List[str]] = MappingProxyType({
FileType.AAC: ["aac"],
FileType.CSV: ["csv"],
FileType.DOC: ["doc"],
FileType.DOCX: ["docx"],
FileType.FLAC: ["flac"],
FileType.FLV: ["flv"],
FileType.GIF: ["gif"],
FileType.GOOGLE_DOC: ["gdoc"],
FileType.GOOGLE_DRAWINGS: ["gdraw"],
FileType.GOOGLE_SHEETS: ["gsheet"],
FileType.GOOGLE_SLIDES: ["gslides"],
FileType.HEIC: ["heic"],
FileType.HEIF: ["heif"],
FileType.HTML: ["html", "htm"],
FileType.JPEG: ["jpeg", "jpg"],
FileType.JSON: ["json"],
FileType.M4A: ["m4a"],
FileType.M4V: ["m4v"],
FileType.MOV: ["mov"],
FileType.MP3: ["mp3"],
FileType.MP4: ["mp4"],
FileType.MPEG: ["mpeg"],
FileType.MPEGPS: ["mpegps"],
FileType.MPG: ["mpg"],
FileType.MPA: ["mpa"],
FileType.MPGA: ["mpga"],
FileType.OGG: ["ogg"],
FileType.OPUS: ["opus"],
FileType.PDF: ["pdf"],
FileType.PCM: ["pcm"],
FileType.PNG: ["png"],
FileType.PPT: ["ppt"],
FileType.PPTX: ["pptx"],
FileType.RTF: ["rtf"],
FileType.THREE_GPP: ["3gpp"],
FileType.TXT: ["txt"],
FileType.WAV: ["wav"],
FileType.WEBM: ["webm"],
FileType.WEBP: ["webp"],
FileType.WMV: ["wmv"],
FileType.XLS: ["xls"],
FileType.XLSX: ["xlsx"],
})
FILE_MIME_TYPES: MappingProxyType[FileType, str] = MappingProxyType({
FileType.AAC: "audio/aac",
FileType.CSV: "text/csv",
FileType.DOC: "application/msword",
FileType.DOCX: "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
FileType.FLAC: "audio/flac",
FileType.FLV: "video/x-flv",
FileType.GIF: "image/gif",
FileType.GOOGLE_DOC: "application/vnd.google-apps.document",
FileType.GOOGLE_DRAWINGS: "application/vnd.google-apps.drawing",
FileType.GOOGLE_SHEETS: "application/vnd.google-apps.spreadsheet",
FileType.GOOGLE_SLIDES: "application/vnd.google-apps.presentation",
FileType.HEIC: "image/heic",
FileType.HEIF: "image/heif",
FileType.HTML: "text/html",
FileType.JPEG: "image/jpeg",
FileType.JSON: "application/json",
FileType.M4A: "audio/x-m4a",
FileType.M4V: "video/x-m4v",
FileType.MOV: "video/quicktime",
FileType.MP3: "audio/mpeg",
FileType.MP4: "video/mp4",
FileType.MPEG: "video/mpeg",
FileType.MPEGPS: "video/mpegps",
FileType.MPG: "video/mpg",
FileType.MPA: "audio/m4a",
FileType.MPGA: "audio/mpga",
FileType.OGG: "audio/ogg",
FileType.OPUS: "audio/opus",
FileType.PDF: "application/pdf",
FileType.PCM: "audio/pcm",
FileType.PNG: "image/png",
FileType.PPT: "application/vnd.ms-powerpoint",
FileType.PPTX: "application/vnd.openxmlformats-officedocument.presentationml.presentation",
FileType.RTF: "application/rtf",
FileType.THREE_GPP: "video/3gpp",
FileType.TXT: "text/plain",
FileType.WAV: "audio/wav",
FileType.WEBM: "video/webm",
FileType.WEBP: "image/webp",
FileType.WMV: "video/wmv",
FileType.XLS: "application/vnd.ms-excel",
FileType.XLSX: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
})
"""
Util Functions
"""
def get_file_mime_type_from_extension(extension: str) -> str:
for file_type, extensions in FILE_EXTENSIONS.items():
if extension in extensions:
return FILE_MIME_TYPES[file_type]
raise ValueError(f"Unknown mime type for extension: {extension}")
def get_file_extension_from_mime_type(mime_type: str) -> str:
for file_type, mime in FILE_MIME_TYPES.items():
if mime == mime_type:
return FILE_EXTENSIONS[file_type][0]
raise ValueError(f"Unknown extension for mime type: {mime_type}")
def get_file_type_from_extension(extension: str) -> FileType:
for file_type, extensions in FILE_EXTENSIONS.items():
if extension in extensions:
return file_type
raise ValueError(f"Unknown file type for extension: {extension}")
def get_file_extension_for_file_type(file_type: FileType) -> str:
return FILE_EXTENSIONS[file_type][0]
def get_file_mime_type_for_file_type(file_type: FileType) -> str:
return FILE_MIME_TYPES[file_type]
"""
FileType Type Groupings (Videos, Images, etc)
"""
# Images
IMAGE_FILE_TYPES = {
FileType.PNG,
FileType.JPEG,
FileType.GIF,
FileType.WEBP,
FileType.HEIC,
FileType.HEIF
}
def is_image_file_type(file_type):
return file_type in IMAGE_FILE_TYPES
# Videos
VIDEO_FILE_TYPES = {
FileType.MOV,
FileType.MP4,
FileType.MPEG,
FileType.M4V,
FileType.FLV,
FileType.MPEGPS,
FileType.MPG,
FileType.WEBM,
FileType.WMV,
FileType.THREE_GPP
}
def is_video_file_type(file_type):
return file_type in VIDEO_FILE_TYPES
# Audio
AUDIO_FILE_TYPES = {
FileType.AAC,
FileType.FLAC,
FileType.MP3,
FileType.MPA,
FileType.MPGA,
FileType.OPUS,
FileType.PCM,
FileType.WAV,
}
def is_audio_file_type(file_type):
return file_type in AUDIO_FILE_TYPES
# Text
TEXT_FILE_TYPES = {
FileType.CSV,
FileType.HTML,
FileType.RTF,
FileType.TXT
}
def is_text_file_type(file_type):
return file_type in TEXT_FILE_TYPES
"""
Other FileType Groupings
"""
# Accepted file types for GEMINI 1.5 through Vertex AI
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-images-nodejs
GEMINI_1_5_ACCEPTED_FILE_TYPES: Set[FileType] = {
# Image
FileType.PNG,
FileType.JPEG,
# Audio
FileType.AAC,
FileType.FLAC,
FileType.MP3,
FileType.MPA,
FileType.MPGA,
FileType.OPUS,
FileType.PCM,
FileType.WAV,
# Video
FileType.FLV,
FileType.MOV,
FileType.MPEG,
FileType.MPEGPS,
FileType.MPG,
FileType.MP4,
FileType.WEBM,
FileType.WMV,
FileType.THREE_GPP,
# PDF
FileType.PDF,
}
def is_gemini_1_5_accepted_file_type(file_type: FileType) -> bool:
return file_type in GEMINI_1_5_ACCEPTED_FILE_TYPES

View file

@ -1,4 +1,4 @@
from typing import TypedDict, Any, Union, Optional
from typing import TypedDict, Any, Union, Optional, Literal, List
import json
from typing_extensions import (
Self,
@ -11,10 +11,137 @@ from typing_extensions import (
)
class SystemContentBlock(TypedDict):
text: str
class ImageSourceBlock(TypedDict):
bytes: Optional[str] # base 64 encoded string
class ImageBlock(TypedDict):
format: Literal["png", "jpeg", "gif", "webp"]
source: ImageSourceBlock
class ToolResultContentBlock(TypedDict, total=False):
image: ImageBlock
json: dict
text: str
class ToolResultBlock(TypedDict, total=False):
content: Required[List[ToolResultContentBlock]]
toolUseId: Required[str]
status: Literal["success", "error"]
class ToolUseBlock(TypedDict):
input: dict
name: str
toolUseId: str
class ContentBlock(TypedDict, total=False):
text: str
image: ImageBlock
toolResult: ToolResultBlock
toolUse: ToolUseBlock
class MessageBlock(TypedDict):
content: List[ContentBlock]
role: Literal["user", "assistant"]
class ConverseMetricsBlock(TypedDict):
latencyMs: float # time in ms
class ConverseResponseOutputBlock(TypedDict):
message: Optional[MessageBlock]
class ConverseTokenUsageBlock(TypedDict):
inputTokens: int
outputTokens: int
totalTokens: int
class ConverseResponseBlock(TypedDict):
additionalModelResponseFields: dict
metrics: ConverseMetricsBlock
output: ConverseResponseOutputBlock
stopReason: (
str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
)
usage: ConverseTokenUsageBlock
class ToolInputSchemaBlock(TypedDict):
json: Optional[dict]
class ToolSpecBlock(TypedDict, total=False):
inputSchema: Required[ToolInputSchemaBlock]
name: Required[str]
description: str
class ToolBlock(TypedDict):
toolSpec: Optional[ToolSpecBlock]
class SpecificToolChoiceBlock(TypedDict):
name: str
class ToolChoiceValuesBlock(TypedDict, total=False):
any: dict
auto: dict
tool: SpecificToolChoiceBlock
class ToolConfigBlock(TypedDict, total=False):
tools: Required[List[ToolBlock]]
toolChoice: Union[str, ToolChoiceValuesBlock]
class InferenceConfig(TypedDict, total=False):
maxTokens: int
stopSequences: List[str]
temperature: float
topP: float
class ToolBlockDeltaEvent(TypedDict):
input: str
class ContentBlockDeltaEvent(TypedDict, total=False):
"""
Either 'text' or 'toolUse' will be specified for Converse API streaming response.
"""
text: str
toolUse: ToolBlockDeltaEvent
class RequestObject(TypedDict, total=False):
additionalModelRequestFields: dict
additionalModelResponseFieldPaths: List[str]
inferenceConfig: InferenceConfig
messages: Required[List[MessageBlock]]
system: List[SystemContentBlock]
toolConfig: ToolConfigBlock
class GenericStreamingChunk(TypedDict):
text: Required[str]
tool_str: Required[str]
is_finished: Required[bool]
finish_reason: Required[str]
usage: Optional[ConverseTokenUsageBlock]
class Document(TypedDict):

View file

@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
class ChatCompletionToolCallFunctionChunk(TypedDict):
name: str
arguments: str
class ChatCompletionToolCallChunk(TypedDict):
id: str
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
class ChatCompletionResponseMessage(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionToolCallChunk]
role: Literal["assistant"]

View file

@ -239,6 +239,8 @@ def map_finish_reason(
return "length"
elif finish_reason == "tool_use": # anthropic
return "tool_calls"
elif finish_reason == "content_filtered":
return "content_filter"
return finish_reason
@ -1372,8 +1374,12 @@ class Logging:
callback_func=callback,
)
except Exception as e:
traceback.print_exc()
print_verbose(
verbose_logger.error(
"litellm.Logging.pre_call(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while input logging with integrations {traceback.format_exc()}"
)
print_verbose(
@ -4060,6 +4066,9 @@ def openai_token_counter(
for c in value:
if c["type"] == "text":
text += c["text"]
num_tokens += len(
encoding.encode(c["text"], disallowed_special=())
)
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
@ -5634,19 +5643,29 @@ def get_optional_params(
optional_params["stream"] = stream
elif "anthropic" in model:
_check_valid_arg(supported_params=supported_params)
# anthropic params on bedrock
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
else: # bedrock httpx route
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif "amazon" in model: # amazon titan llms
_check_valid_arg(supported_params=supported_params)
@ -6198,6 +6217,27 @@ def calculate_max_parallel_requests(
return None
def _get_order_filtered_deployments(healthy_deployments: List[Dict]) -> List:
min_order = min(
(
deployment["litellm_params"]["order"]
for deployment in healthy_deployments
if "order" in deployment["litellm_params"]
),
default=None,
)
if min_order is not None:
filtered_deployments = [
deployment
for deployment in healthy_deployments
if deployment["litellm_params"].get("order") == min_order
]
return filtered_deployments
return healthy_deployments
def _get_model_region(
custom_llm_provider: str, litellm_params: LiteLLM_Params
) -> Optional[str]:
@ -6403,20 +6443,7 @@ def get_supported_openai_params(
- None if unmapped
"""
if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"):
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
elif model.startswith("anthropic"):
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
elif model.startswith("ai21"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("amazon"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
elif model.startswith("meta"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("cohere"):
return ["stream", "temperature", "max_tokens"]
elif model.startswith("mistral"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_supported_openai_params()
elif custom_llm_provider == "ollama_chat":
@ -8516,7 +8543,11 @@ def exception_type(
extra_information = f"\nModel: {model}"
if _api_base:
extra_information += f"\nAPI Base: `{_api_base}`"
if messages and len(messages) > 0:
if (
messages
and len(messages) > 0
and litellm.redact_messages_in_exceptions is False
):
extra_information += f"\nMessages: `{messages}`"
if _model_group is not None:
@ -9803,8 +9834,7 @@ def exception_type(
elif custom_llm_provider == "azure":
if "Internal server error" in error_str:
exception_mapping_worked = True
raise APIError(
status_code=500,
raise litellm.InternalServerError(
message=f"AzureException Internal server error - {original_exception.message}",
llm_provider="azure",
model=model,
@ -10054,6 +10084,8 @@ def get_secret(
):
key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
args = locals()
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
@ -10141,13 +10173,13 @@ def get_secret(
key_manager = "local"
if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS
key_manager == KeyManagementSystem.GOOGLE_KMS.value
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
@ -10175,6 +10207,25 @@ def get_secret(
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception("encrypted value for AWS KMS cannot be None.")
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
@ -10195,10 +10246,14 @@ def get_secret(
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
elif key_manager == "local":
secret = os.getenv(secret_name)
else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
print_verbose(f"An exception occurred - {str(e)}")
verbose_logger.error(
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
)
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
@ -10532,7 +10587,12 @@ class CustomStreamWrapper:
"finish_reason": finish_reason,
}
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.CustomStreamWrapper.handle_predibase_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e
def handle_huggingface_chunk(self, chunk):
@ -10576,7 +10636,12 @@ class CustomStreamWrapper:
"finish_reason": finish_reason,
}
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.CustomStreamWrapper.handle_huggingface_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e
def handle_ai21_chunk(self, chunk): # fake streaming
@ -10811,7 +10876,12 @@ class CustomStreamWrapper:
"usage": usage,
}
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.CustomStreamWrapper.handle_openai_chat_completion_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
raise e
def handle_azure_text_completion_chunk(self, chunk):
@ -10892,7 +10962,12 @@ class CustomStreamWrapper:
else:
return ""
except:
traceback.print_exc()
verbose_logger.error(
"litellm.CustomStreamWrapper.handle_baseten_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
return ""
def handle_cloudlfare_stream(self, chunk):
@ -11091,7 +11166,12 @@ class CustomStreamWrapper:
"is_finished": True,
}
except:
traceback.print_exc()
verbose_logger.error(
"litellm.CustomStreamWrapper.handle_clarifai_chunk(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
return ""
def model_response_creator(self):
@ -11332,12 +11412,27 @@ class CustomStreamWrapper:
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj = self.handle_bedrock_stream(chunk)
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
@ -11563,7 +11658,12 @@ class CustomStreamWrapper:
tool["type"] = "function"
model_response.choices[0].delta = Delta(**_json_delta)
except Exception as e:
traceback.print_exc()
verbose_logger.error(
"litellm.CustomStreamWrapper.chunk_creator(): Exception occured - {}".format(
str(e)
)
)
verbose_logger.debug(traceback.format_exc())
model_response.choices[0].delta = Delta()
else:
try:
@ -11599,7 +11699,7 @@ class CustomStreamWrapper:
and hasattr(model_response, "usage")
and hasattr(model_response.usage, "prompt_tokens")
):
if self.sent_first_chunk == False:
if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
@ -11767,6 +11867,8 @@ class CustomStreamWrapper:
def __next__(self):
try:
if self.completion_stream is None:
self.fetch_sync_stream()
while True:
if (
isinstance(self.completion_stream, str)
@ -11841,6 +11943,14 @@ class CustomStreamWrapper:
custom_llm_provider=self.custom_llm_provider,
)
def fetch_sync_stream(self):
if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream
self.completion_stream = self.make_call(client=litellm.module_level_client)
self._stream_iter = self.completion_stream.__iter__()
return self.completion_stream
async def fetch_stream(self):
if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream

View file

@ -1,10 +1,14 @@
[tool.poetry]
name = "litellm"
version = "1.40.4"
version = "1.40.5"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
readme = "README.md"
packages = [
{ include = "litellm" },
{ include = "litellm/py.typed"},
]
[tool.poetry.urls]
homepage = "https://litellm.ai"
@ -80,7 +84,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.40.4"
version = "1.40.5"
version_files = [
"pyproject.toml:^version"
]

3
ruff.toml Normal file
View file

@ -0,0 +1,3 @@
ignore = ["F405"]
extend-select = ["E501"]
line-length = 120