diff --git a/.circleci/config.yml b/.circleci/config.yml
index 26a2ae356b..b43a8aa64c 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -125,6 +125,7 @@ jobs:
pip install tiktoken
pip install aiohttp
pip install click
+ pip install "boto3==1.34.34"
pip install jinja2
pip install tokenizers
pip install openai
@@ -287,6 +288,7 @@ jobs:
pip install "pytest==7.3.1"
pip install "pytest-mock==3.12.0"
pip install "pytest-asyncio==0.21.1"
+ pip install "boto3==1.34.34"
pip install mypy
pip install pyarrow
pip install numpydoc
diff --git a/Dockerfile b/Dockerfile
index c8e9956b29..bd840eaf54 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -62,6 +62,11 @@ COPY --from=builder /wheels/ /wheels/
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels
# Generate prisma client
+ENV PRISMA_BINARY_CACHE_DIR=/app/prisma
+RUN mkdir -p /.cache
+RUN chmod -R 777 /.cache
+RUN pip install nodejs-bin
+RUN pip install prisma
RUN prisma generate
RUN chmod +x entrypoint.sh
diff --git a/Dockerfile.database b/Dockerfile.database
index 22084bab89..c995939e5b 100644
--- a/Dockerfile.database
+++ b/Dockerfile.database
@@ -62,6 +62,11 @@ RUN pip install PyJWT --no-cache-dir
RUN chmod +x build_admin_ui.sh && ./build_admin_ui.sh
# Generate prisma client
+ENV PRISMA_BINARY_CACHE_DIR=/app/prisma
+RUN mkdir -p /.cache
+RUN chmod -R 777 /.cache
+RUN pip install nodejs-bin
+RUN pip install prisma
RUN prisma generate
RUN chmod +x entrypoint.sh
diff --git a/docs/my-website/docs/completion/json_mode.md b/docs/my-website/docs/completion/json_mode.md
index bf159cd07e..1d12a22ba0 100644
--- a/docs/my-website/docs/completion/json_mode.md
+++ b/docs/my-website/docs/completion/json_mode.md
@@ -84,17 +84,20 @@ from litellm import completion
# add to env var
os.environ["OPENAI_API_KEY"] = ""
-messages = [{"role": "user", "content": "List 5 cookie recipes"}]
+messages = [{"role": "user", "content": "List 5 important events in the 19th century"}]
class CalendarEvent(BaseModel):
name: str
date: str
participants: list[str]
+class EventsList(BaseModel):
+ events: list[CalendarEvent]
+
resp = completion(
model="gpt-4o-2024-08-06",
messages=messages,
- response_format=CalendarEvent
+ response_format=EventsList
)
print("Received={}".format(resp))
diff --git a/docs/my-website/docs/providers/anthropic.md b/docs/my-website/docs/providers/anthropic.md
index 2227b7a6b5..2a7804bfda 100644
--- a/docs/my-website/docs/providers/anthropic.md
+++ b/docs/my-website/docs/providers/anthropic.md
@@ -225,22 +225,336 @@ print(response)
| claude-instant-1.2 | `completion('claude-instant-1.2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-instant-1 | `completion('claude-instant-1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
-## Passing Extra Headers to Anthropic API
+## **Prompt Caching**
-Pass `extra_headers: dict` to `litellm.completion`
+Use Anthropic Prompt Caching with LiteLLM by adding `cache_control` to message content, tool definitions, or system blocks, and sending the `anthropic-beta: prompt-caching-2024-07-31` header, as shown in the examples below.
-```python
-from litellm import completion
-messages = [{"role": "user", "content": "What is Anthropic?"}]
-response = completion(
- model="claude-3-5-sonnet-20240620",
- messages=messages,
- extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
+
+[Relevant Anthropic API Docs](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching)
+
+### Caching - Large Context Caching
+
+This example demonstrates basic Prompt Caching usage, caching the full text of the legal agreement as a prefix while keeping the user instruction uncached.
+
+
+
+
+```python
+response = await litellm.acompletion(
+ model="anthropic/claude-3-5-sonnet-20240620",
+ messages=[
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "You are an AI assistant tasked with analyzing legal documents.",
+ },
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ],
+ },
+ {
+ "role": "user",
+ "content": "what are the key terms and conditions in this agreement?",
+ },
+ ],
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
+)
+
+```
+
+
+
+:::info
+
+LiteLLM Proxy is OpenAI compatible
+
+This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy
+
+Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy)
+
+:::
+
+```python
+import openai
+client = openai.AsyncOpenAI(
+ api_key="anything", # litellm proxy api key
+ base_url="http://0.0.0.0:4000" # litellm proxy base url
+)
+
+
+response = await client.chat.completions.create(
+ model="anthropic/claude-3-5-sonnet-20240620",
+ messages=[
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "You are an AI assistant tasked with analyzing legal documents.",
+ },
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ],
+ },
+ {
+ "role": "user",
+ "content": "what are the key terms and conditions in this agreement?",
+ },
+ ],
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
+)
+
+```
+
+
+
+
+### Caching - Tools definitions
+
+In this example, we demonstrate caching tool definitions.
+
+The `cache_control` parameter is placed on the final tool definition.
+
+
+
+
+```python
+import litellm
+
+response = await litellm.acompletion(
+ model="anthropic/claude-3-5-sonnet-20240620",
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}],
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+ },
+ "required": ["location"],
+ },
+ "cache_control": {"type": "ephemeral"}
+ },
+ }
+ ],
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
)
```
-## Advanced
+
+
-## Usage - Function Calling
+:::info
+
+LiteLLM Proxy is OpenAI compatible
+
+This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy
+
+Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy)
+
+:::
+
+```python
+import openai
+client = openai.AsyncOpenAI(
+ api_key="anything", # litellm proxy api key
+ base_url="http://0.0.0.0:4000" # litellm proxy base url
+)
+
+response = await client.chat.completions.create(
+ model="anthropic/claude-3-5-sonnet-20240620",
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}],
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+ },
+ "required": ["location"],
+ },
+ "cache_control": {"type": "ephemeral"}
+ },
+ }
+ ],
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
+)
+```
+
+
+
+
+
+### Caching - Continuing Multi-Turn Convo
+
+In this example, we demonstrate how to use Prompt Caching in a multi-turn conversation.
+
+The `cache_control` parameter is placed on the system message to designate it as part of the static prefix.
+
+The conversation history (previous messages) is included in the messages array. The final turn is marked with `cache_control` so the conversation can be continued in follow-ups, and the second-to-last user message is also marked for caching so that this checkpoint can read from the previous cache.
+
+
+
+
+```python
+import litellm
+
+response = await litellm.acompletion(
+ model="anthropic/claude-3-5-sonnet-20240620",
+ messages=[
+ # System Message
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement"
+ * 400,
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+ },
+ # The final turn is marked with cache-control, for continuing in followups.
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ ],
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
+)
+```
+
+
+
+:::info
+
+LiteLLM Proxy is OpenAI compatible
+
+This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy
+
+Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy)
+
+:::
+
+```python
+import openai
+client = openai.AsyncOpenAI(
+ api_key="anything", # litellm proxy api key
+ base_url="http://0.0.0.0:4000" # litellm proxy base url
+)
+
+response = await client.chat.completions.create(
+ model="anthropic/claude-3-5-sonnet-20240620",
+ messages=[
+ # System Message
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement"
+ * 400,
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+ },
+ # The final turn is marked with cache-control, for continuing in followups.
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ ],
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
+)
+```
+
+
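+When the cached prefix is hit on a later request, Anthropic reports cache usage in the API response, and LiteLLM surfaces the `cache_creation_input_tokens` / `cache_read_input_tokens` fields on the response `usage` object when they are present. A minimal sketch of checking them (assuming `response` comes from one of the cached requests above):
+
+```python
+usage = response.usage
+print("prompt_tokens:", usage.prompt_tokens)
+print("completion_tokens:", usage.completion_tokens)
+# cache fields are only reported when the prompt-caching beta header was sent
+print("cache_creation_input_tokens:", getattr(usage, "cache_creation_input_tokens", None))
+print("cache_read_input_tokens:", getattr(usage, "cache_read_input_tokens", None))
+```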
+
+
+## **Function/Tool Calling**
:::info
@@ -429,6 +743,20 @@ resp = litellm.completion(
print(f"\nResponse: {resp}")
```
+## **Passing Extra Headers to Anthropic API**
+
+Pass `extra_headers: dict` to `litellm.completion`
+
+```python
+from litellm import completion
+messages = [{"role": "user", "content": "What is Anthropic?"}]
+response = completion(
+ model="claude-3-5-sonnet-20240620",
+ messages=messages,
+ extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
+)
+```
+
## Usage - "Assistant Pre-fill"
You can "put words in Claude's mouth" by including an `assistant` role message as the last item in the `messages` array.
diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md
index 485dbf892b..907dfc2337 100644
--- a/docs/my-website/docs/providers/bedrock.md
+++ b/docs/my-website/docs/providers/bedrock.md
@@ -393,7 +393,7 @@ response = completion(
)
```
-
+
```python
@@ -420,6 +420,55 @@ extra_body={
}
)
+print(response)
+```
+
+
+
+1. Update config.yaml
+
+```yaml
+model_list:
+ - model_name: bedrock-claude-v1
+ litellm_params:
+ model: bedrock/anthropic.claude-instant-v1
+ aws_access_key_id: os.environ/CUSTOM_AWS_ACCESS_KEY_ID
+ aws_secret_access_key: os.environ/CUSTOM_AWS_SECRET_ACCESS_KEY
+ aws_region_name: os.environ/CUSTOM_AWS_REGION_NAME
+ guardrailConfig: {
+ "guardrailIdentifier": "ff6ujrregl1q", # The identifier (ID) for the guardrail.
+ "guardrailVersion": "DRAFT", # The version of the guardrail.
+ "trace": "disabled", # The trace behavior for the guardrail. Can either be "disabled" or "enabled"
+ }
+
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```python
+
+import openai
+client = openai.OpenAI(
+ api_key="anything",
+ base_url="http://0.0.0.0:4000"
+)
+
+# request sent to model set on litellm proxy, `litellm --model`
+response = client.chat.completions.create(model="bedrock-claude-v1", messages = [
+ {
+ "role": "user",
+ "content": "this is a test request, write a short poem"
+ }
+],
+temperature=0.7
+)
+
print(response)
```
diff --git a/docs/my-website/docs/proxy/deploy.md b/docs/my-website/docs/proxy/deploy.md
index 7c254ed35d..9f21068e03 100644
--- a/docs/my-website/docs/proxy/deploy.md
+++ b/docs/my-website/docs/proxy/deploy.md
@@ -705,6 +705,29 @@ docker run ghcr.io/berriai/litellm:main-latest \
Provide an ssl certificate when starting litellm proxy server
+### 3. Providing LiteLLM config.yaml file as an s3 object/URL
+
+Use this if you cannot mount a config file on your deployment service (for example, AWS Fargate, Railway, etc.)
+
+LiteLLM Proxy will read your config.yaml from an s3 bucket.
+
+Set the following .env vars
+```shell
+LITELLM_CONFIG_BUCKET_NAME="litellm-proxy" # your bucket name on s3
+LITELLM_CONFIG_BUCKET_OBJECT_KEY="litellm_proxy_config.yaml" # object key on s3
+```
+
+Start litellm proxy with these env vars - litellm will read your config from s3
+
+```shell
+docker run --name litellm-proxy \
+ -e DATABASE_URL= \
+ -e LITELLM_CONFIG_BUCKET_NAME= \
+    -e LITELLM_CONFIG_BUCKET_OBJECT_KEY= \
+ -p 4000:4000 \
+ ghcr.io/berriai/litellm-database:main-latest
+```
+
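+The object itself is a standard LiteLLM config.yaml. A minimal sketch of what the uploaded file might contain (model and env var names are placeholders):
+
+```yaml
+model_list:
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: os.environ/OPENAI_API_KEY
+```
+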
## Platform-specific Guide
diff --git a/docs/my-website/docs/proxy/model_management.md b/docs/my-website/docs/proxy/model_management.md
index 02ce4ba23b..a8cc66ae76 100644
--- a/docs/my-website/docs/proxy/model_management.md
+++ b/docs/my-website/docs/proxy/model_management.md
@@ -17,7 +17,7 @@ model_list:
## Get Model Information - `/model/info`
-Retrieve detailed information about each model listed in the `/model/info` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled the model_info you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes.
+Retrieve detailed information about each model listed in the `/model/info` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled from the model_info you set and the [litellm model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). Sensitive details like API keys are excluded for security purposes.
-
+
+
```bash
curl -X POST "http://0.0.0.0:4000/model/new" \
- -H "accept: application/json" \
- -H "Content-Type: application/json" \
- -d '{ "model_name": "azure-gpt-turbo", "litellm_params": {"model": "azure/gpt-3.5-turbo", "api_key": "os.environ/AZURE_API_KEY", "api_base": "my-azure-api-base"} }'
+ -H "accept: application/json" \
+ -H "Content-Type: application/json" \
+ -d '{ "model_name": "azure-gpt-turbo", "litellm_params": {"model": "azure/gpt-3.5-turbo", "api_key": "os.environ/AZURE_API_KEY", "api_base": "my-azure-api-base"} }'
```
-
+
+
+
+```yaml
+model_list:
+ - model_name: gpt-3.5-turbo ### RECEIVED MODEL NAME ### `openai.chat.completions.create(model="gpt-3.5-turbo",...)`
+ litellm_params: # all params accepted by litellm.completion() - https://github.com/BerriAI/litellm/blob/9b46ec05b02d36d6e4fb5c32321e51e7f56e4a6e/litellm/types/router.py#L297
+ model: azure/gpt-turbo-small-eu ### MODEL NAME sent to `litellm.completion()` ###
+ api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+ api_key: "os.environ/AZURE_API_KEY_EU" # does os.getenv("AZURE_API_KEY_EU")
+ rpm: 6 # [OPTIONAL] Rate limit for this deployment: in requests per minute (rpm)
+ model_info:
+ my_custom_key: my_custom_value # additional model metadata
+```
+
+
@@ -85,4 +96,83 @@ Keep in mind that as both endpoints are in [BETA], you may need to visit the ass
- Get Model Information: [Issue #933](https://github.com/BerriAI/litellm/issues/933)
- Add a New Model: [Issue #964](https://github.com/BerriAI/litellm/issues/964)
-Feedback on the beta endpoints is valuable and helps improve the API for all users.
\ No newline at end of file
+Feedback on the beta endpoints is valuable and helps improve the API for all users.
+
+
+## Add Additional Model Information
+
+If you want the ability to add a display name, description, and labels for models, just use `model_info:`
+
+```yaml
+model_list:
+ - model_name: "gpt-4"
+ litellm_params:
+ model: "gpt-4"
+ api_key: "os.environ/OPENAI_API_KEY"
+ model_info: # 👈 KEY CHANGE
+ my_custom_key: "my_custom_value"
+```
+
+### Usage
+
+1. Add additional information to model
+
+```yaml
+model_list:
+ - model_name: "gpt-4"
+ litellm_params:
+ model: "gpt-4"
+ api_key: "os.environ/OPENAI_API_KEY"
+ model_info: # 👈 KEY CHANGE
+ my_custom_key: "my_custom_value"
+```
+
+2. Call with `/model/info`
+
+Use a key with access to the model `gpt-4`.
+
+```bash
+curl -L -X GET 'http://0.0.0.0:4000/v1/model/info' \
+-H 'Authorization: Bearer LITELLM_KEY'
+```
+
+3. **Expected Response**
+
+Returned `model_info = Your custom model_info + (if exists) LITELLM MODEL INFO`
+
+
+[**How LiteLLM Model Info is found**](https://github.com/BerriAI/litellm/blob/9b46ec05b02d36d6e4fb5c32321e51e7f56e4a6e/litellm/proxy/proxy_server.py#L7460)
+
+[Tell us how this can be improved!](https://github.com/BerriAI/litellm/issues)
+
+```bash
+{
+ "data": [
+ {
+ "model_name": "gpt-4",
+ "litellm_params": {
+ "model": "gpt-4"
+ },
+ "model_info": {
+ "id": "e889baacd17f591cce4c63639275ba5e8dc60765d6c553e6ee5a504b19e50ddc",
+ "db_model": false,
+ "my_custom_key": "my_custom_value", # 👈 CUSTOM INFO
+ "key": "gpt-4", # 👈 KEY in LiteLLM MODEL INFO/COST MAP - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
+ "max_tokens": 4096,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 3e-05,
+ "input_cost_per_character": null,
+ "input_cost_per_token_above_128k_tokens": null,
+ "output_cost_per_token": 6e-05,
+ "output_cost_per_character": null,
+ "output_cost_per_token_above_128k_tokens": null,
+ "output_cost_per_character_above_128k_tokens": null,
+ "output_vector_size": null,
+ "litellm_provider": "openai",
+ "mode": "chat"
+ }
+ },
+ ]
+}
+```
diff --git a/docs/my-website/docs/proxy/pass_through.md b/docs/my-website/docs/proxy/pass_through.md
index 4554f80135..bad23f0de0 100644
--- a/docs/my-website/docs/proxy/pass_through.md
+++ b/docs/my-website/docs/proxy/pass_through.md
@@ -193,6 +193,53 @@ curl --request POST \
}'
```
+### Use Langfuse client sdk w/ LiteLLM Key
+
+**Usage**
+
+1. Set up the config.yaml to pass through the langfuse `/api/public/ingestion` endpoint
+
+```yaml
+general_settings:
+ master_key: sk-1234
+ pass_through_endpoints:
+ - path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server
+ target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward
+ auth: true # 👈 KEY CHANGE
+ custom_auth_parser: "langfuse" # 👈 KEY CHANGE
+ headers:
+ LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY" # your langfuse account public key
+ LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY" # your langfuse account secret key
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test with langfuse sdk
+
+
+```python
+
+from langfuse import Langfuse
+
+langfuse = Langfuse(
+ host="http://localhost:4000", # your litellm proxy endpoint
+ public_key="sk-1234", # your litellm proxy api key
+ secret_key="anything", # no key required since this is a pass through
+)
+
+print("sending langfuse trace request")
+trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
+print("flushing langfuse request")
+langfuse.flush()
+
+print("flushed langfuse request")
+```
+
+
## `pass_through_endpoints` Spec on config.yaml
All possible values for `pass_through_endpoints` and what they mean
diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md
index 6c856f58b3..4b913d2e82 100644
--- a/docs/my-website/docs/proxy/prometheus.md
+++ b/docs/my-website/docs/proxy/prometheus.md
@@ -72,15 +72,15 @@ http://localhost:4000/metrics
| Metric Name | Description |
|----------------------|--------------------------------------|
-| `deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. |
+| `litellm_deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. |
| `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment |
| `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` return from LLM API Deployment |
- `llm_deployment_success_responses` | Total number of successful LLM API calls for deployment |
-| `llm_deployment_failure_responses` | Total number of failed LLM API calls for deployment |
-| `llm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure |
-| `llm_deployment_latency_per_output_token` | Latency per output token for deployment |
-| `llm_deployment_successful_fallbacks` | Number of successful fallback requests from primary model -> fallback model |
-| `llm_deployment_failed_fallbacks` | Number of failed fallback requests from primary model -> fallback model |
+| `litellm_deployment_success_responses` | Total number of successful LLM API calls for deployment |
+| `litellm_deployment_failure_responses` | Total number of failed LLM API calls for deployment |
+| `litellm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure |
+| `litellm_deployment_latency_per_output_token` | Latency per output token for deployment |
+| `litellm_deployment_successful_fallbacks` | Number of successful fallback requests from primary model -> fallback model |
+| `litellm_deployment_failed_fallbacks` | Number of failed fallback requests from primary model -> fallback model |
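+
+For example, the renamed counters can be queried through the Prometheus HTTP API. A sketch (assumes a local Prometheus instance scraping the `/metrics` endpoint above; `prometheus_client` exposes counters with the `_total` suffix):
+
+```shell
+curl -G 'http://localhost:9090/api/v1/query' \
+  --data-urlencode 'query=sum by (litellm_model_name) (rate(litellm_deployment_failure_responses_total[5m]))'
+```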
diff --git a/docs/my-website/docs/proxy/team_logging.md b/docs/my-website/docs/proxy/team_logging.md
index 1cc91c2dfe..e36cb8f669 100644
--- a/docs/my-website/docs/proxy/team_logging.md
+++ b/docs/my-website/docs/proxy/team_logging.md
@@ -2,9 +2,9 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
-# 👥📊 Team Based Logging
+# 👥📊 Team/Key Based Logging
-Allow each team to use their own Langfuse Project / custom callbacks
+Allow each key/team to use their own Langfuse Project / custom callbacks
**This allows you to do the following**
```
@@ -189,3 +189,39 @@ curl -X GET 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/cal
+
+
+## [BETA] Key Based Logging
+
+Use the `/key/generate` or `/key/update` endpoints to add logging callbacks to a specific key.
+
+:::info
+
+✨ This is an Enterprise only feature [Get Started with Enterprise here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
+
+:::
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/key/generate' \
+-H 'Authorization: Bearer sk-1234' \
+-H 'Content-Type: application/json' \
+-d '{
+ "metadata": {
+ "logging": {
+ "callback_name": "langfuse", # 'otel', 'langfuse', 'lunary'
+            "callback_type": "success", # set, if required by integration - future improvement, have logging tools work for success + failure by default
+ "callback_vars": {
+ "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY", # [RECOMMENDED] reference key in proxy environment
+ "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY", # [RECOMMENDED] reference key in proxy environment
+ "langfuse_host": "https://cloud.langfuse.com"
+ }
+ }
+ }
+}'
+
+```
+
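+Once generated, requests made with the returned key are logged to the callbacks configured on it. A sketch of calling the proxy with that key (the key value and model name are placeholders):
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Authorization: Bearer sk-<generated-key>' \
+-H 'Content-Type: application/json' \
+-d '{
+    "model": "gpt-3.5-turbo",
+    "messages": [{"role": "user", "content": "Hello"}]
+}'
+```
+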
+---
+
+Help us improve this feature, by filing a [ticket here](https://github.com/BerriAI/litellm/issues)
+
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 7df5e61578..3c3e1cbf97 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -151,7 +151,7 @@ const sidebars = {
},
{
type: "category",
- label: "Chat Completions (litellm.completion)",
+ label: "Chat Completions (litellm.completion + PROXY)",
link: {
type: "generated-index",
title: "Chat Completions",
diff --git a/litellm/integrations/gcs_bucket.py b/litellm/integrations/gcs_bucket.py
index 46f55f8f01..be7f8e39c2 100644
--- a/litellm/integrations/gcs_bucket.py
+++ b/litellm/integrations/gcs_bucket.py
@@ -1,5 +1,6 @@
import json
import os
+import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, Union
@@ -29,6 +30,8 @@ class GCSBucketPayload(TypedDict):
end_time: str
response_cost: Optional[float]
spend_log_metadata: str
+ exception: Optional[str]
+ log_event_type: Optional[str]
class GCSBucketLogger(CustomLogger):
@@ -79,6 +82,7 @@ class GCSBucketLogger(CustomLogger):
logging_payload: GCSBucketPayload = await self.get_gcs_payload(
kwargs, response_obj, start_time_str, end_time_str
)
+ logging_payload["log_event_type"] = "successful_api_call"
json_logged_payload = json.dumps(logging_payload)
@@ -103,7 +107,56 @@ class GCSBucketLogger(CustomLogger):
verbose_logger.error("GCS Bucket logging error: %s", str(e))
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
- pass
+ from litellm.proxy.proxy_server import premium_user
+
+ if premium_user is not True:
+ raise ValueError(
+ f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
+ )
+ try:
+ verbose_logger.debug(
+ "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+
+ start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
+ end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
+ headers = await self.construct_request_headers()
+
+ logging_payload: GCSBucketPayload = await self.get_gcs_payload(
+ kwargs, response_obj, start_time_str, end_time_str
+ )
+ logging_payload["log_event_type"] = "failed_api_call"
+
+ _litellm_params = kwargs.get("litellm_params") or {}
+ metadata = _litellm_params.get("metadata") or {}
+
+ json_logged_payload = json.dumps(logging_payload)
+
+ # Get the current date
+ current_date = datetime.now().strftime("%Y-%m-%d")
+
+ # Modify the object_name to include the date-based folder
+ object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
+
+ if "gcs_log_id" in metadata:
+ object_name = metadata["gcs_log_id"]
+
+ response = await self.async_httpx_client.post(
+ headers=headers,
+ url=f"https://storage.googleapis.com/upload/storage/v1/b/{self.BUCKET_NAME}/o?uploadType=media&name={object_name}",
+ data=json_logged_payload,
+ )
+
+ if response.status_code != 200:
+ verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
+
+ verbose_logger.debug("GCS Bucket response %s", response)
+ verbose_logger.debug("GCS Bucket status code %s", response.status_code)
+ verbose_logger.debug("GCS Bucket response.text %s", response.text)
+ except Exception as e:
+ verbose_logger.error("GCS Bucket logging error: %s", str(e))
async def construct_request_headers(self) -> Dict[str, str]:
from litellm import vertex_chat_completion
@@ -139,9 +192,18 @@ class GCSBucketLogger(CustomLogger):
optional_params=kwargs.get("optional_params", None),
)
response_dict = {}
- response_dict = convert_litellm_response_object_to_dict(
- response_obj=response_obj
- )
+ if response_obj:
+ response_dict = convert_litellm_response_object_to_dict(
+ response_obj=response_obj
+ )
+
+ exception_str = None
+
+ # Handle logging exception attributes
+ if "exception" in kwargs:
+ exception_str = kwargs.get("exception", "")
+ if not isinstance(exception_str, str):
+ exception_str = str(exception_str)
_spend_log_payload: SpendLogsPayload = get_logging_payload(
kwargs=kwargs,
@@ -156,8 +218,10 @@ class GCSBucketLogger(CustomLogger):
response_obj=response_dict,
start_time=start_time,
end_time=end_time,
- spend_log_metadata=_spend_log_payload["metadata"],
+ spend_log_metadata=_spend_log_payload.get("metadata", ""),
response_cost=kwargs.get("response_cost", None),
+ exception=exception_str,
+ log_event_type=None,
)
return gcs_payload
diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index 864fb34e20..d6c235d0cb 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -605,6 +605,12 @@ class LangFuseLogger:
if "cache_key" in litellm.langfuse_default_tags:
_hidden_params = metadata.get("hidden_params", {}) or {}
_cache_key = _hidden_params.get("cache_key", None)
+ if _cache_key is None:
+ # fallback to using "preset_cache_key"
+ _preset_cache_key = kwargs.get("litellm_params", {}).get(
+ "preset_cache_key", None
+ )
+ _cache_key = _preset_cache_key
tags.append(f"cache_key:{_cache_key}")
return tags
@@ -676,7 +682,6 @@ def log_provider_specific_information_as_span(
Returns:
None
"""
- from litellm.proxy.proxy_server import premium_user
_hidden_params = clean_metadata.get("hidden_params", None)
if _hidden_params is None:
diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index 8797807ac6..08431fd7af 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -141,42 +141,42 @@ class PrometheusLogger(CustomLogger):
]
# Metric for deployment state
- self.deployment_state = Gauge(
- "deployment_state",
+ self.litellm_deployment_state = Gauge(
+ "litellm_deployment_state",
"LLM Deployment Analytics - The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage",
labelnames=_logged_llm_labels,
)
- self.llm_deployment_success_responses = Counter(
- name="llm_deployment_success_responses",
+ self.litellm_deployment_success_responses = Counter(
+ name="litellm_deployment_success_responses",
documentation="LLM Deployment Analytics - Total number of successful LLM API calls via litellm",
labelnames=_logged_llm_labels,
)
- self.llm_deployment_failure_responses = Counter(
- name="llm_deployment_failure_responses",
+ self.litellm_deployment_failure_responses = Counter(
+ name="litellm_deployment_failure_responses",
documentation="LLM Deployment Analytics - Total number of failed LLM API calls via litellm",
labelnames=_logged_llm_labels,
)
- self.llm_deployment_total_requests = Counter(
- name="llm_deployment_total_requests",
+ self.litellm_deployment_total_requests = Counter(
+ name="litellm_deployment_total_requests",
documentation="LLM Deployment Analytics - Total number of LLM API calls via litellm - success + failure",
labelnames=_logged_llm_labels,
)
# Deployment Latency tracking
- self.llm_deployment_latency_per_output_token = Histogram(
- name="llm_deployment_latency_per_output_token",
+ self.litellm_deployment_latency_per_output_token = Histogram(
+ name="litellm_deployment_latency_per_output_token",
documentation="LLM Deployment Analytics - Latency per output token",
labelnames=_logged_llm_labels,
)
- self.llm_deployment_successful_fallbacks = Counter(
- "llm_deployment_successful_fallbacks",
+ self.litellm_deployment_successful_fallbacks = Counter(
+ "litellm_deployment_successful_fallbacks",
"LLM Deployment Analytics - Number of successful fallback requests from primary model -> fallback model",
["primary_model", "fallback_model"],
)
- self.llm_deployment_failed_fallbacks = Counter(
- "llm_deployment_failed_fallbacks",
+ self.litellm_deployment_failed_fallbacks = Counter(
+ "litellm_deployment_failed_fallbacks",
"LLM Deployment Analytics - Number of failed fallback requests from primary model -> fallback model",
["primary_model", "fallback_model"],
)
@@ -358,14 +358,14 @@ class PrometheusLogger(CustomLogger):
api_provider=llm_provider,
)
- self.llm_deployment_failure_responses.labels(
+ self.litellm_deployment_failure_responses.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
api_provider=llm_provider,
).inc()
- self.llm_deployment_total_requests.labels(
+ self.litellm_deployment_total_requests.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
@@ -438,14 +438,14 @@ class PrometheusLogger(CustomLogger):
api_provider=llm_provider,
)
- self.llm_deployment_success_responses.labels(
+ self.litellm_deployment_success_responses.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
api_provider=llm_provider,
).inc()
- self.llm_deployment_total_requests.labels(
+ self.litellm_deployment_total_requests.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
@@ -475,7 +475,7 @@ class PrometheusLogger(CustomLogger):
latency_per_token = None
if output_tokens is not None and output_tokens > 0:
latency_per_token = _latency_seconds / output_tokens
- self.llm_deployment_latency_per_output_token.labels(
+ self.litellm_deployment_latency_per_output_token.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
@@ -497,7 +497,7 @@ class PrometheusLogger(CustomLogger):
kwargs,
)
_new_model = kwargs.get("model")
- self.llm_deployment_successful_fallbacks.labels(
+ self.litellm_deployment_successful_fallbacks.labels(
primary_model=original_model_group, fallback_model=_new_model
).inc()
@@ -508,11 +508,11 @@ class PrometheusLogger(CustomLogger):
kwargs,
)
_new_model = kwargs.get("model")
- self.llm_deployment_failed_fallbacks.labels(
+ self.litellm_deployment_failed_fallbacks.labels(
primary_model=original_model_group, fallback_model=_new_model
).inc()
- def set_deployment_state(
+ def set_litellm_deployment_state(
self,
state: int,
litellm_model_name: str,
@@ -520,7 +520,7 @@ class PrometheusLogger(CustomLogger):
api_base: str,
api_provider: str,
):
- self.deployment_state.labels(
+ self.litellm_deployment_state.labels(
litellm_model_name, model_id, api_base, api_provider
).set(state)
@@ -531,7 +531,7 @@ class PrometheusLogger(CustomLogger):
api_base: str,
api_provider: str,
):
- self.set_deployment_state(
+ self.set_litellm_deployment_state(
0, litellm_model_name, model_id, api_base, api_provider
)
@@ -542,7 +542,7 @@ class PrometheusLogger(CustomLogger):
api_base: str,
api_provider: str,
):
- self.set_deployment_state(
+ self.set_litellm_deployment_state(
1, litellm_model_name, model_id, api_base, api_provider
)
@@ -553,7 +553,7 @@ class PrometheusLogger(CustomLogger):
api_base: str,
api_provider: str,
):
- self.set_deployment_state(
+ self.set_litellm_deployment_state(
2, litellm_model_name, model_id, api_base, api_provider
)
diff --git a/litellm/integrations/prometheus_helpers/prometheus_api.py b/litellm/integrations/prometheus_helpers/prometheus_api.py
index 86764df7dd..13ccc15620 100644
--- a/litellm/integrations/prometheus_helpers/prometheus_api.py
+++ b/litellm/integrations/prometheus_helpers/prometheus_api.py
@@ -41,8 +41,8 @@ async def get_fallback_metric_from_prometheus():
"""
response_message = ""
relevant_metrics = [
- "llm_deployment_successful_fallbacks_total",
- "llm_deployment_failed_fallbacks_total",
+ "litellm_deployment_successful_fallbacks_total",
+ "litellm_deployment_failed_fallbacks_total",
]
for metric in relevant_metrics:
response_json = await get_metric_from_prometheus(
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 6f05aa226e..cf58163461 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -35,6 +35,7 @@ from litellm.types.llms.anthropic import (
AnthropicResponseContentBlockText,
AnthropicResponseContentBlockToolUse,
AnthropicResponseUsageBlock,
+ AnthropicSystemMessageContent,
ContentBlockDelta,
ContentBlockStart,
ContentBlockStop,
@@ -759,6 +760,7 @@ class AnthropicChatCompletion(BaseLLM):
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
+ _usage = completion_response["usage"]
total_tokens = prompt_tokens + completion_tokens
model_response.created = int(time.time())
@@ -768,6 +770,11 @@ class AnthropicChatCompletion(BaseLLM):
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
+
+ if "cache_creation_input_tokens" in _usage:
+ usage["cache_creation_input_tokens"] = _usage["cache_creation_input_tokens"]
+ if "cache_read_input_tokens" in _usage:
+ usage["cache_read_input_tokens"] = _usage["cache_read_input_tokens"]
setattr(model_response, "usage", usage) # type: ignore
return model_response
@@ -901,6 +908,7 @@ class AnthropicChatCompletion(BaseLLM):
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
+ anthropic_system_message_list = None
for idx, message in enumerate(messages):
if message["role"] == "system":
valid_content: bool = False
@@ -908,8 +916,23 @@ class AnthropicChatCompletion(BaseLLM):
system_prompt += message["content"]
valid_content = True
elif isinstance(message["content"], list):
- for content in message["content"]:
- system_prompt += content.get("text", "")
+ for _content in message["content"]:
+ anthropic_system_message_content = (
+ AnthropicSystemMessageContent(
+ type=_content.get("type"),
+ text=_content.get("text"),
+ )
+ )
+ if "cache_control" in _content:
+ anthropic_system_message_content["cache_control"] = (
+ _content["cache_control"]
+ )
+
+ if anthropic_system_message_list is None:
+ anthropic_system_message_list = []
+ anthropic_system_message_list.append(
+ anthropic_system_message_content
+ )
valid_content = True
if valid_content:
@@ -919,6 +942,10 @@ class AnthropicChatCompletion(BaseLLM):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
+
+ # Handling anthropic API Prompt Caching
+ if anthropic_system_message_list is not None:
+ optional_params["system"] = anthropic_system_message_list
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
@@ -954,6 +981,8 @@ class AnthropicChatCompletion(BaseLLM):
else: # assume openai tool call
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
+ if "cache_control" in tool:
+ new_tool["cache_control"] = tool["cache_control"]
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools
diff --git a/litellm/llms/base_aws_llm.py b/litellm/llms/base_aws_llm.py
new file mode 100644
index 0000000000..8de42eda73
--- /dev/null
+++ b/litellm/llms/base_aws_llm.py
@@ -0,0 +1,218 @@
+import json
+from typing import List, Optional
+
+import httpx
+
+from litellm._logging import verbose_logger
+from litellm.caching import DualCache, InMemoryCache
+from litellm.utils import get_secret
+
+from .base import BaseLLM
+
+
+class AwsAuthError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(
+ method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
+ )
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+class BaseAWSLLM(BaseLLM):
+ def __init__(self) -> None:
+ self.iam_cache = DualCache()
+ super().__init__()
+
+ def get_credentials(
+ self,
+ aws_access_key_id: Optional[str] = None,
+ aws_secret_access_key: Optional[str] = None,
+ aws_session_token: Optional[str] = None,
+ aws_region_name: Optional[str] = None,
+ aws_session_name: Optional[str] = None,
+ aws_profile_name: Optional[str] = None,
+ aws_role_name: Optional[str] = None,
+ aws_web_identity_token: Optional[str] = None,
+ aws_sts_endpoint: Optional[str] = None,
+ ):
+ """
+ Return a boto3.Credentials object
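+
+        Auth flows handled, in order of precedence:
+        1. aws_web_identity_token + aws_role_name + aws_session_name -> STS AssumeRoleWithWebIdentity (result cached for ~1 hour)
+        2. aws_role_name + aws_session_name -> STS AssumeRole
+        3. aws_profile_name -> credentials from the named AWS profile (~/.aws/credentials)
+        4. aws_access_key_id + aws_secret_access_key + aws_session_token -> static session credentials
+        5. otherwise -> default boto3 session (explicit key/secret/region if given, else the environment)
+
+        Any parameter passed as "os.environ/<VAR>" is resolved via get_secret() first.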
+ """
+ import boto3
+
+ ## CHECK IS 'os.environ/' passed in
+ params_to_check: List[Optional[str]] = [
+ aws_access_key_id,
+ aws_secret_access_key,
+ aws_session_token,
+ aws_region_name,
+ aws_session_name,
+ aws_profile_name,
+ aws_role_name,
+ aws_web_identity_token,
+ aws_sts_endpoint,
+ ]
+
+ # Iterate over parameters and update if needed
+ for i, param in enumerate(params_to_check):
+ if param and param.startswith("os.environ/"):
+ _v = get_secret(param)
+ if _v is not None and isinstance(_v, str):
+ params_to_check[i] = _v
+ # Assign updated values back to parameters
+ (
+ aws_access_key_id,
+ aws_secret_access_key,
+ aws_session_token,
+ aws_region_name,
+ aws_session_name,
+ aws_profile_name,
+ aws_role_name,
+ aws_web_identity_token,
+ aws_sts_endpoint,
+ ) = params_to_check
+
+ verbose_logger.debug(
+ "in get credentials\n"
+ "aws_access_key_id=%s\n"
+ "aws_secret_access_key=%s\n"
+ "aws_session_token=%s\n"
+ "aws_region_name=%s\n"
+ "aws_session_name=%s\n"
+ "aws_profile_name=%s\n"
+ "aws_role_name=%s\n"
+ "aws_web_identity_token=%s\n"
+ "aws_sts_endpoint=%s",
+ aws_access_key_id,
+ aws_secret_access_key,
+ aws_session_token,
+ aws_region_name,
+ aws_session_name,
+ aws_profile_name,
+ aws_role_name,
+ aws_web_identity_token,
+ aws_sts_endpoint,
+ )
+
+ ### CHECK STS ###
+ if (
+ aws_web_identity_token is not None
+ and aws_role_name is not None
+ and aws_session_name is not None
+ ):
+ verbose_logger.debug(
+ f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
+ )
+
+ if aws_sts_endpoint is None:
+ sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
+ else:
+ sts_endpoint = aws_sts_endpoint
+
+ iam_creds_cache_key = json.dumps(
+ {
+ "aws_web_identity_token": aws_web_identity_token,
+ "aws_role_name": aws_role_name,
+ "aws_session_name": aws_session_name,
+ "aws_region_name": aws_region_name,
+ "aws_sts_endpoint": sts_endpoint,
+ }
+ )
+
+ iam_creds_dict = self.iam_cache.get_cache(iam_creds_cache_key)
+ if iam_creds_dict is None:
+ oidc_token = get_secret(aws_web_identity_token)
+
+ if oidc_token is None:
+ raise AwsAuthError(
+ message="OIDC token could not be retrieved from secret manager.",
+ status_code=401,
+ )
+
+ sts_client = boto3.client(
+ "sts",
+ region_name=aws_region_name,
+ endpoint_url=sts_endpoint,
+ )
+
+ # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+ sts_response = sts_client.assume_role_with_web_identity(
+ RoleArn=aws_role_name,
+ RoleSessionName=aws_session_name,
+ WebIdentityToken=oidc_token,
+ DurationSeconds=3600,
+ )
+
+ iam_creds_dict = {
+ "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
+ "aws_secret_access_key": sts_response["Credentials"][
+ "SecretAccessKey"
+ ],
+ "aws_session_token": sts_response["Credentials"]["SessionToken"],
+ "region_name": aws_region_name,
+ }
+
+ self.iam_cache.set_cache(
+ key=iam_creds_cache_key,
+ value=json.dumps(iam_creds_dict),
+ ttl=3600 - 60,
+ )
+
+ session = boto3.Session(**iam_creds_dict)
+
+ iam_creds = session.get_credentials()
+
+ return iam_creds
+ elif aws_role_name is not None and aws_session_name is not None:
+ sts_client = boto3.client(
+ "sts",
+ aws_access_key_id=aws_access_key_id, # [OPTIONAL]
+ aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
+ )
+
+ sts_response = sts_client.assume_role(
+ RoleArn=aws_role_name, RoleSessionName=aws_session_name
+ )
+
+ # Extract the credentials from the response and convert to Session Credentials
+ sts_credentials = sts_response["Credentials"]
+ from botocore.credentials import Credentials
+
+ credentials = Credentials(
+ access_key=sts_credentials["AccessKeyId"],
+ secret_key=sts_credentials["SecretAccessKey"],
+ token=sts_credentials["SessionToken"],
+ )
+ return credentials
+ elif aws_profile_name is not None: ### CHECK SESSION ###
+ # uses auth values from AWS profile usually stored in ~/.aws/credentials
+ client = boto3.Session(profile_name=aws_profile_name)
+
+ return client.get_credentials()
+ elif (
+ aws_access_key_id is not None
+ and aws_secret_access_key is not None
+ and aws_session_token is not None
+ ): ### CHECK FOR AWS SESSION TOKEN ###
+ from botocore.credentials import Credentials
+
+ credentials = Credentials(
+ access_key=aws_access_key_id,
+ secret_key=aws_secret_access_key,
+ token=aws_session_token,
+ )
+ return credentials
+ else:
+ session = boto3.Session(
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=aws_region_name,
+ )
+
+ return session.get_credentials()
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index ffc096f762..73387212ff 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -57,6 +57,7 @@ from litellm.utils import (
)
from .base import BaseLLM
+from .base_aws_llm import BaseAWSLLM
from .bedrock import BedrockError, ModelResponseIterator, convert_messages_to_prompt
from .prompt_templates.factory import (
_bedrock_converse_messages_pt,
@@ -87,7 +88,6 @@ BEDROCK_CONVERSE_MODELS = [
]
-iam_cache = DualCache()
_response_stream_shape_cache = None
bedrock_tool_name_mappings: InMemoryCache = InMemoryCache(
max_size_in_memory=50, default_ttl=600
@@ -312,7 +312,7 @@ def make_sync_call(
return completion_stream
-class BedrockLLM(BaseLLM):
+class BedrockLLM(BaseAWSLLM):
"""
Example call
@@ -380,183 +380,6 @@ class BedrockLLM(BaseLLM):
prompt += f"{message['content']}"
return prompt, chat_history # type: ignore
- def get_credentials(
- self,
- aws_access_key_id: Optional[str] = None,
- aws_secret_access_key: Optional[str] = None,
- aws_session_token: Optional[str] = None,
- aws_region_name: Optional[str] = None,
- aws_session_name: Optional[str] = None,
- aws_profile_name: Optional[str] = None,
- aws_role_name: Optional[str] = None,
- aws_web_identity_token: Optional[str] = None,
- aws_sts_endpoint: Optional[str] = None,
- ):
- """
- Return a boto3.Credentials object
- """
- import boto3
-
- print_verbose(
- f"Boto3 get_credentials called variables passed to function {locals()}"
- )
-
- ## CHECK IS 'os.environ/' passed in
- params_to_check: List[Optional[str]] = [
- aws_access_key_id,
- aws_secret_access_key,
- aws_session_token,
- aws_region_name,
- aws_session_name,
- aws_profile_name,
- aws_role_name,
- aws_web_identity_token,
- aws_sts_endpoint,
- ]
-
- # Iterate over parameters and update if needed
- for i, param in enumerate(params_to_check):
- if param and param.startswith("os.environ/"):
- _v = get_secret(param)
- if _v is not None and isinstance(_v, str):
- params_to_check[i] = _v
- # Assign updated values back to parameters
- (
- aws_access_key_id,
- aws_secret_access_key,
- aws_session_token,
- aws_region_name,
- aws_session_name,
- aws_profile_name,
- aws_role_name,
- aws_web_identity_token,
- aws_sts_endpoint,
- ) = params_to_check
-
- ### CHECK STS ###
- if (
- aws_web_identity_token is not None
- and aws_role_name is not None
- and aws_session_name is not None
- ):
- print_verbose(
- f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
- )
-
- if aws_sts_endpoint is None:
- sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
- else:
- sts_endpoint = aws_sts_endpoint
-
- iam_creds_cache_key = json.dumps(
- {
- "aws_web_identity_token": aws_web_identity_token,
- "aws_role_name": aws_role_name,
- "aws_session_name": aws_session_name,
- "aws_region_name": aws_region_name,
- "aws_sts_endpoint": sts_endpoint,
- }
- )
-
- iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
- if iam_creds_dict is None:
- oidc_token = get_secret(aws_web_identity_token)
-
- if oidc_token is None:
- raise BedrockError(
- message="OIDC token could not be retrieved from secret manager.",
- status_code=401,
- )
-
- sts_client = boto3.client(
- "sts",
- region_name=aws_region_name,
- endpoint_url=sts_endpoint,
- )
-
- # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
- # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
- sts_response = sts_client.assume_role_with_web_identity(
- RoleArn=aws_role_name,
- RoleSessionName=aws_session_name,
- WebIdentityToken=oidc_token,
- DurationSeconds=3600,
- )
-
- iam_creds_dict = {
- "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
- "aws_secret_access_key": sts_response["Credentials"][
- "SecretAccessKey"
- ],
- "aws_session_token": sts_response["Credentials"]["SessionToken"],
- "region_name": aws_region_name,
- }
-
- iam_cache.set_cache(
- key=iam_creds_cache_key,
- value=json.dumps(iam_creds_dict),
- ttl=3600 - 60,
- )
-
- session = boto3.Session(**iam_creds_dict)
-
- iam_creds = session.get_credentials()
-
- return iam_creds
- elif aws_role_name is not None and aws_session_name is not None:
- print_verbose(
- f"Using STS Client AWS aws_role_name: {aws_role_name} aws_session_name: {aws_session_name}"
- )
- sts_client = boto3.client(
- "sts",
- aws_access_key_id=aws_access_key_id, # [OPTIONAL]
- aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
- )
-
- sts_response = sts_client.assume_role(
- RoleArn=aws_role_name, RoleSessionName=aws_session_name
- )
-
- # Extract the credentials from the response and convert to Session Credentials
- sts_credentials = sts_response["Credentials"]
- from botocore.credentials import Credentials
-
- credentials = Credentials(
- access_key=sts_credentials["AccessKeyId"],
- secret_key=sts_credentials["SecretAccessKey"],
- token=sts_credentials["SessionToken"],
- )
- return credentials
- elif aws_profile_name is not None: ### CHECK SESSION ###
- # uses auth values from AWS profile usually stored in ~/.aws/credentials
- print_verbose(f"Using AWS profile: {aws_profile_name}")
- client = boto3.Session(profile_name=aws_profile_name)
-
- return client.get_credentials()
- elif (
- aws_access_key_id is not None
- and aws_secret_access_key is not None
- and aws_session_token is not None
- ): ### CHECK FOR AWS SESSION TOKEN ###
- print_verbose(f"Using AWS Session Token: {aws_session_token}")
- from botocore.credentials import Credentials
-
- credentials = Credentials(
- access_key=aws_access_key_id,
- secret_key=aws_secret_access_key,
- token=aws_session_token,
- )
- return credentials
- else:
- print_verbose("Using Default AWS Session")
- session = boto3.Session(
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- region_name=aws_region_name,
- )
-
- return session.get_credentials()
-
def process_response(
self,
model: str,
@@ -1055,8 +878,8 @@ class BedrockLLM(BaseLLM):
},
)
raise BedrockError(
- status_code=400,
- message="Bedrock HTTPX: Unsupported provider={}, model={}".format(
+ status_code=404,
+ message="Bedrock HTTPX: Unknown provider={}, model={}".format(
provider, model
),
)
@@ -1414,7 +1237,7 @@ class AmazonConverseConfig:
return optional_params
-class BedrockConverseLLM(BaseLLM):
+class BedrockConverseLLM(BaseAWSLLM):
def __init__(self) -> None:
super().__init__()
@@ -1554,173 +1377,6 @@ class BedrockConverseLLM(BaseLLM):
"""
return urllib.parse.quote(model_id, safe="")
- def get_credentials(
- self,
- aws_access_key_id: Optional[str] = None,
- aws_secret_access_key: Optional[str] = None,
- aws_session_token: Optional[str] = None,
- aws_region_name: Optional[str] = None,
- aws_session_name: Optional[str] = None,
- aws_profile_name: Optional[str] = None,
- aws_role_name: Optional[str] = None,
- aws_web_identity_token: Optional[str] = None,
- aws_sts_endpoint: Optional[str] = None,
- ):
- """
- Return a boto3.Credentials object
- """
- import boto3
-
- ## CHECK IS 'os.environ/' passed in
- params_to_check: List[Optional[str]] = [
- aws_access_key_id,
- aws_secret_access_key,
- aws_session_token,
- aws_region_name,
- aws_session_name,
- aws_profile_name,
- aws_role_name,
- aws_web_identity_token,
- aws_sts_endpoint,
- ]
-
- # Iterate over parameters and update if needed
- for i, param in enumerate(params_to_check):
- if param and param.startswith("os.environ/"):
- _v = get_secret(param)
- if _v is not None and isinstance(_v, str):
- params_to_check[i] = _v
- # Assign updated values back to parameters
- (
- aws_access_key_id,
- aws_secret_access_key,
- aws_session_token,
- aws_region_name,
- aws_session_name,
- aws_profile_name,
- aws_role_name,
- aws_web_identity_token,
- aws_sts_endpoint,
- ) = params_to_check
-
- ### CHECK STS ###
- if (
- aws_web_identity_token is not None
- and aws_role_name is not None
- and aws_session_name is not None
- ):
- print_verbose(
- f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
- )
-
- if aws_sts_endpoint is None:
- sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
- else:
- sts_endpoint = aws_sts_endpoint
-
- iam_creds_cache_key = json.dumps(
- {
- "aws_web_identity_token": aws_web_identity_token,
- "aws_role_name": aws_role_name,
- "aws_session_name": aws_session_name,
- "aws_region_name": aws_region_name,
- "aws_sts_endpoint": sts_endpoint,
- }
- )
-
- iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
- if iam_creds_dict is None:
- oidc_token = get_secret(aws_web_identity_token)
-
- if oidc_token is None:
- raise BedrockError(
- message="OIDC token could not be retrieved from secret manager.",
- status_code=401,
- )
-
- sts_client = boto3.client(
- "sts",
- region_name=aws_region_name,
- endpoint_url=sts_endpoint,
- )
-
- # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
- # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
- sts_response = sts_client.assume_role_with_web_identity(
- RoleArn=aws_role_name,
- RoleSessionName=aws_session_name,
- WebIdentityToken=oidc_token,
- DurationSeconds=3600,
- )
-
- iam_creds_dict = {
- "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
- "aws_secret_access_key": sts_response["Credentials"][
- "SecretAccessKey"
- ],
- "aws_session_token": sts_response["Credentials"]["SessionToken"],
- "region_name": aws_region_name,
- }
-
- iam_cache.set_cache(
- key=iam_creds_cache_key,
- value=json.dumps(iam_creds_dict),
- ttl=3600 - 60,
- )
-
- session = boto3.Session(**iam_creds_dict)
-
- iam_creds = session.get_credentials()
-
- return iam_creds
- elif aws_role_name is not None and aws_session_name is not None:
- sts_client = boto3.client(
- "sts",
- aws_access_key_id=aws_access_key_id, # [OPTIONAL]
- aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
- )
-
- sts_response = sts_client.assume_role(
- RoleArn=aws_role_name, RoleSessionName=aws_session_name
- )
-
- # Extract the credentials from the response and convert to Session Credentials
- sts_credentials = sts_response["Credentials"]
- from botocore.credentials import Credentials
-
- credentials = Credentials(
- access_key=sts_credentials["AccessKeyId"],
- secret_key=sts_credentials["SecretAccessKey"],
- token=sts_credentials["SessionToken"],
- )
- return credentials
- elif aws_profile_name is not None: ### CHECK SESSION ###
- # uses auth values from AWS profile usually stored in ~/.aws/credentials
- client = boto3.Session(profile_name=aws_profile_name)
-
- return client.get_credentials()
- elif (
- aws_access_key_id is not None
- and aws_secret_access_key is not None
- and aws_session_token is not None
- ): ### CHECK FOR AWS SESSION TOKEN ###
- from botocore.credentials import Credentials
-
- credentials = Credentials(
- access_key=aws_access_key_id,
- secret_key=aws_secret_access_key,
- token=aws_session_token,
- )
- return credentials
- else:
- session = boto3.Session(
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- region_name=aws_region_name,
- )
-
- return session.get_credentials()
-
async def async_streaming(
self,
model: str,
diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index 6b984e1d82..f699cf0f5f 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -601,12 +601,13 @@ def ollama_embeddings(
):
return asyncio.run(
ollama_aembeddings(
- api_base,
- model,
- prompts,
- optional_params,
- logging_obj,
- model_response,
- encoding,
+ api_base=api_base,
+ model=model,
+ prompts=prompts,
+ model_response=model_response,
+ optional_params=optional_params,
+ logging_obj=logging_obj,
+ encoding=encoding,
)
+
)
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index b0dd5d905a..ea84fa95cf 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -356,6 +356,7 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
+ "follow_redirects": True
}
if api_key is not None:
_request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 7c3c7e80fb..f81515e98d 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -1224,6 +1224,19 @@ def convert_to_anthropic_tool_invoke(
return anthropic_tool_invoke
+def add_cache_control_to_content(
+ anthropic_content_element: Union[
+ dict, AnthropicMessagesImageParam, AnthropicMessagesTextParam
+ ],
+ original_content_element: dict,
+):
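+ """Copy 'cache_control' (if present) from the original OpenAI-format content element onto the converted Anthropic element."""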
+ if "cache_control" in orignal_content_element:
+ anthropic_content_element["cache_control"] = orignal_content_element[
+ "cache_control"
+ ]
+ return anthropic_content_element
+
+
def anthropic_messages_pt(
messages: list,
model: str,
@@ -1264,18 +1277,31 @@ def anthropic_messages_pt(
image_chunk = convert_to_anthropic_image_obj(
m["image_url"]["url"]
)
- user_content.append(
- AnthropicMessagesImageParam(
- type="image",
- source=AnthropicImageParamSource(
- type="base64",
- media_type=image_chunk["media_type"],
- data=image_chunk["data"],
- ),
- )
+
+ _anthropic_content_element = AnthropicMessagesImageParam(
+ type="image",
+ source=AnthropicImageParamSource(
+ type="base64",
+ media_type=image_chunk["media_type"],
+ data=image_chunk["data"],
+ ),
)
+
+ anthropic_content_element = add_cache_control_to_content(
+ anthropic_content_element=_anthropic_content_element,
+ original_content_element=m,
+ )
+ user_content.append(anthropic_content_element)
elif m.get("type", "") == "text":
- user_content.append({"type": "text", "text": m["text"]})
+ _anthropic_text_content_element = {
+ "type": "text",
+ "text": m["text"],
+ }
+ anthropic_content_element = add_cache_control_to_content(
+ anthropic_content_element=_anthropic_text_content_element,
+ original_content_element=m,
+ )
+ user_content.append(anthropic_content_element)
elif (
messages[msg_i]["role"] == "tool"
or messages[msg_i]["role"] == "function"
@@ -1306,6 +1332,10 @@ def anthropic_messages_pt(
anthropic_message = AnthropicMessagesTextParam(
type="text", text=m.get("text")
)
+ anthropic_message = add_cache_control_to_content(
+ anthropic_content_element=anthropic_message,
+ original_content_element=m,
+ )
assistant_content.append(anthropic_message)
elif (
"content" in messages[msg_i]
@@ -1313,9 +1343,17 @@ def anthropic_messages_pt(
and len(messages[msg_i]["content"])
> 0 # don't pass empty text blocks. anthropic api raises errors.
):
- assistant_content.append(
- {"type": "text", "text": messages[msg_i]["content"]}
+
+ _anthropic_text_content_element = {
+ "type": "text",
+ "text": messages[msg_i]["content"],
+ }
+
+ anthropic_content_element = add_cache_control_to_content(
+ anthropic_content_element=_anthropic_text_content_element,
+ original_content_element=messages[msg_i],
)
+ assistant_content.append(anthropic_content_element)
if messages[msg_i].get(
"tool_calls", []
@@ -1701,12 +1739,14 @@ def cohere_messages_pt_v2(
assistant_tool_calls: List[ToolCallObject] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
- assistant_text = (
- messages[msg_i].get("content") or ""
- ) # either string or none
- if assistant_text:
- assistant_content += assistant_text
-
+ if isinstance(messages[msg_i]["content"], list):
+ for m in messages[msg_i]["content"]:
+ if m.get("type", "") == "text":
+ assistant_content += m["text"]
+ elif messages[msg_i].get("content") is not None and isinstance(
+ messages[msg_i]["content"], str
+ ):
+ assistant_content += messages[msg_i]["content"]
if messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke conversion
diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index d16d2bd11b..32146b9cae 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -7,16 +7,38 @@ import traceback
import types
from copy import deepcopy
from enum import Enum
-from typing import Any, Callable, Optional
+from functools import partial
+from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
-from litellm.utils import EmbeddingResponse, ModelResponse, Usage, get_secret
+from litellm._logging import verbose_logger
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_async_httpx_client,
+ _get_httpx_client,
+)
+from litellm.types.llms.openai import (
+ ChatCompletionToolCallChunk,
+ ChatCompletionUsageBlock,
+)
+from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.utils import (
+ CustomStreamWrapper,
+ EmbeddingResponse,
+ ModelResponse,
+ Usage,
+ get_secret,
+)
+from .base_aws_llm import BaseAWSLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
+_response_stream_shape_cache = None
+
class SagemakerError(Exception):
def __init__(self, status_code, message):
@@ -31,73 +53,6 @@ class SagemakerError(Exception):
) # Call the base class constructor with the parameters it needs
-class TokenIterator:
- def __init__(self, stream, acompletion: bool = False):
- if acompletion == False:
- self.byte_iterator = iter(stream)
- elif acompletion == True:
- self.byte_iterator = stream
- self.buffer = io.BytesIO()
- self.read_pos = 0
- self.end_of_data = False
-
- def __iter__(self):
- return self
-
- def __next__(self):
- try:
- while True:
- self.buffer.seek(self.read_pos)
- line = self.buffer.readline()
- if line and line[-1] == ord("\n"):
- response_obj = {"text": "", "is_finished": False}
- self.read_pos += len(line) + 1
- full_line = line[:-1].decode("utf-8")
- line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
- if line_data.get("generated_text", None) is not None:
- self.end_of_data = True
- response_obj["is_finished"] = True
- response_obj["text"] = line_data["token"]["text"]
- return response_obj
- chunk = next(self.byte_iterator)
- self.buffer.seek(0, io.SEEK_END)
- self.buffer.write(chunk["PayloadPart"]["Bytes"])
- except StopIteration as e:
- if self.end_of_data == True:
- raise e # Re-raise StopIteration
- else:
- self.end_of_data = True
- return "data: [DONE]"
-
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- try:
- while True:
- self.buffer.seek(self.read_pos)
- line = self.buffer.readline()
- if line and line[-1] == ord("\n"):
- response_obj = {"text": "", "is_finished": False}
- self.read_pos += len(line) + 1
- full_line = line[:-1].decode("utf-8")
- line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
- if line_data.get("generated_text", None) is not None:
- self.end_of_data = True
- response_obj["is_finished"] = True
- response_obj["text"] = line_data["token"]["text"]
- return response_obj
- chunk = await self.byte_iterator.__anext__()
- self.buffer.seek(0, io.SEEK_END)
- self.buffer.write(chunk["PayloadPart"]["Bytes"])
- except StopAsyncIteration as e:
- if self.end_of_data == True:
- raise e # Re-raise StopIteration
- else:
- self.end_of_data = True
- return "data: [DONE]"
-
-
class SagemakerConfig:
"""
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
@@ -145,439 +100,498 @@ os.environ['AWS_ACCESS_KEY_ID'] = ""
os.environ['AWS_SECRET_ACCESS_KEY'] = ""
"""
+
# set os.environ['AWS_REGION_NAME'] =
+class SagemakerLLM(BaseAWSLLM):
+ def _load_credentials(
+ self,
+ optional_params: dict,
+ ):
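+ """Pop AWS auth params from optional_params and resolve the boto3 credentials and region to use for the SageMaker call."""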
+ try:
+ from botocore.credentials import Credentials
+ except ImportError as e:
+ raise ImportError("Missing boto3 to call sagemaker. Run 'pip install boto3'.")
+ ## CREDENTIALS ##
+ # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
+ aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
+ aws_access_key_id = optional_params.pop("aws_access_key_id", None)
+ aws_session_token = optional_params.pop("aws_session_token", None)
+ aws_region_name = optional_params.pop("aws_region_name", None)
+ aws_role_name = optional_params.pop("aws_role_name", None)
+ aws_session_name = optional_params.pop("aws_session_name", None)
+ aws_profile_name = optional_params.pop("aws_profile_name", None)
+ aws_bedrock_runtime_endpoint = optional_params.pop(
+ "aws_bedrock_runtime_endpoint", None
+ ) # https://bedrock-runtime.{region_name}.amazonaws.com
+ aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
+ aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
-def completion(
- model: str,
- messages: list,
- model_response: ModelResponse,
- print_verbose: Callable,
- encoding,
- logging_obj,
- custom_prompt_dict={},
- hf_model_name=None,
- optional_params=None,
- litellm_params=None,
- logger_fn=None,
- acompletion: bool = False,
-):
- import boto3
+ ### SET REGION NAME ###
+ if aws_region_name is None:
+ # check env #
+ litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
- # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
- aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
- aws_access_key_id = optional_params.pop("aws_access_key_id", None)
- aws_region_name = optional_params.pop("aws_region_name", None)
- model_id = optional_params.pop("model_id", None)
+ if litellm_aws_region_name is not None and isinstance(
+ litellm_aws_region_name, str
+ ):
+ aws_region_name = litellm_aws_region_name
- if aws_access_key_id != None:
- # uses auth params passed to completion
- # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
- client = boto3.client(
- service_name="sagemaker-runtime",
+ standard_aws_region_name = get_secret("AWS_REGION", None)
+ if standard_aws_region_name is not None and isinstance(
+ standard_aws_region_name, str
+ ):
+ aws_region_name = standard_aws_region_name
+
+ if aws_region_name is None:
+ aws_region_name = "us-west-2"
+
+ credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
- region_name=aws_region_name,
+ aws_session_token=aws_session_token,
+ aws_region_name=aws_region_name,
+ aws_session_name=aws_session_name,
+ aws_profile_name=aws_profile_name,
+ aws_role_name=aws_role_name,
+ aws_web_identity_token=aws_web_identity_token,
+ aws_sts_endpoint=aws_sts_endpoint,
)
- else:
- # aws_access_key_id is None, assume user is trying to auth using env variables
- # boto3 automaticaly reads env variables
+ return credentials, aws_region_name
- # we need to read region name from env
- # I assume majority of users use .env for auth
- region_name = (
- get_secret("AWS_REGION_NAME")
- or aws_region_name # get region from config file if specified
- or "us-west-2" # default to us-west-2 if region not specified
- )
- client = boto3.client(
- service_name="sagemaker-runtime",
- region_name=region_name,
- )
+ def _prepare_request(
+ self,
+ credentials,
+ model: str,
+ data: dict,
+ optional_params: dict,
+ aws_region_name: str,
+ extra_headers: Optional[dict] = None,
+ ):
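+ """Build the SageMaker runtime invocation URL and return a SigV4-signed request for the JSON payload."""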
+ try:
+ import boto3
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ from botocore.credentials import Credentials
+ except ImportError as e:
+ raise ImportError("Missing boto3 to call sagemaker. Run 'pip install boto3'.")
- # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
- inference_params = deepcopy(optional_params)
+ sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
+ if optional_params.get("stream") is True:
+ api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
+ else:
+ api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations"
- ## Load Config
- config = litellm.SagemakerConfig.get_config()
- for k, v in config.items():
- if (
- k not in inference_params
- ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
- inference_params[k] = v
+ encoded_data = json.dumps(data).encode("utf-8")
+ headers = {"Content-Type": "application/json"}
+ if extra_headers is not None:
+ headers = {"Content-Type": "application/json", **extra_headers}
+ request = AWSRequest(
+ method="POST", url=api_base, data=encoded_data, headers=headers
+ )
+ sigv4.add_auth(request)
+ prepped_request = request.prepare()
- model = model
- if model in custom_prompt_dict:
- # check if the model has a registered custom prompt
- model_prompt_details = custom_prompt_dict[model]
- prompt = custom_prompt(
- role_dict=model_prompt_details.get("roles", None),
- initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
- final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
- messages=messages,
- )
- elif hf_model_name in custom_prompt_dict:
- # check if the base huggingface model has a registered custom prompt
- model_prompt_details = custom_prompt_dict[hf_model_name]
- prompt = custom_prompt(
- role_dict=model_prompt_details.get("roles", None),
- initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
- final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
- messages=messages,
- )
- else:
- if hf_model_name is None:
- if "llama-2" in model.lower(): # llama-2 model
- if "chat" in model.lower(): # apply llama2 chat template
- hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
- else: # apply regular llama2 template
- hf_model_name = "meta-llama/Llama-2-7b"
- hf_model_name = (
- hf_model_name or model
- ) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
- prompt = prompt_factory(model=hf_model_name, messages=messages)
- stream = inference_params.pop("stream", None)
- if stream == True:
- data = json.dumps(
- {"inputs": prompt, "parameters": inference_params, "stream": True}
- ).encode("utf-8")
- if acompletion == True:
- response = async_streaming(
- optional_params=optional_params,
- encoding=encoding,
- model_response=model_response,
+ return prepped_request
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ logging_obj,
+ custom_prompt_dict={},
+ hf_model_name=None,
+ optional_params=None,
+ litellm_params=None,
+ logger_fn=None,
+ acompletion: bool = False,
+ ):
+
+ # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
+ credentials, aws_region_name = self._load_credentials(optional_params)
+ inference_params = deepcopy(optional_params)
+
+ ## Load Config
+ config = litellm.SagemakerConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+
+ if model in custom_prompt_dict:
+ # check if the model has a registered custom prompt
+ model_prompt_details = custom_prompt_dict[model]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details.get("roles", None),
+ initial_prompt_value=model_prompt_details.get(
+ "initial_prompt_value", ""
+ ),
+ final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+ messages=messages,
+ )
+ elif hf_model_name in custom_prompt_dict:
+ # check if the base huggingface model has a registered custom prompt
+ model_prompt_details = custom_prompt_dict[hf_model_name]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details.get("roles", None),
+ initial_prompt_value=model_prompt_details.get(
+ "initial_prompt_value", ""
+ ),
+ final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+ messages=messages,
+ )
+ else:
+ if hf_model_name is None:
+ if "llama-2" in model.lower(): # llama-2 model
+ if "chat" in model.lower(): # apply llama2 chat template
+ hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
+ else: # apply regular llama2 template
+ hf_model_name = "meta-llama/Llama-2-7b"
+ hf_model_name = (
+ hf_model_name or model
+ ) # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
+ prompt = prompt_factory(model=hf_model_name, messages=messages)
+ stream = inference_params.pop("stream", None)
+ model_id = optional_params.get("model_id", None)
+
+ if stream is True:
+ data = {"inputs": prompt, "parameters": inference_params, "stream": True}
+ prepared_request = self._prepare_request(
model=model,
- logging_obj=logging_obj,
data=data,
- model_id=model_id,
- aws_secret_access_key=aws_secret_access_key,
- aws_access_key_id=aws_access_key_id,
+ optional_params=optional_params,
+ credentials=credentials,
aws_region_name=aws_region_name,
)
- return response
+ if model_id is not None:
+ # Add model_id as InferenceComponentName header
+ # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
+ prepared_request.headers.update(
+ {"X-Amzn-SageMaker-Inference-Componen": model_id}
+ )
- if model_id is not None:
- response = client.invoke_endpoint_with_response_stream(
- EndpointName=model,
- InferenceComponentName=model_id,
- ContentType="application/json",
- Body=data,
- CustomAttributes="accept_eula=true",
+ if acompletion is True:
+ response = self.async_streaming(
+ prepared_request=prepared_request,
+ optional_params=optional_params,
+ encoding=encoding,
+ model_response=model_response,
+ model=model,
+ logging_obj=logging_obj,
+ data=data,
+ model_id=model_id,
+ )
+ return response
+ else:
+ if stream is not None and stream == True:
+ sync_handler = _get_httpx_client()
+ sync_response = sync_handler.post(
+ url=prepared_request.url,
+ headers=prepared_request.headers, # type: ignore
+ json=data,
+ stream=stream,
+ )
+
+ if sync_response.status_code != 200:
+ raise SagemakerError(
+ status_code=sync_response.status_code,
+ message=sync_response.read(),
+ )
+
+ decoder = AWSEventStreamDecoder(model="")
+
+ completion_stream = decoder.iter_bytes(
+ sync_response.iter_bytes(chunk_size=1024)
+ )
+ streaming_response = CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider="sagemaker",
+ logging_obj=logging_obj,
+ )
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=streaming_response,
+ additional_args={"complete_input_dict": data},
)
- else:
- response = client.invoke_endpoint_with_response_stream(
- EndpointName=model,
- ContentType="application/json",
- Body=data,
- CustomAttributes="accept_eula=true",
- )
- return response["Body"]
- elif acompletion == True:
+ return streaming_response
+
+ # Non-Streaming Requests
_data = {"inputs": prompt, "parameters": inference_params}
- return async_completion(
- optional_params=optional_params,
- encoding=encoding,
- model_response=model_response,
+ prepared_request = self._prepare_request(
model=model,
- logging_obj=logging_obj,
data=_data,
- model_id=model_id,
- aws_secret_access_key=aws_secret_access_key,
- aws_access_key_id=aws_access_key_id,
+ optional_params=optional_params,
+ credentials=credentials,
aws_region_name=aws_region_name,
)
- data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
- "utf-8"
- )
- ## COMPLETION CALL
- try:
- if model_id is not None:
- ## LOGGING
- request_str = f"""
- response = client.invoke_endpoint(
- EndpointName={model},
- InferenceComponentName={model_id},
- ContentType="application/json",
- Body={data}, # type: ignore
- CustomAttributes="accept_eula=true",
+
+ # Async completion
+ if acompletion == True:
+ return self.async_completion(
+ prepared_request=prepared_request,
+ model_response=model_response,
+ encoding=encoding,
+ model=model,
+ logging_obj=logging_obj,
+ data=_data,
+ model_id=model_id,
)
- """ # type: ignore
- logging_obj.pre_call(
- input=prompt,
- api_key="",
- additional_args={
- "complete_input_dict": data,
- "request_str": request_str,
- "hf_model_name": hf_model_name,
- },
- )
- response = client.invoke_endpoint(
- EndpointName=model,
- InferenceComponentName=model_id,
- ContentType="application/json",
- Body=data,
- CustomAttributes="accept_eula=true",
- )
- else:
- ## LOGGING
- request_str = f"""
- response = client.invoke_endpoint(
- EndpointName={model},
- ContentType="application/json",
- Body={data}, # type: ignore
- CustomAttributes="accept_eula=true",
- )
- """ # type: ignore
- logging_obj.pre_call(
- input=prompt,
- api_key="",
- additional_args={
- "complete_input_dict": data,
- "request_str": request_str,
- "hf_model_name": hf_model_name,
- },
- )
- response = client.invoke_endpoint(
- EndpointName=model,
- ContentType="application/json",
- Body=data,
- CustomAttributes="accept_eula=true",
- )
- except Exception as e:
- status_code = (
- getattr(e, "response", {})
- .get("ResponseMetadata", {})
- .get("HTTPStatusCode", 500)
- )
- error_message = (
- getattr(e, "response", {}).get("Error", {}).get("Message", str(e))
- )
- if "Inference Component Name header is required" in error_message:
- error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
- raise SagemakerError(status_code=status_code, message=error_message)
-
- response = response["Body"].read().decode("utf8")
- ## LOGGING
- logging_obj.post_call(
- input=prompt,
- api_key="",
- original_response=response,
- additional_args={"complete_input_dict": data},
- )
- print_verbose(f"raw model_response: {response}")
- ## RESPONSE OBJECT
- completion_response = json.loads(response)
- try:
- if isinstance(completion_response, list):
- completion_response_choices = completion_response[0]
- else:
- completion_response_choices = completion_response
- completion_output = ""
- if "generation" in completion_response_choices:
- completion_output += completion_response_choices["generation"]
- elif "generated_text" in completion_response_choices:
- completion_output += completion_response_choices["generated_text"]
-
- # check if the prompt template is part of output, if so - filter it out
- if completion_output.startswith(prompt) and "<s>" in prompt:
- completion_output = completion_output.replace(prompt, "", 1)
-
- model_response.choices[0].message.content = completion_output # type: ignore
- except:
- raise SagemakerError(
- message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
- status_code=500,
- )
-
- ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
- prompt_tokens = len(encoding.encode(prompt))
- completion_tokens = len(
- encoding.encode(model_response["choices"][0]["message"].get("content", ""))
- )
-
- model_response.created = int(time.time())
- model_response.model = model
- usage = Usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- )
- setattr(model_response, "usage", usage)
- return model_response
-
-
-async def async_streaming(
- optional_params,
- encoding,
- model_response: ModelResponse,
- model: str,
- model_id: Optional[str],
- logging_obj: Any,
- data,
- aws_secret_access_key: Optional[str],
- aws_access_key_id: Optional[str],
- aws_region_name: Optional[str],
-):
- """
- Use aioboto3
- """
- import aioboto3
-
- session = aioboto3.Session()
-
- if aws_access_key_id != None:
- # uses auth params passed to completion
- # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
- _client = session.client(
- service_name="sagemaker-runtime",
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- region_name=aws_region_name,
- )
- else:
- # aws_access_key_id is None, assume user is trying to auth using env variables
- # boto3 automaticaly reads env variables
-
- # we need to read region name from env
- # I assume majority of users use .env for auth
- region_name = (
- get_secret("AWS_REGION_NAME")
- or aws_region_name # get region from config file if specified
- or "us-west-2" # default to us-west-2 if region not specified
- )
- _client = session.client(
- service_name="sagemaker-runtime",
- region_name=region_name,
- )
-
- async with _client as client:
+ ## Non-Streaming completion CALL
try:
if model_id is not None:
- response = await client.invoke_endpoint_with_response_stream(
- EndpointName=model,
- InferenceComponentName=model_id,
- ContentType="application/json",
- Body=data,
- CustomAttributes="accept_eula=true",
+ # Add model_id as InferenceComponentName header
+ # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
+ prepared_request.headers.update(
+ {"X-Amzn-SageMaker-Inference-Componen": model_id}
)
- else:
- response = await client.invoke_endpoint_with_response_stream(
- EndpointName=model,
- ContentType="application/json",
- Body=data,
- CustomAttributes="accept_eula=true",
+
+ ## LOGGING
+ timeout = 300.0
+ sync_handler = _get_httpx_client()
+ ## LOGGING
+ logging_obj.pre_call(
+ input=[],
+ api_key="",
+ additional_args={
+ "complete_input_dict": _data,
+ "api_base": prepared_request.url,
+ "headers": prepared_request.headers,
+ },
+ )
+
+ # make sync httpx post request here
+ try:
+ sync_response = sync_handler.post(
+ url=prepared_request.url,
+ headers=prepared_request.headers,
+ json=_data,
+ timeout=timeout,
)
+
+ if sync_response.status_code != 200:
+ raise SagemakerError(
+ status_code=sync_response.status_code,
+ message=sync_response.text,
+ )
+ except Exception as e:
+ ## LOGGING
+ logging_obj.post_call(
+ input=[],
+ api_key="",
+ original_response=str(e),
+ additional_args={"complete_input_dict": _data},
+ )
+ raise e
except Exception as e:
- raise SagemakerError(status_code=500, message=f"{str(e)}")
- response = response["Body"]
- async for chunk in response:
- yield chunk
+ verbose_logger.error("Sagemaker error %s", str(e))
+ status_code = (
+ getattr(e, "response", {})
+ .get("ResponseMetadata", {})
+ .get("HTTPStatusCode", 500)
+ )
+ error_message = (
+ getattr(e, "response", {}).get("Error", {}).get("Message", str(e))
+ )
+ if "Inference Component Name header is required" in error_message:
+ error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
+ raise SagemakerError(status_code=status_code, message=error_message)
-
-async def async_completion(
- optional_params,
- encoding,
- model_response: ModelResponse,
- model: str,
- logging_obj: Any,
- data: dict,
- model_id: Optional[str],
- aws_secret_access_key: Optional[str],
- aws_access_key_id: Optional[str],
- aws_region_name: Optional[str],
-):
- """
- Use aioboto3
- """
- import aioboto3
-
- session = aioboto3.Session()
-
- if aws_access_key_id != None:
- # uses auth params passed to completion
- # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
- _client = session.client(
- service_name="sagemaker-runtime",
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- region_name=aws_region_name,
+ completion_response = sync_response.json()
+ ## LOGGING
+ logging_obj.post_call(
+ input=prompt,
+ api_key="",
+ original_response=completion_response,
+ additional_args={"complete_input_dict": _data},
)
- else:
- # aws_access_key_id is None, assume user is trying to auth using env variables
- # boto3 automaticaly reads env variables
+ print_verbose(f"raw model_response: {completion_response}")
+ ## RESPONSE OBJECT
+ try:
+ if isinstance(completion_response, list):
+ completion_response_choices = completion_response[0]
+ else:
+ completion_response_choices = completion_response
+ completion_output = ""
+ if "generation" in completion_response_choices:
+ completion_output += completion_response_choices["generation"]
+ elif "generated_text" in completion_response_choices:
+ completion_output += completion_response_choices["generated_text"]
- # we need to read region name from env
- # I assume majority of users use .env for auth
- region_name = (
- get_secret("AWS_REGION_NAME")
- or aws_region_name # get region from config file if specified
- or "us-west-2" # default to us-west-2 if region not specified
- )
- _client = session.client(
- service_name="sagemaker-runtime",
- region_name=region_name,
+ # check if the prompt template is part of output, if so - filter it out
+ if completion_output.startswith(prompt) and "<s>" in prompt:
+ completion_output = completion_output.replace(prompt, "", 1)
+
+ model_response.choices[0].message.content = completion_output # type: ignore
+ except:
+ raise SagemakerError(
+ message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
+ status_code=500,
+ )
+
+ ## CALCULATING USAGE - sagemaker does not return token counts, so estimate usage with the encoding
+ prompt_tokens = len(encoding.encode(prompt))
+ completion_tokens = len(
+ encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
- async with _client as client:
- encoded_data = json.dumps(data).encode("utf-8")
+ model_response.created = int(time.time())
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ setattr(model_response, "usage", usage)
+ return model_response
+
+ async def make_async_call(
+ self,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ logging_obj,
+ client=None,
+ ):
+ try:
+ if client is None:
+ client = (
+ _get_async_httpx_client()
+ ) # Create a new client if none provided
+ response = await client.post(
+ api_base,
+ headers=headers,
+ json=data,
+ stream=True,
+ )
+
+ if response.status_code != 200:
+ raise SagemakerError(
+ status_code=response.status_code, message=response.text
+ )
+
+ decoder = AWSEventStreamDecoder(model="")
+ completion_stream = decoder.aiter_bytes(
+ response.aiter_bytes(chunk_size=1024)
+ )
+
+ # LOGGING
+ logging_obj.post_call(
+ input=[],
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise SagemakerError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException as e:
+ raise SagemakerError(status_code=408, message="Timeout error occurred.")
+ except Exception as e:
+ raise SagemakerError(status_code=500, message=str(e))
+
+ async def async_streaming(
+ self,
+ prepared_request,
+ optional_params,
+ encoding,
+ model_response: ModelResponse,
+ model: str,
+ model_id: Optional[str],
+ logging_obj: Any,
+ data,
+ ):
+ streaming_response = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ self.make_async_call,
+ api_base=prepared_request.url,
+ headers=prepared_request.headers,
+ data=data,
+ logging_obj=logging_obj,
+ ),
+ model=model,
+ custom_llm_provider="sagemaker",
+ logging_obj=logging_obj,
+ )
+
+ # LOGGING
+ logging_obj.post_call(
+ input=[],
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return streaming_response
+
+ async def async_completion(
+ self,
+ prepared_request,
+ encoding,
+ model_response: ModelResponse,
+ model: str,
+ logging_obj: Any,
+ data: dict,
+ model_id: Optional[str],
+ ):
+ timeout = 300.0
+ async_handler = _get_async_httpx_client()
+ ## LOGGING
+ logging_obj.pre_call(
+ input=[],
+ api_key="",
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": prepared_request.url,
+ "headers": prepared_request.headers,
+ },
+ )
try:
if model_id is not None:
- ## LOGGING
- request_str = f"""
- response = client.invoke_endpoint(
- EndpointName={model},
- InferenceComponentName={model_id},
- ContentType="application/json",
- Body={data},
- CustomAttributes="accept_eula=true",
+ # Add model_id as InferenceComponentName header
+ # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
+ prepared_request.headers.update(
+ {"X-Amzn-SageMaker-Inference-Componen": model_id}
)
- """ # type: ignore
- logging_obj.pre_call(
+ # make async httpx post request here
+ try:
+ response = await async_handler.post(
+ url=prepared_request.url,
+ headers=prepared_request.headers,
+ json=data,
+ timeout=timeout,
+ )
+
+ if response.status_code != 200:
+ raise SagemakerError(
+ status_code=response.status_code, message=response.text
+ )
+ except Exception as e:
+ ## LOGGING
+ logging_obj.post_call(
input=data["inputs"],
api_key="",
- additional_args={
- "complete_input_dict": data,
- "request_str": request_str,
- },
- )
- response = await client.invoke_endpoint(
- EndpointName=model,
- InferenceComponentName=model_id,
- ContentType="application/json",
- Body=encoded_data,
- CustomAttributes="accept_eula=true",
- )
- else:
- ## LOGGING
- request_str = f"""
- response = client.invoke_endpoint(
- EndpointName={model},
- ContentType="application/json",
- Body={data},
- CustomAttributes="accept_eula=true",
- )
- """ # type: ignore
- logging_obj.pre_call(
- input=data["inputs"],
- api_key="",
- additional_args={
- "complete_input_dict": data,
- "request_str": request_str,
- },
- )
- response = await client.invoke_endpoint(
- EndpointName=model,
- ContentType="application/json",
- Body=encoded_data,
- CustomAttributes="accept_eula=true",
+ original_response=str(e),
+ additional_args={"complete_input_dict": data},
)
+ raise e
except Exception as e:
error_message = f"{str(e)}"
if "Inference Component Name header is required" in error_message:
error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
raise SagemakerError(status_code=500, message=error_message)
- response = await response["Body"].read()
- response = response.decode("utf8")
+ completion_response = response.json()
## LOGGING
logging_obj.post_call(
input=data["inputs"],
@@ -586,7 +600,6 @@ async def async_completion(
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
- completion_response = json.loads(response)
try:
if isinstance(completion_response, list):
completion_response_choices = completion_response[0]
@@ -625,141 +638,296 @@ async def async_completion(
setattr(model_response, "usage", usage)
return model_response
+ def embedding(
+ self,
+ model: str,
+ input: list,
+ model_response: EmbeddingResponse,
+ print_verbose: Callable,
+ encoding,
+ logging_obj,
+ custom_prompt_dict={},
+ optional_params=None,
+ litellm_params=None,
+ logger_fn=None,
+ ):
+ """
+ Supports Huggingface Jumpstart embeddings like GPT-6B
+ """
+ ### BOTO3 INIT
+ import boto3
-def embedding(
- model: str,
- input: list,
- model_response: EmbeddingResponse,
- print_verbose: Callable,
- encoding,
- logging_obj,
- custom_prompt_dict={},
- optional_params=None,
- litellm_params=None,
- logger_fn=None,
-):
- """
- Supports Huggingface Jumpstart embeddings like GPT-6B
- """
- ### BOTO3 INIT
- import boto3
+ # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
+ aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
+ aws_access_key_id = optional_params.pop("aws_access_key_id", None)
+ aws_region_name = optional_params.pop("aws_region_name", None)
- # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
- aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
- aws_access_key_id = optional_params.pop("aws_access_key_id", None)
- aws_region_name = optional_params.pop("aws_region_name", None)
+ if aws_access_key_id is not None:
+ # uses auth params passed to completion
+ # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
+ client = boto3.client(
+ service_name="sagemaker-runtime",
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=aws_region_name,
+ )
+ else:
+ # aws_access_key_id is None, assume user is trying to auth using env variables
+ # boto3 automatically reads env variables
- if aws_access_key_id is not None:
- # uses auth params passed to completion
- # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
- client = boto3.client(
- service_name="sagemaker-runtime",
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- region_name=aws_region_name,
- )
- else:
- # aws_access_key_id is None, assume user is trying to auth using env variables
- # boto3 automaticaly reads env variables
+ # we need to read region name from env
+ # I assume majority of users use .env for auth
+ region_name = (
+ get_secret("AWS_REGION_NAME")
+ or aws_region_name # get region from config file if specified
+ or "us-west-2" # default to us-west-2 if region not specified
+ )
+ client = boto3.client(
+ service_name="sagemaker-runtime",
+ region_name=region_name,
+ )
- # we need to read region name from env
- # I assume majority of users use .env for auth
- region_name = (
- get_secret("AWS_REGION_NAME")
- or aws_region_name # get region from config file if specified
- or "us-west-2" # default to us-west-2 if region not specified
- )
- client = boto3.client(
- service_name="sagemaker-runtime",
- region_name=region_name,
- )
+ # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
+ inference_params = deepcopy(optional_params)
+ inference_params.pop("stream", None)
- # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
- inference_params = deepcopy(optional_params)
- inference_params.pop("stream", None)
+ ## Load Config
+ config = litellm.SagemakerConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
- ## Load Config
- config = litellm.SagemakerConfig.get_config()
- for k, v in config.items():
- if (
- k not in inference_params
- ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
- inference_params[k] = v
+ #### HF EMBEDDING LOGIC
+ data = json.dumps({"text_inputs": input}).encode("utf-8")
- #### HF EMBEDDING LOGIC
- data = json.dumps({"text_inputs": input}).encode("utf-8")
-
- ## LOGGING
- request_str = f"""
- response = client.invoke_endpoint(
- EndpointName={model},
- ContentType="application/json",
- Body={data}, # type: ignore
- CustomAttributes="accept_eula=true",
- )""" # type: ignore
- logging_obj.pre_call(
- input=input,
- api_key="",
- additional_args={"complete_input_dict": data, "request_str": request_str},
- )
- ## EMBEDDING CALL
- try:
+ ## LOGGING
+ request_str = f"""
response = client.invoke_endpoint(
- EndpointName=model,
+ EndpointName={model},
ContentType="application/json",
- Body=data,
+ Body={data}, # type: ignore
CustomAttributes="accept_eula=true",
+ )""" # type: ignore
+ logging_obj.pre_call(
+ input=input,
+ api_key="",
+ additional_args={"complete_input_dict": data, "request_str": request_str},
)
- except Exception as e:
- status_code = (
- getattr(e, "response", {})
- .get("ResponseMetadata", {})
- .get("HTTPStatusCode", 500)
- )
- error_message = (
- getattr(e, "response", {}).get("Error", {}).get("Message", str(e))
- )
- raise SagemakerError(status_code=status_code, message=error_message)
+ ## EMBEDDING CALL
+ try:
+ response = client.invoke_endpoint(
+ EndpointName=model,
+ ContentType="application/json",
+ Body=data,
+ CustomAttributes="accept_eula=true",
+ )
+ except Exception as e:
+ status_code = (
+ getattr(e, "response", {})
+ .get("ResponseMetadata", {})
+ .get("HTTPStatusCode", 500)
+ )
+ error_message = (
+ getattr(e, "response", {}).get("Error", {}).get("Message", str(e))
+ )
+ raise SagemakerError(status_code=status_code, message=error_message)
- response = json.loads(response["Body"].read().decode("utf8"))
- ## LOGGING
- logging_obj.post_call(
- input=input,
- api_key="",
- original_response=response,
- additional_args={"complete_input_dict": data},
- )
-
- print_verbose(f"raw model_response: {response}")
- if "embedding" not in response:
- raise SagemakerError(status_code=500, message="embedding not found in response")
- embeddings = response["embedding"]
-
- if not isinstance(embeddings, list):
- raise SagemakerError(
- status_code=422, message=f"Response not in expected format - {embeddings}"
+ response = json.loads(response["Body"].read().decode("utf8"))
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key="",
+ original_response=response,
+ additional_args={"complete_input_dict": data},
)
- output_data = []
- for idx, embedding in enumerate(embeddings):
- output_data.append(
- {"object": "embedding", "index": idx, "embedding": embedding}
+ print_verbose(f"raw model_response: {response}")
+ if "embedding" not in response:
+ raise SagemakerError(
+ status_code=500, message="embedding not found in response"
+ )
+ embeddings = response["embedding"]
+
+ if not isinstance(embeddings, list):
+ raise SagemakerError(
+ status_code=422,
+ message=f"Response not in expected format - {embeddings}",
+ )
+
+ output_data = []
+ for idx, embedding in enumerate(embeddings):
+ output_data.append(
+ {"object": "embedding", "index": idx, "embedding": embedding}
+ )
+
+ model_response.object = "list"
+ model_response.data = output_data
+ model_response.model = model
+
+ input_tokens = 0
+ for text in input:
+ input_tokens += len(encoding.encode(text))
+
+ setattr(
+ model_response,
+ "usage",
+ Usage(
+ prompt_tokens=input_tokens,
+ completion_tokens=0,
+ total_tokens=input_tokens,
+ ),
)
- model_response.object = "list"
- model_response.data = output_data
- model_response.model = model
+ return model_response
- input_tokens = 0
- for text in input:
- input_tokens += len(encoding.encode(text))
- setattr(
- model_response,
- "usage",
- Usage(
- prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
- ),
- )
+def get_response_stream_shape():
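+ """Lazily load and cache the botocore shape used to parse SageMaker InvokeEndpointWithResponseStream events."""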
+ global _response_stream_shape_cache
+ if _response_stream_shape_cache is None:
- return model_response
+ from botocore.loaders import Loader
+ from botocore.model import ServiceModel
+
+ loader = Loader()
+ sagemaker_service_dict = loader.load_service_model(
+ "sagemaker-runtime", "service-2"
+ )
+ sagemaker_service_model = ServiceModel(sagemaker_service_dict)
+ _response_stream_shape_cache = sagemaker_service_model.shape_for(
+ "InvokeEndpointWithResponseStreamOutput"
+ )
+ return _response_stream_shape_cache
+
+
+class AWSEventStreamDecoder:
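+ """Decode AWS event-stream bytes from SageMaker responses into generic streaming chunks."""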
+ def __init__(self, model: str) -> None:
+ from botocore.parsers import EventStreamJSONParser
+
+ self.model = model
+ self.parser = EventStreamJSONParser()
+ self.content_blocks: List = []
+
+ def _chunk_parser(self, chunk_data: dict) -> GChunk:
+ verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
+ _token = chunk_data.get("token", {}) or {}
+ _index = chunk_data.get("index", None) or 0
+ is_finished = False
+ finish_reason = ""
+
+ _text = _token.get("text", "")
+ if _text == "<|endoftext|>":
+ return GChunk(
+ text="",
+ index=_index,
+ is_finished=True,
+ finish_reason="stop",
+ )
+
+ return GChunk(
+ text=_text,
+ index=_index,
+ is_finished=is_finished,
+ finish_reason=finish_reason,
+ )
+
+ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
+ """Given an iterator that yields lines, iterate over it & yield every event encountered"""
+ from botocore.eventstream import EventStreamBuffer
+
+ event_stream_buffer = EventStreamBuffer()
+ accumulated_json = ""
+
+ for chunk in iterator:
+ event_stream_buffer.add_data(chunk)
+ for event in event_stream_buffer:
+ message = self._parse_message_from_event(event)
+ if message:
+ # remove data: prefix and "\n\n" at the end
+ message = message.replace("data:", "").replace("\n\n", "")
+
+ # Accumulate JSON data
+ accumulated_json += message
+
+ # Try to parse the accumulated JSON
+ try:
+ _data = json.loads(accumulated_json)
+ yield self._chunk_parser(chunk_data=_data)
+ # Reset accumulated_json after successful parsing
+ accumulated_json = ""
+ except json.JSONDecodeError:
+ # If it's not valid JSON yet, continue to the next event
+ continue
+
+ # Handle any remaining data after the iterator is exhausted
+ if accumulated_json:
+ try:
+ _data = json.loads(accumulated_json)
+ yield self._chunk_parser(chunk_data=_data)
+ except json.JSONDecodeError:
+ # Handle or log any unparseable data at the end
+ verbose_logger.error(
+ f"Warning: Unparseable JSON data remained: {accumulated_json}"
+ )
+
+ async def aiter_bytes(
+ self, iterator: AsyncIterator[bytes]
+ ) -> AsyncIterator[GChunk]:
+ """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
+ from botocore.eventstream import EventStreamBuffer
+
+ event_stream_buffer = EventStreamBuffer()
+ accumulated_json = ""
+
+ async for chunk in iterator:
+ event_stream_buffer.add_data(chunk)
+ for event in event_stream_buffer:
+ message = self._parse_message_from_event(event)
+ if message:
+ verbose_logger.debug("sagemaker parsed chunk bytes %s", message)
+ # remove data: prefix and "\n\n" at the end
+ message = message.replace("data:", "").replace("\n\n", "")
+
+ # Accumulate JSON data
+ accumulated_json += message
+
+ # Try to parse the accumulated JSON
+ try:
+ _data = json.loads(accumulated_json)
+ yield self._chunk_parser(chunk_data=_data)
+ # Reset accumulated_json after successful parsing
+ accumulated_json = ""
+ except json.JSONDecodeError:
+ # If it's not valid JSON yet, continue to the next event
+ continue
+
+ # Handle any remaining data after the iterator is exhausted
+ if accumulated_json:
+ try:
+ _data = json.loads(accumulated_json)
+ yield self._chunk_parser(chunk_data=_data)
+ except json.JSONDecodeError:
+ # Handle or log any unparseable data at the end
+ verbose_logger.error(
+ f"Warning: Unparseable JSON data remained: {accumulated_json}"
+ )
+
+ def _parse_message_from_event(self, event) -> Optional[str]:
+ response_dict = event.to_response_dict()
+ parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
+
+ if response_dict["status_code"] != 200:
+ raise ValueError(f"Bad response code, expected 200: {response_dict}")
+
+ if "chunk" in parsed_response:
+ chunk = parsed_response.get("chunk")
+ if not chunk:
+ return None
+ return chunk.get("bytes").decode() # type: ignore[no-any-return]
+ else:
+ chunk = response_dict.get("body")
+ if not chunk:
+ return None
+
+ return chunk.decode() # type: ignore[no-any-return]
diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py
index 7d0338d069..14a2e828b4 100644
--- a/litellm/llms/triton.py
+++ b/litellm/llms/triton.py
@@ -240,10 +240,10 @@ class TritonChatCompletion(BaseLLM):
handler = HTTPHandler()
if stream:
return self._handle_stream(
- handler, api_base, data_for_triton, model, logging_obj
+ handler, api_base, json_data_for_triton, model, logging_obj
)
else:
- response = handler.post(url=api_base, data=data_for_triton, headers=headers)
+ response = handler.post(url=api_base, data=json_data_for_triton, headers=headers)
return self._handle_response(
response, model_response, logging_obj, type_of_model=type_of_model
)
diff --git a/litellm/main.py b/litellm/main.py
index d7a3ca996d..cf7a4a5e7e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -95,7 +95,6 @@ from .llms import (
palm,
petals,
replicate,
- sagemaker,
together_ai,
triton,
vertex_ai,
@@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
prompt_factory,
stringify_json_tool_call_content,
)
+from .llms.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_partner import VertexAIPartnerModels
@@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
watsonxai = IBMWatsonXAI()
+sagemaker_llm = SagemakerLLM()
####### COMPLETION ENDPOINTS ################
@@ -2216,7 +2217,7 @@ def completion(
response = model_response
elif custom_llm_provider == "sagemaker":
# boto3 reads keys from .env
- model_response = sagemaker.completion(
+ model_response = sagemaker_llm.completion(
model=model,
messages=messages,
model_response=model_response,
@@ -2230,26 +2231,13 @@ def completion(
logging_obj=logging,
acompletion=acompletion,
)
- if (
- "stream" in optional_params and optional_params["stream"] == True
- ): ## [BETA]
- print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
- from .llms.sagemaker import TokenIterator
-
- tokenIterator = TokenIterator(model_response, acompletion=acompletion)
- response = CustomStreamWrapper(
- completion_stream=tokenIterator,
- model=model,
- custom_llm_provider="sagemaker",
- logging_obj=logging,
- )
+ if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
- original_response=response,
+ original_response=model_response,
)
- return response
## RESPONSE OBJECT
response = model_response
@@ -3529,7 +3517,7 @@ def embedding(
model_response=EmbeddingResponse(),
)
elif custom_llm_provider == "sagemaker":
- response = sagemaker.embedding(
+ response = sagemaker_llm.embedding(
model=model,
input=input,
encoding=encoding,
@@ -4898,7 +4886,6 @@ async def ahealth_check(
verbose_logger.error(
"litellm.ahealth_check(): Exception occured - {}".format(str(e))
)
- verbose_logger.debug(traceback.format_exc())
stack_trace = traceback.format_exc()
if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000]
@@ -4907,7 +4894,12 @@ async def ahealth_check(
"error": "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
}
- error_to_return = str(e) + " stack trace: " + stack_trace
+ error_to_return = (
+ str(e)
+ + "\nHave you set 'mode' - https://docs.litellm.ai/docs/proxy/health#embedding-models"
+ + "\nstack trace: "
+ + stack_trace
+ )
return {"error": error_to_return}
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 455fe1e3c5..d30270c5c8 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -57,6 +57,18 @@
"supports_parallel_function_calling": true,
"supports_vision": true
},
+ "chatgpt-4o-latest": {
+ "max_tokens": 4096,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000005,
+ "output_cost_per_token": 0.000015,
+ "litellm_provider": "openai",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_parallel_function_calling": true,
+ "supports_vision": true
+ },
"gpt-4o-2024-05-13": {
"max_tokens": 4096,
"max_input_tokens": 128000,
@@ -2062,7 +2074,8 @@
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
- "supports_vision": true
+ "supports_vision": true,
+ "supports_assistant_prefill": true
},
"vertex_ai/claude-3-5-sonnet@20240620": {
"max_tokens": 4096,
@@ -2073,7 +2086,8 @@
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
- "supports_vision": true
+ "supports_vision": true,
+ "supports_assistant_prefill": true
},
"vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 4096,
@@ -2084,7 +2098,8 @@
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
- "supports_vision": true
+ "supports_vision": true,
+ "supports_assistant_prefill": true
},
"vertex_ai/claude-3-opus@20240229": {
"max_tokens": 4096,
@@ -2095,7 +2110,8 @@
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
- "supports_vision": true
+ "supports_vision": true,
+ "supports_assistant_prefill": true
},
"vertex_ai/meta/llama3-405b-instruct-maas": {
"max_tokens": 32000,
@@ -4519,6 +4535,69 @@
"litellm_provider": "perplexity",
"mode": "chat"
},
+ "perplexity/llama-3.1-70b-instruct": {
+ "max_tokens": 131072,
+ "max_input_tokens": 131072,
+ "max_output_tokens": 131072,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-8b-instruct": {
+ "max_tokens": 131072,
+ "max_input_tokens": 131072,
+ "max_output_tokens": 131072,
+ "input_cost_per_token": 0.0000002,
+ "output_cost_per_token": 0.0000002,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-huge-128k-online": {
+ "max_tokens": 127072,
+ "max_input_tokens": 127072,
+ "max_output_tokens": 127072,
+ "input_cost_per_token": 0.000005,
+ "output_cost_per_token": 0.000005,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-large-128k-online": {
+ "max_tokens": 127072,
+ "max_input_tokens": 127072,
+ "max_output_tokens": 127072,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-large-128k-chat": {
+ "max_tokens": 131072,
+ "max_input_tokens": 131072,
+ "max_output_tokens": 131072,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-small-128k-chat": {
+ "max_tokens": 131072,
+ "max_input_tokens": 131072,
+ "max_output_tokens": 131072,
+ "input_cost_per_token": 0.0000002,
+ "output_cost_per_token": 0.0000002,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-small-128k-online": {
+ "max_tokens": 127072,
+ "max_input_tokens": 127072,
+ "max_output_tokens": 127072,
+ "input_cost_per_token": 0.0000002,
+ "output_cost_per_token": 0.0000002,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
"perplexity/pplx-7b-chat": {
"max_tokens": 8192,
"max_input_tokens": 8192,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index e4e180727d..dfa5c16520 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,13 +1,6 @@
model_list:
- - model_name: "*"
+ - model_name: "gpt-4"
litellm_params:
- model: "*"
-
-# general_settings:
-# master_key: sk-1234
-# pass_through_endpoints:
-# - path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server
-# target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward
-# headers:
-# LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_PUBLIC_KEY" # your langfuse account public key
-# LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_SECRET_KEY" # your langfuse account secret key
\ No newline at end of file
+ model: "gpt-4"
+ model_info:
+ my_custom_key: "my_custom_value"
\ No newline at end of file
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 5ae149f1bd..00e78f64e6 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -12,7 +12,7 @@ import json
import secrets
import traceback
from datetime import datetime, timedelta, timezone
-from typing import Optional
+from typing import Optional, Tuple
from uuid import uuid4
import fastapi
@@ -125,7 +125,7 @@ async def user_api_key_auth(
# Check 2. FILTER IP ADDRESS
await check_if_request_size_is_safe(request=request)
- is_valid_ip = _check_valid_ip(
+ is_valid_ip, passed_in_ip = _check_valid_ip(
allowed_ips=general_settings.get("allowed_ips", None),
use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
request=request,
@@ -134,7 +134,7 @@ async def user_api_key_auth(
if not is_valid_ip:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
- detail="Access forbidden: IP address not allowed.",
+ detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
)
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
@@ -1251,12 +1251,12 @@ def _check_valid_ip(
allowed_ips: Optional[List[str]],
request: Request,
use_x_forwarded_for: Optional[bool] = False,
-) -> bool:
+) -> Tuple[bool, Optional[str]]:
"""
Returns if ip is allowed or not
"""
if allowed_ips is None: # if not set, assume true
- return True
+ return True, None
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
client_ip = None
@@ -1267,9 +1267,9 @@ def _check_valid_ip(
# Check if IP address is allowed
if client_ip not in allowed_ips:
- return False
+ return False, client_ip
- return True
+ return True, client_ip
def get_api_key_from_custom_header(
diff --git a/litellm/proxy/common_utils/load_config_utils.py b/litellm/proxy/common_utils/load_config_utils.py
new file mode 100644
index 0000000000..b009695b8c
--- /dev/null
+++ b/litellm/proxy/common_utils/load_config_utils.py
@@ -0,0 +1,56 @@
+import yaml
+
+from litellm._logging import verbose_proxy_logger
+
+
+def get_file_contents_from_s3(bucket_name, object_key):
+ try:
+ # v0: rely on boto3 for authentication - let boto3 handle IAM credentials, etc.
+ import tempfile
+
+ import boto3
+ from botocore.config import Config
+ from botocore.credentials import Credentials
+
+ from litellm.main import bedrock_converse_chat_completion
+
+ credentials: Credentials = bedrock_converse_chat_completion.get_credentials()
+ s3_client = boto3.client(
+ "s3",
+ aws_access_key_id=credentials.access_key,
+ aws_secret_access_key=credentials.secret_key,
+ aws_session_token=credentials.token, # Optional, if using temporary credentials
+ )
+ verbose_proxy_logger.debug(
+ f"Retrieving {object_key} from S3 bucket: {bucket_name}"
+ )
+ response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
+ verbose_proxy_logger.debug(f"Response: {response}")
+
+ # Read the file contents
+ file_contents = response["Body"].read().decode("utf-8")
+ verbose_proxy_logger.debug(f"File contents retrieved from S3")
+
+ # Create a temporary file with YAML extension
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as temp_file:
+ temp_file.write(file_contents.encode("utf-8"))
+ temp_file_path = temp_file.name
+ verbose_proxy_logger.debug(f"File stored temporarily at: {temp_file_path}")
+
+ # Load the YAML file content
+ with open(temp_file_path, "r") as yaml_file:
+ config = yaml.safe_load(yaml_file)
+
+ return config
+ except ImportError as e:
+ # most likely raised when a user is not running the litellm docker container
+ verbose_proxy_logger.error(f"ImportError: {str(e)}")
+ pass
+ except Exception as e:
+ verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
+ return None
+
+
+# # Example usage
+# bucket_name = 'litellm-proxy'
+# object_key = 'litellm_proxy_config.yaml'
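+# # Minimal usage sketch (assumes the bucket/key above exist and boto3 can
+# # resolve AWS credentials from the environment or an attached IAM role):
+# config = get_file_contents_from_s3(bucket_name=bucket_name, object_key=object_key)
+# if config is not None:
+#     print(config.get("model_list", []))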
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 990cb52337..dd39efd6b7 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -5,7 +5,12 @@ from fastapi import Request
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
-from litellm.proxy._types import CommonProxyErrors, TeamCallbackMetadata, UserAPIKeyAuth
+from litellm.proxy._types import (
+ AddTeamCallback,
+ CommonProxyErrors,
+ TeamCallbackMetadata,
+ UserAPIKeyAuth,
+)
from litellm.types.utils import SupportedCacheControls
if TYPE_CHECKING:
@@ -59,6 +64,42 @@ def safe_add_api_version_from_query_params(data: dict, request: Request):
verbose_logger.error("error checking api version in query params: %s", str(e))
+def convert_key_logging_metadata_to_callback(
+ data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata]
+) -> TeamCallbackMetadata:
+ if team_callback_settings_obj is None:
+ team_callback_settings_obj = TeamCallbackMetadata()
+ if data.callback_type == "success":
+ if team_callback_settings_obj.success_callback is None:
+ team_callback_settings_obj.success_callback = []
+
+ if data.callback_name not in team_callback_settings_obj.success_callback:
+ team_callback_settings_obj.success_callback.append(data.callback_name)
+ elif data.callback_type == "failure":
+ if team_callback_settings_obj.failure_callback is None:
+ team_callback_settings_obj.failure_callback = []
+
+ if data.callback_name not in team_callback_settings_obj.failure_callback:
+ team_callback_settings_obj.failure_callback.append(data.callback_name)
+ elif data.callback_type == "success_and_failure":
+ if team_callback_settings_obj.success_callback is None:
+ team_callback_settings_obj.success_callback = []
+ if team_callback_settings_obj.failure_callback is None:
+ team_callback_settings_obj.failure_callback = []
+ if data.callback_name not in team_callback_settings_obj.success_callback:
+ team_callback_settings_obj.success_callback.append(data.callback_name)
+
+ if data.callback_name not in team_callback_settings_obj.failure_callback:
+ team_callback_settings_obj.failure_callback.append(data.callback_name)
+
+ for var, value in data.callback_vars.items():
+ if team_callback_settings_obj.callback_vars is None:
+ team_callback_settings_obj.callback_vars = {}
+ team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value)
+
+ return team_callback_settings_obj
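+# Sketch of the conversion above (hypothetical values): a key metadata "logging" entry like
+#   {"callback_name": "langfuse", "callback_type": "success",
+#    "callback_vars": {"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY"}}
+# becomes TeamCallbackMetadata(success_callback=["langfuse"],
+# callback_vars={"langfuse_public_key": <value resolved via litellm.get_secret>}).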
+
+
async def add_litellm_data_to_request(
data: dict,
request: Request,
@@ -85,14 +126,19 @@ async def add_litellm_data_to_request(
safe_add_api_version_from_query_params(data, request)
+ _headers = dict(request.headers)
+
# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
- "headers": dict(request.headers),
+ "headers": _headers,
"body": copy.copy(data), # use copy instead of deepcopy
}
+ ## Forward any LLM API provider-specific headers in extra_headers
+ add_provider_specific_headers_to_request(data=data, headers=_headers)
+
## Cache Controls
headers = request.headers
verbose_proxy_logger.debug("Request Headers: %s", headers)
@@ -224,6 +270,7 @@ async def add_litellm_data_to_request(
} # add the team-specific configs to the completion call
# Team Callbacks controls
+ callback_settings_obj: Optional[TeamCallbackMetadata] = None
if user_api_key_dict.team_metadata is not None:
team_metadata = user_api_key_dict.team_metadata
if "callback_settings" in team_metadata:
@@ -241,17 +288,54 @@ async def add_litellm_data_to_request(
}
}
"""
- data["success_callback"] = callback_settings_obj.success_callback
- data["failure_callback"] = callback_settings_obj.failure_callback
+ elif (
+ user_api_key_dict.metadata is not None
+ and "logging" in user_api_key_dict.metadata
+ ):
+ for item in user_api_key_dict.metadata["logging"]:
- if callback_settings_obj.callback_vars is not None:
- # unpack callback_vars in data
- for k, v in callback_settings_obj.callback_vars.items():
- data[k] = v
+ callback_settings_obj = convert_key_logging_metadata_to_callback(
+ data=AddTeamCallback(**item),
+ team_callback_settings_obj=callback_settings_obj,
+ )
+
+ if callback_settings_obj is not None:
+ data["success_callback"] = callback_settings_obj.success_callback
+ data["failure_callback"] = callback_settings_obj.failure_callback
+
+ if callback_settings_obj.callback_vars is not None:
+ # unpack callback_vars in data
+ for k, v in callback_settings_obj.callback_vars.items():
+ data[k] = v
return data
+def add_provider_specific_headers_to_request(
+ data: dict,
+ headers: dict,
+):
+ ANTHROPIC_API_HEADERS = [
+ "anthropic-version",
+ "anthropic-beta",
+ ]
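+ # e.g. (sketch): an incoming request header
+ #   anthropic-beta: prompt-caching-2024-07-31
+ # is copied into data["extra_headers"] below and forwarded to the Anthropic API.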
+
+ extra_headers = data.get("extra_headers", {}) or {}
+
+ # boolean to indicate if a header was added
+ added_header = False
+ for header in ANTHROPIC_API_HEADERS:
+ if header in headers:
+ header_value = headers[header]
+ extra_headers.update({header: header_value})
+ added_header = True
+
+ if added_header is True:
+ data["extra_headers"] = extra_headers
+
+ return
+
+
def _add_otel_traceparent_to_data(data: dict, request: Request):
from litellm.proxy.proxy_server import open_telemetry_logger
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 660c27f249..d25f1b9468 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -19,6 +19,9 @@ model_list:
litellm_params:
model: mistral/mistral-small-latest
api_key: "os.environ/MISTRAL_API_KEY"
+ - model_name: bedrock-anthropic
+ litellm_params:
+ model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
- model_name: gemini-1.5-pro-001
litellm_params:
model: vertex_ai_beta/gemini-1.5-pro-001
@@ -39,7 +42,7 @@ general_settings:
litellm_settings:
fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
- success_callback: ["langfuse", "prometheus"]
- langfuse_default_tags: ["cache_hit", "cache_key", "proxy_base_url", "user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"]
- failure_callback: ["prometheus"]
+ callbacks: ["gcs_bucket"]
+ success_callback: ["langfuse"]
+ langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
cache: True
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 4d141955b2..10c06b2ece 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -159,6 +159,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
check_file_size_under_limit,
)
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
+from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment,
)
@@ -197,6 +198,8 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
router as pass_through_router,
)
+from litellm.proxy.route_llm_request import route_request
+
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
@@ -1444,7 +1447,18 @@ class ProxyConfig:
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details
# Load existing config
- config = await self.get_config(config_file_path=config_file_path)
+ if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
+ bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
+ object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
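+ # e.g. (sketch, matching the example in load_config_utils.py):
+ #   LITELLM_CONFIG_BUCKET_NAME="litellm-proxy"
+ #   LITELLM_CONFIG_BUCKET_OBJECT_KEY="litellm_proxy_config.yaml"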
+ verbose_proxy_logger.debug(
+ "bucket_name: %s, object_key: %s", bucket_name, object_key
+ )
+ config = get_file_contents_from_s3(
+ bucket_name=bucket_name, object_key=object_key
+ )
+ else:
+ # default to file
+ config = await self.get_config(config_file_path=config_file_path)
## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None)
@@ -2652,6 +2666,15 @@ async def startup_event():
)
else:
await initialize(**worker_config)
+ elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
+ (
+ llm_router,
+ llm_model_list,
+ general_settings,
+ ) = await proxy_config.load_config(
+ router=llm_router, config_file_path=worker_config
+ )
+
else:
# if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
@@ -3036,68 +3059,13 @@ async def chat_completion(
### ROUTE THE REQUEST ###
# Do not change this - it should be a constant time fetch - ALWAYS
- router_model_names = llm_router.model_names if llm_router is not None else []
- # skip router if user passed their key
- if "api_key" in data:
- tasks.append(litellm.acompletion(**data))
- elif "," in data["model"] and llm_router is not None:
- if (
- data.get("fastest_response", None) is not None
- and data["fastest_response"] == True
- ):
- tasks.append(llm_router.abatch_completion_fastest_response(**data))
- else:
- _models_csv_string = data.pop("model")
- _models = [model.strip() for model in _models_csv_string.split(",")]
- tasks.append(llm_router.abatch_completion(models=_models, **data))
- elif "user_config" in data:
- # initialize a new router instance. make request using this Router
- router_config = data.pop("user_config")
- user_router = litellm.Router(**router_config)
- tasks.append(user_router.acompletion(**data))
- elif (
- llm_router is not None and data["model"] in router_model_names
- ): # model in router model list
- tasks.append(llm_router.acompletion(**data))
- elif (
- llm_router is not None and data["model"] in llm_router.get_model_ids()
- ): # model in router model list
- tasks.append(llm_router.acompletion(**data))
- elif (
- llm_router is not None
- and llm_router.model_group_alias is not None
- and data["model"] in llm_router.model_group_alias
- ): # model set in model_group_alias
- tasks.append(llm_router.acompletion(**data))
- elif (
- llm_router is not None and data["model"] in llm_router.deployment_names
- ): # model in router deployments, calling a specific deployment on the router
- tasks.append(llm_router.acompletion(**data, specific_deployment=True))
- elif (
- llm_router is not None
- and data["model"] not in router_model_names
- and llm_router.router_general_settings.pass_through_all_models is True
- ):
- tasks.append(litellm.acompletion(**data))
- elif (
- llm_router is not None
- and data["model"] not in router_model_names
- and (
- llm_router.default_deployment is not None
- or len(llm_router.provider_default_deployments) > 0
- )
- ): # model in router deployments, calling a specific deployment on the router
- tasks.append(llm_router.acompletion(**data))
- elif user_model is not None: # `litellm --model `
- tasks.append(litellm.acompletion(**data))
- else:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail={
- "error": "chat_completion: Invalid model name passed in model="
- + data.get("model", "")
- },
- )
+ llm_call = await route_request(
+ data=data,
+ route_type="acompletion",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ tasks.append(llm_call)
# wait for call to end
llm_responses = asyncio.gather(
@@ -3320,58 +3288,15 @@ async def completion(
)
### ROUTE THE REQUESTs ###
- router_model_names = llm_router.model_names if llm_router is not None else []
- # skip router if user passed their key
- if "api_key" in data:
- llm_response = asyncio.create_task(litellm.atext_completion(**data))
- elif (
- llm_router is not None and data["model"] in router_model_names
- ): # model in router model list
- llm_response = asyncio.create_task(llm_router.atext_completion(**data))
- elif (
- llm_router is not None
- and llm_router.model_group_alias is not None
- and data["model"] in llm_router.model_group_alias
- ): # model set in model_group_alias
- llm_response = asyncio.create_task(llm_router.atext_completion(**data))
- elif (
- llm_router is not None and data["model"] in llm_router.deployment_names
- ): # model in router deployments, calling a specific deployment on the router
- llm_response = asyncio.create_task(
- llm_router.atext_completion(**data, specific_deployment=True)
- )
- elif (
- llm_router is not None and data["model"] in llm_router.get_model_ids()
- ): # model in router model list
- llm_response = asyncio.create_task(llm_router.atext_completion(**data))
- elif (
- llm_router is not None
- and data["model"] not in router_model_names
- and llm_router.router_general_settings.pass_through_all_models is True
- ):
- llm_response = asyncio.create_task(litellm.atext_completion(**data))
- elif (
- llm_router is not None
- and data["model"] not in router_model_names
- and (
- llm_router.default_deployment is not None
- or len(llm_router.provider_default_deployments) > 0
- )
- ): # model in router deployments, calling a specific deployment on the router
- llm_response = asyncio.create_task(llm_router.atext_completion(**data))
- elif user_model is not None: # `litellm --model `
- llm_response = asyncio.create_task(litellm.atext_completion(**data))
- else:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail={
- "error": "completion: Invalid model name passed in model="
- + data.get("model", "")
- },
- )
+ llm_call = await route_request(
+ data=data,
+ route_type="atext_completion",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
# Await the llm_response task
- response = await llm_response
+ response = await llm_call
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
@@ -3585,59 +3510,13 @@ async def embeddings(
)
## ROUTE TO CORRECT ENDPOINT ##
- # skip router if user passed their key
- if "api_key" in data:
- tasks.append(litellm.aembedding(**data))
- elif "user_config" in data:
- # initialize a new router instance. make request using this Router
- router_config = data.pop("user_config")
- user_router = litellm.Router(**router_config)
- tasks.append(user_router.aembedding(**data))
- elif (
- llm_router is not None and data["model"] in router_model_names
- ): # model in router model list
- tasks.append(llm_router.aembedding(**data))
- elif (
- llm_router is not None
- and llm_router.model_group_alias is not None
- and data["model"] in llm_router.model_group_alias
- ): # model set in model_group_alias
- tasks.append(
- llm_router.aembedding(**data)
- ) # ensure this goes the llm_router, router will do the correct alias mapping
- elif (
- llm_router is not None and data["model"] in llm_router.deployment_names
- ): # model in router deployments, calling a specific deployment on the router
- tasks.append(llm_router.aembedding(**data, specific_deployment=True))
- elif (
- llm_router is not None and data["model"] in llm_router.get_model_ids()
- ): # model in router deployments, calling a specific deployment on the router
- tasks.append(llm_router.aembedding(**data))
- elif (
- llm_router is not None
- and data["model"] not in router_model_names
- and llm_router.router_general_settings.pass_through_all_models is True
- ):
- tasks.append(litellm.aembedding(**data))
- elif (
- llm_router is not None
- and data["model"] not in router_model_names
- and (
- llm_router.default_deployment is not None
- or len(llm_router.provider_default_deployments) > 0
- )
- ): # model in router deployments, calling a specific deployment on the router
- tasks.append(llm_router.aembedding(**data))
- elif user_model is not None: # `litellm --model `
- tasks.append(litellm.aembedding(**data))
- else:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail={
- "error": "embeddings: Invalid model name passed in model="
- + data.get("model", "")
- },
- )
+ llm_call = await route_request(
+ data=data,
+ route_type="aembedding",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ tasks.append(llm_call)
# wait for call to end
llm_responses = asyncio.gather(
@@ -3768,46 +3647,13 @@ async def image_generation(
)
## ROUTE TO CORRECT ENDPOINT ##
- # skip router if user passed their key
- if "api_key" in data:
- response = await litellm.aimage_generation(**data)
- elif (
- llm_router is not None and data["model"] in router_model_names
- ): # model in router model list
- response = await llm_router.aimage_generation(**data)
- elif (
- llm_router is not None and data["model"] in llm_router.deployment_names
- ): # model in router deployments, calling a specific deployment on the router
- response = await llm_router.aimage_generation(
- **data, specific_deployment=True
- )
- elif (
- llm_router is not None
- and llm_router.model_group_alias is not None
- and data["model"] in llm_router.model_group_alias
- ): # model set in model_group_alias
- response = await llm_router.aimage_generation(
- **data
- ) # ensure this goes the llm_router, router will do the correct alias mapping
- elif (
- llm_router is not None
- and data["model"] not in router_model_names
- and (
- llm_router.default_deployment is not None
- or len(llm_router.provider_default_deployments) > 0
- )
- ): # model in router deployments, calling a specific deployment on the router
- response = await llm_router.aimage_generation(**data)
- elif user_model is not None: # `litellm --model `
- response = await litellm.aimage_generation(**data)
- else:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail={
- "error": "image_generation: Invalid model name passed in model="
- + data.get("model", "")
- },
- )
+ llm_call = await route_request(
+ data=data,
+ route_type="aimage_generation",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ response = await llm_call
### ALERTING ###
asyncio.create_task(
@@ -3915,44 +3761,13 @@ async def audio_speech(
)
## ROUTE TO CORRECT ENDPOINT ##
- # skip router if user passed their key
- if "api_key" in data:
- response = await litellm.aspeech(**data)
- elif (
- llm_router is not None and data["model"] in router_model_names
- ): # model in router model list
- response = await llm_router.aspeech(**data)
- elif (
- llm_router is not None and data["model"] in llm_router.deployment_names
- ): # model in router deployments, calling a specific deployment on the router
- response = await llm_router.aspeech(**data, specific_deployment=True)
- elif (
- llm_router is not None
- and llm_router.model_group_alias is not None
- and data["model"] in llm_router.model_group_alias
- ): # model set in model_group_alias
- response = await llm_router.aspeech(
- **data
- ) # ensure this goes the llm_router, router will do the correct alias mapping
- elif (
- llm_router is not None
- and data["model"] not in router_model_names
- and (
- llm_router.default_deployment is not None
- or len(llm_router.provider_default_deployments) > 0
- )
- ): # model in router deployments, calling a specific deployment on the router
- response = await llm_router.aspeech(**data)
- elif user_model is not None: # `litellm --model `
- response = await litellm.aspeech(**data)
- else:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail={
- "error": "audio_speech: Invalid model name passed in model="
- + data.get("model", "")
- },
- )
+ llm_call = await route_request(
+ data=data,
+ route_type="aspeech",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ response = await llm_call
### ALERTING ###
asyncio.create_task(
@@ -4085,47 +3900,13 @@ async def audio_transcriptions(
)
## ROUTE TO CORRECT ENDPOINT ##
- # skip router if user passed their key
- if "api_key" in data:
- response = await litellm.atranscription(**data)
- elif (
- llm_router is not None and data["model"] in router_model_names
- ): # model in router model list
- response = await llm_router.atranscription(**data)
-
- elif (
- llm_router is not None and data["model"] in llm_router.deployment_names
- ): # model in router deployments, calling a specific deployment on the router
- response = await llm_router.atranscription(
- **data, specific_deployment=True
- )
- elif (
- llm_router is not None
- and llm_router.model_group_alias is not None
- and data["model"] in llm_router.model_group_alias
- ): # model set in model_group_alias
- response = await llm_router.atranscription(
- **data
- ) # ensure this goes the llm_router, router will do the correct alias mapping
- elif (
- llm_router is not None
- and data["model"] not in router_model_names
- and (
- llm_router.default_deployment is not None
- or len(llm_router.provider_default_deployments) > 0
- )
- ): # model in router deployments, calling a specific deployment on the router
- response = await llm_router.atranscription(**data)
- elif user_model is not None: # `litellm --model `
- response = await litellm.atranscription(**data)
- else:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail={
- "error": "audio_transcriptions: Invalid model name passed in model="
- + data.get("model", "")
- },
- )
+ llm_call = await route_request(
+ data=data,
+ route_type="atranscription",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ response = await llm_call
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
@@ -5341,40 +5122,13 @@ async def moderations(
start_time = time.time()
## ROUTE TO CORRECT ENDPOINT ##
- # skip router if user passed their key
- if "api_key" in data:
- response = await litellm.amoderation(**data)
- elif (
- llm_router is not None and data.get("model") in router_model_names
- ): # model in router model list
- response = await llm_router.amoderation(**data)
- elif (
- llm_router is not None and data.get("model") in llm_router.deployment_names
- ): # model in router deployments, calling a specific deployment on the router
- response = await llm_router.amoderation(**data, specific_deployment=True)
- elif (
- llm_router is not None
- and llm_router.model_group_alias is not None
- and data.get("model") in llm_router.model_group_alias
- ): # model set in model_group_alias
- response = await llm_router.amoderation(
- **data
- ) # ensure this goes the llm_router, router will do the correct alias mapping
- elif (
- llm_router is not None
- and data.get("model") not in router_model_names
- and (
- llm_router.default_deployment is not None
- or len(llm_router.provider_default_deployments) > 0
- )
- ): # model in router deployments, calling a specific deployment on the router
- response = await llm_router.amoderation(**data)
- elif user_model is not None: # `litellm --model `
- response = await litellm.amoderation(**data)
- else:
- # /moderations does not need a "model" passed
- # see https://platform.openai.com/docs/api-reference/moderations
- response = await litellm.amoderation(**data)
+ llm_call = await route_request(
+ data=data,
+ route_type="amoderation",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ response = await llm_call
### ALERTING ###
asyncio.create_task(
diff --git a/litellm/proxy/route_llm_request.py b/litellm/proxy/route_llm_request.py
new file mode 100644
index 0000000000..7a7be55b22
--- /dev/null
+++ b/litellm/proxy/route_llm_request.py
@@ -0,0 +1,117 @@
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+
+from fastapi import (
+ Depends,
+ FastAPI,
+ File,
+ Form,
+ Header,
+ HTTPException,
+ Path,
+ Request,
+ Response,
+ UploadFile,
+ status,
+)
+
+import litellm
+from litellm._logging import verbose_logger
+
+if TYPE_CHECKING:
+ from litellm.router import Router as _Router
+
+ LitellmRouter = _Router
+else:
+ LitellmRouter = Any
+
+
+ROUTE_ENDPOINT_MAPPING = {
+ "acompletion": "/chat/completions",
+ "atext_completion": "/completions",
+ "aembedding": "/embeddings",
+ "aimage_generation": "/image/generations",
+ "aspeech": "/audio/speech",
+ "atranscription": "/audio/transcriptions",
+ "amoderation": "/moderations",
+}
+
+
+async def route_request(
+ data: dict,
+ llm_router: Optional[LitellmRouter],
+ user_model: Optional[str],
+ route_type: Literal[
+ "acompletion",
+ "atext_completion",
+ "aembedding",
+ "aimage_generation",
+ "aspeech",
+ "atranscription",
+ "amoderation",
+ ],
+):
+ """
+ Common helper that routes a proxy request to the llm_router, a per-request
+ user Router, or the base litellm module, based on the request payload.
+ """
+ router_model_names = llm_router.model_names if llm_router is not None else []
+
+ if "api_key" in data:
+ return getattr(litellm, f"{route_type}")(**data)
+
+ elif "user_config" in data:
+ router_config = data.pop("user_config")
+ user_router = litellm.Router(**router_config)
+ return getattr(user_router, f"{route_type}")(**data)
+
+ elif (
+ route_type == "acompletion"
+ and data.get("model", "") is not None
+ and "," in data.get("model", "")
+ and llm_router is not None
+ ):
+ if data.get("fastest_response", False):
+ return llm_router.abatch_completion_fastest_response(**data)
+ else:
+ models = [model.strip() for model in data.pop("model").split(",")]
+ return llm_router.abatch_completion(models=models, **data)
+
+ elif llm_router is not None:
+ if (
+ data["model"] in router_model_names
+ or data["model"] in llm_router.get_model_ids()
+ ):
+ return getattr(llm_router, f"{route_type}")(**data)
+
+ elif (
+ llm_router.model_group_alias is not None
+ and data["model"] in llm_router.model_group_alias
+ ):
+ return getattr(llm_router, f"{route_type}")(**data)
+
+ elif data["model"] in llm_router.deployment_names:
+ return getattr(llm_router, f"{route_type}")(
+ **data, specific_deployment=True
+ )
+
+ elif data["model"] not in router_model_names:
+ if llm_router.router_general_settings.pass_through_all_models:
+ return getattr(litellm, f"{route_type}")(**data)
+ elif (
+ llm_router.default_deployment is not None
+ or len(llm_router.provider_default_deployments) > 0
+ ):
+ return getattr(llm_router, f"{route_type}")(**data)
+
+ elif user_model is not None:
+ return getattr(litellm, f"{route_type}")(**data)
+
+ # if no route found then it's a bad request
+ route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail={
+ "error": f"{route_name}: Invalid model name passed in model="
+ + data.get("model", "")
+ },
+ )
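+
+
+# Usage sketch (hypothetical payload; mirrors the call-sites in proxy_server.py):
+#   llm_call = await route_request(
+#       data={"model": "gpt-4", "messages": [...]},
+#       llm_router=llm_router,
+#       user_model=None,
+#       route_type="acompletion",
+#   )
+#   response = await llm_call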
diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py
index cd7004e41d..6a28d70b17 100644
--- a/litellm/proxy/spend_tracking/spend_tracking_utils.py
+++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -21,6 +21,8 @@ def get_logging_payload(
if kwargs is None:
kwargs = {}
+ if response_obj is None:
+ response_obj = {}
# standardize this function to be used across, s3, dynamoDB, langfuse logging
litellm_params = kwargs.get("litellm_params", {})
metadata = (
diff --git a/litellm/proxy/tests/test_anthropic_context_caching.py b/litellm/proxy/tests/test_anthropic_context_caching.py
new file mode 100644
index 0000000000..6156e4a048
--- /dev/null
+++ b/litellm/proxy/tests/test_anthropic_context_caching.py
@@ -0,0 +1,37 @@
+import openai
+
+client = openai.OpenAI(
+ api_key="sk-1234", # litellm proxy api key
+ base_url="http://0.0.0.0:4000", # litellm proxy base url
+)
+
+
+response = client.chat.completions.create(
+ model="anthropic/claude-3-5-sonnet-20240620",
+ messages=[
+ { # type: ignore
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "You are an AI assistant tasked with analyzing legal documents.",
+ },
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement" * 100,
+ "cache_control": {"type": "ephemeral"},
+ },
+ ],
+ },
+ {
+ "role": "user",
+ "content": "what are the key terms and conditions in this agreement?",
+ },
+ ],
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
+)
+
+print(response)
diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py
index 073a87901a..f396defb51 100644
--- a/litellm/router_utils/client_initalization_utils.py
+++ b/litellm/router_utils/client_initalization_utils.py
@@ -190,7 +190,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None:
- api_version = litellm.AZURE_DEFAULT_API_VERSION
+ api_version = os.getenv("AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION)
if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"):
diff --git a/litellm/tests/test_anthropic_prompt_caching.py b/litellm/tests/test_anthropic_prompt_caching.py
new file mode 100644
index 0000000000..87bfc23f84
--- /dev/null
+++ b/litellm/tests/test_anthropic_prompt_caching.py
@@ -0,0 +1,321 @@
+import json
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+ 0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+
+import os
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+import litellm
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.prompt_templates.factory import anthropic_messages_pt
+
+# litellm.num_retries =3
+litellm.cache = None
+litellm.success_callback = []
+user_message = "Write a short poem about the sky"
+messages = [{"content": user_message, "role": "user"}]
+
+
+def logger_fn(user_model_dict):
+ print(f"user_model_dict: {user_model_dict}")
+
+
+@pytest.fixture(autouse=True)
+def reset_callbacks():
+ print("\npytest fixture - resetting callbacks")
+ litellm.success_callback = []
+ litellm._async_success_callback = []
+ litellm.failure_callback = []
+ litellm.callbacks = []
+
+
+@pytest.mark.asyncio
+async def test_litellm_anthropic_prompt_caching_tools():
+ # Arrange: Set up the MagicMock for the httpx.AsyncClient
+ mock_response = AsyncMock()
+
+ def return_val():
+ return {
+ "id": "msg_01XFDUDYJgAACzvnptvVoYEL",
+ "type": "message",
+ "role": "assistant",
+ "content": [{"type": "text", "text": "Hello!"}],
+ "model": "claude-3-5-sonnet-20240620",
+ "stop_reason": "end_turn",
+ "stop_sequence": None,
+ "usage": {"input_tokens": 12, "output_tokens": 6},
+ }
+
+ mock_response.json = return_val
+
+ litellm.set_verbose = True
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+ # Act: Call the litellm.acompletion function
+ response = await litellm.acompletion(
+ api_key="mock_api_key",
+ model="anthropic/claude-3-5-sonnet-20240620",
+ messages=[
+ {"role": "user", "content": "What's the weather like in Boston today?"}
+ ],
+ tools=[
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ },
+ },
+ "required": ["location"],
+ },
+ "cache_control": {"type": "ephemeral"},
+ },
+ }
+ ],
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
+ )
+
+ # Print what was called on the mock
+ print("call args=", mock_post.call_args)
+
+ expected_url = "https://api.anthropic.com/v1/messages"
+ expected_headers = {
+ "accept": "application/json",
+ "content-type": "application/json",
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ "x-api-key": "mock_api_key",
+ }
+
+ expected_json = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What's the weather like in Boston today?",
+ }
+ ],
+ }
+ ],
+ "tools": [
+ {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location",
+ "cache_control": {"type": "ephemeral"},
+ "input_schema": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ },
+ },
+ "required": ["location"],
+ },
+ }
+ ],
+ "max_tokens": 4096,
+ "model": "claude-3-5-sonnet-20240620",
+ }
+
+ mock_post.assert_called_once_with(
+ expected_url, json=expected_json, headers=expected_headers, timeout=600.0
+ )
+
+
+@pytest.mark.asyncio()
+async def test_anthropic_api_prompt_caching_basic():
+ litellm.set_verbose = True
+ response = await litellm.acompletion(
+ model="anthropic/claude-3-5-sonnet-20240620",
+ messages=[
+ # System Message
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement"
+ * 400,
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+ },
+ # The final turn is marked with cache-control, for continuing in followups.
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ ],
+ temperature=0.2,
+ max_tokens=10,
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
+ )
+
+ print("response=", response)
+
+ assert "cache_read_input_tokens" in response.usage
+ assert "cache_creation_input_tokens" in response.usage
+
+ # Assert either a cache entry was created or cache was read - changes depending on the anthropic api ttl
+ assert (response.usage.cache_read_input_tokens > 0) or (
+ response.usage.cache_creation_input_tokens > 0
+ )
+
+
+@pytest.mark.asyncio
+async def test_litellm_anthropic_prompt_caching_system():
+ # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#prompt-caching-examples
+ # Large Context Caching Example
+ mock_response = AsyncMock()
+
+ def return_val():
+ return {
+ "id": "msg_01XFDUDYJgAACzvnptvVoYEL",
+ "type": "message",
+ "role": "assistant",
+ "content": [{"type": "text", "text": "Hello!"}],
+ "model": "claude-3-5-sonnet-20240620",
+ "stop_reason": "end_turn",
+ "stop_sequence": None,
+ "usage": {"input_tokens": 12, "output_tokens": 6},
+ }
+
+ mock_response.json = return_val
+
+ litellm.set_verbose = True
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+ # Act: Call the litellm.acompletion function
+ response = await litellm.acompletion(
+ api_key="mock_api_key",
+ model="anthropic/claude-3-5-sonnet-20240620",
+ messages=[
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "You are an AI assistant tasked with analyzing legal documents.",
+ },
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ],
+ },
+ {
+ "role": "user",
+ "content": "what are the key terms and conditions in this agreement?",
+ },
+ ],
+ extra_headers={
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ },
+ )
+
+ # Print what was called on the mock
+ print("call args=", mock_post.call_args)
+
+ expected_url = "https://api.anthropic.com/v1/messages"
+ expected_headers = {
+ "accept": "application/json",
+ "content-type": "application/json",
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "prompt-caching-2024-07-31",
+ "x-api-key": "mock_api_key",
+ }
+
+ expected_json = {
+ "system": [
+ {
+ "type": "text",
+ "text": "You are an AI assistant tasked with analyzing legal documents.",
+ },
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ],
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "what are the key terms and conditions in this agreement?",
+ }
+ ],
+ }
+ ],
+ "max_tokens": 4096,
+ "model": "claude-3-5-sonnet-20240620",
+ }
+
+ mock_post.assert_called_once_with(
+ expected_url, json=expected_json, headers=expected_headers, timeout=600.0
+ )
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index 4da18144d0..c331021213 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -1159,8 +1159,8 @@ def test_bedrock_tools_pt_invalid_names():
assert result[1]["toolSpec"]["name"] == "another_invalid_name"
-def test_bad_request_error():
- with pytest.raises(litellm.BadRequestError):
+def test_not_found_error():
+ with pytest.raises(litellm.NotFoundError):
completion(
model="bedrock/bad_model",
messages=[
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index db0239ca33..654b210ff7 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -14,7 +14,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
import os
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
-# litellm.num_retries = 3
+# litellm.num_retries =3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
@@ -190,6 +190,31 @@ def test_completion_azure_command_r():
pytest.fail(f"Error occurred: {e}")
+@pytest.mark.parametrize(
+ "api_base",
+ [
+ "https://litellm8397336933.openai.azure.com",
+ "https://litellm8397336933.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2023-03-15-preview",
+ ],
+)
+def test_completion_azure_ai_gpt_4o(api_base):
+ try:
+ litellm.set_verbose = True
+
+ response = completion(
+ model="azure_ai/gpt-4o",
+ api_base=api_base,
+ api_key=os.getenv("AZURE_AI_OPENAI_KEY"),
+ messages=[{"role": "user", "content": "What is the meaning of life?"}],
+ )
+
+ print(response)
+ except litellm.Timeout as e:
+ pass
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
+
+
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_databricks(sync_mode):
@@ -3312,108 +3337,6 @@ def test_customprompt_together_ai():
# test_customprompt_together_ai()
-@pytest.mark.skip(reason="AWS Suspended Account")
-def test_completion_sagemaker():
- try:
- litellm.set_verbose = True
- print("testing sagemaker")
- response = completion(
- model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-ins-20240329-150233",
- model_id="huggingface-llm-mistral-7b-instruct-20240329-150233",
- messages=messages,
- temperature=0.2,
- max_tokens=80,
- aws_region_name=os.getenv("AWS_REGION_NAME_2"),
- aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID_2"),
- aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY_2"),
- input_cost_per_second=0.000420,
- )
- # Add any assertions here to check the response
- print(response)
- cost = completion_cost(completion_response=response)
- print("calculated cost", cost)
- assert (
- cost > 0.0 and cost < 1.0
- ) # should never be > $1 for a single completion call
- except Exception as e:
- pytest.fail(f"Error occurred: {e}")
-
-
-# test_completion_sagemaker()
-
-
-@pytest.mark.skip(reason="AWS Suspended Account")
-@pytest.mark.asyncio
-async def test_acompletion_sagemaker():
- try:
- litellm.set_verbose = True
- print("testing sagemaker")
- response = await litellm.acompletion(
- model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-ins-20240329-150233",
- model_id="huggingface-llm-mistral-7b-instruct-20240329-150233",
- messages=messages,
- temperature=0.2,
- max_tokens=80,
- aws_region_name=os.getenv("AWS_REGION_NAME_2"),
- aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID_2"),
- aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY_2"),
- input_cost_per_second=0.000420,
- )
- # Add any assertions here to check the response
- print(response)
- cost = completion_cost(completion_response=response)
- print("calculated cost", cost)
- assert (
- cost > 0.0 and cost < 1.0
- ) # should never be > $1 for a single completion call
- except Exception as e:
- pytest.fail(f"Error occurred: {e}")
-
-
-@pytest.mark.skip(reason="AWS Suspended Account")
-def test_completion_chat_sagemaker():
- try:
- messages = [{"role": "user", "content": "Hey, how's it going?"}]
- litellm.set_verbose = True
- response = completion(
- model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
- messages=messages,
- max_tokens=100,
- temperature=0.7,
- stream=True,
- )
- # Add any assertions here to check the response
- complete_response = ""
- for chunk in response:
- complete_response += chunk.choices[0].delta.content or ""
- print(f"complete_response: {complete_response}")
- assert len(complete_response) > 0
- except Exception as e:
- pytest.fail(f"Error occurred: {e}")
-
-
-# test_completion_chat_sagemaker()
-
-
-@pytest.mark.skip(reason="AWS Suspended Account")
-def test_completion_chat_sagemaker_mistral():
- try:
- messages = [{"role": "user", "content": "Hey, how's it going?"}]
-
- response = completion(
- model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct",
- messages=messages,
- max_tokens=100,
- )
- # Add any assertions here to check the response
- print(response)
- except Exception as e:
- pytest.fail(f"An error occurred: {str(e)}")
-
-
-# test_completion_chat_sagemaker_mistral()
-
-
def response_format_tests(response: litellm.ModelResponse):
assert isinstance(response.id, str)
assert response.id != ""
@@ -3449,7 +3372,6 @@ def response_format_tests(response: litellm.ModelResponse):
assert isinstance(response.usage.total_tokens, int) # type: ignore
-@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
"model",
[
@@ -3463,6 +3385,7 @@ def response_format_tests(response: litellm.ModelResponse):
"cohere.command-text-v14",
],
)
+@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_bedrock_httpx_models(sync_mode, model):
litellm.set_verbose = True
@@ -3705,19 +3628,21 @@ def test_completion_anyscale_api():
# test_completion_anyscale_api()
-@pytest.mark.skip(reason="flaky test, times out frequently")
+# @pytest.mark.skip(reason="flaky test, times out frequently")
def test_completion_cohere():
try:
# litellm.set_verbose=True
messages = [
{"role": "system", "content": "You're a good bot"},
+ {"role": "assistant", "content": [{"text": "2", "type": "text"}]},
+ {"role": "assistant", "content": [{"text": "3", "type": "text"}]},
{
"role": "user",
"content": "Hey",
},
]
response = completion(
- model="command-nightly",
+ model="command-r",
messages=messages,
)
print(response)
diff --git a/litellm/tests/test_function_call_parsing.py b/litellm/tests/test_function_call_parsing.py
index d223a7c8f6..fab9cf110c 100644
--- a/litellm/tests/test_function_call_parsing.py
+++ b/litellm/tests/test_function_call_parsing.py
@@ -1,23 +1,27 @@
# What is this?
## Test to make sure function call response always works with json.loads() -> no extra parsing required. Relevant issue - https://github.com/BerriAI/litellm/issues/2654
-import sys, os
+import os
+import sys
import traceback
+
from dotenv import load_dotenv
load_dotenv()
-import os, io
+import io
+import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
-import pytest
-import litellm
import json
import warnings
-
-from litellm import completion
from typing import List
+import pytest
+
+import litellm
+from litellm import completion
+
# Just a stub to keep the sample code simple
class Trade:
@@ -78,58 +82,60 @@ def trade(model_name: str) -> List[Trade]:
},
}
- response = completion(
- model_name,
- [
- {
- "role": "system",
- "content": """You are an expert asset manager, managing a portfolio.
+ try:
+ response = completion(
+ model_name,
+ [
+ {
+ "role": "system",
+ "content": """You are an expert asset manager, managing a portfolio.
- Always use the `trade` function. Make sure that you call it correctly. For example, the following is a valid call:
+ Always use the `trade` function. Make sure that you call it correctly. For example, the following is a valid call:
+ ```
+ trade({
+ "orders": [
+ {"action": "buy", "asset": "BTC", "amount": 0.1},
+ {"action": "sell", "asset": "ETH", "amount": 0.2}
+ ]
+ })
+ ```
+
+ If there are no trades to make, call `trade` with an empty array:
+ ```
+ trade({ "orders": [] })
+ ```
+ """,
+ },
+ {
+ "role": "user",
+ "content": """Manage the portfolio.
+
+ Don't jabber.
+
+ This is the current market data:
```
- trade({
- "orders": [
- {"action": "buy", "asset": "BTC", "amount": 0.1},
- {"action": "sell", "asset": "ETH", "amount": 0.2}
- ]
- })
+ {market_data}
```
- If there are no trades to make, call `trade` with an empty array:
+ Your portfolio is as follows:
```
- trade({ "orders": [] })
+ {portfolio}
```
- """,
+ """.replace(
+ "{market_data}", "BTC: 64,000 USD\nETH: 3,500 USD"
+ ).replace(
+ "{portfolio}", "USD: 1000, BTC: 0.1, ETH: 0.2"
+ ),
+ },
+ ],
+ tools=[tool_spec],
+ tool_choice={
+ "type": "function",
+ "function": {"name": tool_spec["function"]["name"]}, # type: ignore
},
- {
- "role": "user",
- "content": """Manage the portfolio.
-
- Don't jabber.
-
- This is the current market data:
- ```
- {market_data}
- ```
-
- Your portfolio is as follows:
- ```
- {portfolio}
- ```
- """.replace(
- "{market_data}", "BTC: 64,000 USD\nETH: 3,500 USD"
- ).replace(
- "{portfolio}", "USD: 1000, BTC: 0.1, ETH: 0.2"
- ),
- },
- ],
- tools=[tool_spec],
- tool_choice={
- "type": "function",
- "function": {"name": tool_spec["function"]["name"]}, # type: ignore
- },
- )
-
+ )
+ except litellm.InternalServerError:
+ pass
calls = response.choices[0].message.tool_calls
trades = [trade for call in calls for trade in parse_call(call)]
return trades
diff --git a/litellm/tests/test_gcs_bucket.py b/litellm/tests/test_gcs_bucket.py
index c21988c73d..f0aaf8d8dd 100644
--- a/litellm/tests/test_gcs_bucket.py
+++ b/litellm/tests/test_gcs_bucket.py
@@ -147,6 +147,117 @@ async def test_basic_gcs_logger():
assert gcs_payload["response_cost"] > 0.0
+ assert gcs_payload["log_event_type"] == "successful_api_call"
+ gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"])
+
+ assert (
+ gcs_payload["spend_log_metadata"]["user_api_key"]
+ == "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b"
+ )
+ assert (
+ gcs_payload["spend_log_metadata"]["user_api_key_user_id"]
+ == "116544810872468347480"
+ )
+
+ # Delete Object from GCS
+ print("deleting object from GCS")
+ await gcs_logger.delete_gcs_object(object_name=object_name)
+
+
+@pytest.mark.asyncio
+async def test_basic_gcs_logger_failure():
+ load_vertex_ai_credentials()
+ gcs_logger = GCSBucketLogger()
+ print("GCSBucketLogger", gcs_logger)
+
+ gcs_log_id = f"failure-test-{uuid.uuid4().hex}"
+
+ litellm.callbacks = [gcs_logger]
+
+ try:
+ response = await litellm.acompletion(
+ model="gpt-3.5-turbo",
+ temperature=0.7,
+ messages=[{"role": "user", "content": "This is a test"}],
+ max_tokens=10,
+ user="ishaan-2",
+ mock_response=litellm.BadRequestError(
+ model="gpt-3.5-turbo",
+ message="Error: 400: Bad Request: Invalid API key, please check your API key and try again.",
+ llm_provider="openai",
+ ),
+ metadata={
+ "gcs_log_id": gcs_log_id,
+ "tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"],
+ "user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
+ "user_api_key_alias": None,
+ "user_api_end_user_max_budget": None,
+ "litellm_api_version": "0.0.0",
+ "global_max_parallel_requests": None,
+ "user_api_key_user_id": "116544810872468347480",
+ "user_api_key_org_id": None,
+ "user_api_key_team_id": None,
+ "user_api_key_team_alias": None,
+ "user_api_key_metadata": {},
+ "requester_ip_address": "127.0.0.1",
+ "spend_logs_metadata": {"hello": "world"},
+ "headers": {
+ "content-type": "application/json",
+ "user-agent": "PostmanRuntime/7.32.3",
+ "accept": "*/*",
+ "postman-token": "92300061-eeaa-423b-a420-0b44896ecdc4",
+ "host": "localhost:4000",
+ "accept-encoding": "gzip, deflate, br",
+ "connection": "keep-alive",
+ "content-length": "163",
+ },
+ "endpoint": "http://localhost:4000/chat/completions",
+ "model_group": "gpt-3.5-turbo",
+ "deployment": "azure/chatgpt-v-2",
+ "model_info": {
+ "id": "4bad40a1eb6bebd1682800f16f44b9f06c52a6703444c99c7f9f32e9de3693b4",
+ "db_model": False,
+ },
+ "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
+ "caching_groups": None,
+ "raw_request": "\n\nPOST Request Sent from LiteLLM:\ncurl -X POST \\\nhttps://openai-gpt-4-test-v-1.openai.azure.com//openai/ \\\n-H 'Authorization: *****' \\\n-d '{'model': 'chatgpt-v-2', 'messages': [{'role': 'system', 'content': 'you are a helpful assistant.\\n'}, {'role': 'user', 'content': 'bom dia'}], 'stream': False, 'max_tokens': 10, 'user': '116544810872468347480', 'extra_body': {}}'\n",
+ },
+ )
+ except:
+ pass
+
+ await asyncio.sleep(5)
+
+ # Get the current date
+ current_date = datetime.now().strftime("%Y-%m-%d")
+
+ # Modify the object_name to include the date-based folder
+ object_name = gcs_log_id
+
+ print("object_name", object_name)
+
+ # Check if object landed on GCS
+ object_from_gcs = await gcs_logger.download_gcs_object(object_name=object_name)
+ print("object from gcs=", object_from_gcs)
+ # convert object_from_gcs from bytes to DICT
+ parsed_data = json.loads(object_from_gcs)
+ print("object_from_gcs as dict", parsed_data)
+
+ print("type of object_from_gcs", type(parsed_data))
+
+ gcs_payload = GCSBucketPayload(**parsed_data)
+
+ print("gcs_payload", gcs_payload)
+
+ assert gcs_payload["request_kwargs"]["model"] == "gpt-3.5-turbo"
+ assert gcs_payload["request_kwargs"]["messages"] == [
+ {"role": "user", "content": "This is a test"}
+ ]
+
+ assert gcs_payload["response_cost"] == 0
+ assert gcs_payload["log_event_type"] == "failed_api_call"
+
gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"])
assert (
diff --git a/litellm/tests/test_prometheus.py b/litellm/tests/test_prometheus.py
index 64e824e6db..7574beb9d9 100644
--- a/litellm/tests/test_prometheus.py
+++ b/litellm/tests/test_prometheus.py
@@ -76,6 +76,6 @@ async def test_async_prometheus_success_logging():
print("metrics from prometheus", metrics)
assert metrics["litellm_requests_metric_total"] == 1.0
assert metrics["litellm_total_tokens_total"] == 30.0
- assert metrics["llm_deployment_success_responses_total"] == 1.0
- assert metrics["llm_deployment_total_requests_total"] == 1.0
- assert metrics["llm_deployment_latency_per_output_token_bucket"] == 1.0
+ assert metrics["litellm_deployment_success_responses_total"] == 1.0
+ assert metrics["litellm_deployment_total_requests_total"] == 1.0
+ assert metrics["litellm_deployment_latency_per_output_token_bucket"] == 1.0
diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py
index f7a715a220..93e92a7926 100644
--- a/litellm/tests/test_prompt_factory.py
+++ b/litellm/tests/test_prompt_factory.py
@@ -260,3 +260,56 @@ def test_anthropic_messages_tool_call():
translated_messages[-1]["content"][0]["tool_use_id"]
== "bc8cb4b6-88c4-4138-8993-3a9d9cd51656"
)
+
+
+def test_anthropic_cache_controls_pt():
+ "see anthropic docs for this: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation"
+ messages = [
+ # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+ },
+ # The final turn is marked with cache-control, for continuing in followups.
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ]
+
+ translated_messages = anthropic_messages_pt(
+ messages, model="claude-3-5-sonnet-20240620", llm_provider="anthropic"
+ )
+
+ for i, msg in enumerate(translated_messages):
+ if i == 0:
+ assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
+ elif i == 1:
+ assert "cache_controls" not in msg["content"][0]
+ elif i == 2:
+ assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
+ elif i == 3:
+ assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
+
+ print("translated_messages: ", translated_messages)
diff --git a/litellm/tests/test_provider_specific_config.py b/litellm/tests/test_provider_specific_config.py
index c20c44fb13..a7765f658c 100644
--- a/litellm/tests/test_provider_specific_config.py
+++ b/litellm/tests/test_provider_specific_config.py
@@ -2,16 +2,19 @@
# This tests setting provider specific configs across providers
# There are 2 types of tests - changing config dynamically or by setting class variables
-import sys, os
+import os
+import sys
import traceback
+
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
+
import litellm
-from litellm import completion
-from litellm import RateLimitError
+from litellm import RateLimitError, completion
# Huggingface - Expensive to deploy models and keep them running. Maybe we can try doing this via baseten??
# def hf_test_completion_tgi():
@@ -513,102 +516,165 @@ def sagemaker_test_completion():
# sagemaker_test_completion()
-def test_sagemaker_default_region(mocker):
+def test_sagemaker_default_region():
"""
If no regions are specified in config or in environment, the default region is us-west-2
"""
- mock_client = mocker.patch("boto3.client")
- try:
+ mock_response = MagicMock()
+
+ def return_val():
+ return {
+ "generated_text": "This is a mock response from SageMaker.",
+ "id": "cmpl-mockid",
+ "object": "text_completion",
+ "created": 1629800000,
+ "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ "choices": [
+ {
+ "text": "This is a mock response from SageMaker.",
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "length",
+ }
+ ],
+ "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+ }
+
+ mock_response.json = return_val
+ mock_response.status_code = 200
+
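+    # Patch the shared HTTP handler so no real SageMaker request is made; the
+    # default region is then asserted via the endpoint URL the handler receives.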
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
response = litellm.completion(
model="sagemaker/mock-endpoint",
- messages=[
- {
- "content": "Hello, world!",
- "role": "user"
- }
- ]
+ messages=[{"content": "Hello, world!", "role": "user"}],
)
- except Exception:
- pass # expected serialization exception because AWS client was replaced with a Mock
- assert mock_client.call_args.kwargs["region_name"] == "us-west-2"
+ mock_post.assert_called_once()
+ _, kwargs = mock_post.call_args
+ args_to_sagemaker = kwargs["json"]
+ print("Arguments passed to sagemaker=", args_to_sagemaker)
+ print("url=", kwargs["url"])
+
+ assert (
+ kwargs["url"]
+ == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/mock-endpoint/invocations"
+ )
+
# test_sagemaker_default_region()
-def test_sagemaker_environment_region(mocker):
+def test_sagemaker_environment_region():
"""
If a region is specified in the environment, use that region instead of us-west-2
"""
expected_region = "us-east-1"
os.environ["AWS_REGION_NAME"] = expected_region
- mock_client = mocker.patch("boto3.client")
- try:
+ mock_response = MagicMock()
+
+ def return_val():
+ return {
+ "generated_text": "This is a mock response from SageMaker.",
+ "id": "cmpl-mockid",
+ "object": "text_completion",
+ "created": 1629800000,
+ "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ "choices": [
+ {
+ "text": "This is a mock response from SageMaker.",
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "length",
+ }
+ ],
+ "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+ }
+
+ mock_response.json = return_val
+ mock_response.status_code = 200
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
response = litellm.completion(
model="sagemaker/mock-endpoint",
- messages=[
- {
- "content": "Hello, world!",
- "role": "user"
- }
- ]
+ messages=[{"content": "Hello, world!", "role": "user"}],
)
- except Exception:
- pass # expected serialization exception because AWS client was replaced with a Mock
- del os.environ["AWS_REGION_NAME"] # cleanup
- assert mock_client.call_args.kwargs["region_name"] == expected_region
+ mock_post.assert_called_once()
+ _, kwargs = mock_post.call_args
+ args_to_sagemaker = kwargs["json"]
+ print("Arguments passed to sagemaker=", args_to_sagemaker)
+ print("url=", kwargs["url"])
+
+ assert (
+ kwargs["url"]
+ == f"https://runtime.sagemaker.{expected_region}.amazonaws.com/endpoints/mock-endpoint/invocations"
+ )
+
+ del os.environ["AWS_REGION_NAME"] # cleanup
+
# test_sagemaker_environment_region()
-def test_sagemaker_config_region(mocker):
+def test_sagemaker_config_region():
"""
If a region is specified as part of the optional parameters of the completion, including as
part of the config file, then use that region instead of us-west-2
"""
expected_region = "us-east-1"
- mock_client = mocker.patch("boto3.client")
- try:
- response = litellm.completion(
- model="sagemaker/mock-endpoint",
- messages=[
+ mock_response = MagicMock()
+
+ def return_val():
+ return {
+ "generated_text": "This is a mock response from SageMaker.",
+ "id": "cmpl-mockid",
+ "object": "text_completion",
+ "created": 1629800000,
+ "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ "choices": [
{
- "content": "Hello, world!",
- "role": "user"
+ "text": "This is a mock response from SageMaker.",
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "length",
}
],
+ "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+ }
+
+ mock_response.json = return_val
+ mock_response.status_code = 200
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+
+ response = litellm.completion(
+ model="sagemaker/mock-endpoint",
+ messages=[{"content": "Hello, world!", "role": "user"}],
aws_region_name=expected_region,
)
- except Exception:
- pass # expected serialization exception because AWS client was replaced with a Mock
- assert mock_client.call_args.kwargs["region_name"] == expected_region
+
+ mock_post.assert_called_once()
+ _, kwargs = mock_post.call_args
+ args_to_sagemaker = kwargs["json"]
+ print("Arguments passed to sagemaker=", args_to_sagemaker)
+ print("url=", kwargs["url"])
+
+ assert (
+ kwargs["url"]
+ == f"https://runtime.sagemaker.{expected_region}.amazonaws.com/endpoints/mock-endpoint/invocations"
+ )
+
# test_sagemaker_config_region()
-def test_sagemaker_config_and_environment_region(mocker):
- """
- If both the environment and config file specify a region, the environment region is expected
- """
- expected_region = "us-east-1"
- unexpected_region = "us-east-2"
- os.environ["AWS_REGION_NAME"] = expected_region
- mock_client = mocker.patch("boto3.client")
- try:
- response = litellm.completion(
- model="sagemaker/mock-endpoint",
- messages=[
- {
- "content": "Hello, world!",
- "role": "user"
- }
- ],
- aws_region_name=unexpected_region,
- )
- except Exception:
- pass # expected serialization exception because AWS client was replaced with a Mock
- del os.environ["AWS_REGION_NAME"] # cleanup
- assert mock_client.call_args.kwargs["region_name"] == expected_region
-
# test_sagemaker_config_and_environment_region()
diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py
index a774d1b0ef..89b4cd926e 100644
--- a/litellm/tests/test_proxy_exception_mapping.py
+++ b/litellm/tests/test_proxy_exception_mapping.py
@@ -229,8 +229,9 @@ def test_chat_completion_exception_any_model(client):
)
assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message
- assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(
- _error_message
+ assert (
+ "/chat/completions: Invalid model name passed in model=Lite-GPT-12"
+ in str(_error_message)
)
except Exception as e:
@@ -259,7 +260,7 @@ def test_embedding_exception_any_model(client):
print("Exception raised=", openai_exception)
assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message
- assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(
+ assert "/embeddings: Invalid model name passed in model=Lite-GPT-12" in str(
_error_message
)
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index 757eef6d62..9a1c091267 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -966,3 +966,203 @@ async def test_user_info_team_list(prisma_client):
pass
mock_client.assert_called()
+
+
+@pytest.mark.skip(reason="Local test")
+@pytest.mark.asyncio
+async def test_add_callback_via_key(prisma_client):
+ """
+    Test that a callback specified on the key is used.
+ """
+ global headers
+ import json
+
+ from fastapi import HTTPException, Request, Response
+ from starlette.datastructures import URL
+
+ from litellm.proxy.proxy_server import chat_completion
+
+ setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+ setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+ await litellm.proxy.proxy_server.prisma_client.connect()
+
+ litellm.set_verbose = True
+
+ try:
+ # Your test data
+ test_data = {
+ "model": "azure/chatgpt-v-2",
+ "messages": [
+ {"role": "user", "content": "write 1 sentence poem"},
+ ],
+ "max_tokens": 10,
+ "mock_response": "Hello world",
+ "api_key": "my-fake-key",
+ }
+
+ request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+ request._url = URL(url="/chat/completions")
+
+ json_bytes = json.dumps(test_data).encode("utf-8")
+
+ request._body = json_bytes
+
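+        # Patch the LangFuseLogger class so the key-scoped callback can be verified
+        # without sending anything to Langfuse.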
+ with patch.object(
+ litellm.litellm_core_utils.litellm_logging,
+ "LangFuseLogger",
+ new=MagicMock(),
+ ) as mock_client:
+ resp = await chat_completion(
+ request=request,
+ fastapi_response=Response(),
+ user_api_key_dict=UserAPIKeyAuth(
+ metadata={
+ "logging": [
+ {
+ "callback_name": "langfuse", # 'otel', 'langfuse', 'lunary'
+                                "callback_type": "success", # set if required by the integration; future improvement: have logging tools handle success + failure by default
+ "callback_vars": {
+ "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
+ "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
+ "langfuse_host": "https://us.cloud.langfuse.com",
+ },
+ }
+ ]
+ }
+ ),
+ )
+ print(resp)
+ mock_client.assert_called()
+ mock_client.return_value.log_event.assert_called()
+ args, kwargs = mock_client.return_value.log_event.call_args
+ kwargs = kwargs["kwargs"]
+ assert "user_api_key_metadata" in kwargs["litellm_params"]["metadata"]
+ assert (
+ "logging"
+ in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"]
+ )
+ checked_keys = False
+ for item in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"][
+ "logging"
+ ]:
+ for k, v in item["callback_vars"].items():
+ print("k={}, v={}".format(k, v))
+ if "key" in k:
+ assert "os.environ" in v
+ checked_keys = True
+
+ assert checked_keys
+ except Exception as e:
+ pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+
+
+@pytest.mark.asyncio
+async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
+ import json
+
+ from fastapi import HTTPException, Request, Response
+ from starlette.datastructures import URL
+
+ from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+
+ setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+ setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+ await litellm.proxy.proxy_server.prisma_client.connect()
+
+ proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+ request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+ request._url = URL(url="/chat/completions")
+
+ test_data = {
+ "model": "azure/chatgpt-v-2",
+ "messages": [
+ {"role": "user", "content": "write 1 sentence poem"},
+ ],
+ "max_tokens": 10,
+ "mock_response": "Hello world",
+ "api_key": "my-fake-key",
+ }
+
+ json_bytes = json.dumps(test_data).encode("utf-8")
+
+ request._body = json_bytes
+
+ data = {
+ "data": {
+ "model": "azure/chatgpt-v-2",
+ "messages": [{"role": "user", "content": "write 1 sentence poem"}],
+ "max_tokens": 10,
+ "mock_response": "Hello world",
+ "api_key": "my-fake-key",
+ },
+ "request": request,
+ "user_api_key_dict": UserAPIKeyAuth(
+ token=None,
+ key_name=None,
+ key_alias=None,
+ spend=0.0,
+ max_budget=None,
+ expires=None,
+ models=[],
+ aliases={},
+ config={},
+ user_id=None,
+ team_id=None,
+ max_parallel_requests=None,
+ metadata={
+ "logging": [
+ {
+ "callback_name": "langfuse",
+ "callback_type": "success",
+ "callback_vars": {
+ "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
+ "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
+ "langfuse_host": "https://us.cloud.langfuse.com",
+ },
+ }
+ ]
+ },
+ tpm_limit=None,
+ rpm_limit=None,
+ budget_duration=None,
+ budget_reset_at=None,
+ allowed_cache_controls=[],
+ permissions={},
+ model_spend={},
+ model_max_budget={},
+ soft_budget_cooldown=False,
+ litellm_budget_table=None,
+ org_id=None,
+ team_spend=None,
+ team_alias=None,
+ team_tpm_limit=None,
+ team_rpm_limit=None,
+ team_max_budget=None,
+ team_models=[],
+ team_blocked=False,
+ soft_budget=None,
+ team_model_aliases=None,
+ team_member_spend=None,
+ team_metadata=None,
+ end_user_id=None,
+ end_user_tpm_limit=None,
+ end_user_rpm_limit=None,
+ end_user_max_budget=None,
+ last_refreshed_at=None,
+ api_key=None,
+ user_role=None,
+ allowed_model_region=None,
+ parent_otel_span=None,
+ ),
+ "proxy_config": proxy_config,
+ "general_settings": {},
+ "version": "0.0.0",
+ }
+
+ new_data = await add_litellm_data_to_request(**data)
+
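+    # The key-level "logging" metadata should be translated into request-level litellm
+    # params: a success_callback entry plus the callback's env-var references.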
+ assert "success_callback" in new_data
+ assert new_data["success_callback"] == ["langfuse"]
+ assert "langfuse_public_key" in new_data
+ assert "langfuse_secret_key" in new_data
diff --git a/litellm/tests/test_sagemaker.py b/litellm/tests/test_sagemaker.py
new file mode 100644
index 0000000000..b6b4251c6b
--- /dev/null
+++ b/litellm/tests/test_sagemaker.py
@@ -0,0 +1,316 @@
+import json
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+ 0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+
+import os
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+import litellm
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.prompt_templates.factory import anthropic_messages_pt
+
+# litellm.num_retries =3
+litellm.cache = None
+litellm.success_callback = []
+user_message = "Write a short poem about the sky"
+messages = [{"content": user_message, "role": "user"}]
+import logging
+
+from litellm._logging import verbose_logger
+
+
+def logger_fn(user_model_dict):
+ print(f"user_model_dict: {user_model_dict}")
+
+
+@pytest.fixture(autouse=True)
+def reset_callbacks():
+ print("\npytest fixture - resetting callbacks")
+ litellm.success_callback = []
+ litellm._async_success_callback = []
+ litellm.failure_callback = []
+ litellm.callbacks = []
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_completion_sagemaker(sync_mode):
+ try:
+ litellm.set_verbose = True
+ verbose_logger.setLevel(logging.DEBUG)
+ print("testing sagemaker")
+ if sync_mode is True:
+ response = litellm.completion(
+ model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ messages=[
+ {"role": "user", "content": "hi"},
+ ],
+ temperature=0.2,
+ max_tokens=80,
+ input_cost_per_second=0.000420,
+ )
+ else:
+ response = await litellm.acompletion(
+ model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ messages=[
+ {"role": "user", "content": "hi"},
+ ],
+ temperature=0.2,
+ max_tokens=80,
+ input_cost_per_second=0.000420,
+ )
+ # Add any assertions here to check the response
+ print(response)
+ cost = completion_cost(completion_response=response)
+ print("calculated cost", cost)
+ assert (
+ cost > 0.0 and cost < 1.0
+ ) # should never be > $1 for a single completion call
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [False, True])
+async def test_completion_sagemaker_stream(sync_mode):
+ try:
+ litellm.set_verbose = False
+ print("testing sagemaker")
+ verbose_logger.setLevel(logging.DEBUG)
+ full_text = ""
+ if sync_mode is True:
+ response = litellm.completion(
+ model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ messages=[
+ {"role": "user", "content": "hi - what is ur name"},
+ ],
+ temperature=0.2,
+ stream=True,
+ max_tokens=80,
+ input_cost_per_second=0.000420,
+ )
+
+ for chunk in response:
+ print(chunk)
+ full_text += chunk.choices[0].delta.content or ""
+
+ print("SYNC RESPONSE full text", full_text)
+ else:
+ response = await litellm.acompletion(
+ model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ messages=[
+ {"role": "user", "content": "hi - what is ur name"},
+ ],
+ stream=True,
+ temperature=0.2,
+ max_tokens=80,
+ input_cost_per_second=0.000420,
+ )
+
+ print("streaming response")
+
+ async for chunk in response:
+ print(chunk)
+ full_text += chunk.choices[0].delta.content or ""
+
+ print("ASYNC RESPONSE full text", full_text)
+
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_acompletion_sagemaker_non_stream():
+ mock_response = AsyncMock()
+
+ def return_val():
+ return {
+ "generated_text": "This is a mock response from SageMaker.",
+ "id": "cmpl-mockid",
+ "object": "text_completion",
+ "created": 1629800000,
+ "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ "choices": [
+ {
+ "text": "This is a mock response from SageMaker.",
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "length",
+ }
+ ],
+ "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+ }
+
+ mock_response.json = return_val
+ mock_response.status_code = 200
+
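+    # litellm maps OpenAI-style params onto the SageMaker payload,
+    # e.g. max_tokens -> parameters.max_new_tokens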
+ expected_payload = {
+ "inputs": "hi",
+ "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+ }
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+ # Act: Call the litellm.acompletion function
+ response = await litellm.acompletion(
+ model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ messages=[
+ {"role": "user", "content": "hi"},
+ ],
+ temperature=0.2,
+ max_tokens=80,
+ input_cost_per_second=0.000420,
+ )
+
+ # Print what was called on the mock
+ print("call args=", mock_post.call_args)
+
+ # Assert
+ mock_post.assert_called_once()
+ _, kwargs = mock_post.call_args
+ args_to_sagemaker = kwargs["json"]
+ print("Arguments passed to sagemaker=", args_to_sagemaker)
+ assert args_to_sagemaker == expected_payload
+ assert (
+ kwargs["url"]
+ == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
+ )
+
+
+@pytest.mark.asyncio
+async def test_completion_sagemaker_non_stream():
+ mock_response = MagicMock()
+
+ def return_val():
+ return {
+ "generated_text": "This is a mock response from SageMaker.",
+ "id": "cmpl-mockid",
+ "object": "text_completion",
+ "created": 1629800000,
+ "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ "choices": [
+ {
+ "text": "This is a mock response from SageMaker.",
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "length",
+ }
+ ],
+ "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+ }
+
+ mock_response.json = return_val
+ mock_response.status_code = 200
+
+ expected_payload = {
+ "inputs": "hi",
+ "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+ }
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+        # Act: Call the litellm.completion function
+ response = litellm.completion(
+ model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ messages=[
+ {"role": "user", "content": "hi"},
+ ],
+ temperature=0.2,
+ max_tokens=80,
+ input_cost_per_second=0.000420,
+ )
+
+ # Print what was called on the mock
+ print("call args=", mock_post.call_args)
+
+ # Assert
+ mock_post.assert_called_once()
+ _, kwargs = mock_post.call_args
+ args_to_sagemaker = kwargs["json"]
+ print("Arguments passed to sagemaker=", args_to_sagemaker)
+ assert args_to_sagemaker == expected_payload
+ assert (
+ kwargs["url"]
+ == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
+ )
+
+
+@pytest.mark.asyncio
+async def test_completion_sagemaker_non_stream_with_aws_params():
+ mock_response = MagicMock()
+
+ def return_val():
+ return {
+ "generated_text": "This is a mock response from SageMaker.",
+ "id": "cmpl-mockid",
+ "object": "text_completion",
+ "created": 1629800000,
+ "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ "choices": [
+ {
+ "text": "This is a mock response from SageMaker.",
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "length",
+ }
+ ],
+ "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+ }
+
+ mock_response.json = return_val
+ mock_response.status_code = 200
+
+ expected_payload = {
+ "inputs": "hi",
+ "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+ }
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+        # Act: Call the litellm.completion function
+ response = litellm.completion(
+ model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+ messages=[
+ {"role": "user", "content": "hi"},
+ ],
+ temperature=0.2,
+ max_tokens=80,
+ input_cost_per_second=0.000420,
+ aws_access_key_id="gm",
+ aws_secret_access_key="s",
+ aws_region_name="us-west-5",
+ )
+
+ # Print what was called on the mock
+ print("call args=", mock_post.call_args)
+
+ # Assert
+ mock_post.assert_called_once()
+ _, kwargs = mock_post.call_args
+ args_to_sagemaker = kwargs["json"]
+ print("Arguments passed to sagemaker=", args_to_sagemaker)
+ assert args_to_sagemaker == expected_payload
+ assert (
+ kwargs["url"]
+ == "https://runtime.sagemaker.us-west-5.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
+ )
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 025ea81200..3f42331879 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1683,6 +1683,7 @@ def test_completion_bedrock_mistral_stream():
pytest.fail(f"Error occurred: {e}")
+@pytest.mark.skip(reason="stopped using TokenIterator")
def test_sagemaker_weird_response():
"""
When the stream ends, flush any remaining holding chunks.
diff --git a/litellm/tests/test_user_api_key_auth.py b/litellm/tests/test_user_api_key_auth.py
index ad057ee572..e0595ac13c 100644
--- a/litellm/tests/test_user_api_key_auth.py
+++ b/litellm/tests/test_user_api_key_auth.py
@@ -44,7 +44,7 @@ def test_check_valid_ip(
request = Request(client_ip)
- assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
+ assert _check_valid_ip(allowed_ips, request)[0] == expected_result # type: ignore
# test x-forwarder for is used when user has opted in
@@ -72,7 +72,7 @@ def test_check_valid_ip_sent_with_x_forwarded_for(
request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
- assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True) == expected_result # type: ignore
+ assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True)[0] == expected_result # type: ignore
@pytest.mark.asyncio
diff --git a/litellm/types/llms/anthropic.py b/litellm/types/llms/anthropic.py
index 36bcb6cc73..f14aa20c73 100644
--- a/litellm/types/llms/anthropic.py
+++ b/litellm/types/llms/anthropic.py
@@ -15,9 +15,10 @@ class AnthropicMessagesTool(TypedDict, total=False):
input_schema: Required[dict]
-class AnthropicMessagesTextParam(TypedDict):
+class AnthropicMessagesTextParam(TypedDict, total=False):
type: Literal["text"]
text: str
+ cache_control: Optional[dict]
class AnthropicMessagesToolUseParam(TypedDict):
@@ -54,9 +55,10 @@ class AnthropicImageParamSource(TypedDict):
data: str
-class AnthropicMessagesImageParam(TypedDict):
+class AnthropicMessagesImageParam(TypedDict, total=False):
type: Literal["image"]
source: AnthropicImageParamSource
+ cache_control: Optional[dict]
class AnthropicMessagesToolResultContent(TypedDict):
@@ -92,6 +94,12 @@ class AnthropicMetadata(TypedDict, total=False):
user_id: str
+class AnthropicSystemMessageContent(TypedDict, total=False):
+ type: str
+ text: str
+ cache_control: Optional[dict]
+
+
class AnthropicMessagesRequest(TypedDict, total=False):
model: Required[str]
messages: Required[
@@ -106,7 +114,7 @@ class AnthropicMessagesRequest(TypedDict, total=False):
metadata: AnthropicMetadata
stop_sequences: List[str]
stream: bool
- system: str
+ system: Union[str, List]
temperature: float
tool_choice: AnthropicMessagesToolChoice
tools: List[AnthropicMessagesTool]
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index 0d67d5d602..5d2c416f9c 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -361,7 +361,7 @@ class ChatCompletionToolMessage(TypedDict):
class ChatCompletionSystemMessage(TypedDict, total=False):
role: Required[Literal["system"]]
- content: Required[str]
+ content: Required[Union[str, List]]
name: str
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 5cf6270868..519b301039 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -80,7 +80,7 @@ class ModelInfo(TypedDict, total=False):
supports_assistant_prefill: Optional[bool]
-class GenericStreamingChunk(TypedDict):
+class GenericStreamingChunk(TypedDict, total=False):
text: Required[str]
tool_use: Optional[ChatCompletionToolCallChunk]
is_finished: Required[bool]
diff --git a/litellm/utils.py b/litellm/utils.py
index 49528d0f77..0875b0e0e5 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4479,7 +4479,22 @@ def _is_non_openai_azure_model(model: str) -> bool:
or f"mistral/{model_name}" in litellm.mistral_chat_models
):
return True
- except:
+ except Exception:
+ return False
+ return False
+
+
+def _is_azure_openai_model(model: str) -> bool:
+ try:
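+        # strip a provider prefix if present, e.g. "azure_ai/gpt-4o" -> "gpt-4o"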
+ if "/" in model:
+ model = model.split("/", 1)[1]
+ if (
+ model in litellm.open_ai_chat_completion_models
+ or model in litellm.open_ai_text_completion_models
+ or model in litellm.open_ai_embedding_models
+ ):
+ return True
+ except Exception:
return False
return False
@@ -4613,6 +4628,14 @@ def get_llm_provider(
elif custom_llm_provider == "azure_ai":
api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
+
+ if _is_azure_openai_model(model=model):
+ verbose_logger.debug(
+ "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
+ model
+ )
+ )
+ custom_llm_provider = "azure"
elif custom_llm_provider == "github":
api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
@@ -9825,11 +9848,28 @@ class CustomStreamWrapper:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker":
- print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
- response_obj = self.handle_sagemaker_stream(chunk)
+ from litellm.types.llms.bedrock import GenericStreamingChunk
+
+ if self.received_finish_reason is not None:
+ raise StopIteration
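+                # SageMaker streaming now yields pre-parsed GenericStreamingChunk dicts,
+                # so the chunk is consumed directly instead of via handle_sagemaker_stream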
+ response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
+
+ if (
+ self.stream_options
+ and self.stream_options.get("include_usage", False) is True
+ and response_obj["usage"] is not None
+ ):
+ model_response.usage = litellm.Usage(
+ prompt_tokens=response_obj["usage"]["inputTokens"],
+ completion_tokens=response_obj["usage"]["outputTokens"],
+ total_tokens=response_obj["usage"]["totalTokens"],
+ )
+
+ if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+ completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.received_finish_reason is not None:
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 455fe1e3c5..d30270c5c8 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -57,6 +57,18 @@
"supports_parallel_function_calling": true,
"supports_vision": true
},
+ "chatgpt-4o-latest": {
+ "max_tokens": 4096,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000005,
+ "output_cost_per_token": 0.000015,
+ "litellm_provider": "openai",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_parallel_function_calling": true,
+ "supports_vision": true
+ },
"gpt-4o-2024-05-13": {
"max_tokens": 4096,
"max_input_tokens": 128000,
@@ -2062,7 +2074,8 @@
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
- "supports_vision": true
+ "supports_vision": true,
+ "supports_assistant_prefill": true
},
"vertex_ai/claude-3-5-sonnet@20240620": {
"max_tokens": 4096,
@@ -2073,7 +2086,8 @@
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
- "supports_vision": true
+ "supports_vision": true,
+ "supports_assistant_prefill": true
},
"vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 4096,
@@ -2084,7 +2098,8 @@
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
- "supports_vision": true
+ "supports_vision": true,
+ "supports_assistant_prefill": true
},
"vertex_ai/claude-3-opus@20240229": {
"max_tokens": 4096,
@@ -2095,7 +2110,8 @@
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
- "supports_vision": true
+ "supports_vision": true,
+ "supports_assistant_prefill": true
},
"vertex_ai/meta/llama3-405b-instruct-maas": {
"max_tokens": 32000,
@@ -4519,6 +4535,69 @@
"litellm_provider": "perplexity",
"mode": "chat"
},
+ "perplexity/llama-3.1-70b-instruct": {
+ "max_tokens": 131072,
+ "max_input_tokens": 131072,
+ "max_output_tokens": 131072,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-8b-instruct": {
+ "max_tokens": 131072,
+ "max_input_tokens": 131072,
+ "max_output_tokens": 131072,
+ "input_cost_per_token": 0.0000002,
+ "output_cost_per_token": 0.0000002,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-huge-128k-online": {
+ "max_tokens": 127072,
+ "max_input_tokens": 127072,
+ "max_output_tokens": 127072,
+ "input_cost_per_token": 0.000005,
+ "output_cost_per_token": 0.000005,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-large-128k-online": {
+ "max_tokens": 127072,
+ "max_input_tokens": 127072,
+ "max_output_tokens": 127072,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-large-128k-chat": {
+ "max_tokens": 131072,
+ "max_input_tokens": 131072,
+ "max_output_tokens": 131072,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-small-128k-chat": {
+ "max_tokens": 131072,
+ "max_input_tokens": 131072,
+ "max_output_tokens": 131072,
+ "input_cost_per_token": 0.0000002,
+ "output_cost_per_token": 0.0000002,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
+ "perplexity/llama-3.1-sonar-small-128k-online": {
+ "max_tokens": 127072,
+ "max_input_tokens": 127072,
+ "max_output_tokens": 127072,
+ "input_cost_per_token": 0.0000002,
+ "output_cost_per_token": 0.0000002,
+ "litellm_provider": "perplexity",
+ "mode": "chat"
+ },
"perplexity/pplx-7b-chat": {
"max_tokens": 8192,
"max_input_tokens": 8192,
diff --git a/pyproject.toml b/pyproject.toml
index ae9ba13da2..a7d069789d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "1.43.9"
+version = "1.43.15"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
-version = "1.43.9"
+version = "1.43.15"
version_files = [
"pyproject.toml:^version"
]