Merge branch 'main' into litellm_pass_through_endpoints_api

commit b3d15ace89
Krish Dholakia, 2024-08-15 22:39:19 -07:00 (committed by GitHub)
55 changed files with 3658 additions and 1644 deletions


@@ -125,6 +125,7 @@ jobs:
pip install tiktoken
pip install aiohttp
pip install click
pip install "boto3==1.34.34"
pip install jinja2
pip install tokenizers
pip install openai
@@ -287,6 +288,7 @@ jobs:
pip install "pytest==7.3.1"
pip install "pytest-mock==3.12.0"
pip install "pytest-asyncio==0.21.1"
pip install "boto3==1.34.34"
pip install mypy
pip install pyarrow
pip install numpydoc


@@ -62,6 +62,11 @@ COPY --from=builder /wheels/ /wheels/
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels
# Generate prisma client
ENV PRISMA_BINARY_CACHE_DIR=/app/prisma
RUN mkdir -p /.cache
RUN chmod -R 777 /.cache
RUN pip install nodejs-bin
RUN pip install prisma
RUN prisma generate
RUN chmod +x entrypoint.sh


@@ -62,6 +62,11 @@ RUN pip install PyJWT --no-cache-dir
RUN chmod +x build_admin_ui.sh && ./build_admin_ui.sh
# Generate prisma client
ENV PRISMA_BINARY_CACHE_DIR=/app/prisma
RUN mkdir -p /.cache
RUN chmod -R 777 /.cache
RUN pip install nodejs-bin
RUN pip install prisma
RUN prisma generate
RUN chmod +x entrypoint.sh


@@ -84,17 +84,20 @@ from litellm import completion
# add to env var
os.environ["OPENAI_API_KEY"] = ""
- messages = [{"role": "user", "content": "List 5 cookie recipes"}]
+ messages = [{"role": "user", "content": "List 5 important events in the XIX century"}]
class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]
class EventsList(BaseModel):
    events: list[CalendarEvent]
resp = completion(
    model="gpt-4o-2024-08-06",
    messages=messages,
-     response_format=CalendarEvent
+     response_format=EventsList
)
print("Received={}".format(resp))


@@ -225,22 +225,336 @@ print(response)
| claude-instant-1.2 | `completion('claude-instant-1.2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-instant-1 | `completion('claude-instant-1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
- ## Passing Extra Headers to Anthropic API
- Pass `extra_headers: dict` to `litellm.completion`
+ ## **Prompt Caching**
+ Use Anthropic Prompt Caching
[Relevant Anthropic API Docs](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching)
### Caching - Large Context Caching
This example demonstrates basic Prompt Caching usage, caching the full text of the legal agreement as a prefix while keeping the user instruction uncached.
<Tabs>
<TabItem value="sdk" label="LiteLLM SDK">
- ```python
- from litellm import completion
- messages = [{"role": "user", "content": "What is Anthropic?"}]
- response = completion(
- model="claude-3-5-sonnet-20240620",
- messages=messages,
- extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
+ ```python
+ response = await litellm.acompletion(
+ model="anthropic/claude-3-5-sonnet-20240620",
+ messages=[
+ {
+ "role": "system",
+ "content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"type": "text",
"text": "Here is the full text of a complex legal agreement",
"cache_control": {"type": "ephemeral"},
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
],
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM Proxy">
:::info
LiteLLM Proxy is OpenAI compatible
This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy
Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy)
:::
```python
import openai
client = openai.AsyncOpenAI(
api_key="anything", # litellm proxy api key
base_url="http://0.0.0.0:4000" # litellm proxy base url
)
response = await client.chat.completions.create(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"type": "text",
"text": "Here is the full text of a complex legal agreement",
"cache_control": {"type": "ephemeral"},
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
],
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
```
</TabItem>
</Tabs>
### Caching - Tools definitions
In this example, we demonstrate caching tool definitions.
The cache_control parameter is placed on the final tool
<Tabs>
<TabItem value="sdk" label="LiteLLM SDK">
```python
import litellm
response = await litellm.acompletion(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
"cache_control": {"type": "ephemeral"}
},
}
],
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
```
- ## Advanced
+ </TabItem>
<TabItem value="proxy" label="LiteLLM Proxy">
- ## Usage - Function Calling
+ :::info
LiteLLM Proxy is OpenAI compatible
This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy
Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy)
:::
```python
import openai
client = openai.AsyncOpenAI(
api_key="anything", # litellm proxy api key
base_url="http://0.0.0.0:4000" # litellm proxy base url
)
response = await client.chat.completions.create(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
"cache_control": {"type": "ephemeral"}
},
}
],
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
```
</TabItem>
</Tabs>
### Caching - Continuing Multi-Turn Convo
In this example, we demonstrate how to use Prompt Caching in a multi-turn conversation.
The cache_control parameter is placed on the system message to designate it as part of the static prefix.
The conversation history (previous messages) is included in the messages array. The second-to-last user message is marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache, and the final turn is also marked with cache_control so follow-up requests can continue from it.
<Tabs>
<TabItem value="sdk" label="LiteLLM SDK">
```python
import litellm
response = await litellm.acompletion(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement"
* 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
],
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM Proxy">
:::info
LiteLLM Proxy is OpenAI compatible
This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy
Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy)
:::
```python
import openai
client = openai.AsyncOpenAI(
api_key="anything", # litellm proxy api key
base_url="http://0.0.0.0:4000" # litellm proxy base url
)
response = await client.chat.completions.create(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement"
* 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
],
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
```
</TabItem>
</Tabs>
## **Function/Tool Calling**
:::info
@@ -429,6 +743,20 @@ resp = litellm.completion(
print(f"\nResponse: {resp}")
```
## **Passing Extra Headers to Anthropic API**
Pass `extra_headers: dict` to `litellm.completion`
```python
from litellm import completion
messages = [{"role": "user", "content": "What is Anthropic?"}]
response = completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
)
```
## Usage - "Assistant Pre-fill" ## Usage - "Assistant Pre-fill"
You can "put words in Claude's mouth" by including an `assistant` role message as the last item in the `messages` array. You can "put words in Claude's mouth" by including an `assistant` role message as the last item in the `messages` array.


@@ -393,7 +393,7 @@ response = completion(
)
```
</TabItem>
- <TabItem value="proxy" label="LiteLLM Proxy Server">
+ <TabItem value="proxy" label="Proxy on request">
```python
@@ -420,6 +420,55 @@ extra_body={
}
)
print(response)
```
</TabItem>
<TabItem value="proxy-config" label="Proxy on config.yaml">
1. Update config.yaml
```yaml
model_list:
- model_name: bedrock-claude-v1
litellm_params:
model: bedrock/anthropic.claude-instant-v1
aws_access_key_id: os.environ/CUSTOM_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/CUSTOM_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/CUSTOM_AWS_REGION_NAME
guardrailConfig: {
"guardrailIdentifier": "ff6ujrregl1q", # The identifier (ID) for the guardrail.
"guardrailVersion": "DRAFT", # The version of the guardrail.
"trace": "disabled", # The trace behavior for the guardrail. Can either be "disabled" or "enabled"
}
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="bedrock-claude-v1", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
temperature=0.7
)
print(response)
```
</TabItem>


@@ -705,6 +705,29 @@ docker run ghcr.io/berriai/litellm:main-latest \
Provide an ssl certificate when starting litellm proxy server
### 3. Providing LiteLLM config.yaml file as a s3 Object/url
Use this if you cannot mount a config file on your deployment service (example - AWS Fargate, Railway etc)
LiteLLM Proxy will read your config.yaml from an s3 Bucket
Set the following .env vars
```shell
LITELLM_CONFIG_BUCKET_NAME="litellm-proxy" # your bucket name on s3
LITELLM_CONFIG_BUCKET_OBJECT_KEY="litellm_proxy_config.yaml" # object key on s3
```
Start litellm proxy with these env vars - litellm will read your config from s3
```shell
docker run --name litellm-proxy \
-e DATABASE_URL=<database_url> \
-e LITELLM_CONFIG_BUCKET_NAME=<bucket_name> \
-e LITELLM_CONFIG_BUCKET_OBJECT_KEY="<object_key>" \
-p 4000:4000 \
ghcr.io/berriai/litellm-database:main-latest
```
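One way to place the config object in the bucket before starting the container is a short boto3 upload. This is a sketch under the assumption that the bucket already exists and local AWS credentials are configured; the bucket and key names mirror the env vars above.

```python
# Sketch: upload the proxy config to s3 so the container can read it at startup.
import boto3

s3 = boto3.client("s3")
s3.upload_file(
    Filename="litellm_proxy_config.yaml",   # local config file
    Bucket="litellm-proxy",                 # LITELLM_CONFIG_BUCKET_NAME
    Key="litellm_proxy_config.yaml",        # LITELLM_CONFIG_BUCKET_OBJECT_KEY
)
```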
## Platform-specific Guide
<Tabs>


@@ -17,7 +17,7 @@ model_list:
## Get Model Information - `/model/info`
- Retrieve detailed information about each model listed in the `/model/info` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled the model_info you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes.
+ Retrieve detailed information about each model listed in the `/model/info` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled from the model_info you set and the [litellm model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). Sensitive details like API keys are excluded for security purposes.
<Tabs
defaultValue="curl"
@@ -35,22 +35,33 @@ curl -X GET "http://0.0.0.0:4000/model/info" \
## Add a New Model
- Add a new model to the list in the `config.yaml` by providing the model parameters. This allows you to update the model list without restarting the proxy.
- <Tabs
- defaultValue="curl"
- values={[
- { label: 'cURL', value: 'curl', },
- ]}>
- <TabItem value="curl">
+ Add a new model to the proxy via the `/model/new` API, to add models without restarting the proxy.
+ <Tabs>
+ <TabItem value="API">
```bash
curl -X POST "http://0.0.0.0:4000/model/new" \
-H "accept: application/json" \
-H "Content-Type: application/json" \
-d '{ "model_name": "azure-gpt-turbo", "litellm_params": {"model": "azure/gpt-3.5-turbo", "api_key": "os.environ/AZURE_API_KEY", "api_base": "my-azure-api-base"} }'
```
</TabItem>
<TabItem value="Yaml">
```yaml
model_list:
- model_name: gpt-3.5-turbo ### RECEIVED MODEL NAME ### `openai.chat.completions.create(model="gpt-3.5-turbo",...)`
litellm_params: # all params accepted by litellm.completion() - https://github.com/BerriAI/litellm/blob/9b46ec05b02d36d6e4fb5c32321e51e7f56e4a6e/litellm/types/router.py#L297
model: azure/gpt-turbo-small-eu ### MODEL NAME sent to `litellm.completion()` ###
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key: "os.environ/AZURE_API_KEY_EU" # does os.getenv("AZURE_API_KEY_EU")
rpm: 6 # [OPTIONAL] Rate limit for this deployment: in requests per minute (rpm)
model_info:
my_custom_key: my_custom_value # additional model metadata
```
</TabItem>
</Tabs>
@@ -86,3 +97,82 @@ Keep in mind that as both endpoints are in [BETA], you may need to visit the ass
- Add a New Model: [Issue #964](https://github.com/BerriAI/litellm/issues/964)
Feedback on the beta endpoints is valuable and helps improve the API for all users.
## Add Additional Model Information
If you want the ability to add a display name, description, and labels for models, just use `model_info:`
```yaml
model_list:
- model_name: "gpt-4"
litellm_params:
model: "gpt-4"
api_key: "os.environ/OPENAI_API_KEY"
model_info: # 👈 KEY CHANGE
my_custom_key: "my_custom_value"
```
### Usage
1. Add additional information to model
```yaml
model_list:
- model_name: "gpt-4"
litellm_params:
model: "gpt-4"
api_key: "os.environ/OPENAI_API_KEY"
model_info: # 👈 KEY CHANGE
my_custom_key: "my_custom_value"
```
2. Call with `/model/info`
Use a key with access to the model `gpt-4`.
```bash
curl -L -X GET 'http://0.0.0.0:4000/v1/model/info' \
-H 'Authorization: Bearer LITELLM_KEY'
```
3. **Expected Response**
Returned `model_info = Your custom model_info + (if exists) LITELLM MODEL INFO`
[**How LiteLLM Model Info is found**](https://github.com/BerriAI/litellm/blob/9b46ec05b02d36d6e4fb5c32321e51e7f56e4a6e/litellm/proxy/proxy_server.py#L7460)
[Tell us how this can be improved!](https://github.com/BerriAI/litellm/issues)
```bash
{
"data": [
{
"model_name": "gpt-4",
"litellm_params": {
"model": "gpt-4"
},
"model_info": {
"id": "e889baacd17f591cce4c63639275ba5e8dc60765d6c553e6ee5a504b19e50ddc",
"db_model": false,
"my_custom_key": "my_custom_value", # 👈 CUSTOM INFO
"key": "gpt-4", # 👈 KEY in LiteLLM MODEL INFO/COST MAP - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 3e-05,
"input_cost_per_character": null,
"input_cost_per_token_above_128k_tokens": null,
"output_cost_per_token": 6e-05,
"output_cost_per_character": null,
"output_cost_per_token_above_128k_tokens": null,
"output_cost_per_character_above_128k_tokens": null,
"output_vector_size": null,
"litellm_provider": "openai",
"mode": "chat"
}
},
]
}
```


@@ -193,6 +193,53 @@ curl --request POST \
}'
```
### Use Langfuse client sdk w/ LiteLLM Key
**Usage**
1. Set up config.yaml to pass through langfuse's `/api/public/ingestion` endpoint
```yaml
general_settings:
master_key: sk-1234
pass_through_endpoints:
- path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server
target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward
auth: true # 👈 KEY CHANGE
custom_auth_parser: "langfuse" # 👈 KEY CHANGE
headers:
LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY" # your langfuse account public key
LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY" # your langfuse account secret key
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test with langfuse sdk
```python
from langfuse import Langfuse
langfuse = Langfuse(
host="http://localhost:4000", # your litellm proxy endpoint
public_key="sk-1234", # your litellm proxy api key
secret_key="anything", # no key required since this is a pass through
)
print("sending langfuse trace request")
trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
print("flushing langfuse request")
langfuse.flush()
print("flushed langfuse request")
```
## `pass_through_endpoints` Spec on config.yaml
All possible values for `pass_through_endpoints` and what they mean


@@ -72,15 +72,15 @@ http://localhost:4000/metrics
| Metric Name | Description |
|----------------------|--------------------------------------|
- | `deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. |
+ | `litellm_deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. |
| `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment |
| `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` return from LLM API Deployment |
- | `llm_deployment_success_responses` | Total number of successful LLM API calls for deployment |
+ | `litellm_deployment_success_responses` | Total number of successful LLM API calls for deployment |
- | `llm_deployment_failure_responses` | Total number of failed LLM API calls for deployment |
+ | `litellm_deployment_failure_responses` | Total number of failed LLM API calls for deployment |
- | `llm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure |
+ | `litellm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure |
- | `llm_deployment_latency_per_output_token` | Latency per output token for deployment |
+ | `litellm_deployment_latency_per_output_token` | Latency per output token for deployment |
- | `llm_deployment_successful_fallbacks` | Number of successful fallback requests from primary model -> fallback model |
+ | `litellm_deployment_successful_fallbacks` | Number of successful fallback requests from primary model -> fallback model |
- | `llm_deployment_failed_fallbacks` | Number of failed fallback requests from primary model -> fallback model |
+ | `litellm_deployment_failed_fallbacks` | Number of failed fallback requests from primary model -> fallback model |
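A quick way to confirm the renamed metrics are being exported (not part of the docs change; assumes a proxy with the Prometheus callback enabled on localhost:4000, per the `/metrics` endpoint above):

```python
# Sketch: fetch /metrics and print only the renamed litellm_deployment_* series.
import urllib.request

with urllib.request.urlopen("http://localhost:4000/metrics") as resp:
    for line in resp.read().decode("utf-8").splitlines():
        if line.startswith("litellm_deployment_"):
            print(line)
```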


@@ -2,9 +2,9 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
- # 👥📊 Team Based Logging
- Allow each team to use their own Langfuse Project / custom callbacks
+ # 👥📊 Team/Key Based Logging
+ Allow each key/team to use their own Langfuse Project / custom callbacks
**This allows you to do the following**
```
@@ -189,3 +189,39 @@ curl -X GET 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/cal
## [BETA] Key Based Logging
Use the `/key/generate` or `/key/update` endpoints to add logging callbacks to a specific key.
:::info
✨ This is an Enterprise only feature [Get Started with Enterprise here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"logging": {
"callback_name": "langfuse", # 'otel', 'langfuse', 'lunary'
"callback_type": "success" # set, if required by integration - future improvement, have logging tools work for success + failure by default
"callback_vars": {
"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY", # [RECOMMENDED] reference key in proxy environment
"langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY", # [RECOMMENDED] reference key in proxy environment
"langfuse_host": "https://cloud.langfuse.com"
}
}
}
}'
```
---
Help us improve this feature, by filing a [ticket here](https://github.com/BerriAI/litellm/issues)


@@ -151,7 +151,7 @@ const sidebars = {
},
{
type: "category",
- label: "Chat Completions (litellm.completion)",
+ label: "Chat Completions (litellm.completion + PROXY)",
link: {
type: "generated-index",
title: "Chat Completions",


@@ -1,5 +1,6 @@
import json
import os
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, Union
@@ -29,6 +30,8 @@ class GCSBucketPayload(TypedDict):
end_time: str
response_cost: Optional[float]
spend_log_metadata: str
exception: Optional[str]
log_event_type: Optional[str]
class GCSBucketLogger(CustomLogger):
@@ -79,6 +82,7 @@ class GCSBucketLogger(CustomLogger):
logging_payload: GCSBucketPayload = await self.get_gcs_payload(
kwargs, response_obj, start_time_str, end_time_str
)
logging_payload["log_event_type"] = "successful_api_call"
json_logged_payload = json.dumps(logging_payload)
@@ -103,7 +107,56 @@
verbose_logger.error("GCS Bucket logging error: %s", str(e))
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
- pass
+ from litellm.proxy.proxy_server import premium_user
if premium_user is not True:
raise ValueError(
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
)
try:
verbose_logger.debug(
"GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
kwargs,
response_obj,
)
start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers()
logging_payload: GCSBucketPayload = await self.get_gcs_payload(
kwargs, response_obj, start_time_str, end_time_str
)
logging_payload["log_event_type"] = "failed_api_call"
_litellm_params = kwargs.get("litellm_params") or {}
metadata = _litellm_params.get("metadata") or {}
json_logged_payload = json.dumps(logging_payload)
# Get the current date
current_date = datetime.now().strftime("%Y-%m-%d")
# Modify the object_name to include the date-based folder
object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
if "gcs_log_id" in metadata:
object_name = metadata["gcs_log_id"]
response = await self.async_httpx_client.post(
headers=headers,
url=f"https://storage.googleapis.com/upload/storage/v1/b/{self.BUCKET_NAME}/o?uploadType=media&name={object_name}",
data=json_logged_payload,
)
if response.status_code != 200:
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
verbose_logger.debug("GCS Bucket response %s", response)
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
verbose_logger.debug("GCS Bucket response.text %s", response.text)
except Exception as e:
verbose_logger.error("GCS Bucket logging error: %s", str(e))
async def construct_request_headers(self) -> Dict[str, str]:
from litellm import vertex_chat_completion
@@ -139,9 +192,18 @@
optional_params=kwargs.get("optional_params", None),
)
response_dict = {}
- response_dict = convert_litellm_response_object_to_dict(
- response_obj=response_obj
- )
+ if response_obj:
+ response_dict = convert_litellm_response_object_to_dict(
+ response_obj=response_obj
)
exception_str = None
# Handle logging exception attributes
if "exception" in kwargs:
exception_str = kwargs.get("exception", "")
if not isinstance(exception_str, str):
exception_str = str(exception_str)
_spend_log_payload: SpendLogsPayload = get_logging_payload(
kwargs=kwargs,
@@ -156,8 +218,10 @@
response_obj=response_dict,
start_time=start_time,
end_time=end_time,
- spend_log_metadata=_spend_log_payload["metadata"],
+ spend_log_metadata=_spend_log_payload.get("metadata", ""),
response_cost=kwargs.get("response_cost", None),
exception=exception_str,
log_event_type=None,
)
return gcs_payload
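For context on exercising the new failure path (an illustration, not taken from this commit): the logger is registered like any other LiteLLM callback, and a failed call should now produce a `failure-<uuid>` object tagged `log_event_type="failed_api_call"`. The callback name `"gcs_bucket"` and the `GCS_BUCKET_NAME` / `GCS_PATH_SERVICE_ACCOUNT` environment variables are assumptions based on the existing GCS logging docs, and the feature is gated to premium proxy users as shown above.

```python
# Sketch: trigger the failure-logging path with a deliberately bad model name.
import os
import litellm

os.environ["GCS_BUCKET_NAME"] = "my-litellm-logs"            # assumed env var name
os.environ["GCS_PATH_SERVICE_ACCOUNT"] = "/path/to/sa.json"  # assumed env var name
litellm.failure_callback = ["gcs_bucket"]                    # assumed callback name

try:
    litellm.completion(model="gpt-bad-model", messages=[{"role": "user", "content": "hi"}])
except Exception:
    pass  # the exception string is serialized into the GCS payload's "exception" field
```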


@@ -605,6 +605,12 @@ class LangFuseLogger:
if "cache_key" in litellm.langfuse_default_tags:
_hidden_params = metadata.get("hidden_params", {}) or {}
_cache_key = _hidden_params.get("cache_key", None)
if _cache_key is None:
# fallback to using "preset_cache_key"
_preset_cache_key = kwargs.get("litellm_params", {}).get(
"preset_cache_key", None
)
_cache_key = _preset_cache_key
tags.append(f"cache_key:{_cache_key}") tags.append(f"cache_key:{_cache_key}")
return tags return tags
@ -676,7 +682,6 @@ def log_provider_specific_information_as_span(
Returns: Returns:
None None
""" """
from litellm.proxy.proxy_server import premium_user
_hidden_params = clean_metadata.get("hidden_params", None) _hidden_params = clean_metadata.get("hidden_params", None)
if _hidden_params is None: if _hidden_params is None:
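For reference (not part of the commit), the cache_key fallback above only matters when cache-key tagging is enabled; a minimal setup that exercises it might look like the following, assuming Langfuse credentials are already set in the environment.

```python
# Sketch: enable Langfuse logging and opt in to the cache_key tag referenced above.
import litellm

litellm.success_callback = ["langfuse"]        # standard Langfuse callback
litellm.langfuse_default_tags = ["cache_key"]  # turns on the tag logic shown in the diff

# On a cache hit where hidden_params lacks "cache_key", the logger now falls back to
# litellm_params["preset_cache_key"] instead of tagging "cache_key:None".
```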


@@ -141,42 +141,42 @@ class PrometheusLogger(CustomLogger):
]
# Metric for deployment state
- self.deployment_state = Gauge(
- "deployment_state",
+ self.litellm_deployment_state = Gauge(
+ "litellm_deployment_state",
"LLM Deployment Analytics - The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage",
labelnames=_logged_llm_labels,
)
- self.llm_deployment_success_responses = Counter(
- name="llm_deployment_success_responses",
+ self.litellm_deployment_success_responses = Counter(
+ name="litellm_deployment_success_responses",
documentation="LLM Deployment Analytics - Total number of successful LLM API calls via litellm",
labelnames=_logged_llm_labels,
)
- self.llm_deployment_failure_responses = Counter(
- name="llm_deployment_failure_responses",
+ self.litellm_deployment_failure_responses = Counter(
+ name="litellm_deployment_failure_responses",
documentation="LLM Deployment Analytics - Total number of failed LLM API calls via litellm",
labelnames=_logged_llm_labels,
)
- self.llm_deployment_total_requests = Counter(
- name="llm_deployment_total_requests",
+ self.litellm_deployment_total_requests = Counter(
+ name="litellm_deployment_total_requests",
documentation="LLM Deployment Analytics - Total number of LLM API calls via litellm - success + failure",
labelnames=_logged_llm_labels,
)
# Deployment Latency tracking
- self.llm_deployment_latency_per_output_token = Histogram(
- name="llm_deployment_latency_per_output_token",
+ self.litellm_deployment_latency_per_output_token = Histogram(
+ name="litellm_deployment_latency_per_output_token",
documentation="LLM Deployment Analytics - Latency per output token",
labelnames=_logged_llm_labels,
)
- self.llm_deployment_successful_fallbacks = Counter(
- "llm_deployment_successful_fallbacks",
+ self.litellm_deployment_successful_fallbacks = Counter(
+ "litellm_deployment_successful_fallbacks",
"LLM Deployment Analytics - Number of successful fallback requests from primary model -> fallback model",
["primary_model", "fallback_model"],
)
- self.llm_deployment_failed_fallbacks = Counter(
- "llm_deployment_failed_fallbacks",
+ self.litellm_deployment_failed_fallbacks = Counter(
+ "litellm_deployment_failed_fallbacks",
"LLM Deployment Analytics - Number of failed fallback requests from primary model -> fallback model",
["primary_model", "fallback_model"],
)
@@ -358,14 +358,14 @@ class PrometheusLogger(CustomLogger):
api_provider=llm_provider,
)
- self.llm_deployment_failure_responses.labels(
+ self.litellm_deployment_failure_responses.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
api_provider=llm_provider,
).inc()
- self.llm_deployment_total_requests.labels(
+ self.litellm_deployment_total_requests.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
@@ -438,14 +438,14 @@ class PrometheusLogger(CustomLogger):
api_provider=llm_provider,
)
- self.llm_deployment_success_responses.labels(
+ self.litellm_deployment_success_responses.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
api_provider=llm_provider,
).inc()
- self.llm_deployment_total_requests.labels(
+ self.litellm_deployment_total_requests.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
@@ -475,7 +475,7 @@ class PrometheusLogger(CustomLogger):
latency_per_token = None
if output_tokens is not None and output_tokens > 0:
latency_per_token = _latency_seconds / output_tokens
- self.llm_deployment_latency_per_output_token.labels(
+ self.litellm_deployment_latency_per_output_token.labels(
litellm_model_name=litellm_model_name,
model_id=model_id,
api_base=api_base,
@@ -497,7 +497,7 @@ class PrometheusLogger(CustomLogger):
kwargs,
)
_new_model = kwargs.get("model")
- self.llm_deployment_successful_fallbacks.labels(
+ self.litellm_deployment_successful_fallbacks.labels(
primary_model=original_model_group, fallback_model=_new_model
).inc()
@@ -508,11 +508,11 @@ class PrometheusLogger(CustomLogger):
kwargs,
)
_new_model = kwargs.get("model")
- self.llm_deployment_failed_fallbacks.labels(
+ self.litellm_deployment_failed_fallbacks.labels(
primary_model=original_model_group, fallback_model=_new_model
).inc()
- def set_deployment_state(
+ def set_litellm_deployment_state(
self,
state: int,
litellm_model_name: str,
@@ -520,7 +520,7 @@ class PrometheusLogger(CustomLogger):
api_base: str,
api_provider: str,
):
- self.deployment_state.labels(
+ self.litellm_deployment_state.labels(
litellm_model_name, model_id, api_base, api_provider
).set(state)
@@ -531,7 +531,7 @@ class PrometheusLogger(CustomLogger):
api_base: str,
api_provider: str,
):
- self.set_deployment_state(
+ self.set_litellm_deployment_state(
0, litellm_model_name, model_id, api_base, api_provider
)
@@ -542,7 +542,7 @@ class PrometheusLogger(CustomLogger):
api_base: str,
api_provider: str,
):
- self.set_deployment_state(
+ self.set_litellm_deployment_state(
1, litellm_model_name, model_id, api_base, api_provider
)
@@ -553,7 +553,7 @@ class PrometheusLogger(CustomLogger):
api_base: str,
api_provider: str,
):
- self.set_deployment_state(
+ self.set_litellm_deployment_state(
2, litellm_model_name, model_id, api_base, api_provider
)


@@ -41,8 +41,8 @@ async def get_fallback_metric_from_prometheus():
"""
response_message = ""
relevant_metrics = [
- "llm_deployment_successful_fallbacks_total",
- "llm_deployment_failed_fallbacks_total",
+ "litellm_deployment_successful_fallbacks_total",
+ "litellm_deployment_failed_fallbacks_total",
]
for metric in relevant_metrics:
response_json = await get_metric_from_prometheus(


@@ -35,6 +35,7 @@ from litellm.types.llms.anthropic import (
AnthropicResponseContentBlockText,
AnthropicResponseContentBlockToolUse,
AnthropicResponseUsageBlock,
AnthropicSystemMessageContent,
ContentBlockDelta,
ContentBlockStart,
ContentBlockStop,
@@ -759,6 +760,7 @@ class AnthropicChatCompletion(BaseLLM):
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
_usage = completion_response["usage"]
total_tokens = prompt_tokens + completion_tokens
model_response.created = int(time.time())
@@ -768,6 +770,11 @@ class AnthropicChatCompletion(BaseLLM):
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
if "cache_creation_input_tokens" in _usage:
usage["cache_creation_input_tokens"] = _usage["cache_creation_input_tokens"]
if "cache_read_input_tokens" in _usage:
usage["cache_read_input_tokens"] = _usage["cache_read_input_tokens"]
setattr(model_response, "usage", usage) # type: ignore
return model_response
@@ -901,6 +908,7 @@ class AnthropicChatCompletion(BaseLLM):
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
anthropic_system_message_list = None
for idx, message in enumerate(messages):
if message["role"] == "system":
valid_content: bool = False
@@ -908,8 +916,23 @@ class AnthropicChatCompletion(BaseLLM):
system_prompt += message["content"]
valid_content = True
elif isinstance(message["content"], list):
- for content in message["content"]:
- system_prompt += content.get("text", "")
+ for _content in message["content"]:
+ anthropic_system_message_content = (
AnthropicSystemMessageContent(
type=_content.get("type"),
text=_content.get("text"),
)
)
if "cache_control" in _content:
anthropic_system_message_content["cache_control"] = (
_content["cache_control"]
)
if anthropic_system_message_list is None:
anthropic_system_message_list = []
anthropic_system_message_list.append(
anthropic_system_message_content
)
valid_content = True
if valid_content:
@@ -919,6 +942,10 @@ class AnthropicChatCompletion(BaseLLM):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Handling anthropic API Prompt Caching
if anthropic_system_message_list is not None:
optional_params["system"] = anthropic_system_message_list
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
@@ -954,6 +981,8 @@ class AnthropicChatCompletion(BaseLLM):
else: # assume openai tool call
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
if "cache_control" in tool:
new_tool["cache_control"] = tool["cache_control"]
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools


@@ -0,0 +1,218 @@
import json
from typing import List, Optional
import httpx
from litellm._logging import verbose_logger
from litellm.caching import DualCache, InMemoryCache
from litellm.utils import get_secret
from .base import BaseLLM
class AwsAuthError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class BaseAWSLLM(BaseLLM):
def __init__(self) -> None:
self.iam_cache = DualCache()
super().__init__()
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
aws_sts_endpoint: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
) = params_to_check
verbose_logger.debug(
"in get credentials\n"
"aws_access_key_id=%s\n"
"aws_secret_access_key=%s\n"
"aws_session_token=%s\n"
"aws_region_name=%s\n"
"aws_session_name=%s\n"
"aws_profile_name=%s\n"
"aws_role_name=%s\n"
"aws_web_identity_token=%s\n"
"aws_sts_endpoint=%s",
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
)
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
verbose_logger.debug(
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
)
if aws_sts_endpoint is None:
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
else:
sts_endpoint = aws_sts_endpoint
iam_creds_cache_key = json.dumps(
{
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
"aws_sts_endpoint": sts_endpoint,
}
)
iam_creds_dict = self.iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise AwsAuthError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=sts_endpoint,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"][
"SecretAccessKey"
],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
self.iam_cache.set_cache(
key=iam_creds_cache_key,
value=json.dumps(iam_creds_dict),
ttl=3600 - 60,
)
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
from botocore.credentials import Credentials
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_session_token is not None
): ### CHECK FOR AWS SESSION TOKEN ###
from botocore.credentials import Credentials
credentials = Credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
)
return credentials
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
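A minimal sketch of how the consolidated helper is meant to be consumed (the `litellm.llms` package path is inferred from the relative import in the Bedrock handler below; the profile and region values are placeholders):

```python
# Sketch: any AWS-backed handler can now inherit get_credentials() instead of duplicating it.
from litellm.llms.base_aws_llm import BaseAWSLLM


class MyAwsHandler(BaseAWSLLM):
    def resolve_creds(self):
        # "os.environ/..." values would be resolved via get_secret(), mirroring proxy config style.
        return self.get_credentials(
            aws_profile_name="default",   # placeholder profile
            aws_region_name="us-west-2",  # placeholder region
        )


creds = MyAwsHandler().resolve_creds()  # botocore credentials (or None if the profile is empty)
```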


@@ -57,6 +57,7 @@ from litellm.utils import (
)
from .base import BaseLLM
from .base_aws_llm import BaseAWSLLM
from .bedrock import BedrockError, ModelResponseIterator, convert_messages_to_prompt
from .prompt_templates.factory import (
_bedrock_converse_messages_pt,
@@ -87,7 +88,6 @@ BEDROCK_CONVERSE_MODELS = [
]
- iam_cache = DualCache()
_response_stream_shape_cache = None
bedrock_tool_name_mappings: InMemoryCache = InMemoryCache(
max_size_in_memory=50, default_ttl=600
@@ -312,7 +312,7 @@ def make_sync_call(
return completion_stream
- class BedrockLLM(BaseLLM):
+ class BedrockLLM(BaseAWSLLM):
"""
Example call
@@ -380,183 +380,6 @@ class BedrockLLM(BaseLLM):
prompt += f"{message['content']}"
return prompt, chat_history # type: ignore
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
aws_sts_endpoint: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
print_verbose(
f"Boto3 get_credentials called variables passed to function {locals()}"
)
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
print_verbose(
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
)
if aws_sts_endpoint is None:
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
else:
sts_endpoint = aws_sts_endpoint
iam_creds_cache_key = json.dumps(
{
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
"aws_sts_endpoint": sts_endpoint,
}
)
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=sts_endpoint,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"][
"SecretAccessKey"
],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
iam_cache.set_cache(
key=iam_creds_cache_key,
value=json.dumps(iam_creds_dict),
ttl=3600 - 60,
)
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
print_verbose(
f"Using STS Client AWS aws_role_name: {aws_role_name} aws_session_name: {aws_session_name}"
)
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
from botocore.credentials import Credentials
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
print_verbose(f"Using AWS profile: {aws_profile_name}")
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_session_token is not None
): ### CHECK FOR AWS SESSION TOKEN ###
print_verbose(f"Using AWS Session Token: {aws_session_token}")
from botocore.credentials import Credentials
credentials = Credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
)
return credentials
else:
print_verbose("Using Default AWS Session")
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
def process_response(
self,
model: str,
@@ -1055,8 +878,8 @@ class BedrockLLM(BaseLLM):
},
)
raise BedrockError(
- status_code=400,
- message="Bedrock HTTPX: Unsupported provider={}, model={}".format(
+ status_code=404,
+ message="Bedrock HTTPX: Unknown provider={}, model={}".format(
provider, model
),
)
@@ -1414,7 +1237,7 @@ class AmazonConverseConfig:
return optional_params
- class BedrockConverseLLM(BaseLLM):
+ class BedrockConverseLLM(BaseAWSLLM):
def __init__(self) -> None:
super().__init__()
@@ -1554,173 +1377,6 @@ class BedrockConverseLLM(BaseLLM):
"""
return urllib.parse.quote(model_id, safe="")
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
aws_sts_endpoint: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
print_verbose(
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
)
if aws_sts_endpoint is None:
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
else:
sts_endpoint = aws_sts_endpoint
iam_creds_cache_key = json.dumps(
{
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
"aws_sts_endpoint": sts_endpoint,
}
)
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=sts_endpoint,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"][
"SecretAccessKey"
],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
iam_cache.set_cache(
key=iam_creds_cache_key,
value=json.dumps(iam_creds_dict),
ttl=3600 - 60,
)
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
from botocore.credentials import Credentials
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_session_token is not None
): ### CHECK FOR AWS SESSION TOKEN ###
from botocore.credentials import Credentials
credentials = Credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
)
return credentials
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
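A minimal standalone sketch of the web-identity flow removed above (now inherited from BaseAWSLLM): exchange an OIDC token for temporary AWS credentials and cache them slightly shorter than their lifetime, mirroring the ttl=3600-60 above. The helper and cache names are illustrative, not litellm APIs.

import json
import time

import boto3

_CREDS_CACHE: dict = {}  # cache_key -> (expiry_timestamp, credentials dict)

def web_identity_credentials(role_arn: str, session_name: str, oidc_token: str,
                             region: str, ttl: int = 3600) -> dict:
    cache_key = json.dumps({"role": role_arn, "session": session_name, "region": region})
    cached = _CREDS_CACHE.get(cache_key)
    if cached is not None and cached[0] > time.time():
        return cached[1]
    sts = boto3.client(
        "sts",
        region_name=region,
        endpoint_url=f"https://sts.{region}.amazonaws.com",
    )
    resp = sts.assume_role_with_web_identity(
        RoleArn=role_arn,
        RoleSessionName=session_name,
        WebIdentityToken=oidc_token,
        DurationSeconds=ttl,
    )
    creds = {
        "aws_access_key_id": resp["Credentials"]["AccessKeyId"],
        "aws_secret_access_key": resp["Credentials"]["SecretAccessKey"],
        "aws_session_token": resp["Credentials"]["SessionToken"],
        "region_name": region,
    }
    # cache just under the credential lifetime so a refresh happens before expiry
    _CREDS_CACHE[cache_key] = (time.time() + ttl - 60, creds)
    return creds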
async def async_streaming( async def async_streaming(
self, self,
model: str, model: str,

View file

@ -601,12 +601,13 @@ def ollama_embeddings(
): ):
return asyncio.run( return asyncio.run(
ollama_aembeddings( ollama_aembeddings(
api_base, api_base=api_base,
model, model=model,
prompts, prompts=prompts,
optional_params, model_response=model_response,
logging_obj, optional_params=optional_params,
model_response, logging_obj=logging_obj,
encoding, encoding=encoding,
) )
) )

View file

@ -356,6 +356,7 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
"json": data, "json": data,
"method": "POST", "method": "POST",
"timeout": litellm.request_timeout, "timeout": litellm.request_timeout,
"follow_redirects": True
} }
if api_key is not None: if api_key is not None:
_request["headers"] = {"Authorization": "Bearer {}".format(api_key)} _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}

View file

@ -1224,6 +1224,19 @@ def convert_to_anthropic_tool_invoke(
return anthropic_tool_invoke return anthropic_tool_invoke
def add_cache_control_to_content(
anthropic_content_element: Union[
dict, AnthropicMessagesImageParam, AnthropicMessagesTextParam
],
orignal_content_element: dict,
):
if "cache_control" in orignal_content_element:
anthropic_content_element["cache_control"] = orignal_content_element[
"cache_control"
]
return anthropic_content_element
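A small usage sketch of the helper above: if the original OpenAI-format content part carries cache_control, it is copied onto the Anthropic block; otherwise the block comes back untouched (values are illustrative).

openai_text_part = {
    "type": "text",
    "text": "Here is the full text of a complex legal agreement",
    "cache_control": {"type": "ephemeral"},
}
anthropic_block = {"type": "text", "text": openai_text_part["text"]}
anthropic_block = add_cache_control_to_content(
    anthropic_content_element=anthropic_block,
    orignal_content_element=openai_text_part,
)
# anthropic_block now includes "cache_control": {"type": "ephemeral"}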
def anthropic_messages_pt( def anthropic_messages_pt(
messages: list, messages: list,
model: str, model: str,
@ -1264,18 +1277,31 @@ def anthropic_messages_pt(
image_chunk = convert_to_anthropic_image_obj( image_chunk = convert_to_anthropic_image_obj(
m["image_url"]["url"] m["image_url"]["url"]
) )
user_content.append(
AnthropicMessagesImageParam( _anthropic_content_element = AnthropicMessagesImageParam(
type="image", type="image",
source=AnthropicImageParamSource( source=AnthropicImageParamSource(
type="base64", type="base64",
media_type=image_chunk["media_type"], media_type=image_chunk["media_type"],
data=image_chunk["data"], data=image_chunk["data"],
), ),
)
) )
anthropic_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_content_element,
orignal_content_element=m,
)
user_content.append(anthropic_content_element)
elif m.get("type", "") == "text": elif m.get("type", "") == "text":
user_content.append({"type": "text", "text": m["text"]}) _anthropic_text_content_element = {
"type": "text",
"text": m["text"],
}
anthropic_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=m,
)
user_content.append(anthropic_content_element)
elif ( elif (
messages[msg_i]["role"] == "tool" messages[msg_i]["role"] == "tool"
or messages[msg_i]["role"] == "function" or messages[msg_i]["role"] == "function"
@ -1306,6 +1332,10 @@ def anthropic_messages_pt(
anthropic_message = AnthropicMessagesTextParam( anthropic_message = AnthropicMessagesTextParam(
type="text", text=m.get("text") type="text", text=m.get("text")
) )
anthropic_message = add_cache_control_to_content(
anthropic_content_element=anthropic_message,
orignal_content_element=m,
)
assistant_content.append(anthropic_message) assistant_content.append(anthropic_message)
elif ( elif (
"content" in messages[msg_i] "content" in messages[msg_i]
@ -1313,9 +1343,17 @@ def anthropic_messages_pt(
and len(messages[msg_i]["content"]) and len(messages[msg_i]["content"])
> 0 # don't pass empty text blocks. anthropic api raises errors. > 0 # don't pass empty text blocks. anthropic api raises errors.
): ):
assistant_content.append(
{"type": "text", "text": messages[msg_i]["content"]} _anthropic_text_content_element = {
"type": "text",
"text": messages[msg_i]["content"],
}
anthropic_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=messages[msg_i],
) )
assistant_content.append(anthropic_content_element)
if messages[msg_i].get( if messages[msg_i].get(
"tool_calls", [] "tool_calls", []
@ -1701,12 +1739,14 @@ def cohere_messages_pt_v2(
assistant_tool_calls: List[ToolCallObject] = [] assistant_tool_calls: List[ToolCallObject] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ## ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_text = ( if isinstance(messages[msg_i]["content"], list):
messages[msg_i].get("content") or "" for m in messages[msg_i]["content"]:
) # either string or none if m.get("type", "") == "text":
if assistant_text: assistant_content += m["text"]
assistant_content += assistant_text elif messages[msg_i].get("content") is not None and isinstance(
messages[msg_i]["content"], str
):
assistant_content += messages[msg_i]["content"]
if messages[msg_i].get( if messages[msg_i].get(
"tool_calls", [] "tool_calls", []
): # support assistant tool invoke conversion ): # support assistant tool invoke conversion
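A minimal sketch of the flattening the hunk above performs, assuming assistant messages whose content is either a plain string or a list of typed parts (the helper name is illustrative):

def flatten_assistant_content(message: dict) -> str:
    content = message.get("content")
    if isinstance(content, list):
        # concatenate only the text parts, ignoring non-text blocks
        return "".join(p["text"] for p in content if p.get("type", "") == "text")
    if isinstance(content, str):
        return content
    return ""

flatten_assistant_content(
    {"role": "assistant", "content": [{"type": "text", "text": "hello"}]}
)  # -> "hello"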

File diff suppressed because it is too large

View file

@ -240,10 +240,10 @@ class TritonChatCompletion(BaseLLM):
handler = HTTPHandler() handler = HTTPHandler()
if stream: if stream:
return self._handle_stream( return self._handle_stream(
handler, api_base, data_for_triton, model, logging_obj handler, api_base, json_data_for_triton, model, logging_obj
) )
else: else:
response = handler.post(url=api_base, data=data_for_triton, headers=headers) response = handler.post(url=api_base, data=json_data_for_triton, headers=headers)
return self._handle_response( return self._handle_response(
response, model_response, logging_obj, type_of_model=type_of_model response, model_response, logging_obj, type_of_model=type_of_model
) )

View file

@ -95,7 +95,6 @@ from .llms import (
palm, palm,
petals, petals,
replicate, replicate,
sagemaker,
together_ai, together_ai,
triton, triton,
vertex_ai, vertex_ai,
@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
prompt_factory, prompt_factory,
stringify_json_tool_call_content, stringify_json_tool_call_content,
) )
from .llms.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_partner import VertexAIPartnerModels from .llms.vertex_ai_partner import VertexAIPartnerModels
@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM() vertex_chat_completion = VertexLLM()
vertex_partner_models_chat_completion = VertexAIPartnerModels() vertex_partner_models_chat_completion = VertexAIPartnerModels()
watsonxai = IBMWatsonXAI() watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
####### COMPLETION ENDPOINTS ################ ####### COMPLETION ENDPOINTS ################
@ -2216,7 +2217,7 @@ def completion(
response = model_response response = model_response
elif custom_llm_provider == "sagemaker": elif custom_llm_provider == "sagemaker":
# boto3 reads keys from .env # boto3 reads keys from .env
model_response = sagemaker.completion( model_response = sagemaker_llm.completion(
model=model, model=model,
messages=messages, messages=messages,
model_response=model_response, model_response=model_response,
@ -2230,26 +2231,13 @@ def completion(
logging_obj=logging, logging_obj=logging,
acompletion=acompletion, acompletion=acompletion,
) )
if ( if optional_params.get("stream", False):
"stream" in optional_params and optional_params["stream"] == True
): ## [BETA]
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
from .llms.sagemaker import TokenIterator
tokenIterator = TokenIterator(model_response, acompletion=acompletion)
response = CustomStreamWrapper(
completion_stream=tokenIterator,
model=model,
custom_llm_provider="sagemaker",
logging_obj=logging,
)
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
api_key=None, api_key=None,
original_response=response, original_response=model_response,
) )
return response
## RESPONSE OBJECT ## RESPONSE OBJECT
response = model_response response = model_response
@ -3529,7 +3517,7 @@ def embedding(
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
) )
elif custom_llm_provider == "sagemaker": elif custom_llm_provider == "sagemaker":
response = sagemaker.embedding( response = sagemaker_llm.embedding(
model=model, model=model,
input=input, input=input,
encoding=encoding, encoding=encoding,
@ -4898,7 +4886,6 @@ async def ahealth_check(
verbose_logger.error( verbose_logger.error(
"litellm.ahealth_check(): Exception occured - {}".format(str(e)) "litellm.ahealth_check(): Exception occured - {}".format(str(e))
) )
verbose_logger.debug(traceback.format_exc())
stack_trace = traceback.format_exc() stack_trace = traceback.format_exc()
if isinstance(stack_trace, str): if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000] stack_trace = stack_trace[:1000]
@ -4907,7 +4894,12 @@ async def ahealth_check(
"error": "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models" "error": "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
} }
error_to_return = str(e) + " stack trace: " + stack_trace error_to_return = (
str(e)
+ "\nHave you set 'mode' - https://docs.litellm.ai/docs/proxy/health#embedding-models"
+ "\nstack trace: "
+ stack_trace
)
return {"error": error_to_return} return {"error": error_to_return}

View file

@ -57,6 +57,18 @@
"supports_parallel_function_calling": true, "supports_parallel_function_calling": true,
"supports_vision": true "supports_vision": true
}, },
"chatgpt-4o-latest": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4o-2024-05-13": { "gpt-4o-2024-05-13": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 128000, "max_input_tokens": 128000,
@ -2062,7 +2074,8 @@
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true "supports_vision": true,
"supports_assistant_prefill": true
}, },
"vertex_ai/claude-3-5-sonnet@20240620": { "vertex_ai/claude-3-5-sonnet@20240620": {
"max_tokens": 4096, "max_tokens": 4096,
@ -2073,7 +2086,8 @@
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true "supports_vision": true,
"supports_assistant_prefill": true
}, },
"vertex_ai/claude-3-haiku@20240307": { "vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 4096, "max_tokens": 4096,
@ -2084,7 +2098,8 @@
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true "supports_vision": true,
"supports_assistant_prefill": true
}, },
"vertex_ai/claude-3-opus@20240229": { "vertex_ai/claude-3-opus@20240229": {
"max_tokens": 4096, "max_tokens": 4096,
@ -2095,7 +2110,8 @@
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true "supports_vision": true,
"supports_assistant_prefill": true
}, },
"vertex_ai/meta/llama3-405b-instruct-maas": { "vertex_ai/meta/llama3-405b-instruct-maas": {
"max_tokens": 32000, "max_tokens": 32000,
@ -4519,6 +4535,69 @@
"litellm_provider": "perplexity", "litellm_provider": "perplexity",
"mode": "chat" "mode": "chat"
}, },
"perplexity/llama-3.1-70b-instruct": {
"max_tokens": 131072,
"max_input_tokens": 131072,
"max_output_tokens": 131072,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-8b-instruct": {
"max_tokens": 131072,
"max_input_tokens": 131072,
"max_output_tokens": 131072,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-huge-128k-online": {
"max_tokens": 127072,
"max_input_tokens": 127072,
"max_output_tokens": 127072,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000005,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-large-128k-online": {
"max_tokens": 127072,
"max_input_tokens": 127072,
"max_output_tokens": 127072,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-large-128k-chat": {
"max_tokens": 131072,
"max_input_tokens": 131072,
"max_output_tokens": 131072,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-small-128k-chat": {
"max_tokens": 131072,
"max_input_tokens": 131072,
"max_output_tokens": 131072,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-small-128k-online": {
"max_tokens": 127072,
"max_input_tokens": 127072,
"max_output_tokens": 127072,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/pplx-7b-chat": { "perplexity/pplx-7b-chat": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 8192, "max_input_tokens": 8192,

View file

@ -1,13 +1,6 @@
model_list: model_list:
- model_name: "*" - model_name: "gpt-4"
litellm_params: litellm_params:
model: "*" model: "gpt-4"
model_info:
# general_settings: my_custom_key: "my_custom_value"
# master_key: sk-1234
# pass_through_endpoints:
# - path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server
# target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward
# headers:
# LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_PUBLIC_KEY" # your langfuse account public key
# LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_SECRET_KEY" # your langfuse account secret key

View file

@ -12,7 +12,7 @@ import json
import secrets import secrets
import traceback import traceback
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Optional from typing import Optional, Tuple
from uuid import uuid4 from uuid import uuid4
import fastapi import fastapi
@ -125,7 +125,7 @@ async def user_api_key_auth(
# Check 2. FILTER IP ADDRESS # Check 2. FILTER IP ADDRESS
await check_if_request_size_is_safe(request=request) await check_if_request_size_is_safe(request=request)
is_valid_ip = _check_valid_ip( is_valid_ip, passed_in_ip = _check_valid_ip(
allowed_ips=general_settings.get("allowed_ips", None), allowed_ips=general_settings.get("allowed_ips", None),
use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
request=request, request=request,
@ -134,7 +134,7 @@ async def user_api_key_auth(
if not is_valid_ip: if not is_valid_ip:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Access forbidden: IP address not allowed.", detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
) )
pass_through_endpoints: Optional[List[dict]] = general_settings.get( pass_through_endpoints: Optional[List[dict]] = general_settings.get(
@ -1251,12 +1251,12 @@ def _check_valid_ip(
allowed_ips: Optional[List[str]], allowed_ips: Optional[List[str]],
request: Request, request: Request,
use_x_forwarded_for: Optional[bool] = False, use_x_forwarded_for: Optional[bool] = False,
) -> bool: ) -> Tuple[bool, Optional[str]]:
""" """
Returns if ip is allowed or not Returns if ip is allowed or not
""" """
if allowed_ips is None: # if not set, assume true if allowed_ips is None: # if not set, assume true
return True return True, None
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
client_ip = None client_ip = None
@ -1267,9 +1267,9 @@ def _check_valid_ip(
# Check if IP address is allowed # Check if IP address is allowed
if client_ip not in allowed_ips: if client_ip not in allowed_ips:
return False return False, client_ip
return True return True, client_ip
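A hedged, standalone sketch of the updated check: it now returns both the verdict and the client IP so the 403 detail can name the rejected address. The x-forwarded-for handling shown here is an assumption for the lines elided from the hunk.

from typing import List, Optional, Tuple

from fastapi import Request

def check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: bool = False,
) -> Tuple[bool, Optional[str]]:
    if allowed_ips is None:  # no allow-list configured -> allow
        return True, None
    if use_x_forwarded_for and "x-forwarded-for" in request.headers:
        client_ip = request.headers["x-forwarded-for"].split(",")[0].strip()
    else:
        client_ip = request.client.host if request.client is not None else None
    if client_ip not in allowed_ips:
        return False, client_ip
    return True, client_ip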
def get_api_key_from_custom_header( def get_api_key_from_custom_header(

View file

@ -0,0 +1,56 @@
import yaml
from litellm._logging import verbose_proxy_logger
def get_file_contents_from_s3(bucket_name, object_key):
try:
# v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc
import tempfile
import boto3
from botocore.config import Config
from botocore.credentials import Credentials
from litellm.main import bedrock_converse_chat_completion
credentials: Credentials = bedrock_converse_chat_completion.get_credentials()
s3_client = boto3.client(
"s3",
aws_access_key_id=credentials.access_key,
aws_secret_access_key=credentials.secret_key,
aws_session_token=credentials.token, # Optional, if using temporary credentials
)
verbose_proxy_logger.debug(
f"Retrieving {object_key} from S3 bucket: {bucket_name}"
)
response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
verbose_proxy_logger.debug(f"Response: {response}")
# Read the file contents
file_contents = response["Body"].read().decode("utf-8")
verbose_proxy_logger.debug(f"File contents retrieved from S3")
# Create a temporary file with YAML extension
with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as temp_file:
temp_file.write(file_contents.encode("utf-8"))
temp_file_path = temp_file.name
verbose_proxy_logger.debug(f"File stored temporarily at: {temp_file_path}")
# Load the YAML file content
with open(temp_file_path, "r") as yaml_file:
config = yaml.safe_load(yaml_file)
return config
except ImportError as e:
# this is most likely if a user is not using the litellm docker container
verbose_proxy_logger.error(f"ImportError: {str(e)}")
pass
except Exception as e:
verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
return None
# # Example usage
# bucket_name = 'litellm-proxy'
# object_key = 'litellm_proxy_config.yaml'
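A usage sketch for the helper above, mirroring the commented example (bucket and key names are illustrative; boto3 resolves credentials from the environment):

config = get_file_contents_from_s3(
    bucket_name="litellm-proxy",
    object_key="litellm_proxy_config.yaml",
)
if config is not None:
    print(config.get("model_list", []))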

View file

@ -5,7 +5,12 @@ from fastapi import Request
import litellm import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, TeamCallbackMetadata, UserAPIKeyAuth from litellm.proxy._types import (
AddTeamCallback,
CommonProxyErrors,
TeamCallbackMetadata,
UserAPIKeyAuth,
)
from litellm.types.utils import SupportedCacheControls from litellm.types.utils import SupportedCacheControls
if TYPE_CHECKING: if TYPE_CHECKING:
@ -59,6 +64,42 @@ def safe_add_api_version_from_query_params(data: dict, request: Request):
verbose_logger.error("error checking api version in query params: %s", str(e)) verbose_logger.error("error checking api version in query params: %s", str(e))
def convert_key_logging_metadata_to_callback(
data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata]
) -> TeamCallbackMetadata:
if team_callback_settings_obj is None:
team_callback_settings_obj = TeamCallbackMetadata()
if data.callback_type == "success":
if team_callback_settings_obj.success_callback is None:
team_callback_settings_obj.success_callback = []
if data.callback_name not in team_callback_settings_obj.success_callback:
team_callback_settings_obj.success_callback.append(data.callback_name)
elif data.callback_type == "failure":
if team_callback_settings_obj.failure_callback is None:
team_callback_settings_obj.failure_callback = []
if data.callback_name not in team_callback_settings_obj.failure_callback:
team_callback_settings_obj.failure_callback.append(data.callback_name)
elif data.callback_type == "success_and_failure":
if team_callback_settings_obj.success_callback is None:
team_callback_settings_obj.success_callback = []
if team_callback_settings_obj.failure_callback is None:
team_callback_settings_obj.failure_callback = []
if data.callback_name not in team_callback_settings_obj.success_callback:
team_callback_settings_obj.success_callback.append(data.callback_name)
if data.callback_name not in team_callback_settings_obj.failure_callback:
team_callback_settings_obj.failure_callback.append(data.callback_name)
for var, value in data.callback_vars.items():
if team_callback_settings_obj.callback_vars is None:
team_callback_settings_obj.callback_vars = {}
team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value)
return team_callback_settings_obj
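A usage sketch of the helper above, assuming a key whose metadata carries a "logging" entry; the Langfuse callback vars are placeholders:

item = {
    "callback_name": "langfuse",
    "callback_type": "success",
    "callback_vars": {
        "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
        "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
    },
}
settings = convert_key_logging_metadata_to_callback(
    data=AddTeamCallback(**item),
    team_callback_settings_obj=None,
)
# settings.success_callback == ["langfuse"]; callback_vars hold the resolved secret values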
async def add_litellm_data_to_request( async def add_litellm_data_to_request(
data: dict, data: dict,
request: Request, request: Request,
@ -85,14 +126,19 @@ async def add_litellm_data_to_request(
safe_add_api_version_from_query_params(data, request) safe_add_api_version_from_query_params(data, request)
_headers = dict(request.headers)
# Include original request and headers in the data # Include original request and headers in the data
data["proxy_server_request"] = { data["proxy_server_request"] = {
"url": str(request.url), "url": str(request.url),
"method": request.method, "method": request.method,
"headers": dict(request.headers), "headers": _headers,
"body": copy.copy(data), # use copy instead of deepcopy "body": copy.copy(data), # use copy instead of deepcopy
} }
## Forward any LLM API Provider specific headers in extra_headers
add_provider_specific_headers_to_request(data=data, headers=_headers)
## Cache Controls ## Cache Controls
headers = request.headers headers = request.headers
verbose_proxy_logger.debug("Request Headers: %s", headers) verbose_proxy_logger.debug("Request Headers: %s", headers)
@ -224,6 +270,7 @@ async def add_litellm_data_to_request(
} # add the team-specific configs to the completion call } # add the team-specific configs to the completion call
# Team Callbacks controls # Team Callbacks controls
callback_settings_obj: Optional[TeamCallbackMetadata] = None
if user_api_key_dict.team_metadata is not None: if user_api_key_dict.team_metadata is not None:
team_metadata = user_api_key_dict.team_metadata team_metadata = user_api_key_dict.team_metadata
if "callback_settings" in team_metadata: if "callback_settings" in team_metadata:
@ -241,17 +288,54 @@ async def add_litellm_data_to_request(
} }
} }
""" """
data["success_callback"] = callback_settings_obj.success_callback elif (
data["failure_callback"] = callback_settings_obj.failure_callback user_api_key_dict.metadata is not None
and "logging" in user_api_key_dict.metadata
):
for item in user_api_key_dict.metadata["logging"]:
if callback_settings_obj.callback_vars is not None: callback_settings_obj = convert_key_logging_metadata_to_callback(
# unpack callback_vars in data data=AddTeamCallback(**item),
for k, v in callback_settings_obj.callback_vars.items(): team_callback_settings_obj=callback_settings_obj,
data[k] = v )
if callback_settings_obj is not None:
data["success_callback"] = callback_settings_obj.success_callback
data["failure_callback"] = callback_settings_obj.failure_callback
if callback_settings_obj.callback_vars is not None:
# unpack callback_vars in data
for k, v in callback_settings_obj.callback_vars.items():
data[k] = v
return data return data
def add_provider_specific_headers_to_request(
data: dict,
headers: dict,
):
ANTHROPIC_API_HEADERS = [
"anthropic-version",
"anthropic-beta",
]
extra_headers = data.get("extra_headers", {}) or {}
# boolean to indicate if a header was added
added_header = False
for header in ANTHROPIC_API_HEADERS:
if header in headers:
header_value = headers[header]
extra_headers.update({header: header_value})
added_header = True
if added_header is True:
data["extra_headers"] = extra_headers
return
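A usage sketch of the header-forwarding helper above, assuming an incoming proxy request that carries an Anthropic beta header:

data = {"model": "claude-3-5-sonnet-20240620", "messages": []}
incoming_headers = {
    "authorization": "Bearer sk-1234",
    "anthropic-beta": "prompt-caching-2024-07-31",
}
add_provider_specific_headers_to_request(data=data, headers=incoming_headers)
# data["extra_headers"] == {"anthropic-beta": "prompt-caching-2024-07-31"}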
def _add_otel_traceparent_to_data(data: dict, request: Request): def _add_otel_traceparent_to_data(data: dict, request: Request):
from litellm.proxy.proxy_server import open_telemetry_logger from litellm.proxy.proxy_server import open_telemetry_logger

View file

@ -19,6 +19,9 @@ model_list:
litellm_params: litellm_params:
model: mistral/mistral-small-latest model: mistral/mistral-small-latest
api_key: "os.environ/MISTRAL_API_KEY" api_key: "os.environ/MISTRAL_API_KEY"
- model_name: bedrock-anthropic
litellm_params:
model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
- model_name: gemini-1.5-pro-001 - model_name: gemini-1.5-pro-001
litellm_params: litellm_params:
model: vertex_ai_beta/gemini-1.5-pro-001 model: vertex_ai_beta/gemini-1.5-pro-001
@ -39,7 +42,7 @@ general_settings:
litellm_settings: litellm_settings:
fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}] fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
success_callback: ["langfuse", "prometheus"] callbacks: ["gcs_bucket"]
langfuse_default_tags: ["cache_hit", "cache_key", "proxy_base_url", "user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"] success_callback: ["langfuse"]
failure_callback: ["prometheus"] langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
cache: True cache: True

View file

@ -159,6 +159,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
check_file_size_under_limit, check_file_size_under_limit,
) )
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
from litellm.proxy.common_utils.openai_endpoint_utils import ( from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment, remove_sensitive_info_from_deployment,
) )
@ -197,6 +198,8 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
router as pass_through_router, router as pass_through_router,
) )
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.secret_managers.aws_secret_manager import ( from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_kms, load_aws_kms,
load_aws_secret_manager, load_aws_secret_manager,
@ -1444,7 +1447,18 @@ class ProxyConfig:
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details
# Load existing config # Load existing config
config = await self.get_config(config_file_path=config_file_path) if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
verbose_proxy_logger.debug(
"bucket_name: %s, object_key: %s", bucket_name, object_key
)
config = get_file_contents_from_s3(
bucket_name=bucket_name, object_key=object_key
)
else:
# default to file
config = await self.get_config(config_file_path=config_file_path)
## PRINT YAML FOR CONFIRMING IT WORKS ## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config) printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None) printed_yaml.pop("environment_variables", None)
@ -2652,6 +2666,15 @@ async def startup_event():
) )
else: else:
await initialize(**worker_config) await initialize(**worker_config)
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
else: else:
# if not, assume it's a json string # if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG")) worker_config = json.loads(os.getenv("WORKER_CONFIG"))
@ -3036,68 +3059,13 @@ async def chat_completion(
### ROUTE THE REQUEST ### ### ROUTE THE REQUEST ###
# Do not change this - it should be a constant time fetch - ALWAYS # Do not change this - it should be a constant time fetch - ALWAYS
router_model_names = llm_router.model_names if llm_router is not None else [] llm_call = await route_request(
# skip router if user passed their key data=data,
if "api_key" in data: route_type="acompletion",
tasks.append(litellm.acompletion(**data)) llm_router=llm_router,
elif "," in data["model"] and llm_router is not None: user_model=user_model,
if ( )
data.get("fastest_response", None) is not None tasks.append(llm_call)
and data["fastest_response"] == True
):
tasks.append(llm_router.abatch_completion_fastest_response(**data))
else:
_models_csv_string = data.pop("model")
_models = [model.strip() for model in _models_csv_string.split(",")]
tasks.append(llm_router.abatch_completion(models=_models, **data))
elif "user_config" in data:
# initialize a new router instance. make request using this Router
router_config = data.pop("user_config")
user_router = litellm.Router(**router_config)
tasks.append(user_router.acompletion(**data))
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
tasks.append(llm_router.acompletion(**data))
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router model list
tasks.append(llm_router.acompletion(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
tasks.append(llm_router.acompletion(**data))
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.acompletion(**data, specific_deployment=True))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.router_general_settings.pass_through_all_models is True
):
tasks.append(litellm.acompletion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.acompletion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
tasks.append(litellm.acompletion(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "chat_completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
# wait for call to end # wait for call to end
llm_responses = asyncio.gather( llm_responses = asyncio.gather(
@ -3320,58 +3288,15 @@ async def completion(
) )
### ROUTE THE REQUESTs ### ### ROUTE THE REQUESTs ###
router_model_names = llm_router.model_names if llm_router is not None else [] llm_call = await route_request(
# skip router if user passed their key data=data,
if "api_key" in data: route_type="atext_completion",
llm_response = asyncio.create_task(litellm.atext_completion(**data)) llm_router=llm_router,
elif ( user_model=user_model,
llm_router is not None and data["model"] in router_model_names )
): # model in router model list
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(
llm_router.atext_completion(**data, specific_deployment=True)
)
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router model list
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.router_general_settings.pass_through_all_models is True
):
llm_response = asyncio.create_task(litellm.atext_completion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
llm_response = asyncio.create_task(litellm.atext_completion(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
# Await the llm_response task # Await the llm_response task
response = await llm_response response = await llm_call
hidden_params = getattr(response, "_hidden_params", {}) or {} hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or "" model_id = hidden_params.get("model_id", None) or ""
@ -3585,59 +3510,13 @@ async def embeddings(
) )
## ROUTE TO CORRECT ENDPOINT ## ## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key llm_call = await route_request(
if "api_key" in data: data=data,
tasks.append(litellm.aembedding(**data)) route_type="aembedding",
elif "user_config" in data: llm_router=llm_router,
# initialize a new router instance. make request using this Router user_model=user_model,
router_config = data.pop("user_config") )
user_router = litellm.Router(**router_config) tasks.append(llm_call)
tasks.append(user_router.aembedding(**data))
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
tasks.append(llm_router.aembedding(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
tasks.append(
llm_router.aembedding(**data)
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.aembedding(**data, specific_deployment=True))
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.aembedding(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.router_general_settings.pass_through_all_models is True
):
tasks.append(litellm.aembedding(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.aembedding(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
tasks.append(litellm.aembedding(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "embeddings: Invalid model name passed in model="
+ data.get("model", "")
},
)
# wait for call to end # wait for call to end
llm_responses = asyncio.gather( llm_responses = asyncio.gather(
@ -3768,46 +3647,13 @@ async def image_generation(
) )
## ROUTE TO CORRECT ENDPOINT ## ## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key llm_call = await route_request(
if "api_key" in data: data=data,
response = await litellm.aimage_generation(**data) route_type="aimage_generation",
elif ( llm_router=llm_router,
llm_router is not None and data["model"] in router_model_names user_model=user_model,
): # model in router model list )
response = await llm_router.aimage_generation(**data) response = await llm_call
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aimage_generation(
**data, specific_deployment=True
)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.aimage_generation(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aimage_generation(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aimage_generation(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "image_generation: Invalid model name passed in model="
+ data.get("model", "")
},
)
### ALERTING ### ### ALERTING ###
asyncio.create_task( asyncio.create_task(
@ -3915,44 +3761,13 @@ async def audio_speech(
) )
## ROUTE TO CORRECT ENDPOINT ## ## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key llm_call = await route_request(
if "api_key" in data: data=data,
response = await litellm.aspeech(**data) route_type="aspeech",
elif ( llm_router=llm_router,
llm_router is not None and data["model"] in router_model_names user_model=user_model,
): # model in router model list )
response = await llm_router.aspeech(**data) response = await llm_call
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aspeech(**data, specific_deployment=True)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.aspeech(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aspeech(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aspeech(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "audio_speech: Invalid model name passed in model="
+ data.get("model", "")
},
)
### ALERTING ### ### ALERTING ###
asyncio.create_task( asyncio.create_task(
@ -4085,47 +3900,13 @@ async def audio_transcriptions(
) )
## ROUTE TO CORRECT ENDPOINT ## ## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key llm_call = await route_request(
if "api_key" in data: data=data,
response = await litellm.atranscription(**data) route_type="atranscription",
elif ( llm_router=llm_router,
llm_router is not None and data["model"] in router_model_names user_model=user_model,
): # model in router model list )
response = await llm_router.atranscription(**data) response = await llm_call
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(
**data, specific_deployment=True
)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.atranscription(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.atranscription(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "audio_transcriptions: Invalid model name passed in model="
+ data.get("model", "")
},
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally: finally:
@ -5341,40 +5122,13 @@ async def moderations(
start_time = time.time() start_time = time.time()
## ROUTE TO CORRECT ENDPOINT ## ## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key llm_call = await route_request(
if "api_key" in data: data=data,
response = await litellm.amoderation(**data) route_type="amoderation",
elif ( llm_router=llm_router,
llm_router is not None and data.get("model") in router_model_names user_model=user_model,
): # model in router model list )
response = await llm_router.amoderation(**data) response = await llm_call
elif (
llm_router is not None and data.get("model") in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.amoderation(**data, specific_deployment=True)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data.get("model") in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.amoderation(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data.get("model") not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.amoderation(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.amoderation(**data)
else:
# /moderations does not need a "model" passed
# see https://platform.openai.com/docs/api-reference/moderations
response = await litellm.amoderation(**data)
### ALERTING ### ### ALERTING ###
asyncio.create_task( asyncio.create_task(

View file

@ -0,0 +1,117 @@
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from fastapi import (
Depends,
FastAPI,
File,
Form,
Header,
HTTPException,
Path,
Request,
Response,
UploadFile,
status,
)
import litellm
from litellm._logging import verbose_logger
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
ROUTE_ENDPOINT_MAPPING = {
"acompletion": "/chat/completions",
"atext_completion": "/completions",
"aembedding": "/embeddings",
"aimage_generation": "/image/generations",
"aspeech": "/audio/speech",
"atranscription": "/audio/transcriptions",
"amoderation": "/moderations",
}
async def route_request(
data: dict,
llm_router: Optional[LitellmRouter],
user_model: Optional[str],
route_type: Literal[
"acompletion",
"atext_completion",
"aembedding",
"aimage_generation",
"aspeech",
"atranscription",
"amoderation",
],
):
"""
Common helper to route the request
"""
router_model_names = llm_router.model_names if llm_router is not None else []
if "api_key" in data:
return getattr(litellm, f"{route_type}")(**data)
elif "user_config" in data:
router_config = data.pop("user_config")
user_router = litellm.Router(**router_config)
return getattr(user_router, f"{route_type}")(**data)
elif (
route_type == "acompletion"
and data.get("model", "") is not None
and "," in data.get("model", "")
and llm_router is not None
):
if data.get("fastest_response", False):
return llm_router.abatch_completion_fastest_response(**data)
else:
models = [model.strip() for model in data.pop("model").split(",")]
return llm_router.abatch_completion(models=models, **data)
elif llm_router is not None:
if (
data["model"] in router_model_names
or data["model"] in llm_router.get_model_ids()
):
return getattr(llm_router, f"{route_type}")(**data)
elif (
llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
):
return getattr(llm_router, f"{route_type}")(**data)
elif data["model"] in llm_router.deployment_names:
return getattr(llm_router, f"{route_type}")(
**data, specific_deployment=True
)
elif data["model"] not in router_model_names:
if llm_router.router_general_settings.pass_through_all_models:
return getattr(litellm, f"{route_type}")(**data)
elif (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
):
return getattr(llm_router, f"{route_type}")(**data)
elif user_model is not None:
return getattr(litellm, f"{route_type}")(**data)
# if no route found then it's a bad request
route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": f"{route_name}: Invalid model name passed in model="
+ data.get("model", "")
},
)
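A hedged sketch of how the proxy endpoints above consume this helper: route_request returns an un-awaited coroutine, which the endpoint then awaits (or gathers). The wrapper below is illustrative, not part of the proxy.

async def proxied_chat_completion(data: dict, llm_router, user_model=None):
    llm_call = await route_request(
        data=data,
        route_type="acompletion",
        llm_router=llm_router,
        user_model=user_model,
    )
    # llm_call is the coroutine returned by the matched route, e.g. llm_router.acompletion(**data)
    return await llm_call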

View file

@ -21,6 +21,8 @@ def get_logging_payload(
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
if response_obj is None:
response_obj = {}
# standardize this function to be used across, s3, dynamoDB, langfuse logging # standardize this function to be used across, s3, dynamoDB, langfuse logging
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
metadata = ( metadata = (

View file

@ -0,0 +1,37 @@
import openai
client = openai.OpenAI(
api_key="sk-1234", # litellm proxy api key
base_url="http://0.0.0.0:4000", # litellm proxy base url
)
response = client.chat.completions.create(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[
{ # type: ignore
"role": "system",
"content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 100,
"cache_control": {"type": "ephemeral"},
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
],
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
print(response)
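A hedged follow-up to the script above: when the Anthropic prompt-caching beta is active, the cache counters surface on the usage object (field names mirror the assertions in the tests further below; availability depends on the provider response):

usage = response.usage
print("cache_creation_input_tokens:", getattr(usage, "cache_creation_input_tokens", None))
print("cache_read_input_tokens:", getattr(usage, "cache_read_input_tokens", None))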

View file

@ -190,7 +190,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None: if api_version is None:
api_version = litellm.AZURE_DEFAULT_API_VERSION api_version = os.getenv("AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION)
if "gateway.ai.cloudflare.com" in api_base: if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"): if not api_base.endswith("/"):

View file

@ -0,0 +1,321 @@
import json
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
# litellm.num_retries =3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
def logger_fn(user_model_dict):
print(f"user_model_dict: {user_model_dict}")
@pytest.fixture(autouse=True)
def reset_callbacks():
print("\npytest fixture - resetting callbacks")
litellm.success_callback = []
litellm._async_success_callback = []
litellm.failure_callback = []
litellm.callbacks = []
@pytest.mark.asyncio
async def test_litellm_anthropic_prompt_caching_tools():
# Arrange: Set up the MagicMock for the httpx.AsyncClient
mock_response = AsyncMock()
def return_val():
return {
"id": "msg_01XFDUDYJgAACzvnptvVoYEL",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "Hello!"}],
"model": "claude-3-5-sonnet-20240620",
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {"input_tokens": 12, "output_tokens": 6},
}
mock_response.json = return_val
litellm.set_verbose = True
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response,
) as mock_post:
# Act: Call the litellm.acompletion function
response = await litellm.acompletion(
api_key="mock_api_key",
model="anthropic/claude-3-5-sonnet-20240620",
messages=[
{"role": "user", "content": "What's the weather like in Boston today?"}
],
tools=[
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
"cache_control": {"type": "ephemeral"},
},
}
],
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
# Print what was called on the mock
print("call args=", mock_post.call_args)
expected_url = "https://api.anthropic.com/v1/messages"
expected_headers = {
"accept": "application/json",
"content-type": "application/json",
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
"x-api-key": "mock_api_key",
}
expected_json = {
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's the weather like in Boston today?",
}
],
}
],
"tools": [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"cache_control": {"type": "ephemeral"},
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
}
],
"max_tokens": 4096,
"model": "claude-3-5-sonnet-20240620",
}
mock_post.assert_called_once_with(
expected_url, json=expected_json, headers=expected_headers, timeout=600.0
)
@pytest.mark.asyncio()
async def test_anthropic_api_prompt_caching_basic():
litellm.set_verbose = True
response = await litellm.acompletion(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement"
* 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
],
temperature=0.2,
max_tokens=10,
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
print("response=", response)
assert "cache_read_input_tokens" in response.usage
assert "cache_creation_input_tokens" in response.usage
# Assert either a cache entry was created or cache was read - changes depending on the anthropic api ttl
assert (response.usage.cache_read_input_tokens > 0) or (
response.usage.cache_creation_input_tokens > 0
)
@pytest.mark.asyncio
async def test_litellm_anthropic_prompt_caching_system():
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#prompt-caching-examples
# Large Context Caching Example
mock_response = AsyncMock()
def return_val():
return {
"id": "msg_01XFDUDYJgAACzvnptvVoYEL",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "Hello!"}],
"model": "claude-3-5-sonnet-20240620",
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {"input_tokens": 12, "output_tokens": 6},
}
mock_response.json = return_val
litellm.set_verbose = True
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response,
) as mock_post:
# Act: Call the litellm.acompletion function
response = await litellm.acompletion(
api_key="mock_api_key",
model="anthropic/claude-3-5-sonnet-20240620",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"type": "text",
"text": "Here is the full text of a complex legal agreement",
"cache_control": {"type": "ephemeral"},
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
],
extra_headers={
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
},
)
# Print what was called on the mock
print("call args=", mock_post.call_args)
expected_url = "https://api.anthropic.com/v1/messages"
expected_headers = {
"accept": "application/json",
"content-type": "application/json",
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
"x-api-key": "mock_api_key",
}
expected_json = {
"system": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"type": "text",
"text": "Here is the full text of a complex legal agreement",
"cache_control": {"type": "ephemeral"},
},
],
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "what are the key terms and conditions in this agreement?",
}
],
}
],
"max_tokens": 4096,
"model": "claude-3-5-sonnet-20240620",
}
mock_post.assert_called_once_with(
expected_url, json=expected_json, headers=expected_headers, timeout=600.0
)

View file

@ -1159,8 +1159,8 @@ def test_bedrock_tools_pt_invalid_names():
assert result[1]["toolSpec"]["name"] == "another_invalid_name" assert result[1]["toolSpec"]["name"] == "another_invalid_name"
-def test_bad_request_error():
-with pytest.raises(litellm.BadRequestError):
+def test_not_found_error():
+with pytest.raises(litellm.NotFoundError):
completion( completion(
model="bedrock/bad_model", model="bedrock/bad_model",
messages=[ messages=[

View file

@ -14,7 +14,7 @@ sys.path.insert(
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import os import os
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt from litellm.llms.prompt_templates.factory import anthropic_messages_pt
-# litellm.num_retries = 3
+# litellm.num_retries =3
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
user_message = "Write a short poem about the sky" user_message = "Write a short poem about the sky"
@ -190,6 +190,31 @@ def test_completion_azure_command_r():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize(
"api_base",
[
"https://litellm8397336933.openai.azure.com",
"https://litellm8397336933.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2023-03-15-preview",
],
)
def test_completion_azure_ai_gpt_4o(api_base):
try:
litellm.set_verbose = True
response = completion(
model="azure_ai/gpt-4o",
api_base=api_base,
api_key=os.getenv("AZURE_AI_OPENAI_KEY"),
messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
print(response)
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completion_databricks(sync_mode): async def test_completion_databricks(sync_mode):
@ -3312,108 +3337,6 @@ def test_customprompt_together_ai():
# test_customprompt_together_ai() # test_customprompt_together_ai()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_completion_sagemaker():
try:
litellm.set_verbose = True
print("testing sagemaker")
response = completion(
model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-ins-20240329-150233",
model_id="huggingface-llm-mistral-7b-instruct-20240329-150233",
messages=messages,
temperature=0.2,
max_tokens=80,
aws_region_name=os.getenv("AWS_REGION_NAME_2"),
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID_2"),
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY_2"),
input_cost_per_second=0.000420,
)
# Add any assertions here to check the response
print(response)
cost = completion_cost(completion_response=response)
print("calculated cost", cost)
assert (
cost > 0.0 and cost < 1.0
) # should never be > $1 for a single completion call
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_sagemaker()
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio
async def test_acompletion_sagemaker():
try:
litellm.set_verbose = True
print("testing sagemaker")
response = await litellm.acompletion(
model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-ins-20240329-150233",
model_id="huggingface-llm-mistral-7b-instruct-20240329-150233",
messages=messages,
temperature=0.2,
max_tokens=80,
aws_region_name=os.getenv("AWS_REGION_NAME_2"),
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID_2"),
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY_2"),
input_cost_per_second=0.000420,
)
# Add any assertions here to check the response
print(response)
cost = completion_cost(completion_response=response)
print("calculated cost", cost)
assert (
cost > 0.0 and cost < 1.0
) # should never be > $1 for a single completion call
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="AWS Suspended Account")
def test_completion_chat_sagemaker():
try:
messages = [{"role": "user", "content": "Hey, how's it going?"}]
litellm.set_verbose = True
response = completion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=messages,
max_tokens=100,
temperature=0.7,
stream=True,
)
# Add any assertions here to check the response
complete_response = ""
for chunk in response:
complete_response += chunk.choices[0].delta.content or ""
print(f"complete_response: {complete_response}")
assert len(complete_response) > 0
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_chat_sagemaker()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_completion_chat_sagemaker_mistral():
try:
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = completion(
model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct",
messages=messages,
max_tokens=100,
)
# Add any assertions here to check the response
print(response)
except Exception as e:
pytest.fail(f"An error occurred: {str(e)}")
# test_completion_chat_sagemaker_mistral()
def response_format_tests(response: litellm.ModelResponse): def response_format_tests(response: litellm.ModelResponse):
assert isinstance(response.id, str) assert isinstance(response.id, str)
assert response.id != "" assert response.id != ""
@ -3449,7 +3372,6 @@ def response_format_tests(response: litellm.ModelResponse):
assert isinstance(response.usage.total_tokens, int) # type: ignore assert isinstance(response.usage.total_tokens, int) # type: ignore
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
@ -3463,6 +3385,7 @@ def response_format_tests(response: litellm.ModelResponse):
"cohere.command-text-v14", "cohere.command-text-v14",
], ],
) )
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completion_bedrock_httpx_models(sync_mode, model): async def test_completion_bedrock_httpx_models(sync_mode, model):
litellm.set_verbose = True litellm.set_verbose = True
@ -3705,19 +3628,21 @@ def test_completion_anyscale_api():
# test_completion_anyscale_api() # test_completion_anyscale_api()
@pytest.mark.skip(reason="flaky test, times out frequently") # @pytest.mark.skip(reason="flaky test, times out frequently")
def test_completion_cohere(): def test_completion_cohere():
try: try:
# litellm.set_verbose=True # litellm.set_verbose=True
messages = [ messages = [
{"role": "system", "content": "You're a good bot"}, {"role": "system", "content": "You're a good bot"},
{"role": "assistant", "content": [{"text": "2", "type": "text"}]},
{"role": "assistant", "content": [{"text": "3", "type": "text"}]},
{ {
"role": "user", "role": "user",
"content": "Hey", "content": "Hey",
}, },
] ]
response = completion( response = completion(
model="command-nightly", model="command-r",
messages=messages, messages=messages,
) )
print(response) print(response)

View file

@ -1,23 +1,27 @@
# What is this?
## Test to make sure function call response always works with json.loads() -> no extra parsing required. Relevant issue - https://github.com/BerriAI/litellm/issues/2654
-import sys, os
+import os
+import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
-import os, io
+import io
+import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
-import pytest
-import litellm
import json
import warnings
-from litellm import completion
from typing import List
+import pytest
+import litellm
+from litellm import completion
# Just a stub to keep the sample code simple
class Trade:
@ -78,58 +82,60 @@ def trade(model_name: str) -> List[Trade]:
}, },
} }
response = completion( try:
model_name, response = completion(
[ model_name,
{ [
"role": "system", {
"content": """You are an expert asset manager, managing a portfolio. "role": "system",
"content": """You are an expert asset manager, managing a portfolio.
Always use the `trade` function. Make sure that you call it correctly. For example, the following is a valid call: Always use the `trade` function. Make sure that you call it correctly. For example, the following is a valid call:
```
trade({
"orders": [
{"action": "buy", "asset": "BTC", "amount": 0.1},
{"action": "sell", "asset": "ETH", "amount": 0.2}
]
})
```
If there are no trades to make, call `trade` with an empty array:
```
trade({ "orders": [] })
```
""",
},
{
"role": "user",
"content": """Manage the portfolio.
Don't jabber.
This is the current market data:
``` ```
trade({ {market_data}
"orders": [
{"action": "buy", "asset": "BTC", "amount": 0.1},
{"action": "sell", "asset": "ETH", "amount": 0.2}
]
})
``` ```
If there are no trades to make, call `trade` with an empty array: Your portfolio is as follows:
``` ```
trade({ "orders": [] }) {portfolio}
``` ```
""", """.replace(
"{market_data}", "BTC: 64,000 USD\nETH: 3,500 USD"
).replace(
"{portfolio}", "USD: 1000, BTC: 0.1, ETH: 0.2"
),
},
],
tools=[tool_spec],
tool_choice={
"type": "function",
"function": {"name": tool_spec["function"]["name"]}, # type: ignore
}, },
{ )
"role": "user", except litellm.InternalServerError:
"content": """Manage the portfolio. pass
Don't jabber.
This is the current market data:
```
{market_data}
```
Your portfolio is as follows:
```
{portfolio}
```
""".replace(
"{market_data}", "BTC: 64,000 USD\nETH: 3,500 USD"
).replace(
"{portfolio}", "USD: 1000, BTC: 0.1, ETH: 0.2"
),
},
],
tools=[tool_spec],
tool_choice={
"type": "function",
"function": {"name": tool_spec["function"]["name"]}, # type: ignore
},
)
calls = response.choices[0].message.tool_calls calls = response.choices[0].message.tool_calls
trades = [trade for call in calls for trade in parse_call(call)] trades = [trade for call in calls for trade in parse_call(call)]
return trades return trades

View file

@ -147,6 +147,117 @@ async def test_basic_gcs_logger():
assert gcs_payload["response_cost"] > 0.0 assert gcs_payload["response_cost"] > 0.0
assert gcs_payload["log_event_type"] == "successful_api_call"
gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"])
assert (
gcs_payload["spend_log_metadata"]["user_api_key"]
== "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b"
)
assert (
gcs_payload["spend_log_metadata"]["user_api_key_user_id"]
== "116544810872468347480"
)
# Delete Object from GCS
print("deleting object from GCS")
await gcs_logger.delete_gcs_object(object_name=object_name)
@pytest.mark.asyncio
async def test_basic_gcs_logger_failure():
load_vertex_ai_credentials()
gcs_logger = GCSBucketLogger()
print("GCSBucketLogger", gcs_logger)
gcs_log_id = f"failure-test-{uuid.uuid4().hex}"
litellm.callbacks = [gcs_logger]
try:
response = await litellm.acompletion(
model="gpt-3.5-turbo",
temperature=0.7,
messages=[{"role": "user", "content": "This is a test"}],
max_tokens=10,
user="ishaan-2",
mock_response=litellm.BadRequestError(
model="gpt-3.5-turbo",
message="Error: 400: Bad Request: Invalid API key, please check your API key and try again.",
llm_provider="openai",
),
metadata={
"gcs_log_id": gcs_log_id,
"tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"],
"user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
"user_api_key_alias": None,
"user_api_end_user_max_budget": None,
"litellm_api_version": "0.0.0",
"global_max_parallel_requests": None,
"user_api_key_user_id": "116544810872468347480",
"user_api_key_org_id": None,
"user_api_key_team_id": None,
"user_api_key_team_alias": None,
"user_api_key_metadata": {},
"requester_ip_address": "127.0.0.1",
"spend_logs_metadata": {"hello": "world"},
"headers": {
"content-type": "application/json",
"user-agent": "PostmanRuntime/7.32.3",
"accept": "*/*",
"postman-token": "92300061-eeaa-423b-a420-0b44896ecdc4",
"host": "localhost:4000",
"accept-encoding": "gzip, deflate, br",
"connection": "keep-alive",
"content-length": "163",
},
"endpoint": "http://localhost:4000/chat/completions",
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
"model_info": {
"id": "4bad40a1eb6bebd1682800f16f44b9f06c52a6703444c99c7f9f32e9de3693b4",
"db_model": False,
},
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"caching_groups": None,
"raw_request": "\n\nPOST Request Sent from LiteLLM:\ncurl -X POST \\\nhttps://openai-gpt-4-test-v-1.openai.azure.com//openai/ \\\n-H 'Authorization: *****' \\\n-d '{'model': 'chatgpt-v-2', 'messages': [{'role': 'system', 'content': 'you are a helpful assistant.\\n'}, {'role': 'user', 'content': 'bom dia'}], 'stream': False, 'max_tokens': 10, 'user': '116544810872468347480', 'extra_body': {}}'\n",
},
)
except:
pass
await asyncio.sleep(5)
# Get the current date
# Get the current date
current_date = datetime.now().strftime("%Y-%m-%d")
# Modify the object_name to include the date-based folder
object_name = gcs_log_id
print("object_name", object_name)
# Check if object landed on GCS
object_from_gcs = await gcs_logger.download_gcs_object(object_name=object_name)
print("object from gcs=", object_from_gcs)
# convert object_from_gcs from bytes to DICT
parsed_data = json.loads(object_from_gcs)
print("object_from_gcs as dict", parsed_data)
print("type of object_from_gcs", type(parsed_data))
gcs_payload = GCSBucketPayload(**parsed_data)
print("gcs_payload", gcs_payload)
assert gcs_payload["request_kwargs"]["model"] == "gpt-3.5-turbo"
assert gcs_payload["request_kwargs"]["messages"] == [
{"role": "user", "content": "This is a test"}
]
assert gcs_payload["response_cost"] == 0
assert gcs_payload["log_event_type"] == "failed_api_call"
gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"]) gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"])
assert ( assert (

View file

@ -76,6 +76,6 @@ async def test_async_prometheus_success_logging():
print("metrics from prometheus", metrics) print("metrics from prometheus", metrics)
assert metrics["litellm_requests_metric_total"] == 1.0 assert metrics["litellm_requests_metric_total"] == 1.0
assert metrics["litellm_total_tokens_total"] == 30.0 assert metrics["litellm_total_tokens_total"] == 30.0
assert metrics["llm_deployment_success_responses_total"] == 1.0 assert metrics["litellm_deployment_success_responses_total"] == 1.0
assert metrics["llm_deployment_total_requests_total"] == 1.0 assert metrics["litellm_deployment_total_requests_total"] == 1.0
assert metrics["llm_deployment_latency_per_output_token_bucket"] == 1.0 assert metrics["litellm_deployment_latency_per_output_token_bucket"] == 1.0

View file

@ -260,3 +260,56 @@ def test_anthropic_messages_tool_call():
translated_messages[-1]["content"][0]["tool_use_id"] translated_messages[-1]["content"][0]["tool_use_id"]
== "bc8cb4b6-88c4-4138-8993-3a9d9cd51656" == "bc8cb4b6-88c4-4138-8993-3a9d9cd51656"
) )
def test_anthropic_cache_controls_pt():
"see anthropic docs for this: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation"
messages = [
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
"cache_control": {"type": "ephemeral"},
},
]
translated_messages = anthropic_messages_pt(
messages, model="claude-3-5-sonnet-20240620", llm_provider="anthropic"
)
for i, msg in enumerate(translated_messages):
if i == 0:
assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
elif i == 1:
assert "cache_controls" not in msg["content"][0]
elif i == 2:
assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
elif i == 3:
assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
print("translated_messages: ", translated_messages)

View file

@ -2,16 +2,19 @@
# This tests setting provider specific configs across providers
# There are 2 types of tests - changing config dynamically or by setting class variables
-import sys, os
+import os
+import sys
import traceback
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
import litellm
-from litellm import completion
-from litellm import RateLimitError
+from litellm import RateLimitError, completion
# Huggingface - Expensive to deploy models and keep them running. Maybe we can try doing this via baseten??
# def hf_test_completion_tgi():
@ -513,102 +516,165 @@ def sagemaker_test_completion():
# sagemaker_test_completion() # sagemaker_test_completion()
-def test_sagemaker_default_region(mocker):
+def test_sagemaker_default_region():
""" """
If no regions are specified in config or in environment, the default region is us-west-2 If no regions are specified in config or in environment, the default region is us-west-2
""" """
mock_client = mocker.patch("boto3.client") mock_response = MagicMock()
try:
def return_val():
return {
"generated_text": "This is a mock response from SageMaker.",
"id": "cmpl-mockid",
"object": "text_completion",
"created": 1629800000,
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
"choices": [
{
"text": "This is a mock response from SageMaker.",
"index": 0,
"logprobs": None,
"finish_reason": "length",
}
],
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
}
mock_response.json = return_val
mock_response.status_code = 200
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
return_value=mock_response,
) as mock_post:
response = litellm.completion( response = litellm.completion(
model="sagemaker/mock-endpoint", model="sagemaker/mock-endpoint",
messages=[ messages=[{"content": "Hello, world!", "role": "user"}],
{
"content": "Hello, world!",
"role": "user"
}
]
) )
except Exception: mock_post.assert_called_once()
pass # expected serialization exception because AWS client was replaced with a Mock _, kwargs = mock_post.call_args
assert mock_client.call_args.kwargs["region_name"] == "us-west-2" args_to_sagemaker = kwargs["json"]
print("Arguments passed to sagemaker=", args_to_sagemaker)
print("url=", kwargs["url"])
assert (
kwargs["url"]
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/mock-endpoint/invocations"
)
# test_sagemaker_default_region() # test_sagemaker_default_region()
-def test_sagemaker_environment_region(mocker):
+def test_sagemaker_environment_region():
""" """
If a region is specified in the environment, use that region instead of us-west-2 If a region is specified in the environment, use that region instead of us-west-2
""" """
expected_region = "us-east-1" expected_region = "us-east-1"
os.environ["AWS_REGION_NAME"] = expected_region os.environ["AWS_REGION_NAME"] = expected_region
mock_client = mocker.patch("boto3.client") mock_response = MagicMock()
try:
def return_val():
return {
"generated_text": "This is a mock response from SageMaker.",
"id": "cmpl-mockid",
"object": "text_completion",
"created": 1629800000,
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
"choices": [
{
"text": "This is a mock response from SageMaker.",
"index": 0,
"logprobs": None,
"finish_reason": "length",
}
],
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
}
mock_response.json = return_val
mock_response.status_code = 200
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
return_value=mock_response,
) as mock_post:
response = litellm.completion( response = litellm.completion(
model="sagemaker/mock-endpoint", model="sagemaker/mock-endpoint",
messages=[ messages=[{"content": "Hello, world!", "role": "user"}],
{
"content": "Hello, world!",
"role": "user"
}
]
) )
except Exception: mock_post.assert_called_once()
pass # expected serialization exception because AWS client was replaced with a Mock _, kwargs = mock_post.call_args
del os.environ["AWS_REGION_NAME"] # cleanup args_to_sagemaker = kwargs["json"]
assert mock_client.call_args.kwargs["region_name"] == expected_region print("Arguments passed to sagemaker=", args_to_sagemaker)
print("url=", kwargs["url"])
assert (
kwargs["url"]
== f"https://runtime.sagemaker.{expected_region}.amazonaws.com/endpoints/mock-endpoint/invocations"
)
del os.environ["AWS_REGION_NAME"] # cleanup
# test_sagemaker_environment_region() # test_sagemaker_environment_region()
-def test_sagemaker_config_region(mocker):
+def test_sagemaker_config_region():
""" """
If a region is specified as part of the optional parameters of the completion, including as If a region is specified as part of the optional parameters of the completion, including as
part of the config file, then use that region instead of us-west-2 part of the config file, then use that region instead of us-west-2
""" """
expected_region = "us-east-1" expected_region = "us-east-1"
mock_client = mocker.patch("boto3.client") mock_response = MagicMock()
try:
response = litellm.completion( def return_val():
model="sagemaker/mock-endpoint", return {
messages=[ "generated_text": "This is a mock response from SageMaker.",
"id": "cmpl-mockid",
"object": "text_completion",
"created": 1629800000,
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
"choices": [
{ {
"content": "Hello, world!", "text": "This is a mock response from SageMaker.",
"role": "user" "index": 0,
"logprobs": None,
"finish_reason": "length",
} }
], ],
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
}
mock_response.json = return_val
mock_response.status_code = 200
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
return_value=mock_response,
) as mock_post:
response = litellm.completion(
model="sagemaker/mock-endpoint",
messages=[{"content": "Hello, world!", "role": "user"}],
aws_region_name=expected_region, aws_region_name=expected_region,
) )
except Exception:
pass # expected serialization exception because AWS client was replaced with a Mock mock_post.assert_called_once()
assert mock_client.call_args.kwargs["region_name"] == expected_region _, kwargs = mock_post.call_args
args_to_sagemaker = kwargs["json"]
print("Arguments passed to sagemaker=", args_to_sagemaker)
print("url=", kwargs["url"])
assert (
kwargs["url"]
== f"https://runtime.sagemaker.{expected_region}.amazonaws.com/endpoints/mock-endpoint/invocations"
)
# test_sagemaker_config_region() # test_sagemaker_config_region()
def test_sagemaker_config_and_environment_region(mocker):
"""
If both the environment and config file specify a region, the environment region is expected
"""
expected_region = "us-east-1"
unexpected_region = "us-east-2"
os.environ["AWS_REGION_NAME"] = expected_region
mock_client = mocker.patch("boto3.client")
try:
response = litellm.completion(
model="sagemaker/mock-endpoint",
messages=[
{
"content": "Hello, world!",
"role": "user"
}
],
aws_region_name=unexpected_region,
)
except Exception:
pass # expected serialization exception because AWS client was replaced with a Mock
del os.environ["AWS_REGION_NAME"] # cleanup
assert mock_client.call_args.kwargs["region_name"] == expected_region
# test_sagemaker_config_and_environment_region() # test_sagemaker_config_and_environment_region()

View file

@ -229,8 +229,9 @@ def test_chat_completion_exception_any_model(client):
) )
assert isinstance(openai_exception, openai.BadRequestError) assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message _error_message = openai_exception.message
assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str( assert (
_error_message "/chat/completions: Invalid model name passed in model=Lite-GPT-12"
in str(_error_message)
) )
except Exception as e: except Exception as e:
@ -259,7 +260,7 @@ def test_embedding_exception_any_model(client):
print("Exception raised=", openai_exception) print("Exception raised=", openai_exception)
assert isinstance(openai_exception, openai.BadRequestError) assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message _error_message = openai_exception.message
assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str( assert "/embeddings: Invalid model name passed in model=Lite-GPT-12" in str(
_error_message _error_message
) )

View file

@ -966,3 +966,203 @@ async def test_user_info_team_list(prisma_client):
pass pass
mock_client.assert_called() mock_client.assert_called()
@pytest.mark.skip(reason="Local test")
@pytest.mark.asyncio
async def test_add_callback_via_key(prisma_client):
"""
Test if callback specified in key, is used.
"""
global headers
import json
from fastapi import HTTPException, Request, Response
from starlette.datastructures import URL
from litellm.proxy.proxy_server import chat_completion
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
litellm.set_verbose = True
try:
# Your test data
test_data = {
"model": "azure/chatgpt-v-2",
"messages": [
{"role": "user", "content": "write 1 sentence poem"},
],
"max_tokens": 10,
"mock_response": "Hello world",
"api_key": "my-fake-key",
}
request = Request(scope={"type": "http", "method": "POST", "headers": {}})
request._url = URL(url="/chat/completions")
json_bytes = json.dumps(test_data).encode("utf-8")
request._body = json_bytes
with patch.object(
litellm.litellm_core_utils.litellm_logging,
"LangFuseLogger",
new=MagicMock(),
) as mock_client:
resp = await chat_completion(
request=request,
fastapi_response=Response(),
user_api_key_dict=UserAPIKeyAuth(
metadata={
"logging": [
{
"callback_name": "langfuse", # 'otel', 'langfuse', 'lunary'
"callback_type": "success", # set, if required by integration - future improvement, have logging tools work for success + failure by default
"callback_vars": {
"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
"langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
"langfuse_host": "https://us.cloud.langfuse.com",
},
}
]
}
),
)
print(resp)
mock_client.assert_called()
mock_client.return_value.log_event.assert_called()
args, kwargs = mock_client.return_value.log_event.call_args
kwargs = kwargs["kwargs"]
assert "user_api_key_metadata" in kwargs["litellm_params"]["metadata"]
assert (
"logging"
in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"]
)
checked_keys = False
for item in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"][
"logging"
]:
for k, v in item["callback_vars"].items():
print("k={}, v={}".format(k, v))
if "key" in k:
assert "os.environ" in v
checked_keys = True
assert checked_keys
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
@pytest.mark.asyncio
async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
import json
from fastapi import HTTPException, Request, Response
from starlette.datastructures import URL
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
request = Request(scope={"type": "http", "method": "POST", "headers": {}})
request._url = URL(url="/chat/completions")
test_data = {
"model": "azure/chatgpt-v-2",
"messages": [
{"role": "user", "content": "write 1 sentence poem"},
],
"max_tokens": 10,
"mock_response": "Hello world",
"api_key": "my-fake-key",
}
json_bytes = json.dumps(test_data).encode("utf-8")
request._body = json_bytes
data = {
"data": {
"model": "azure/chatgpt-v-2",
"messages": [{"role": "user", "content": "write 1 sentence poem"}],
"max_tokens": 10,
"mock_response": "Hello world",
"api_key": "my-fake-key",
},
"request": request,
"user_api_key_dict": UserAPIKeyAuth(
token=None,
key_name=None,
key_alias=None,
spend=0.0,
max_budget=None,
expires=None,
models=[],
aliases={},
config={},
user_id=None,
team_id=None,
max_parallel_requests=None,
metadata={
"logging": [
{
"callback_name": "langfuse",
"callback_type": "success",
"callback_vars": {
"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
"langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
"langfuse_host": "https://us.cloud.langfuse.com",
},
}
]
},
tpm_limit=None,
rpm_limit=None,
budget_duration=None,
budget_reset_at=None,
allowed_cache_controls=[],
permissions={},
model_spend={},
model_max_budget={},
soft_budget_cooldown=False,
litellm_budget_table=None,
org_id=None,
team_spend=None,
team_alias=None,
team_tpm_limit=None,
team_rpm_limit=None,
team_max_budget=None,
team_models=[],
team_blocked=False,
soft_budget=None,
team_model_aliases=None,
team_member_spend=None,
team_metadata=None,
end_user_id=None,
end_user_tpm_limit=None,
end_user_rpm_limit=None,
end_user_max_budget=None,
last_refreshed_at=None,
api_key=None,
user_role=None,
allowed_model_region=None,
parent_otel_span=None,
),
"proxy_config": proxy_config,
"general_settings": {},
"version": "0.0.0",
}
new_data = await add_litellm_data_to_request(**data)
assert "success_callback" in new_data
assert new_data["success_callback"] == ["langfuse"]
assert "langfuse_public_key" in new_data
assert "langfuse_secret_key" in new_data

View file

@ -0,0 +1,316 @@
import json
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
# litellm.num_retries =3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
import logging
from litellm._logging import verbose_logger
def logger_fn(user_model_dict):
print(f"user_model_dict: {user_model_dict}")
@pytest.fixture(autouse=True)
def reset_callbacks():
print("\npytest fixture - resetting callbacks")
litellm.success_callback = []
litellm._async_success_callback = []
litellm.failure_callback = []
litellm.callbacks = []
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_completion_sagemaker(sync_mode):
try:
litellm.set_verbose = True
verbose_logger.setLevel(logging.DEBUG)
print("testing sagemaker")
if sync_mode is True:
response = litellm.completion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi"},
],
temperature=0.2,
max_tokens=80,
input_cost_per_second=0.000420,
)
else:
response = await litellm.acompletion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi"},
],
temperature=0.2,
max_tokens=80,
input_cost_per_second=0.000420,
)
# Add any assertions here to check the response
print(response)
cost = completion_cost(completion_response=response)
print("calculated cost", cost)
assert (
cost > 0.0 and cost < 1.0
) # should never be > $1 for a single completion call
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [False, True])
async def test_completion_sagemaker_stream(sync_mode):
try:
litellm.set_verbose = False
print("testing sagemaker")
verbose_logger.setLevel(logging.DEBUG)
full_text = ""
if sync_mode is True:
response = litellm.completion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi - what is ur name"},
],
temperature=0.2,
stream=True,
max_tokens=80,
input_cost_per_second=0.000420,
)
for chunk in response:
print(chunk)
full_text += chunk.choices[0].delta.content or ""
print("SYNC RESPONSE full text", full_text)
else:
response = await litellm.acompletion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi - what is ur name"},
],
stream=True,
temperature=0.2,
max_tokens=80,
input_cost_per_second=0.000420,
)
print("streaming response")
async for chunk in response:
print(chunk)
full_text += chunk.choices[0].delta.content or ""
print("ASYNC RESPONSE full text", full_text)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_acompletion_sagemaker_non_stream():
mock_response = AsyncMock()
def return_val():
return {
"generated_text": "This is a mock response from SageMaker.",
"id": "cmpl-mockid",
"object": "text_completion",
"created": 1629800000,
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
"choices": [
{
"text": "This is a mock response from SageMaker.",
"index": 0,
"logprobs": None,
"finish_reason": "length",
}
],
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
}
mock_response.json = return_val
mock_response.status_code = 200
expected_payload = {
"inputs": "hi",
"parameters": {"temperature": 0.2, "max_new_tokens": 80},
}
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response,
) as mock_post:
# Act: Call the litellm.acompletion function
response = await litellm.acompletion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi"},
],
temperature=0.2,
max_tokens=80,
input_cost_per_second=0.000420,
)
# Print what was called on the mock
print("call args=", mock_post.call_args)
# Assert
mock_post.assert_called_once()
_, kwargs = mock_post.call_args
args_to_sagemaker = kwargs["json"]
print("Arguments passed to sagemaker=", args_to_sagemaker)
assert args_to_sagemaker == expected_payload
assert (
kwargs["url"]
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
)
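
The mocked tests above pin down the outgoing request shape; here is a small sketch of that shape for reference (the helper function below is illustrative, not a litellm API):

```python
# Illustrative helper (not part of litellm): the SageMaker runtime invocation
# URL and text-generation payload asserted by the mocked tests above.
def sagemaker_invoke_url(region: str, endpoint: str) -> str:
    return (
        f"https://runtime.sagemaker.{region}.amazonaws.com"
        f"/endpoints/{endpoint}/invocations"
    )

expected_url = sagemaker_invoke_url(
    "us-west-2", "jumpstart-dft-hf-textgeneration1-mp-20240815-185614"
)
expected_payload = {
    "inputs": "hi",
    "parameters": {"temperature": 0.2, "max_new_tokens": 80},
}
```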
@pytest.mark.asyncio
async def test_completion_sagemaker_non_stream():
mock_response = MagicMock()
def return_val():
return {
"generated_text": "This is a mock response from SageMaker.",
"id": "cmpl-mockid",
"object": "text_completion",
"created": 1629800000,
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
"choices": [
{
"text": "This is a mock response from SageMaker.",
"index": 0,
"logprobs": None,
"finish_reason": "length",
}
],
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
}
mock_response.json = return_val
mock_response.status_code = 200
expected_payload = {
"inputs": "hi",
"parameters": {"temperature": 0.2, "max_new_tokens": 80},
}
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
return_value=mock_response,
) as mock_post:
# Act: Call the litellm.acompletion function
response = litellm.completion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi"},
],
temperature=0.2,
max_tokens=80,
input_cost_per_second=0.000420,
)
# Print what was called on the mock
print("call args=", mock_post.call_args)
# Assert
mock_post.assert_called_once()
_, kwargs = mock_post.call_args
args_to_sagemaker = kwargs["json"]
print("Arguments passed to sagemaker=", args_to_sagemaker)
assert args_to_sagemaker == expected_payload
assert (
kwargs["url"]
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
)
@pytest.mark.asyncio
async def test_completion_sagemaker_non_stream_with_aws_params():
mock_response = MagicMock()
def return_val():
return {
"generated_text": "This is a mock response from SageMaker.",
"id": "cmpl-mockid",
"object": "text_completion",
"created": 1629800000,
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
"choices": [
{
"text": "This is a mock response from SageMaker.",
"index": 0,
"logprobs": None,
"finish_reason": "length",
}
],
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
}
mock_response.json = return_val
mock_response.status_code = 200
expected_payload = {
"inputs": "hi",
"parameters": {"temperature": 0.2, "max_new_tokens": 80},
}
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
return_value=mock_response,
) as mock_post:
# Act: Call the litellm.acompletion function
response = litellm.completion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi"},
],
temperature=0.2,
max_tokens=80,
input_cost_per_second=0.000420,
aws_access_key_id="gm",
aws_secret_access_key="s",
aws_region_name="us-west-5",
)
# Print what was called on the mock
print("call args=", mock_post.call_args)
# Assert
mock_post.assert_called_once()
_, kwargs = mock_post.call_args
args_to_sagemaker = kwargs["json"]
print("Arguments passed to sagemaker=", args_to_sagemaker)
assert args_to_sagemaker == expected_payload
assert (
kwargs["url"]
== "https://runtime.sagemaker.us-west-5.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
)

View file

@ -1683,6 +1683,7 @@ def test_completion_bedrock_mistral_stream():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="stopped using TokenIterator")
def test_sagemaker_weird_response(): def test_sagemaker_weird_response():
""" """
When the stream ends, flush any remaining holding chunks. When the stream ends, flush any remaining holding chunks.

View file

@ -44,7 +44,7 @@ def test_check_valid_ip(
request = Request(client_ip) request = Request(client_ip)
-assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
+assert _check_valid_ip(allowed_ips, request)[0] == expected_result # type: ignore
# test x-forwarder for is used when user has opted in # test x-forwarder for is used when user has opted in
@ -72,7 +72,7 @@ def test_check_valid_ip_sent_with_x_forwarded_for(
request = Request(client_ip, headers={"X-Forwarded-For": client_ip}) request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
-assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True) == expected_result # type: ignore
+assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True)[0] == expected_result # type: ignore
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -15,9 +15,10 @@ class AnthropicMessagesTool(TypedDict, total=False):
input_schema: Required[dict] input_schema: Required[dict]
-class AnthropicMessagesTextParam(TypedDict):
+class AnthropicMessagesTextParam(TypedDict, total=False):
type: Literal["text"] type: Literal["text"]
text: str text: str
cache_control: Optional[dict]
class AnthropicMessagesToolUseParam(TypedDict): class AnthropicMessagesToolUseParam(TypedDict):
@ -54,9 +55,10 @@ class AnthropicImageParamSource(TypedDict):
data: str data: str
-class AnthropicMessagesImageParam(TypedDict):
+class AnthropicMessagesImageParam(TypedDict, total=False):
type: Literal["image"] type: Literal["image"]
source: AnthropicImageParamSource source: AnthropicImageParamSource
cache_control: Optional[dict]
class AnthropicMessagesToolResultContent(TypedDict): class AnthropicMessagesToolResultContent(TypedDict):
@ -92,6 +94,12 @@ class AnthropicMetadata(TypedDict, total=False):
user_id: str user_id: str
class AnthropicSystemMessageContent(TypedDict, total=False):
type: str
text: str
cache_control: Optional[dict]
class AnthropicMessagesRequest(TypedDict, total=False): class AnthropicMessagesRequest(TypedDict, total=False):
model: Required[str] model: Required[str]
messages: Required[ messages: Required[
@ -106,7 +114,7 @@ class AnthropicMessagesRequest(TypedDict, total=False):
metadata: AnthropicMetadata metadata: AnthropicMetadata
stop_sequences: List[str] stop_sequences: List[str]
stream: bool stream: bool
-system: str
+system: Union[str, List]
temperature: float temperature: float
tool_choice: AnthropicMessagesToolChoice tool_choice: AnthropicMessagesToolChoice
tools: List[AnthropicMessagesTool] tools: List[AnthropicMessagesTool]
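
A short sketch of a content block matching the new optional `cache_control` field; the values reuse strings from the tests earlier in this diff and are purely illustrative:

```python
# Illustrative only: a system/text content block shaped like the updated
# AnthropicSystemMessageContent / AnthropicMessagesTextParam typed dicts above.
system_block = {
    "type": "text",
    "text": "Here is the full text of a complex legal agreement",
    "cache_control": {"type": "ephemeral"},  # optional, per the TypedDict change
}
```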

View file

@ -361,7 +361,7 @@ class ChatCompletionToolMessage(TypedDict):
class ChatCompletionSystemMessage(TypedDict, total=False): class ChatCompletionSystemMessage(TypedDict, total=False):
role: Required[Literal["system"]] role: Required[Literal["system"]]
-content: Required[str]
+content: Required[Union[str, List]]
name: str name: str

View file

@ -80,7 +80,7 @@ class ModelInfo(TypedDict, total=False):
supports_assistant_prefill: Optional[bool] supports_assistant_prefill: Optional[bool]
-class GenericStreamingChunk(TypedDict):
+class GenericStreamingChunk(TypedDict, total=False):
text: Required[str] text: Required[str]
tool_use: Optional[ChatCompletionToolCallChunk] tool_use: Optional[ChatCompletionToolCallChunk]
is_finished: Required[bool] is_finished: Required[bool]

View file

@ -4479,7 +4479,22 @@ def _is_non_openai_azure_model(model: str) -> bool:
or f"mistral/{model_name}" in litellm.mistral_chat_models or f"mistral/{model_name}" in litellm.mistral_chat_models
): ):
return True return True
-except:
+except Exception:
return False
return False
def _is_azure_openai_model(model: str) -> bool:
try:
if "/" in model:
model = model.split("/", 1)[1]
if (
model in litellm.open_ai_chat_completion_models
or model in litellm.open_ai_text_completion_models
or model in litellm.open_ai_embedding_models
):
return True
except Exception:
return False return False
return False return False
@ -4613,6 +4628,14 @@ def get_llm_provider(
elif custom_llm_provider == "azure_ai": elif custom_llm_provider == "azure_ai":
api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY") dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
if _is_azure_openai_model(model=model):
verbose_logger.debug(
"Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
model
)
)
custom_llm_provider = "azure"
elif custom_llm_provider == "github": elif custom_llm_provider == "github":
api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
dynamic_api_key = api_key or get_secret("GITHUB_API_KEY") dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
@ -9825,11 +9848,28 @@ class CustomStreamWrapper:
completion_obj["tool_calls"] = [response_obj["tool_use"]] completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker": elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") from litellm.types.llms.bedrock import GenericStreamingChunk
response_obj = self.handle_sagemaker_stream(chunk)
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"] completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]: if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"] self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "petals": elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0: if len(self.completion_stream) == 0:
if self.received_finish_reason is not None: if self.received_finish_reason is not None:
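
A hedged sketch of the caller-side behaviour this enables; it assumes `stream_options` is forwarded for SageMaker the same way it is for OpenAI-style providers, which this diff does not itself verify:

```python
# Illustration only: request usage in the streamed output so the
# include_usage branch added above is exercised.
import litellm

stream = litellm.completion(
    model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},  # assumed to be forwarded for sagemaker
)
for chunk in stream:
    if getattr(chunk, "usage", None):
        print(chunk.usage)  # prompt/completion/total token counts
```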

View file

@ -57,6 +57,18 @@
"supports_parallel_function_calling": true, "supports_parallel_function_calling": true,
"supports_vision": true "supports_vision": true
}, },
"chatgpt-4o-latest": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4o-2024-05-13": { "gpt-4o-2024-05-13": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 128000, "max_input_tokens": 128000,
@ -2062,7 +2074,8 @@
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true "supports_vision": true,
"supports_assistant_prefill": true
}, },
"vertex_ai/claude-3-5-sonnet@20240620": { "vertex_ai/claude-3-5-sonnet@20240620": {
"max_tokens": 4096, "max_tokens": 4096,
@ -2073,7 +2086,8 @@
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true "supports_vision": true,
"supports_assistant_prefill": true
}, },
"vertex_ai/claude-3-haiku@20240307": { "vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 4096, "max_tokens": 4096,
@ -2084,7 +2098,8 @@
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true "supports_vision": true,
"supports_assistant_prefill": true
}, },
"vertex_ai/claude-3-opus@20240229": { "vertex_ai/claude-3-opus@20240229": {
"max_tokens": 4096, "max_tokens": 4096,
@ -2095,7 +2110,8 @@
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true "supports_vision": true,
"supports_assistant_prefill": true
}, },
"vertex_ai/meta/llama3-405b-instruct-maas": { "vertex_ai/meta/llama3-405b-instruct-maas": {
"max_tokens": 32000, "max_tokens": 32000,
@ -4519,6 +4535,69 @@
"litellm_provider": "perplexity", "litellm_provider": "perplexity",
"mode": "chat" "mode": "chat"
}, },
"perplexity/llama-3.1-70b-instruct": {
"max_tokens": 131072,
"max_input_tokens": 131072,
"max_output_tokens": 131072,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-8b-instruct": {
"max_tokens": 131072,
"max_input_tokens": 131072,
"max_output_tokens": 131072,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-huge-128k-online": {
"max_tokens": 127072,
"max_input_tokens": 127072,
"max_output_tokens": 127072,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000005,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-large-128k-online": {
"max_tokens": 127072,
"max_input_tokens": 127072,
"max_output_tokens": 127072,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-large-128k-chat": {
"max_tokens": 131072,
"max_input_tokens": 131072,
"max_output_tokens": 131072,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-small-128k-chat": {
"max_tokens": 131072,
"max_input_tokens": 131072,
"max_output_tokens": 131072,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-3.1-sonar-small-128k-online": {
"max_tokens": 127072,
"max_input_tokens": 127072,
"max_output_tokens": 127072,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/pplx-7b-chat": { "perplexity/pplx-7b-chat": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 8192, "max_input_tokens": 8192,

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "1.43.9" version = "1.43.15"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT" license = "MIT"
@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.commitizen] [tool.commitizen]
version = "1.43.9" version = "1.43.15"
version_files = [ version_files = [
"pyproject.toml:^version" "pyproject.toml:^version"
] ]