Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

Commit b3d15ace89: Merge branch 'main' into litellm_pass_through_endpoints_api

55 changed files with 3658 additions and 1644 deletions
@@ -125,6 +125,7 @@ jobs:
 pip install tiktoken
 pip install aiohttp
 pip install click
+pip install "boto3==1.34.34"
 pip install jinja2
 pip install tokenizers
 pip install openai

@@ -287,6 +288,7 @@ jobs:
 pip install "pytest==7.3.1"
 pip install "pytest-mock==3.12.0"
 pip install "pytest-asyncio==0.21.1"
+pip install "boto3==1.34.34"
 pip install mypy
 pip install pyarrow
 pip install numpydoc
@@ -62,6 +62,11 @@ COPY --from=builder /wheels/ /wheels/
 RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels

 # Generate prisma client
+ENV PRISMA_BINARY_CACHE_DIR=/app/prisma
+RUN mkdir -p /.cache
+RUN chmod -R 777 /.cache
+RUN pip install nodejs-bin
+RUN pip install prisma
 RUN prisma generate
 RUN chmod +x entrypoint.sh

@@ -62,6 +62,11 @@ RUN pip install PyJWT --no-cache-dir
 RUN chmod +x build_admin_ui.sh && ./build_admin_ui.sh

 # Generate prisma client
+ENV PRISMA_BINARY_CACHE_DIR=/app/prisma
+RUN mkdir -p /.cache
+RUN chmod -R 777 /.cache
+RUN pip install nodejs-bin
+RUN pip install prisma
 RUN prisma generate
 RUN chmod +x entrypoint.sh

@@ -84,17 +84,20 @@ from litellm import completion
 # add to env var
 os.environ["OPENAI_API_KEY"] = ""

-messages = [{"role": "user", "content": "List 5 cookie recipes"}]
+messages = [{"role": "user", "content": "List 5 important events in the XIX century"}]

 class CalendarEvent(BaseModel):
     name: str
     date: str
     participants: list[str]

+class EventsList(BaseModel):
+    events: list[CalendarEvent]
+
 resp = completion(
     model="gpt-4o-2024-08-06",
     messages=messages,
-    response_format=CalendarEvent
+    response_format=EventsList
 )

 print("Received={}".format(resp))
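The updated structured-output example above returns the model's JSON in the standard OpenAI-style response shape. As a rough sketch (not part of the diff, and assuming pydantic v2 for `model_validate`), the content can be validated back into the new `EventsList` model like this:

```python
import json

raw = resp.choices[0].message.content                 # JSON string produced for the response_format model
events = EventsList.model_validate(json.loads(raw))   # pydantic v2 validation
for event in events.events:
    print(event.name, event.date, event.participants)
```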
@@ -225,22 +225,336 @@ print(response)
 | claude-instant-1.2 | `completion('claude-instant-1.2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
 | claude-instant-1 | `completion('claude-instant-1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |

-## Passing Extra Headers to Anthropic API
-
-Pass `extra_headers: dict` to `litellm.completion`
-
-```python
-from litellm import completion
-messages = [{"role": "user", "content": "What is Anthropic?"}]
-response = completion(
-    model="claude-3-5-sonnet-20240620",
-    messages=messages,
-    extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
-)
-```
-## Advanced
-
-## Usage - Function Calling
-
+## **Prompt Caching**
+
+Use Anthropic Prompt Caching
+
+[Relevant Anthropic API Docs](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching)
+
+### Caching - Large Context Caching
+
+This example demonstrates basic Prompt Caching usage, caching the full text of the legal agreement as a prefix while keeping the user instruction uncached.
+
+<Tabs>
+<TabItem value="sdk" label="LiteLLM SDK">
+
+```python
+response = await litellm.acompletion(
+    model="anthropic/claude-3-5-sonnet-20240620",
+    messages=[
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "You are an AI assistant tasked with analyzing legal documents.",
+                },
+                {
+                    "type": "text",
+                    "text": "Here is the full text of a complex legal agreement",
+                    "cache_control": {"type": "ephemeral"},
+                },
+            ],
+        },
+        {
+            "role": "user",
+            "content": "what are the key terms and conditions in this agreement?",
+        },
+    ],
+    extra_headers={
+        "anthropic-version": "2023-06-01",
+        "anthropic-beta": "prompt-caching-2024-07-31",
+    },
+)
+```
+</TabItem>
+<TabItem value="proxy" label="LiteLLM Proxy">
+
+:::info
+
+LiteLLM Proxy is OpenAI compatible
+
+This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy
+
+Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy)
+
+:::
+
+```python
+import openai
+client = openai.AsyncOpenAI(
+    api_key="anything",            # litellm proxy api key
+    base_url="http://0.0.0.0:4000" # litellm proxy base url
+)
+
+response = await client.chat.completions.create(
+    model="anthropic/claude-3-5-sonnet-20240620",
+    messages=[
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "You are an AI assistant tasked with analyzing legal documents.",
+                },
+                {
+                    "type": "text",
+                    "text": "Here is the full text of a complex legal agreement",
+                    "cache_control": {"type": "ephemeral"},
+                },
+            ],
+        },
+        {
+            "role": "user",
+            "content": "what are the key terms and conditions in this agreement?",
+        },
+    ],
+    extra_headers={
+        "anthropic-version": "2023-06-01",
+        "anthropic-beta": "prompt-caching-2024-07-31",
+    },
+)
+```
+
+</TabItem>
+</Tabs>
+
+### Caching - Tools definitions
+
+In this example, we demonstrate caching tool definitions.
+
+The cache_control parameter is placed on the final tool
+
+<Tabs>
+<TabItem value="sdk" label="LiteLLM SDK">
+
+```python
+import litellm
+
+response = await litellm.acompletion(
+    model="anthropic/claude-3-5-sonnet-20240620",
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}],
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+                "cache_control": {"type": "ephemeral"}
+            },
+        }
+    ],
+    extra_headers={
+        "anthropic-version": "2023-06-01",
+        "anthropic-beta": "prompt-caching-2024-07-31",
+    },
+)
+```
+</TabItem>
+<TabItem value="proxy" label="LiteLLM Proxy">
+
+:::info
+
+LiteLLM Proxy is OpenAI compatible
+
+This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy
+
+Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy)
+
+:::
+
+```python
+import openai
+client = openai.AsyncOpenAI(
+    api_key="anything",            # litellm proxy api key
+    base_url="http://0.0.0.0:4000" # litellm proxy base url
+)
+
+response = await client.chat.completions.create(
+    model="anthropic/claude-3-5-sonnet-20240620",
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}],
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+                "cache_control": {"type": "ephemeral"}
+            },
+        }
+    ],
+    extra_headers={
+        "anthropic-version": "2023-06-01",
+        "anthropic-beta": "prompt-caching-2024-07-31",
+    },
+)
+```
+
+</TabItem>
+</Tabs>
+
+### Caching - Continuing Multi-Turn Convo
+
+In this example, we demonstrate how to use Prompt Caching in a multi-turn conversation.
+
+The cache_control parameter is placed on the system message to designate it as part of the static prefix.
+
+The conversation history (previous messages) is included in the messages array. The final turn is marked with cache-control, for continuing in followups. The second-to-last user message is marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+
+<Tabs>
+<TabItem value="sdk" label="LiteLLM SDK">
+
+```python
+import litellm
+
+response = await litellm.acompletion(
+    model="anthropic/claude-3-5-sonnet-20240620",
+    messages=[
+        # System Message
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "Here is the full text of a complex legal agreement" * 400,
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+        },
+        # The final turn is marked with cache-control, for continuing in followups.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+    ],
+    extra_headers={
+        "anthropic-version": "2023-06-01",
+        "anthropic-beta": "prompt-caching-2024-07-31",
+    },
+)
+```
+</TabItem>
+<TabItem value="proxy" label="LiteLLM Proxy">
+
+:::info
+
+LiteLLM Proxy is OpenAI compatible
+
+This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy
+
+Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy)
+
+:::
+
+```python
+import openai
+client = openai.AsyncOpenAI(
+    api_key="anything",            # litellm proxy api key
+    base_url="http://0.0.0.0:4000" # litellm proxy base url
+)
+
+response = await client.chat.completions.create(
+    model="anthropic/claude-3-5-sonnet-20240620",
+    messages=[
+        # System Message
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "Here is the full text of a complex legal agreement" * 400,
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+        },
+        # The final turn is marked with cache-control, for continuing in followups.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+    ],
+    extra_headers={
+        "anthropic-version": "2023-06-01",
+        "anthropic-beta": "prompt-caching-2024-07-31",
+    },
+)
+```
+
+</TabItem>
+</Tabs>
+
+## **Function/Tool Calling**
+
 :::info

@@ -429,6 +743,20 @@ resp = litellm.completion(
 print(f"\nResponse: {resp}")
 ```

+## **Passing Extra Headers to Anthropic API**
+
+Pass `extra_headers: dict` to `litellm.completion`
+
+```python
+from litellm import completion
+messages = [{"role": "user", "content": "What is Anthropic?"}]
+response = completion(
+    model="claude-3-5-sonnet-20240620",
+    messages=messages,
+    extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
+)
+```
+
 ## Usage - "Assistant Pre-fill"

 You can "put words in Claude's mouth" by including an `assistant` role message as the last item in the `messages` array.

@@ -393,7 +393,7 @@ response = completion(
 )
 ```
 </TabItem>
-<TabItem value="proxy" label="LiteLLM Proxy Server">
+<TabItem value="proxy" label="Proxy on request">

 ```python

@@ -420,6 +420,55 @@ extra_body={
 }
 )

+print(response)
+```
+</TabItem>
+<TabItem value="proxy-config" label="Proxy on config.yaml">
+
+1. Update config.yaml
+
+```yaml
+model_list:
+  - model_name: bedrock-claude-v1
+    litellm_params:
+      model: bedrock/anthropic.claude-instant-v1
+      aws_access_key_id: os.environ/CUSTOM_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/CUSTOM_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/CUSTOM_AWS_REGION_NAME
+      guardrailConfig: {
+        "guardrailIdentifier": "ff6ujrregl1q", # The identifier (ID) for the guardrail.
+        "guardrailVersion": "DRAFT",           # The version of the guardrail.
+        "trace": "disabled",                   # The trace behavior for the guardrail. Can either be "disabled" or "enabled"
+      }
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="anything",
+    base_url="http://0.0.0.0:4000"
+)
+
+# request sent to model set on litellm proxy, `litellm --model`
+response = client.chat.completions.create(model="bedrock-claude-v1", messages = [
+    {
+        "role": "user",
+        "content": "this is a test request, write a short poem"
+    }
+    ],
+    temperature=0.7
+)
+
 print(response)
 ```
 </TabItem>

@@ -705,6 +705,29 @@ docker run ghcr.io/berriai/litellm:main-latest \
 Provide an ssl certificate when starting litellm proxy server

+### 3. Providing LiteLLM config.yaml file as a s3 Object/url
+
+Use this if you cannot mount a config file on your deployment service (example - AWS Fargate, Railway etc)
+
+LiteLLM Proxy will read your config.yaml from an s3 Bucket
+
+Set the following .env vars
+```shell
+LITELLM_CONFIG_BUCKET_NAME = "litellm-proxy"                    # your bucket name on s3
+LITELLM_CONFIG_BUCKET_OBJECT_KEY = "litellm_proxy_config.yaml"  # object key on s3
+```
+
+Start litellm proxy with these env vars - litellm will read your config from s3
+
+```shell
+docker run --name litellm-proxy \
+   -e DATABASE_URL=<database_url> \
+   -e LITELLM_CONFIG_BUCKET_NAME=<bucket_name> \
+   -e LITELLM_CONFIG_BUCKET_OBJECT_KEY="<object_key>" \
+   -p 4000:4000 \
+   ghcr.io/berriai/litellm-database:main-latest
+```
+
 ## Platform-specific Guide

 <Tabs>

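The s3-backed config above is read-only from the proxy's side, so the `config.yaml` still has to be uploaded to the bucket separately. A minimal, hedged sketch with boto3 (the bucket name and object key are just the placeholder values used in the diff):

```python
import boto3

s3 = boto3.client("s3")
s3.upload_file(
    Filename="litellm_proxy_config.yaml",  # local proxy config
    Bucket="litellm-proxy",                # matches LITELLM_CONFIG_BUCKET_NAME
    Key="litellm_proxy_config.yaml",       # matches LITELLM_CONFIG_BUCKET_OBJECT_KEY
)
```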
@@ -17,7 +17,7 @@ model_list:

 ## Get Model Information - `/model/info`

-Retrieve detailed information about each model listed in the `/model/info` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled the model_info you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes.
+Retrieve detailed information about each model listed in the `/model/info` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled from the model_info you set and the [litellm model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). Sensitive details like API keys are excluded for security purposes.

 <Tabs
   defaultValue="curl"

@@ -35,22 +35,33 @@ curl -X GET "http://0.0.0.0:4000/model/info" \

 ## Add a New Model

-Add a new model to the list in the `config.yaml` by providing the model parameters. This allows you to update the model list without restarting the proxy.
+Add a new model to the proxy via the `/model/new` API, to add models without restarting the proxy.

-<Tabs
-  defaultValue="curl"
-  values={[
-    { label: 'cURL', value: 'curl', },
-  ]}>
-<TabItem value="curl">
+<Tabs>
+<TabItem value="API">

 ```bash
 curl -X POST "http://0.0.0.0:4000/model/new" \
   -H "accept: application/json" \
   -H "Content-Type: application/json" \
   -d '{ "model_name": "azure-gpt-turbo", "litellm_params": {"model": "azure/gpt-3.5-turbo", "api_key": "os.environ/AZURE_API_KEY", "api_base": "my-azure-api-base"} }'
 ```
 </TabItem>
+<TabItem value="Yaml">
+
+```yaml
+model_list:
+  - model_name: gpt-3.5-turbo  ### RECEIVED MODEL NAME ### `openai.chat.completions.create(model="gpt-3.5-turbo",...)`
+    litellm_params:            # all params accepted by litellm.completion() - https://github.com/BerriAI/litellm/blob/9b46ec05b02d36d6e4fb5c32321e51e7f56e4a6e/litellm/types/router.py#L297
+      model: azure/gpt-turbo-small-eu  ### MODEL NAME sent to `litellm.completion()` ###
+      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+      api_key: "os.environ/AZURE_API_KEY_EU"  # does os.getenv("AZURE_API_KEY_EU")
+      rpm: 6                   # [OPTIONAL] Rate limit for this deployment: in requests per minute (rpm)
+    model_info:
+      my_custom_key: my_custom_value  # additional model metadata
+```
+
+</TabItem>
 </Tabs>

@@ -86,3 +97,82 @@ Keep in mind that as both endpoints are in [BETA], you may need to visit the ass
 - Add a New Model: [Issue #964](https://github.com/BerriAI/litellm/issues/964)

 Feedback on the beta endpoints is valuable and helps improve the API for all users.
+
+## Add Additional Model Information
+
+If you want the ability to add a display name, description, and labels for models, just use `model_info:`
+
+```yaml
+model_list:
+  - model_name: "gpt-4"
+    litellm_params:
+      model: "gpt-4"
+      api_key: "os.environ/OPENAI_API_KEY"
+    model_info: # 👈 KEY CHANGE
+      my_custom_key: "my_custom_value"
+```
+
+### Usage
+
+1. Add additional information to model
+
+```yaml
+model_list:
+  - model_name: "gpt-4"
+    litellm_params:
+      model: "gpt-4"
+      api_key: "os.environ/OPENAI_API_KEY"
+    model_info: # 👈 KEY CHANGE
+      my_custom_key: "my_custom_value"
+```
+
+2. Call with `/model/info`
+
+Use a key with access to the model `gpt-4`.
+
+```bash
+curl -L -X GET 'http://0.0.0.0:4000/v1/model/info' \
+  -H 'Authorization: Bearer LITELLM_KEY'
+```
+
+3. **Expected Response**
+
+Returned `model_info = Your custom model_info + (if exists) LITELLM MODEL INFO`
+
+[**How LiteLLM Model Info is found**](https://github.com/BerriAI/litellm/blob/9b46ec05b02d36d6e4fb5c32321e51e7f56e4a6e/litellm/proxy/proxy_server.py#L7460)
+
+[Tell us how this can be improved!](https://github.com/BerriAI/litellm/issues)
+
+```bash
+{
+  "data": [
+    {
+      "model_name": "gpt-4",
+      "litellm_params": {
+        "model": "gpt-4"
+      },
+      "model_info": {
+        "id": "e889baacd17f591cce4c63639275ba5e8dc60765d6c553e6ee5a504b19e50ddc",
+        "db_model": false,
+        "my_custom_key": "my_custom_value", # 👈 CUSTOM INFO
+        "key": "gpt-4", # 👈 KEY in LiteLLM MODEL INFO/COST MAP - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 3e-05,
+        "input_cost_per_character": null,
+        "input_cost_per_token_above_128k_tokens": null,
+        "output_cost_per_token": 6e-05,
+        "output_cost_per_character": null,
+        "output_cost_per_token_above_128k_tokens": null,
+        "output_cost_per_character_above_128k_tokens": null,
+        "output_vector_size": null,
+        "litellm_provider": "openai",
+        "mode": "chat"
+      }
+    },
+  ]
+}
+```

@@ -193,6 +193,53 @@ curl --request POST \
 }'
 ```

+### Use Langfuse client sdk w/ LiteLLM Key
+
+**Usage**
+
+1. Set-up yaml to pass-through langfuse /api/public/ingestion
+
+```yaml
+general_settings:
+  master_key: sk-1234
+  pass_through_endpoints:
+    - path: "/api/public/ingestion"                                 # route you want to add to LiteLLM Proxy Server
+      target: "https://us.cloud.langfuse.com/api/public/ingestion"  # URL this route should forward
+      auth: true                                                    # 👈 KEY CHANGE
+      custom_auth_parser: "langfuse"                                # 👈 KEY CHANGE
+      headers:
+        LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"   # your langfuse account public key
+        LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY"       # your langfuse account secret key
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test with langfuse sdk
+
+```python
+from langfuse import Langfuse
+
+langfuse = Langfuse(
+    host="http://localhost:4000",  # your litellm proxy endpoint
+    public_key="sk-1234",          # your litellm proxy api key
+    secret_key="anything",         # no key required since this is a pass through
+)
+
+print("sending langfuse trace request")
+trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
+print("flushing langfuse request")
+langfuse.flush()
+
+print("flushed langfuse request")
+```
+
 ## `pass_through_endpoints` Spec on config.yaml

 All possible values for `pass_through_endpoints` and what they mean

@@ -72,15 +72,15 @@ http://localhost:4000/metrics

 | Metric Name | Description |
 |----------------------|--------------------------------------|
-| `deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. |
+| `litellm_deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. |
 | `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment |
 | `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` return from LLM API Deployment |
-| `llm_deployment_success_responses` | Total number of successful LLM API calls for deployment |
-| `llm_deployment_failure_responses` | Total number of failed LLM API calls for deployment |
-| `llm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure |
-| `llm_deployment_latency_per_output_token` | Latency per output token for deployment |
-| `llm_deployment_successful_fallbacks` | Number of successful fallback requests from primary model -> fallback model |
-| `llm_deployment_failed_fallbacks` | Number of failed fallback requests from primary model -> fallback model |
+| `litellm_deployment_success_responses` | Total number of successful LLM API calls for deployment |
+| `litellm_deployment_failure_responses` | Total number of failed LLM API calls for deployment |
+| `litellm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure |
+| `litellm_deployment_latency_per_output_token` | Latency per output token for deployment |
+| `litellm_deployment_successful_fallbacks` | Number of successful fallback requests from primary model -> fallback model |
+| `litellm_deployment_failed_fallbacks` | Number of failed fallback requests from primary model -> fallback model |

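To confirm the renamed metrics are actually being exported, one option is to query them through the Prometheus HTTP API. This is a hedged sketch, assuming a Prometheus server at `localhost:9090` is already scraping the proxy's `/metrics` endpoint:

```python
import requests

resp = requests.get(
    "http://localhost:9090/api/v1/query",
    params={"query": "litellm_deployment_state"},  # one of the renamed metrics from the table above
)
for sample in resp.json().get("data", {}).get("result", []):
    print(sample["metric"].get("litellm_model_name"), sample["value"])
```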
@@ -2,9 +2,9 @@ import Image from '@theme/IdealImage';
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# 👥📊 Team Based Logging
+# 👥📊 Team/Key Based Logging

-Allow each team to use their own Langfuse Project / custom callbacks
+Allow each key/team to use their own Langfuse Project / custom callbacks

 **This allows you to do the following**
 ```

@@ -189,3 +189,39 @@ curl -X GET 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/cal
+
+## [BETA] Key Based Logging
+
+Use the `/key/generate` or `/key/update` endpoints to add logging callbacks to a specific key.
+
+:::info
+
+✨ This is an Enterprise only feature [Get Started with Enterprise here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
+
+:::
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/key/generate' \
+    -H 'Authorization: Bearer sk-1234' \
+    -H 'Content-Type: application/json' \
+    -d '{
+        "metadata": {
+            "logging": {
+                "callback_name": "langfuse", # 'otel', 'langfuse', 'lunary'
+                "callback_type": "success", # set, if required by integration - future improvement, have logging tools work for success + failure by default
+                "callback_vars": {
+                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY", # [RECOMMENDED] reference key in proxy environment
+                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY", # [RECOMMENDED] reference key in proxy environment
+                    "langfuse_host": "https://cloud.langfuse.com"
+                }
+            }
+        }
+    }'
+```
+
+---
+
+Help us improve this feature, by filing a [ticket here](https://github.com/BerriAI/litellm/issues)

@@ -151,7 +151,7 @@ const sidebars = {
     },
     {
       type: "category",
-      label: "Chat Completions (litellm.completion)",
+      label: "Chat Completions (litellm.completion + PROXY)",
       link: {
         type: "generated-index",
         title: "Chat Completions",

@@ -1,5 +1,6 @@
 import json
 import os
+import uuid
 from datetime import datetime
 from typing import Any, Dict, List, Optional, TypedDict, Union

@@ -29,6 +30,8 @@ class GCSBucketPayload(TypedDict):
     end_time: str
     response_cost: Optional[float]
     spend_log_metadata: str
+    exception: Optional[str]
+    log_event_type: Optional[str]


 class GCSBucketLogger(CustomLogger):

@@ -79,6 +82,7 @@ class GCSBucketLogger(CustomLogger):
             logging_payload: GCSBucketPayload = await self.get_gcs_payload(
                 kwargs, response_obj, start_time_str, end_time_str
             )
+            logging_payload["log_event_type"] = "successful_api_call"

             json_logged_payload = json.dumps(logging_payload)

@@ -103,7 +107,56 @@ class GCSBucketLogger(CustomLogger):
             verbose_logger.error("GCS Bucket logging error: %s", str(e))

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
-        pass
+        from litellm.proxy.proxy_server import premium_user
+
+        if premium_user is not True:
+            raise ValueError(
+                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
+            )
+        try:
+            verbose_logger.debug(
+                "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
+                kwargs,
+                response_obj,
+            )
+
+            start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
+            end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
+            headers = await self.construct_request_headers()
+
+            logging_payload: GCSBucketPayload = await self.get_gcs_payload(
+                kwargs, response_obj, start_time_str, end_time_str
+            )
+            logging_payload["log_event_type"] = "failed_api_call"
+
+            _litellm_params = kwargs.get("litellm_params") or {}
+            metadata = _litellm_params.get("metadata") or {}
+
+            json_logged_payload = json.dumps(logging_payload)
+
+            # Get the current date
+            current_date = datetime.now().strftime("%Y-%m-%d")
+
+            # Modify the object_name to include the date-based folder
+            object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
+
+            if "gcs_log_id" in metadata:
+                object_name = metadata["gcs_log_id"]
+
+            response = await self.async_httpx_client.post(
+                headers=headers,
+                url=f"https://storage.googleapis.com/upload/storage/v1/b/{self.BUCKET_NAME}/o?uploadType=media&name={object_name}",
+                data=json_logged_payload,
+            )
+
+            if response.status_code != 200:
+                verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
+
+            verbose_logger.debug("GCS Bucket response %s", response)
+            verbose_logger.debug("GCS Bucket status code %s", response.status_code)
+            verbose_logger.debug("GCS Bucket response.text %s", response.text)
+        except Exception as e:
+            verbose_logger.error("GCS Bucket logging error: %s", str(e))

     async def construct_request_headers(self) -> Dict[str, str]:
         from litellm import vertex_chat_completion

@@ -139,9 +192,18 @@ class GCSBucketLogger(CustomLogger):
             optional_params=kwargs.get("optional_params", None),
         )
         response_dict = {}
-        response_dict = convert_litellm_response_object_to_dict(
-            response_obj=response_obj
-        )
+        if response_obj:
+            response_dict = convert_litellm_response_object_to_dict(
+                response_obj=response_obj
+            )
+
+        exception_str = None
+
+        # Handle logging exception attributes
+        if "exception" in kwargs:
+            exception_str = kwargs.get("exception", "")
+            if not isinstance(exception_str, str):
+                exception_str = str(exception_str)

         _spend_log_payload: SpendLogsPayload = get_logging_payload(
             kwargs=kwargs,

@@ -156,8 +218,10 @@ class GCSBucketLogger(CustomLogger):
             response_obj=response_dict,
             start_time=start_time,
             end_time=end_time,
-            spend_log_metadata=_spend_log_payload["metadata"],
+            spend_log_metadata=_spend_log_payload.get("metadata", ""),
             response_cost=kwargs.get("response_cost", None),
+            exception=exception_str,
+            log_event_type=None,
         )

         return gcs_payload

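To sanity-check the new failure logging, the objects it writes can be read back from the bucket. A hedged sketch with the `google-cloud-storage` client (the bucket name is a placeholder and application-default credentials are assumed to be configured):

```python
import json
from datetime import datetime

from google.cloud import storage

client = storage.Client()
prefix = datetime.now().strftime("%Y-%m-%d") + "/failure-"   # date folder + failure prefix used above

for blob in client.list_blobs("my-litellm-gcs-bucket", prefix=prefix):  # hypothetical bucket name
    payload = json.loads(blob.download_as_text())
    print(payload.get("log_event_type"), payload.get("exception"))
```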
@@ -605,6 +605,12 @@ class LangFuseLogger:
         if "cache_key" in litellm.langfuse_default_tags:
             _hidden_params = metadata.get("hidden_params", {}) or {}
             _cache_key = _hidden_params.get("cache_key", None)
+            if _cache_key is None:
+                # fallback to using "preset_cache_key"
+                _preset_cache_key = kwargs.get("litellm_params", {}).get(
+                    "preset_cache_key", None
+                )
+                _cache_key = _preset_cache_key
             tags.append(f"cache_key:{_cache_key}")
     return tags

@@ -676,7 +682,6 @@ def log_provider_specific_information_as_span(
     Returns:
         None
     """
-    from litellm.proxy.proxy_server import premium_user

     _hidden_params = clean_metadata.get("hidden_params", None)
     if _hidden_params is None:

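The `preset_cache_key` fallback above only matters when the `cache_key` tag is enabled in the first place. A minimal sketch of opting in, using the `litellm.langfuse_default_tags` setting the diff checks (callback registration is standard LiteLLM usage):

```python
import litellm

litellm.langfuse_default_tags = ["cache_key"]   # enable the cache_key tag checked above
litellm.success_callback = ["langfuse"]         # route successful calls to the Langfuse logger
```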
@@ -141,42 +141,42 @@ class PrometheusLogger(CustomLogger):
         ]

         # Metric for deployment state
-        self.deployment_state = Gauge(
-            "deployment_state",
+        self.litellm_deployment_state = Gauge(
+            "litellm_deployment_state",
             "LLM Deployment Analytics - The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage",
             labelnames=_logged_llm_labels,
         )

-        self.llm_deployment_success_responses = Counter(
-            name="llm_deployment_success_responses",
+        self.litellm_deployment_success_responses = Counter(
+            name="litellm_deployment_success_responses",
             documentation="LLM Deployment Analytics - Total number of successful LLM API calls via litellm",
             labelnames=_logged_llm_labels,
         )
-        self.llm_deployment_failure_responses = Counter(
-            name="llm_deployment_failure_responses",
+        self.litellm_deployment_failure_responses = Counter(
+            name="litellm_deployment_failure_responses",
             documentation="LLM Deployment Analytics - Total number of failed LLM API calls via litellm",
             labelnames=_logged_llm_labels,
         )
-        self.llm_deployment_total_requests = Counter(
-            name="llm_deployment_total_requests",
+        self.litellm_deployment_total_requests = Counter(
+            name="litellm_deployment_total_requests",
             documentation="LLM Deployment Analytics - Total number of LLM API calls via litellm - success + failure",
             labelnames=_logged_llm_labels,
         )

         # Deployment Latency tracking
-        self.llm_deployment_latency_per_output_token = Histogram(
-            name="llm_deployment_latency_per_output_token",
+        self.litellm_deployment_latency_per_output_token = Histogram(
+            name="litellm_deployment_latency_per_output_token",
             documentation="LLM Deployment Analytics - Latency per output token",
             labelnames=_logged_llm_labels,
         )

-        self.llm_deployment_successful_fallbacks = Counter(
-            "llm_deployment_successful_fallbacks",
+        self.litellm_deployment_successful_fallbacks = Counter(
+            "litellm_deployment_successful_fallbacks",
             "LLM Deployment Analytics - Number of successful fallback requests from primary model -> fallback model",
             ["primary_model", "fallback_model"],
         )
-        self.llm_deployment_failed_fallbacks = Counter(
-            "llm_deployment_failed_fallbacks",
+        self.litellm_deployment_failed_fallbacks = Counter(
+            "litellm_deployment_failed_fallbacks",
             "LLM Deployment Analytics - Number of failed fallback requests from primary model -> fallback model",
             ["primary_model", "fallback_model"],
         )

@@ -358,14 +358,14 @@ class PrometheusLogger(CustomLogger):
                 api_provider=llm_provider,
             )

-            self.llm_deployment_failure_responses.labels(
+            self.litellm_deployment_failure_responses.labels(
                 litellm_model_name=litellm_model_name,
                 model_id=model_id,
                 api_base=api_base,
                 api_provider=llm_provider,
             ).inc()

-            self.llm_deployment_total_requests.labels(
+            self.litellm_deployment_total_requests.labels(
                 litellm_model_name=litellm_model_name,
                 model_id=model_id,
                 api_base=api_base,

@@ -438,14 +438,14 @@ class PrometheusLogger(CustomLogger):
                 api_provider=llm_provider,
             )

-            self.llm_deployment_success_responses.labels(
+            self.litellm_deployment_success_responses.labels(
                 litellm_model_name=litellm_model_name,
                 model_id=model_id,
                 api_base=api_base,
                 api_provider=llm_provider,
             ).inc()

-            self.llm_deployment_total_requests.labels(
+            self.litellm_deployment_total_requests.labels(
                 litellm_model_name=litellm_model_name,
                 model_id=model_id,
                 api_base=api_base,

@@ -475,7 +475,7 @@ class PrometheusLogger(CustomLogger):
             latency_per_token = None
             if output_tokens is not None and output_tokens > 0:
                 latency_per_token = _latency_seconds / output_tokens
-            self.llm_deployment_latency_per_output_token.labels(
+            self.litellm_deployment_latency_per_output_token.labels(
                 litellm_model_name=litellm_model_name,
                 model_id=model_id,
                 api_base=api_base,

@@ -497,7 +497,7 @@ class PrometheusLogger(CustomLogger):
                 kwargs,
             )
             _new_model = kwargs.get("model")
-            self.llm_deployment_successful_fallbacks.labels(
+            self.litellm_deployment_successful_fallbacks.labels(
                 primary_model=original_model_group, fallback_model=_new_model
             ).inc()

@@ -508,11 +508,11 @@ class PrometheusLogger(CustomLogger):
                 kwargs,
             )
             _new_model = kwargs.get("model")
-            self.llm_deployment_failed_fallbacks.labels(
+            self.litellm_deployment_failed_fallbacks.labels(
                 primary_model=original_model_group, fallback_model=_new_model
             ).inc()

-    def set_deployment_state(
+    def set_litellm_deployment_state(
         self,
         state: int,
         litellm_model_name: str,

@@ -520,7 +520,7 @@ class PrometheusLogger(CustomLogger):
         api_base: str,
         api_provider: str,
     ):
-        self.deployment_state.labels(
+        self.litellm_deployment_state.labels(
             litellm_model_name, model_id, api_base, api_provider
         ).set(state)

@@ -531,7 +531,7 @@ class PrometheusLogger(CustomLogger):
         api_base: str,
         api_provider: str,
     ):
-        self.set_deployment_state(
+        self.set_litellm_deployment_state(
             0, litellm_model_name, model_id, api_base, api_provider
         )

@@ -542,7 +542,7 @@ class PrometheusLogger(CustomLogger):
         api_base: str,
         api_provider: str,
     ):
-        self.set_deployment_state(
+        self.set_litellm_deployment_state(
             1, litellm_model_name, model_id, api_base, api_provider
         )

@@ -553,7 +553,7 @@ class PrometheusLogger(CustomLogger):
         api_base: str,
         api_provider: str,
     ):
-        self.set_deployment_state(
+        self.set_litellm_deployment_state(
             2, litellm_model_name, model_id, api_base, api_provider
         )

@@ -41,8 +41,8 @@ async def get_fallback_metric_from_prometheus():
     """
     response_message = ""
     relevant_metrics = [
-        "llm_deployment_successful_fallbacks_total",
-        "llm_deployment_failed_fallbacks_total",
+        "litellm_deployment_successful_fallbacks_total",
+        "litellm_deployment_failed_fallbacks_total",
     ]
     for metric in relevant_metrics:
         response_json = await get_metric_from_prometheus(

@@ -35,6 +35,7 @@ from litellm.types.llms.anthropic import (
     AnthropicResponseContentBlockText,
     AnthropicResponseContentBlockToolUse,
     AnthropicResponseUsageBlock,
+    AnthropicSystemMessageContent,
     ContentBlockDelta,
     ContentBlockStart,
     ContentBlockStop,

@@ -759,6 +760,7 @@ class AnthropicChatCompletion(BaseLLM):
             ## CALCULATING USAGE
             prompt_tokens = completion_response["usage"]["input_tokens"]
             completion_tokens = completion_response["usage"]["output_tokens"]
+            _usage = completion_response["usage"]
             total_tokens = prompt_tokens + completion_tokens

             model_response.created = int(time.time())

@@ -768,6 +770,11 @@ class AnthropicChatCompletion(BaseLLM):
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
             )
+
+            if "cache_creation_input_tokens" in _usage:
+                usage["cache_creation_input_tokens"] = _usage["cache_creation_input_tokens"]
+            if "cache_read_input_tokens" in _usage:
+                usage["cache_read_input_tokens"] = _usage["cache_read_input_tokens"]
             setattr(model_response, "usage", usage)  # type: ignore
         return model_response

@@ -901,6 +908,7 @@ class AnthropicChatCompletion(BaseLLM):
         # Separate system prompt from rest of message
         system_prompt_indices = []
         system_prompt = ""
+        anthropic_system_message_list = None
         for idx, message in enumerate(messages):
             if message["role"] == "system":
                 valid_content: bool = False

@@ -908,8 +916,23 @@ class AnthropicChatCompletion(BaseLLM):
                     system_prompt += message["content"]
                     valid_content = True
                 elif isinstance(message["content"], list):
-                    for content in message["content"]:
-                        system_prompt += content.get("text", "")
+                    for _content in message["content"]:
+                        anthropic_system_message_content = (
+                            AnthropicSystemMessageContent(
+                                type=_content.get("type"),
+                                text=_content.get("text"),
+                            )
+                        )
+                        if "cache_control" in _content:
+                            anthropic_system_message_content["cache_control"] = (
+                                _content["cache_control"]
+                            )
+
+                        if anthropic_system_message_list is None:
+                            anthropic_system_message_list = []
+                        anthropic_system_message_list.append(
+                            anthropic_system_message_content
+                        )
                     valid_content = True

                 if valid_content:

@@ -919,6 +942,10 @@ class AnthropicChatCompletion(BaseLLM):
                 messages.pop(idx)
         if len(system_prompt) > 0:
             optional_params["system"] = system_prompt
+
+        # Handling anthropic API Prompt Caching
+        if anthropic_system_message_list is not None:
+            optional_params["system"] = anthropic_system_message_list
         # Format rest of message according to anthropic guidelines
         try:
             messages = prompt_factory(

@@ -954,6 +981,8 @@ class AnthropicChatCompletion(BaseLLM):
                 else:  # assume openai tool call
                     new_tool = tool["function"]
                     new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
+                    if "cache_control" in tool:
+                        new_tool["cache_control"] = tool["cache_control"]
                     anthropic_tools.append(new_tool)

             optional_params["tools"] = anthropic_tools

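A hedged sketch of what the usage change above exposes to callers: when Anthropic reports prompt-cache activity, the extra counters should appear on `response.usage` next to the normal token counts. The attribute access is illustrative; the fields are only present if the provider returns them.

```python
import litellm

response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "hello"}],
    extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
)
usage = response.usage
print(usage.prompt_tokens, usage.completion_tokens)
print(getattr(usage, "cache_creation_input_tokens", None))  # set only when cache stats are returned
print(getattr(usage, "cache_read_input_tokens", None))
```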
litellm/llms/base_aws_llm.py (new file, 218 lines)
@@ -0,0 +1,218 @@
import json
from typing import List, Optional

import httpx

from litellm._logging import verbose_logger
from litellm.caching import DualCache, InMemoryCache
from litellm.utils import get_secret

from .base import BaseLLM


class AwsAuthError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class BaseAWSLLM(BaseLLM):
    def __init__(self) -> None:
        self.iam_cache = DualCache()
        super().__init__()

    def get_credentials(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_role_name: Optional[str] = None,
        aws_web_identity_token: Optional[str] = None,
        aws_sts_endpoint: Optional[str] = None,
    ):
        """
        Return a boto3.Credentials object
        """
        import boto3

        ## CHECK IS 'os.environ/' passed in
        params_to_check: List[Optional[str]] = [
            aws_access_key_id,
            aws_secret_access_key,
            aws_session_token,
            aws_region_name,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
            aws_web_identity_token,
            aws_sts_endpoint,
        ]

        # Iterate over parameters and update if needed
        for i, param in enumerate(params_to_check):
            if param and param.startswith("os.environ/"):
                _v = get_secret(param)
                if _v is not None and isinstance(_v, str):
                    params_to_check[i] = _v
        # Assign updated values back to parameters
        (
            aws_access_key_id,
            aws_secret_access_key,
            aws_session_token,
            aws_region_name,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
            aws_web_identity_token,
            aws_sts_endpoint,
        ) = params_to_check

        verbose_logger.debug(
            "in get credentials\n"
            "aws_access_key_id=%s\n"
            "aws_secret_access_key=%s\n"
            "aws_session_token=%s\n"
            "aws_region_name=%s\n"
            "aws_session_name=%s\n"
            "aws_profile_name=%s\n"
            "aws_role_name=%s\n"
            "aws_web_identity_token=%s\n"
            "aws_sts_endpoint=%s",
            aws_access_key_id,
            aws_secret_access_key,
            aws_session_token,
            aws_region_name,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
            aws_web_identity_token,
            aws_sts_endpoint,
        )

        ### CHECK STS ###
        if (
            aws_web_identity_token is not None
            and aws_role_name is not None
            and aws_session_name is not None
        ):
            verbose_logger.debug(
                f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
            )

            if aws_sts_endpoint is None:
                sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
            else:
                sts_endpoint = aws_sts_endpoint

            iam_creds_cache_key = json.dumps(
                {
                    "aws_web_identity_token": aws_web_identity_token,
                    "aws_role_name": aws_role_name,
                    "aws_session_name": aws_session_name,
                    "aws_region_name": aws_region_name,
                    "aws_sts_endpoint": sts_endpoint,
                }
            )

            iam_creds_dict = self.iam_cache.get_cache(iam_creds_cache_key)
            if iam_creds_dict is None:
                oidc_token = get_secret(aws_web_identity_token)

                if oidc_token is None:
                    raise AwsAuthError(
                        message="OIDC token could not be retrieved from secret manager.",
                        status_code=401,
                    )

                sts_client = boto3.client(
                    "sts",
                    region_name=aws_region_name,
                    endpoint_url=sts_endpoint,
                )

                # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
                # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
                sts_response = sts_client.assume_role_with_web_identity(
                    RoleArn=aws_role_name,
                    RoleSessionName=aws_session_name,
                    WebIdentityToken=oidc_token,
                    DurationSeconds=3600,
                )

                iam_creds_dict = {
                    "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
                    "aws_secret_access_key": sts_response["Credentials"][
                        "SecretAccessKey"
                    ],
                    "aws_session_token": sts_response["Credentials"]["SessionToken"],
                    "region_name": aws_region_name,
                }

                self.iam_cache.set_cache(
                    key=iam_creds_cache_key,
                    value=json.dumps(iam_creds_dict),
                    ttl=3600 - 60,
                )

            session = boto3.Session(**iam_creds_dict)

            iam_creds = session.get_credentials()

            return iam_creds
        elif aws_role_name is not None and aws_session_name is not None:
            sts_client = boto3.client(
                "sts",
                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
            )

            sts_response = sts_client.assume_role(
                RoleArn=aws_role_name, RoleSessionName=aws_session_name
            )

            # Extract the credentials from the response and convert to Session Credentials
            sts_credentials = sts_response["Credentials"]
            from botocore.credentials import Credentials

            credentials = Credentials(
                access_key=sts_credentials["AccessKeyId"],
                secret_key=sts_credentials["SecretAccessKey"],
                token=sts_credentials["SessionToken"],
            )
            return credentials
        elif aws_profile_name is not None:  ### CHECK SESSION ###
            # uses auth values from AWS profile usually stored in ~/.aws/credentials
            client = boto3.Session(profile_name=aws_profile_name)

            return client.get_credentials()
        elif (
            aws_access_key_id is not None
            and aws_secret_access_key is not None
            and aws_session_token is not None
        ):  ### CHECK FOR AWS SESSION TOKEN ###
            from botocore.credentials import Credentials

            credentials = Credentials(
                access_key=aws_access_key_id,
                secret_key=aws_secret_access_key,
                token=aws_session_token,
            )
            return credentials
        else:
            session = boto3.Session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                region_name=aws_region_name,
            )

            return session.get_credentials()
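A minimal usage sketch of the new base class, assuming AWS keys are available as environment variables; instantiating `BaseAWSLLM` directly is only for illustration, the Bedrock handlers normally reach it through inheritance.

```python
# Sketch: resolve credentials through the shared helper. "os.environ/..." values
# are looked up via get_secret(); everything else falls through to boto3.Session.
from litellm.llms.base_aws_llm import BaseAWSLLM

base_llm = BaseAWSLLM()
credentials = base_llm.get_credentials(
    aws_access_key_id="os.environ/AWS_ACCESS_KEY_ID",
    aws_secret_access_key="os.environ/AWS_SECRET_ACCESS_KEY",
    aws_region_name="us-west-2",  # assumed region
)
print(credentials.access_key is not None)
```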
@@ -57,6 +57,7 @@ from litellm.utils import (
)

from .base import BaseLLM
+from .base_aws_llm import BaseAWSLLM
from .bedrock import BedrockError, ModelResponseIterator, convert_messages_to_prompt
from .prompt_templates.factory import (
    _bedrock_converse_messages_pt,

@@ -87,7 +88,6 @@ BEDROCK_CONVERSE_MODELS = [
]


-iam_cache = DualCache()
_response_stream_shape_cache = None
bedrock_tool_name_mappings: InMemoryCache = InMemoryCache(
    max_size_in_memory=50, default_ttl=600

@@ -312,7 +312,7 @@ def make_sync_call(
    return completion_stream


-class BedrockLLM(BaseLLM):
+class BedrockLLM(BaseAWSLLM):
    """
    Example call

@@ -380,183 +380,6 @@ class BedrockLLM(BaseLLM):
            prompt += f"{message['content']}"
        return prompt, chat_history  # type: ignore

-    def get_credentials(
-        self,
-        aws_access_key_id: Optional[str] = None,
-        aws_secret_access_key: Optional[str] = None,
-        aws_session_token: Optional[str] = None,
-        aws_region_name: Optional[str] = None,
-        aws_session_name: Optional[str] = None,
-        aws_profile_name: Optional[str] = None,
-        aws_role_name: Optional[str] = None,
-        aws_web_identity_token: Optional[str] = None,
-        aws_sts_endpoint: Optional[str] = None,
-    ):
-        """
-        Return a boto3.Credentials object
-        """
-        import boto3
-
-        print_verbose(
-            f"Boto3 get_credentials called variables passed to function {locals()}"
-        )
-
-        ## CHECK IS 'os.environ/' passed in
-        params_to_check: List[Optional[str]] = [
-            aws_access_key_id,
-            aws_secret_access_key,
-            aws_session_token,
-            aws_region_name,
-            aws_session_name,
-            aws_profile_name,
-            aws_role_name,
-            aws_web_identity_token,
-            aws_sts_endpoint,
-        ]
-
-        # Iterate over parameters and update if needed
-        for i, param in enumerate(params_to_check):
-            if param and param.startswith("os.environ/"):
-                _v = get_secret(param)
-                if _v is not None and isinstance(_v, str):
-                    params_to_check[i] = _v
-        # Assign updated values back to parameters
-        (
-            aws_access_key_id,
-            aws_secret_access_key,
-            aws_session_token,
-            aws_region_name,
-            aws_session_name,
-            aws_profile_name,
-            aws_role_name,
-            aws_web_identity_token,
-            aws_sts_endpoint,
-        ) = params_to_check
-
-        ### CHECK STS ###
-        if (
-            aws_web_identity_token is not None
-            and aws_role_name is not None
-            and aws_session_name is not None
-        ):
-            print_verbose(
-                f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
-            )
-
-            if aws_sts_endpoint is None:
-                sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
-            else:
-                sts_endpoint = aws_sts_endpoint
-
-            iam_creds_cache_key = json.dumps(
-                {
-                    "aws_web_identity_token": aws_web_identity_token,
-                    "aws_role_name": aws_role_name,
-                    "aws_session_name": aws_session_name,
-                    "aws_region_name": aws_region_name,
-                    "aws_sts_endpoint": sts_endpoint,
-                }
-            )
-
-            iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
-            if iam_creds_dict is None:
-                oidc_token = get_secret(aws_web_identity_token)
-
-                if oidc_token is None:
-                    raise BedrockError(
-                        message="OIDC token could not be retrieved from secret manager.",
-                        status_code=401,
-                    )
-
-                sts_client = boto3.client(
-                    "sts",
-                    region_name=aws_region_name,
-                    endpoint_url=sts_endpoint,
-                )
-
-                # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
-                # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
-                sts_response = sts_client.assume_role_with_web_identity(
-                    RoleArn=aws_role_name,
-                    RoleSessionName=aws_session_name,
-                    WebIdentityToken=oidc_token,
-                    DurationSeconds=3600,
-                )
-
-                iam_creds_dict = {
-                    "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
-                    "aws_secret_access_key": sts_response["Credentials"][
-                        "SecretAccessKey"
-                    ],
-                    "aws_session_token": sts_response["Credentials"]["SessionToken"],
-                    "region_name": aws_region_name,
-                }
-
-                iam_cache.set_cache(
-                    key=iam_creds_cache_key,
-                    value=json.dumps(iam_creds_dict),
-                    ttl=3600 - 60,
-                )
-
-            session = boto3.Session(**iam_creds_dict)
-
-            iam_creds = session.get_credentials()
-
-            return iam_creds
-        elif aws_role_name is not None and aws_session_name is not None:
-            print_verbose(
-                f"Using STS Client AWS aws_role_name: {aws_role_name} aws_session_name: {aws_session_name}"
-            )
-            sts_client = boto3.client(
-                "sts",
-                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
-                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
-            )
-
-            sts_response = sts_client.assume_role(
-                RoleArn=aws_role_name, RoleSessionName=aws_session_name
-            )
-
-            # Extract the credentials from the response and convert to Session Credentials
-            sts_credentials = sts_response["Credentials"]
-            from botocore.credentials import Credentials
-
-            credentials = Credentials(
-                access_key=sts_credentials["AccessKeyId"],
-                secret_key=sts_credentials["SecretAccessKey"],
-                token=sts_credentials["SessionToken"],
-            )
-            return credentials
-        elif aws_profile_name is not None:  ### CHECK SESSION ###
-            # uses auth values from AWS profile usually stored in ~/.aws/credentials
-            print_verbose(f"Using AWS profile: {aws_profile_name}")
-            client = boto3.Session(profile_name=aws_profile_name)
-
-            return client.get_credentials()
-        elif (
-            aws_access_key_id is not None
-            and aws_secret_access_key is not None
-            and aws_session_token is not None
-        ):  ### CHECK FOR AWS SESSION TOKEN ###
-            print_verbose(f"Using AWS Session Token: {aws_session_token}")
-            from botocore.credentials import Credentials
-
-            credentials = Credentials(
-                access_key=aws_access_key_id,
-                secret_key=aws_secret_access_key,
-                token=aws_session_token,
-            )
-            return credentials
-        else:
-            print_verbose("Using Default AWS Session")
-            session = boto3.Session(
-                aws_access_key_id=aws_access_key_id,
-                aws_secret_access_key=aws_secret_access_key,
-                region_name=aws_region_name,
-            )
-
-            return session.get_credentials()
-
    def process_response(
        self,
        model: str,

@@ -1055,8 +878,8 @@ class BedrockLLM(BaseLLM):
                },
            )
            raise BedrockError(
-                status_code=400,
-                message="Bedrock HTTPX: Unsupported provider={}, model={}".format(
+                status_code=404,
+                message="Bedrock HTTPX: Unknown provider={}, model={}".format(
                    provider, model
                ),
            )

@@ -1414,7 +1237,7 @@ class AmazonConverseConfig:
        return optional_params


-class BedrockConverseLLM(BaseLLM):
+class BedrockConverseLLM(BaseAWSLLM):
    def __init__(self) -> None:
        super().__init__()

@@ -1554,173 +1377,6 @@ class BedrockConverseLLM(BaseLLM):
        """
        return urllib.parse.quote(model_id, safe="")

-    def get_credentials(
-        self,
-        aws_access_key_id: Optional[str] = None,
-        aws_secret_access_key: Optional[str] = None,
-        aws_session_token: Optional[str] = None,
-        aws_region_name: Optional[str] = None,
-        aws_session_name: Optional[str] = None,
-        aws_profile_name: Optional[str] = None,
-        aws_role_name: Optional[str] = None,
-        aws_web_identity_token: Optional[str] = None,
-        aws_sts_endpoint: Optional[str] = None,
-    ):
-        """
-        Return a boto3.Credentials object
-        """
-        import boto3
-
-        ## CHECK IS 'os.environ/' passed in
-        params_to_check: List[Optional[str]] = [
-            aws_access_key_id,
-            aws_secret_access_key,
-            aws_session_token,
-            aws_region_name,
-            aws_session_name,
-            aws_profile_name,
-            aws_role_name,
-            aws_web_identity_token,
-            aws_sts_endpoint,
-        ]
-
-        # Iterate over parameters and update if needed
-        for i, param in enumerate(params_to_check):
-            if param and param.startswith("os.environ/"):
-                _v = get_secret(param)
-                if _v is not None and isinstance(_v, str):
-                    params_to_check[i] = _v
-        # Assign updated values back to parameters
-        (
-            aws_access_key_id,
-            aws_secret_access_key,
-            aws_session_token,
-            aws_region_name,
-            aws_session_name,
-            aws_profile_name,
-            aws_role_name,
-            aws_web_identity_token,
-            aws_sts_endpoint,
-        ) = params_to_check
-
-        ### CHECK STS ###
-        if (
-            aws_web_identity_token is not None
-            and aws_role_name is not None
-            and aws_session_name is not None
-        ):
-            print_verbose(
-                f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
-            )
-
-            if aws_sts_endpoint is None:
-                sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
-            else:
-                sts_endpoint = aws_sts_endpoint
-
-            iam_creds_cache_key = json.dumps(
-                {
-                    "aws_web_identity_token": aws_web_identity_token,
-                    "aws_role_name": aws_role_name,
-                    "aws_session_name": aws_session_name,
-                    "aws_region_name": aws_region_name,
-                    "aws_sts_endpoint": sts_endpoint,
-                }
-            )
-
-            iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
-            if iam_creds_dict is None:
-                oidc_token = get_secret(aws_web_identity_token)
-
-                if oidc_token is None:
-                    raise BedrockError(
-                        message="OIDC token could not be retrieved from secret manager.",
-                        status_code=401,
-                    )
-
-                sts_client = boto3.client(
-                    "sts",
-                    region_name=aws_region_name,
-                    endpoint_url=sts_endpoint,
-                )
-
-                # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
-                # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
-                sts_response = sts_client.assume_role_with_web_identity(
-                    RoleArn=aws_role_name,
-                    RoleSessionName=aws_session_name,
-                    WebIdentityToken=oidc_token,
-                    DurationSeconds=3600,
-                )
-
-                iam_creds_dict = {
-                    "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
-                    "aws_secret_access_key": sts_response["Credentials"][
-                        "SecretAccessKey"
-                    ],
-                    "aws_session_token": sts_response["Credentials"]["SessionToken"],
-                    "region_name": aws_region_name,
-                }
-
-                iam_cache.set_cache(
-                    key=iam_creds_cache_key,
-                    value=json.dumps(iam_creds_dict),
-                    ttl=3600 - 60,
-                )
-
-            session = boto3.Session(**iam_creds_dict)
-
-            iam_creds = session.get_credentials()
-
-            return iam_creds
-        elif aws_role_name is not None and aws_session_name is not None:
-            sts_client = boto3.client(
-                "sts",
-                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
-                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
-            )
-
-            sts_response = sts_client.assume_role(
-                RoleArn=aws_role_name, RoleSessionName=aws_session_name
-            )
-
-            # Extract the credentials from the response and convert to Session Credentials
-            sts_credentials = sts_response["Credentials"]
-            from botocore.credentials import Credentials
-
-            credentials = Credentials(
-                access_key=sts_credentials["AccessKeyId"],
-                secret_key=sts_credentials["SecretAccessKey"],
-                token=sts_credentials["SessionToken"],
-            )
-            return credentials
-        elif aws_profile_name is not None:  ### CHECK SESSION ###
-            # uses auth values from AWS profile usually stored in ~/.aws/credentials
-            client = boto3.Session(profile_name=aws_profile_name)
-
-            return client.get_credentials()
-        elif (
-            aws_access_key_id is not None
-            and aws_secret_access_key is not None
-            and aws_session_token is not None
-        ):  ### CHECK FOR AWS SESSION TOKEN ###
-            from botocore.credentials import Credentials
-
-            credentials = Credentials(
-                access_key=aws_access_key_id,
-                secret_key=aws_secret_access_key,
-                token=aws_session_token,
-            )
-            return credentials
-        else:
-            session = boto3.Session(
-                aws_access_key_id=aws_access_key_id,
-                aws_secret_access_key=aws_secret_access_key,
-                region_name=aws_region_name,
-            )
-
-            return session.get_credentials()
-
    async def async_streaming(
        self,
        model: str,
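Since both Bedrock handlers now inherit this logic, callers keep passing the same `aws_*` parameters through `litellm.completion()`; a hedged sketch with an assumed role ARN:

```python
# Sketch (role ARN and region are assumptions): the aws_* kwargs below are
# resolved by the shared BaseAWSLLM.get_credentials() during the Bedrock call.
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hello from Bedrock"}],
    aws_region_name="us-west-2",
    aws_role_name="arn:aws:iam::123456789012:role/litellm-bedrock",  # hypothetical
    aws_session_name="litellm-session",
)
print(response.choices[0].message.content)
```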
@@ -601,12 +601,13 @@ def ollama_embeddings(
):
    return asyncio.run(
        ollama_aembeddings(
-            api_base,
-            model,
-            prompts,
-            optional_params,
-            logging_obj,
-            model_response,
-            encoding,
+            api_base=api_base,
+            model=model,
+            prompts=prompts,
+            model_response=model_response,
+            optional_params=optional_params,
+            logging_obj=logging_obj,
+            encoding=encoding,
        )
    )
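The change above switches the `ollama_aembeddings` call to keyword arguments; a small sketch of why that matters, using a stand-in signature rather than the real one:

```python
# Stand-in function, for illustration only: with several look-alike parameters,
# a positional call can silently bind values to the wrong names, while a
# keyword call keeps each value attached to the intended parameter.
def embed(api_base, model, prompts, model_response=None, optional_params=None,
          logging_obj=None, encoding=None):
    return {"model": model, "n_prompts": len(prompts)}

# keyword call: argument order no longer matters
print(embed(api_base="http://localhost:11434", model="nomic-embed-text",
            prompts=["hi"], model_response={}, optional_params={}, logging_obj=None))
```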
@@ -356,6 +356,7 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
        "json": data,
        "method": "POST",
        "timeout": litellm.request_timeout,
+        "follow_redirects": True
    }
    if api_key is not None:
        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
@@ -1224,6 +1224,19 @@ def convert_to_anthropic_tool_invoke(
    return anthropic_tool_invoke


+def add_cache_control_to_content(
+    anthropic_content_element: Union[
+        dict, AnthropicMessagesImageParam, AnthropicMessagesTextParam
+    ],
+    orignal_content_element: dict,
+):
+    if "cache_control" in orignal_content_element:
+        anthropic_content_element["cache_control"] = orignal_content_element[
+            "cache_control"
+        ]
+    return anthropic_content_element
+
+
def anthropic_messages_pt(
    messages: list,
    model: str,

@@ -1264,18 +1277,31 @@ def anthropic_messages_pt(
                        image_chunk = convert_to_anthropic_image_obj(
                            m["image_url"]["url"]
                        )
-                        user_content.append(
-                            AnthropicMessagesImageParam(
+
+                        _anthropic_content_element = AnthropicMessagesImageParam(
                            type="image",
                            source=AnthropicImageParamSource(
                                type="base64",
                                media_type=image_chunk["media_type"],
                                data=image_chunk["data"],
                            ),
                        )
-                        )
+
+                        anthropic_content_element = add_cache_control_to_content(
+                            anthropic_content_element=_anthropic_content_element,
+                            orignal_content_element=m,
+                        )
+                        user_content.append(anthropic_content_element)
                    elif m.get("type", "") == "text":
-                        user_content.append({"type": "text", "text": m["text"]})
+                        _anthropic_text_content_element = {
+                            "type": "text",
+                            "text": m["text"],
+                        }
+                        anthropic_content_element = add_cache_control_to_content(
+                            anthropic_content_element=_anthropic_text_content_element,
+                            orignal_content_element=m,
+                        )
+                        user_content.append(anthropic_content_element)
                elif (
                    messages[msg_i]["role"] == "tool"
                    or messages[msg_i]["role"] == "function"

@@ -1306,6 +1332,10 @@ def anthropic_messages_pt(
                    anthropic_message = AnthropicMessagesTextParam(
                        type="text", text=m.get("text")
                    )
+                    anthropic_message = add_cache_control_to_content(
+                        anthropic_content_element=anthropic_message,
+                        orignal_content_element=m,
+                    )
                    assistant_content.append(anthropic_message)
            elif (
                "content" in messages[msg_i]

@@ -1313,9 +1343,17 @@ def anthropic_messages_pt(
                and len(messages[msg_i]["content"])
                > 0  # don't pass empty text blocks. anthropic api raises errors.
            ):
-                assistant_content.append(
-                    {"type": "text", "text": messages[msg_i]["content"]}
+                _anthropic_text_content_element = {
+                    "type": "text",
+                    "text": messages[msg_i]["content"],
+                }
+
+                anthropic_content_element = add_cache_control_to_content(
+                    anthropic_content_element=_anthropic_text_content_element,
+                    orignal_content_element=messages[msg_i],
                )
+                assistant_content.append(anthropic_content_element)

            if messages[msg_i].get(
                "tool_calls", []

@@ -1701,12 +1739,14 @@ def cohere_messages_pt_v2(
    assistant_tool_calls: List[ToolCallObject] = []
    ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
    while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
-        assistant_text = (
-            messages[msg_i].get("content") or ""
-        )  # either string or none
-        if assistant_text:
-            assistant_content += assistant_text
+        if isinstance(messages[msg_i]["content"], list):
+            for m in messages[msg_i]["content"]:
+                if m.get("type", "") == "text":
+                    assistant_content += m["text"]
+        elif messages[msg_i].get("content") is not None and isinstance(
+            messages[msg_i]["content"], str
+        ):
+            assistant_content += messages[msg_i]["content"]

        if messages[msg_i].get(
            "tool_calls", []
        ):  # support assistant tool invoke conversion
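A hedged sketch of the message shape this `cache_control` plumbing supports; the content strings are placeholders.

```python
# Illustrative OpenAI-style message: the text block tagged with cache_control is
# converted by anthropic_messages_pt() into an Anthropic content element that
# carries the same cache_control entry (via add_cache_control_to_content).
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "<long reference material worth caching>",  # placeholder
                "cache_control": {"type": "ephemeral"},
            },
            {"type": "text", "text": "Summarize the reference above."},
        ],
    }
]
```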
File diff suppressed because it is too large
@@ -240,10 +240,10 @@ class TritonChatCompletion(BaseLLM):
        handler = HTTPHandler()
        if stream:
            return self._handle_stream(
-                handler, api_base, data_for_triton, model, logging_obj
+                handler, api_base, json_data_for_triton, model, logging_obj
            )
        else:
-            response = handler.post(url=api_base, data=data_for_triton, headers=headers)
+            response = handler.post(url=api_base, data=json_data_for_triton, headers=headers)
            return self._handle_response(
                response, model_response, logging_obj, type_of_model=type_of_model
            )
@@ -95,7 +95,6 @@ from .llms import (
    palm,
    petals,
    replicate,
-    sagemaker,
    together_ai,
    triton,
    vertex_ai,

@@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
    prompt_factory,
    stringify_json_tool_call_content,
)
+from .llms.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_partner import VertexAIPartnerModels

@@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
watsonxai = IBMWatsonXAI()
+sagemaker_llm = SagemakerLLM()
####### COMPLETION ENDPOINTS ################

@@ -2216,7 +2217,7 @@ def completion(
            response = model_response
        elif custom_llm_provider == "sagemaker":
            # boto3 reads keys from .env
-            model_response = sagemaker.completion(
+            model_response = sagemaker_llm.completion(
                model=model,
                messages=messages,
                model_response=model_response,

@@ -2230,26 +2231,13 @@ def completion(
                logging_obj=logging,
                acompletion=acompletion,
            )
-            if (
-                "stream" in optional_params and optional_params["stream"] == True
-            ):  ## [BETA]
-                print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
-                from .llms.sagemaker import TokenIterator
-
-                tokenIterator = TokenIterator(model_response, acompletion=acompletion)
-                response = CustomStreamWrapper(
-                    completion_stream=tokenIterator,
-                    model=model,
-                    custom_llm_provider="sagemaker",
-                    logging_obj=logging,
-                )
+            if optional_params.get("stream", False):
                ## LOGGING
                logging.post_call(
                    input=messages,
                    api_key=None,
-                    original_response=response,
+                    original_response=model_response,
                )
-                return response

            ## RESPONSE OBJECT
            response = model_response
|
||||||
model_response=EmbeddingResponse(),
|
model_response=EmbeddingResponse(),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "sagemaker":
|
elif custom_llm_provider == "sagemaker":
|
||||||
response = sagemaker.embedding(
|
response = sagemaker_llm.embedding(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
|
@@ -4898,7 +4886,6 @@ async def ahealth_check(
        verbose_logger.error(
            "litellm.ahealth_check(): Exception occured - {}".format(str(e))
        )
-        verbose_logger.debug(traceback.format_exc())
        stack_trace = traceback.format_exc()
        if isinstance(stack_trace, str):
            stack_trace = stack_trace[:1000]

@@ -4907,7 +4894,12 @@ async def ahealth_check(
                "error": "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
            }

-        error_to_return = str(e) + " stack trace: " + stack_trace
+        error_to_return = (
+            str(e)
+            + "\nHave you set 'mode' - https://docs.litellm.ai/docs/proxy/health#embedding-models"
+            + "\nstack trace: "
+            + stack_trace
+        )
        return {"error": error_to_return}
@@ -57,6 +57,18 @@
        "supports_parallel_function_calling": true,
        "supports_vision": true
    },
+    "chatgpt-4o-latest": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015,
+        "litellm_provider": "openai",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_parallel_function_calling": true,
+        "supports_vision": true
+    },
    "gpt-4o-2024-05-13": {
        "max_tokens": 4096,
        "max_input_tokens": 128000,

@@ -2062,7 +2074,8 @@
        "litellm_provider": "vertex_ai-anthropic_models",
        "mode": "chat",
        "supports_function_calling": true,
-        "supports_vision": true
+        "supports_vision": true,
+        "supports_assistant_prefill": true
    },
    "vertex_ai/claude-3-5-sonnet@20240620": {
        "max_tokens": 4096,

@@ -2073,7 +2086,8 @@
        "litellm_provider": "vertex_ai-anthropic_models",
        "mode": "chat",
        "supports_function_calling": true,
-        "supports_vision": true
+        "supports_vision": true,
+        "supports_assistant_prefill": true
    },
    "vertex_ai/claude-3-haiku@20240307": {
        "max_tokens": 4096,

@@ -2084,7 +2098,8 @@
        "litellm_provider": "vertex_ai-anthropic_models",
        "mode": "chat",
        "supports_function_calling": true,
-        "supports_vision": true
+        "supports_vision": true,
+        "supports_assistant_prefill": true
    },
    "vertex_ai/claude-3-opus@20240229": {
        "max_tokens": 4096,

@@ -2095,7 +2110,8 @@
        "litellm_provider": "vertex_ai-anthropic_models",
        "mode": "chat",
        "supports_function_calling": true,
-        "supports_vision": true
+        "supports_vision": true,
+        "supports_assistant_prefill": true
    },
    "vertex_ai/meta/llama3-405b-instruct-maas": {
        "max_tokens": 32000,

@@ -4519,6 +4535,69 @@
        "litellm_provider": "perplexity",
        "mode": "chat"
    },
+    "perplexity/llama-3.1-70b-instruct": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000001,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-8b-instruct": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0000002,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-huge-128k-online": {
+        "max_tokens": 127072,
+        "max_input_tokens": 127072,
+        "max_output_tokens": 127072,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000005,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-large-128k-online": {
+        "max_tokens": 127072,
+        "max_input_tokens": 127072,
+        "max_output_tokens": 127072,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000001,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-large-128k-chat": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000001,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-small-128k-chat": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0000002,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-small-128k-online": {
+        "max_tokens": 127072,
+        "max_input_tokens": 127072,
+        "max_output_tokens": 127072,
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0000002,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
    "perplexity/pplx-7b-chat": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,
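Once these price-map entries land, they can be read back through the usual helpers; a small sketch (values depend on the installed price map):

```python
# Sketch: look up the newly added entries from the model price map.
import litellm

info = litellm.get_model_info("chatgpt-4o-latest")
print(info["max_input_tokens"], info["input_cost_per_token"])

cost = litellm.completion_cost(
    model="perplexity/llama-3.1-sonar-small-128k-online",
    prompt="hello",
    completion="hi there",
)
print(cost)
```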
@@ -1,13 +1,6 @@
 model_list:
-  - model_name: "*"
+  - model_name: "gpt-4"
     litellm_params:
-      model: "*"
-
-# general_settings:
-#   master_key: sk-1234
-#   pass_through_endpoints:
-#     - path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server
-#       target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward
-#       headers:
-#         LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_PUBLIC_KEY" # your langfuse account public key
-#         LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_SECRET_KEY" # your langfuse account secret key
+      model: "gpt-4"
+    model_info:
+      my_custom_key: "my_custom_value"
@@ -12,7 +12,7 @@ import json
import secrets
import traceback
from datetime import datetime, timedelta, timezone
-from typing import Optional
+from typing import Optional, Tuple
from uuid import uuid4

import fastapi

@@ -125,7 +125,7 @@ async def user_api_key_auth(
        # Check 2. FILTER IP ADDRESS
        await check_if_request_size_is_safe(request=request)

-        is_valid_ip = _check_valid_ip(
+        is_valid_ip, passed_in_ip = _check_valid_ip(
            allowed_ips=general_settings.get("allowed_ips", None),
            use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
            request=request,

@@ -134,7 +134,7 @@ async def user_api_key_auth(
        if not is_valid_ip:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
-                detail="Access forbidden: IP address not allowed.",
+                detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
            )

        pass_through_endpoints: Optional[List[dict]] = general_settings.get(

@@ -1251,12 +1251,12 @@ def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
-) -> bool:
+) -> Tuple[bool, Optional[str]]:
    """
    Returns if ip is allowed or not
    """
    if allowed_ips is None:  # if not set, assume true
-        return True
+        return True, None

    # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
    client_ip = None

@@ -1267,9 +1267,9 @@ def _check_valid_ip(
    # Check if IP address is allowed
    if client_ip not in allowed_ips:
-        return False
+        return False, client_ip

-    return True
+    return True, client_ip


def get_api_key_from_custom_header(
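A rough test sketch of the new `(is_allowed, client_ip)` return shape; the import path and the request stub are assumptions.

```python
# Sketch: _check_valid_ip now also returns the client IP so the 403 detail
# can name the rejected address. The Request object is stubbed out here.
from unittest.mock import MagicMock

from litellm.proxy.auth.user_api_key_auth import _check_valid_ip  # assumed module path

request = MagicMock()
request.client.host = "192.168.1.9"
request.headers = {}

is_valid, client_ip = _check_valid_ip(allowed_ips=["10.0.0.1"], request=request)
assert is_valid is False and client_ip == "192.168.1.9"
```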
litellm/proxy/common_utils/load_config_utils.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import yaml

from litellm._logging import verbose_proxy_logger


def get_file_contents_from_s3(bucket_name, object_key):
    try:
        # v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc
        import tempfile

        import boto3
        from botocore.config import Config
        from botocore.credentials import Credentials

        from litellm.main import bedrock_converse_chat_completion

        credentials: Credentials = bedrock_converse_chat_completion.get_credentials()
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,  # Optional, if using temporary credentials
        )
        verbose_proxy_logger.debug(
            f"Retrieving {object_key} from S3 bucket: {bucket_name}"
        )
        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
        verbose_proxy_logger.debug(f"Response: {response}")

        # Read the file contents
        file_contents = response["Body"].read().decode("utf-8")
        verbose_proxy_logger.debug(f"File contents retrieved from S3")

        # Create a temporary file with YAML extension
        with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as temp_file:
            temp_file.write(file_contents.encode("utf-8"))
            temp_file_path = temp_file.name
            verbose_proxy_logger.debug(f"File stored temporarily at: {temp_file_path}")

        # Load the YAML file content
        with open(temp_file_path, "r") as yaml_file:
            config = yaml.safe_load(yaml_file)

        return config
    except ImportError:
        # this is most likely if a user is not using the litellm docker container
        verbose_proxy_logger.error(f"ImportError: {str(e)}")
        pass
    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
        return None


# # Example usage
# bucket_name = 'litellm-proxy'
# object_key = 'litellm_proxy_config.yaml'
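A hedged sketch of wiring this up; the bucket and key names are assumptions, and the same environment variables are read by the proxy_server.py hunk further below.

```python
# Sketch: load the proxy config from S3 instead of a local file.
import os

from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3

os.environ["LITELLM_CONFIG_BUCKET_NAME"] = "litellm-proxy"                    # assumed bucket
os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"] = "litellm_proxy_config.yaml"  # assumed key

config = get_file_contents_from_s3(
    bucket_name=os.environ["LITELLM_CONFIG_BUCKET_NAME"],
    object_key=os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"],
)
print(config.get("model_list") if config else "config not loaded")
```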
@@ -5,7 +5,12 @@ from fastapi import Request

import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
-from litellm.proxy._types import CommonProxyErrors, TeamCallbackMetadata, UserAPIKeyAuth
+from litellm.proxy._types import (
+    AddTeamCallback,
+    CommonProxyErrors,
+    TeamCallbackMetadata,
+    UserAPIKeyAuth,
+)
from litellm.types.utils import SupportedCacheControls

if TYPE_CHECKING:

@@ -59,6 +64,42 @@ def safe_add_api_version_from_query_params(data: dict, request: Request):
        verbose_logger.error("error checking api version in query params: %s", str(e))


+def convert_key_logging_metadata_to_callback(
+    data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata]
+) -> TeamCallbackMetadata:
+    if team_callback_settings_obj is None:
+        team_callback_settings_obj = TeamCallbackMetadata()
+    if data.callback_type == "success":
+        if team_callback_settings_obj.success_callback is None:
+            team_callback_settings_obj.success_callback = []
+
+        if data.callback_name not in team_callback_settings_obj.success_callback:
+            team_callback_settings_obj.success_callback.append(data.callback_name)
+    elif data.callback_type == "failure":
+        if team_callback_settings_obj.failure_callback is None:
+            team_callback_settings_obj.failure_callback = []
+
+        if data.callback_name not in team_callback_settings_obj.failure_callback:
+            team_callback_settings_obj.failure_callback.append(data.callback_name)
+    elif data.callback_type == "success_and_failure":
+        if team_callback_settings_obj.success_callback is None:
+            team_callback_settings_obj.success_callback = []
+        if team_callback_settings_obj.failure_callback is None:
+            team_callback_settings_obj.failure_callback = []
+        if data.callback_name not in team_callback_settings_obj.success_callback:
+            team_callback_settings_obj.success_callback.append(data.callback_name)
+
+        if data.callback_name in team_callback_settings_obj.failure_callback:
+            team_callback_settings_obj.failure_callback.append(data.callback_name)
+
+    for var, value in data.callback_vars.items():
+        if team_callback_settings_obj.callback_vars is None:
+            team_callback_settings_obj.callback_vars = {}
+        team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value)
+
+    return team_callback_settings_obj
+
+
async def add_litellm_data_to_request(
    data: dict,
    request: Request,

@@ -85,14 +126,19 @@ async def add_litellm_data_to_request(
    safe_add_api_version_from_query_params(data, request)

+    _headers = dict(request.headers)
+
    # Include original request and headers in the data
    data["proxy_server_request"] = {
        "url": str(request.url),
        "method": request.method,
-        "headers": dict(request.headers),
+        "headers": _headers,
        "body": copy.copy(data),  # use copy instead of deepcopy
    }

+    ## Forward any LLM API Provider specific headers in extra_headers
+    add_provider_specific_headers_to_request(data=data, headers=_headers)
+
    ## Cache Controls
    headers = request.headers
    verbose_proxy_logger.debug("Request Headers: %s", headers)

@@ -224,6 +270,7 @@ async def add_litellm_data_to_request(
    }  # add the team-specific configs to the completion call

    # Team Callbacks controls
+    callback_settings_obj: Optional[TeamCallbackMetadata] = None
    if user_api_key_dict.team_metadata is not None:
        team_metadata = user_api_key_dict.team_metadata
        if "callback_settings" in team_metadata:

@@ -241,17 +288,54 @@ async def add_litellm_data_to_request(
            }
        }
        """
-        data["success_callback"] = callback_settings_obj.success_callback
-        data["failure_callback"] = callback_settings_obj.failure_callback
-
-        if callback_settings_obj.callback_vars is not None:
-            # unpack callback_vars in data
-            for k, v in callback_settings_obj.callback_vars.items():
-                data[k] = v
+    elif (
+        user_api_key_dict.metadata is not None
+        and "logging" in user_api_key_dict.metadata
+    ):
+        for item in user_api_key_dict.metadata["logging"]:
+
+            callback_settings_obj = convert_key_logging_metadata_to_callback(
+                data=AddTeamCallback(**item),
+                team_callback_settings_obj=callback_settings_obj,
+            )
+
+    if callback_settings_obj is not None:
+        data["success_callback"] = callback_settings_obj.success_callback
+        data["failure_callback"] = callback_settings_obj.failure_callback
+
+        if callback_settings_obj.callback_vars is not None:
+            # unpack callback_vars in data
+            for k, v in callback_settings_obj.callback_vars.items():
+                data[k] = v

    return data


+def add_provider_specific_headers_to_request(
+    data: dict,
+    headers: dict,
+):
+    ANTHROPIC_API_HEADERS = [
+        "anthropic-version",
+        "anthropic-beta",
+    ]
+
+    extra_headers = data.get("extra_headers", {}) or {}
+
+    # boolean to indicate if a header was added
+    added_header = False
+    for header in ANTHROPIC_API_HEADERS:
+        if header in headers:
+            header_value = headers[header]
+            extra_headers.update({header: header_value})
+            added_header = True
+
+    if added_header is True:
+        data["extra_headers"] = extra_headers
+
+    return
+
+
def _add_otel_traceparent_to_data(data: dict, request: Request):
    from litellm.proxy.proxy_server import open_telemetry_logger
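For reference, a hedged sketch of the key metadata shape that `convert_key_logging_metadata_to_callback` consumes at request time; the callback name and variable names are illustrative.

```python
# Sketch: per-key logging settings stored under metadata["logging"]. Each entry
# maps onto an AddTeamCallback(callback_name, callback_type, callback_vars).
key_metadata = {
    "logging": [
        {
            "callback_name": "langfuse",
            "callback_type": "success",
            "callback_vars": {
                "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
                "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
            },
        }
    ]
}
```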
@@ -19,6 +19,9 @@ model_list:
     litellm_params:
       model: mistral/mistral-small-latest
       api_key: "os.environ/MISTRAL_API_KEY"
+  - model_name: bedrock-anthropic
+    litellm_params:
+      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
   - model_name: gemini-1.5-pro-001
     litellm_params:
       model: vertex_ai_beta/gemini-1.5-pro-001

@@ -39,7 +42,7 @@ general_settings:

 litellm_settings:
   fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
-  success_callback: ["langfuse", "prometheus"]
-  langfuse_default_tags: ["cache_hit", "cache_key", "proxy_base_url", "user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"]
-  failure_callback: ["prometheus"]
+  callbacks: ["gcs_bucket"]
+  success_callback: ["langfuse"]
+  langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
   cache: True
@@ -159,6 +159,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
    check_file_size_under_limit,
)
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
+from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
from litellm.proxy.common_utils.openai_endpoint_utils import (
    remove_sensitive_info_from_deployment,
)

@@ -197,6 +198,8 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    router as pass_through_router,
)
+from litellm.proxy.route_llm_request import route_request
+
from litellm.proxy.secret_managers.aws_secret_manager import (
    load_aws_kms,
    load_aws_secret_manager,

@@ -1444,7 +1447,18 @@ class ProxyConfig:
        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details

        # Load existing config
-        config = await self.get_config(config_file_path=config_file_path)
+        if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
+            bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
+            object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
+            verbose_proxy_logger.debug(
+                "bucket_name: %s, object_key: %s", bucket_name, object_key
+            )
+            config = get_file_contents_from_s3(
+                bucket_name=bucket_name, object_key=object_key
+            )
+        else:
+            # default to file
+            config = await self.get_config(config_file_path=config_file_path)
        ## PRINT YAML FOR CONFIRMING IT WORKS
        printed_yaml = copy.deepcopy(config)
        printed_yaml.pop("environment_variables", None)

@@ -2652,6 +2666,15 @@ async def startup_event():
            )
        else:
            await initialize(**worker_config)
+    elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
+        (
+            llm_router,
+            llm_model_list,
+            general_settings,
+        ) = await proxy_config.load_config(
+            router=llm_router, config_file_path=worker_config
+        )
+
    else:
        # if not, assume it's a json string
        worker_config = json.loads(os.getenv("WORKER_CONFIG"))

@@ -3036,68 +3059,13 @@ async def chat_completion(

        ### ROUTE THE REQUEST ###
        # Do not change this - it should be a constant time fetch - ALWAYS
-        router_model_names = llm_router.model_names if llm_router is not None else []
-        # skip router if user passed their key
-        if "api_key" in data:
-            tasks.append(litellm.acompletion(**data))
-        elif "," in data["model"] and llm_router is not None:
-            if (
-                data.get("fastest_response", None) is not None
-                and data["fastest_response"] == True
-            ):
-                tasks.append(llm_router.abatch_completion_fastest_response(**data))
-            else:
-                _models_csv_string = data.pop("model")
-                _models = [model.strip() for model in _models_csv_string.split(",")]
-                tasks.append(llm_router.abatch_completion(models=_models, **data))
-        elif "user_config" in data:
-            # initialize a new router instance. make request using this Router
-            router_config = data.pop("user_config")
-            user_router = litellm.Router(**router_config)
-            tasks.append(user_router.acompletion(**data))
-        elif (
-            llm_router is not None and data["model"] in router_model_names
-        ):  # model in router model list
-            tasks.append(llm_router.acompletion(**data))
-        elif (
-            llm_router is not None and data["model"] in llm_router.get_model_ids()
-        ):  # model in router model list
-            tasks.append(llm_router.acompletion(**data))
-        elif (
-            llm_router is not None
-            and llm_router.model_group_alias is not None
-            and data["model"] in llm_router.model_group_alias
-        ):  # model set in model_group_alias
-            tasks.append(llm_router.acompletion(**data))
-        elif (
-            llm_router is not None and data["model"] in llm_router.deployment_names
-        ):  # model in router deployments, calling a specific deployment on the router
-            tasks.append(llm_router.acompletion(**data, specific_deployment=True))
-        elif (
-            llm_router is not None
-            and data["model"] not in router_model_names
-            and llm_router.router_general_settings.pass_through_all_models is True
-        ):
-            tasks.append(litellm.acompletion(**data))
-        elif (
-            llm_router is not None
-            and data["model"] not in router_model_names
-            and (
-                llm_router.default_deployment is not None
-                or len(llm_router.provider_default_deployments) > 0
+        llm_call = await route_request(
+            data=data,
+            route_type="acompletion",
+            llm_router=llm_router,
+            user_model=user_model,
+        )
+        tasks.append(llm_call)
|
|
||||||
)
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
tasks.append(llm_router.acompletion(**data))
|
|
||||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
||||||
tasks.append(litellm.acompletion(**data))
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"error": "chat_completion: Invalid model name passed in model="
|
|
||||||
+ data.get("model", "")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# wait for call to end
|
# wait for call to end
|
||||||
llm_responses = asyncio.gather(
|
llm_responses = asyncio.gather(
|
||||||
|
@ -3320,58 +3288,15 @@ async def completion(
|
||||||
)
|
)
|
||||||
|
|
||||||
### ROUTE THE REQUESTs ###
|
### ROUTE THE REQUESTs ###
|
||||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
llm_call = await route_request(
|
||||||
# skip router if user passed their key
|
data=data,
|
||||||
if "api_key" in data:
|
route_type="atext_completion",
|
||||||
llm_response = asyncio.create_task(litellm.atext_completion(**data))
|
llm_router=llm_router,
|
||||||
elif (
|
user_model=user_model,
|
||||||
llm_router is not None and data["model"] in router_model_names
|
)
|
||||||
): # model in router model list
|
|
||||||
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and llm_router.model_group_alias is not None
|
|
||||||
and data["model"] in llm_router.model_group_alias
|
|
||||||
): # model set in model_group_alias
|
|
||||||
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
|
||||||
elif (
|
|
||||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
llm_response = asyncio.create_task(
|
|
||||||
llm_router.atext_completion(**data, specific_deployment=True)
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
|
||||||
): # model in router model list
|
|
||||||
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and data["model"] not in router_model_names
|
|
||||||
and llm_router.router_general_settings.pass_through_all_models is True
|
|
||||||
):
|
|
||||||
llm_response = asyncio.create_task(litellm.atext_completion(**data))
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and data["model"] not in router_model_names
|
|
||||||
and (
|
|
||||||
llm_router.default_deployment is not None
|
|
||||||
or len(llm_router.provider_default_deployments) > 0
|
|
||||||
)
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
|
||||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
||||||
llm_response = asyncio.create_task(litellm.atext_completion(**data))
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"error": "completion: Invalid model name passed in model="
|
|
||||||
+ data.get("model", "")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Await the llm_response task
|
# Await the llm_response task
|
||||||
response = await llm_response
|
response = await llm_call
|
||||||
|
|
||||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||||
model_id = hidden_params.get("model_id", None) or ""
|
model_id = hidden_params.get("model_id", None) or ""
|
||||||
|
@ -3585,59 +3510,13 @@ async def embeddings(
|
||||||
)
|
)
|
||||||
|
|
||||||
## ROUTE TO CORRECT ENDPOINT ##
|
## ROUTE TO CORRECT ENDPOINT ##
|
||||||
# skip router if user passed their key
|
llm_call = await route_request(
|
||||||
if "api_key" in data:
|
data=data,
|
||||||
tasks.append(litellm.aembedding(**data))
|
route_type="aembedding",
|
||||||
elif "user_config" in data:
|
llm_router=llm_router,
|
||||||
# initialize a new router instance. make request using this Router
|
user_model=user_model,
|
||||||
router_config = data.pop("user_config")
|
)
|
||||||
user_router = litellm.Router(**router_config)
|
tasks.append(llm_call)
|
||||||
tasks.append(user_router.aembedding(**data))
|
|
||||||
elif (
|
|
||||||
llm_router is not None and data["model"] in router_model_names
|
|
||||||
): # model in router model list
|
|
||||||
tasks.append(llm_router.aembedding(**data))
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and llm_router.model_group_alias is not None
|
|
||||||
and data["model"] in llm_router.model_group_alias
|
|
||||||
): # model set in model_group_alias
|
|
||||||
tasks.append(
|
|
||||||
llm_router.aembedding(**data)
|
|
||||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
||||||
elif (
|
|
||||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
tasks.append(llm_router.aembedding(**data, specific_deployment=True))
|
|
||||||
elif (
|
|
||||||
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
tasks.append(llm_router.aembedding(**data))
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and data["model"] not in router_model_names
|
|
||||||
and llm_router.router_general_settings.pass_through_all_models is True
|
|
||||||
):
|
|
||||||
tasks.append(litellm.aembedding(**data))
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and data["model"] not in router_model_names
|
|
||||||
and (
|
|
||||||
llm_router.default_deployment is not None
|
|
||||||
or len(llm_router.provider_default_deployments) > 0
|
|
||||||
)
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
tasks.append(llm_router.aembedding(**data))
|
|
||||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
||||||
tasks.append(litellm.aembedding(**data))
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"error": "embeddings: Invalid model name passed in model="
|
|
||||||
+ data.get("model", "")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# wait for call to end
|
# wait for call to end
|
||||||
llm_responses = asyncio.gather(
|
llm_responses = asyncio.gather(
|
||||||
|
@ -3768,46 +3647,13 @@ async def image_generation(
|
||||||
)
|
)
|
||||||
|
|
||||||
## ROUTE TO CORRECT ENDPOINT ##
|
## ROUTE TO CORRECT ENDPOINT ##
|
||||||
# skip router if user passed their key
|
llm_call = await route_request(
|
||||||
if "api_key" in data:
|
data=data,
|
||||||
response = await litellm.aimage_generation(**data)
|
route_type="aimage_generation",
|
||||||
elif (
|
llm_router=llm_router,
|
||||||
llm_router is not None and data["model"] in router_model_names
|
user_model=user_model,
|
||||||
): # model in router model list
|
)
|
||||||
response = await llm_router.aimage_generation(**data)
|
response = await llm_call
|
||||||
elif (
|
|
||||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
response = await llm_router.aimage_generation(
|
|
||||||
**data, specific_deployment=True
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and llm_router.model_group_alias is not None
|
|
||||||
and data["model"] in llm_router.model_group_alias
|
|
||||||
): # model set in model_group_alias
|
|
||||||
response = await llm_router.aimage_generation(
|
|
||||||
**data
|
|
||||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and data["model"] not in router_model_names
|
|
||||||
and (
|
|
||||||
llm_router.default_deployment is not None
|
|
||||||
or len(llm_router.provider_default_deployments) > 0
|
|
||||||
)
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
response = await llm_router.aimage_generation(**data)
|
|
||||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
||||||
response = await litellm.aimage_generation(**data)
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"error": "image_generation: Invalid model name passed in model="
|
|
||||||
+ data.get("model", "")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
### ALERTING ###
|
### ALERTING ###
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
|
@ -3915,44 +3761,13 @@ async def audio_speech(
|
||||||
)
|
)
|
||||||
|
|
||||||
## ROUTE TO CORRECT ENDPOINT ##
|
## ROUTE TO CORRECT ENDPOINT ##
|
||||||
# skip router if user passed their key
|
llm_call = await route_request(
|
||||||
if "api_key" in data:
|
data=data,
|
||||||
response = await litellm.aspeech(**data)
|
route_type="aspeech",
|
||||||
elif (
|
llm_router=llm_router,
|
||||||
llm_router is not None and data["model"] in router_model_names
|
user_model=user_model,
|
||||||
): # model in router model list
|
)
|
||||||
response = await llm_router.aspeech(**data)
|
response = await llm_call
|
||||||
elif (
|
|
||||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
response = await llm_router.aspeech(**data, specific_deployment=True)
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and llm_router.model_group_alias is not None
|
|
||||||
and data["model"] in llm_router.model_group_alias
|
|
||||||
): # model set in model_group_alias
|
|
||||||
response = await llm_router.aspeech(
|
|
||||||
**data
|
|
||||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and data["model"] not in router_model_names
|
|
||||||
and (
|
|
||||||
llm_router.default_deployment is not None
|
|
||||||
or len(llm_router.provider_default_deployments) > 0
|
|
||||||
)
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
response = await llm_router.aspeech(**data)
|
|
||||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
||||||
response = await litellm.aspeech(**data)
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"error": "audio_speech: Invalid model name passed in model="
|
|
||||||
+ data.get("model", "")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
### ALERTING ###
|
### ALERTING ###
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
|
@ -4085,47 +3900,13 @@ async def audio_transcriptions(
|
||||||
)
|
)
|
||||||
|
|
||||||
## ROUTE TO CORRECT ENDPOINT ##
|
## ROUTE TO CORRECT ENDPOINT ##
|
||||||
# skip router if user passed their key
|
llm_call = await route_request(
|
||||||
if "api_key" in data:
|
data=data,
|
||||||
response = await litellm.atranscription(**data)
|
route_type="atranscription",
|
||||||
elif (
|
llm_router=llm_router,
|
||||||
llm_router is not None and data["model"] in router_model_names
|
user_model=user_model,
|
||||||
): # model in router model list
|
)
|
||||||
response = await llm_router.atranscription(**data)
|
response = await llm_call
|
||||||
|
|
||||||
elif (
|
|
||||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
response = await llm_router.atranscription(
|
|
||||||
**data, specific_deployment=True
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and llm_router.model_group_alias is not None
|
|
||||||
and data["model"] in llm_router.model_group_alias
|
|
||||||
): # model set in model_group_alias
|
|
||||||
response = await llm_router.atranscription(
|
|
||||||
**data
|
|
||||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and data["model"] not in router_model_names
|
|
||||||
and (
|
|
||||||
llm_router.default_deployment is not None
|
|
||||||
or len(llm_router.provider_default_deployments) > 0
|
|
||||||
)
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
response = await llm_router.atranscription(**data)
|
|
||||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
||||||
response = await litellm.atranscription(**data)
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"error": "audio_transcriptions: Invalid model name passed in model="
|
|
||||||
+ data.get("model", "")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
finally:
|
finally:
|
||||||
|
@ -5341,40 +5122,13 @@ async def moderations(
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
## ROUTE TO CORRECT ENDPOINT ##
|
## ROUTE TO CORRECT ENDPOINT ##
|
||||||
# skip router if user passed their key
|
llm_call = await route_request(
|
||||||
if "api_key" in data:
|
data=data,
|
||||||
response = await litellm.amoderation(**data)
|
route_type="amoderation",
|
||||||
elif (
|
llm_router=llm_router,
|
||||||
llm_router is not None and data.get("model") in router_model_names
|
user_model=user_model,
|
||||||
): # model in router model list
|
)
|
||||||
response = await llm_router.amoderation(**data)
|
response = await llm_call
|
||||||
elif (
|
|
||||||
llm_router is not None and data.get("model") in llm_router.deployment_names
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
response = await llm_router.amoderation(**data, specific_deployment=True)
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and llm_router.model_group_alias is not None
|
|
||||||
and data.get("model") in llm_router.model_group_alias
|
|
||||||
): # model set in model_group_alias
|
|
||||||
response = await llm_router.amoderation(
|
|
||||||
**data
|
|
||||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and data.get("model") not in router_model_names
|
|
||||||
and (
|
|
||||||
llm_router.default_deployment is not None
|
|
||||||
or len(llm_router.provider_default_deployments) > 0
|
|
||||||
)
|
|
||||||
): # model in router deployments, calling a specific deployment on the router
|
|
||||||
response = await llm_router.amoderation(**data)
|
|
||||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
||||||
response = await litellm.amoderation(**data)
|
|
||||||
else:
|
|
||||||
# /moderations does not need a "model" passed
|
|
||||||
# see https://platform.openai.com/docs/api-reference/moderations
|
|
||||||
response = await litellm.amoderation(**data)
|
|
||||||
|
|
||||||
### ALERTING ###
|
### ALERTING ###
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
|
|
117
litellm/proxy/route_llm_request.py
Normal file
117
litellm/proxy/route_llm_request.py
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||||
|
|
||||||
|
from fastapi import (
|
||||||
|
Depends,
|
||||||
|
FastAPI,
|
||||||
|
File,
|
||||||
|
Form,
|
||||||
|
Header,
|
||||||
|
HTTPException,
|
||||||
|
Path,
|
||||||
|
Request,
|
||||||
|
Response,
|
||||||
|
UploadFile,
|
||||||
|
status,
|
||||||
|
)
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.router import Router as _Router
|
||||||
|
|
||||||
|
LitellmRouter = _Router
|
||||||
|
else:
|
||||||
|
LitellmRouter = Any
|
||||||
|
|
||||||
|
|
||||||
|
ROUTE_ENDPOINT_MAPPING = {
|
||||||
|
"acompletion": "/chat/completions",
|
||||||
|
"atext_completion": "/completions",
|
||||||
|
"aembedding": "/embeddings",
|
||||||
|
"aimage_generation": "/image/generations",
|
||||||
|
"aspeech": "/audio/speech",
|
||||||
|
"atranscription": "/audio/transcriptions",
|
||||||
|
"amoderation": "/moderations",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def route_request(
|
||||||
|
data: dict,
|
||||||
|
llm_router: Optional[LitellmRouter],
|
||||||
|
user_model: Optional[str],
|
||||||
|
route_type: Literal[
|
||||||
|
"acompletion",
|
||||||
|
"atext_completion",
|
||||||
|
"aembedding",
|
||||||
|
"aimage_generation",
|
||||||
|
"aspeech",
|
||||||
|
"atranscription",
|
||||||
|
"amoderation",
|
||||||
|
],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Common helper to route the request
|
||||||
|
|
||||||
|
"""
|
||||||
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||||
|
|
||||||
|
if "api_key" in data:
|
||||||
|
return getattr(litellm, f"{route_type}")(**data)
|
||||||
|
|
||||||
|
elif "user_config" in data:
|
||||||
|
router_config = data.pop("user_config")
|
||||||
|
user_router = litellm.Router(**router_config)
|
||||||
|
return getattr(user_router, f"{route_type}")(**data)
|
||||||
|
|
||||||
|
elif (
|
||||||
|
route_type == "acompletion"
|
||||||
|
and data.get("model", "") is not None
|
||||||
|
and "," in data.get("model", "")
|
||||||
|
and llm_router is not None
|
||||||
|
):
|
||||||
|
if data.get("fastest_response", False):
|
||||||
|
return llm_router.abatch_completion_fastest_response(**data)
|
||||||
|
else:
|
||||||
|
models = [model.strip() for model in data.pop("model").split(",")]
|
||||||
|
return llm_router.abatch_completion(models=models, **data)
|
||||||
|
|
||||||
|
elif llm_router is not None:
|
||||||
|
if (
|
||||||
|
data["model"] in router_model_names
|
||||||
|
or data["model"] in llm_router.get_model_ids()
|
||||||
|
):
|
||||||
|
return getattr(llm_router, f"{route_type}")(**data)
|
||||||
|
|
||||||
|
elif (
|
||||||
|
llm_router.model_group_alias is not None
|
||||||
|
and data["model"] in llm_router.model_group_alias
|
||||||
|
):
|
||||||
|
return getattr(llm_router, f"{route_type}")(**data)
|
||||||
|
|
||||||
|
elif data["model"] in llm_router.deployment_names:
|
||||||
|
return getattr(llm_router, f"{route_type}")(
|
||||||
|
**data, specific_deployment=True
|
||||||
|
)
|
||||||
|
|
||||||
|
elif data["model"] not in router_model_names:
|
||||||
|
if llm_router.router_general_settings.pass_through_all_models:
|
||||||
|
return getattr(litellm, f"{route_type}")(**data)
|
||||||
|
elif (
|
||||||
|
llm_router.default_deployment is not None
|
||||||
|
or len(llm_router.provider_default_deployments) > 0
|
||||||
|
):
|
||||||
|
return getattr(llm_router, f"{route_type}")(**data)
|
||||||
|
|
||||||
|
elif user_model is not None:
|
||||||
|
return getattr(litellm, f"{route_type}")(**data)
|
||||||
|
|
||||||
|
# if no route found then it's a bad request
|
||||||
|
route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={
|
||||||
|
"error": f"{route_name}: Invalid model name passed in model="
|
||||||
|
+ data.get("model", "")
|
||||||
|
},
|
||||||
|
)
|
|
@ -21,6 +21,8 @@ def get_logging_payload(
|
||||||
|
|
||||||
if kwargs is None:
|
if kwargs is None:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
if response_obj is None:
|
||||||
|
response_obj = {}
|
||||||
# standardize this function to be used across, s3, dynamoDB, langfuse logging
|
# standardize this function to be used across, s3, dynamoDB, langfuse logging
|
||||||
litellm_params = kwargs.get("litellm_params", {})
|
litellm_params = kwargs.get("litellm_params", {})
|
||||||
metadata = (
|
metadata = (
|
||||||
|
|
37
litellm/proxy/tests/test_anthropic_context_caching.py
Normal file
37
litellm/proxy/tests/test_anthropic_context_caching.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
import openai
|
||||||
|
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="sk-1234", # litellm proxy api key
|
||||||
|
base_url="http://0.0.0.0:4000", # litellm proxy base url
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="anthropic/claude-3-5-sonnet-20240620",
|
||||||
|
messages=[
|
||||||
|
{ # type: ignore
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "You are an AI assistant tasked with analyzing legal documents.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Here is the full text of a complex legal agreement" * 100,
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what are the key terms and conditions in this agreement?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
extra_headers={
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"anthropic-beta": "prompt-caching-2024-07-31",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
|
@ -190,7 +190,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
if api_version is None:
|
if api_version is None:
|
||||||
api_version = litellm.AZURE_DEFAULT_API_VERSION
|
api_version = os.getenv("AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION)
|
||||||
|
|
||||||
if "gateway.ai.cloudflare.com" in api_base:
|
if "gateway.ai.cloudflare.com" in api_base:
|
||||||
if not api_base.endswith("/"):
|
if not api_base.endswith("/"):
|
||||||
|
|
321
litellm/tests/test_anthropic_prompt_caching.py
Normal file
321
litellm/tests/test_anthropic_prompt_caching.py
Normal file
|
@ -0,0 +1,321 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||||
|
|
||||||
|
# litellm.num_retries =3
|
||||||
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
user_message = "Write a short poem about the sky"
|
||||||
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
|
|
||||||
|
|
||||||
|
def logger_fn(user_model_dict):
|
||||||
|
print(f"user_model_dict: {user_model_dict}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_callbacks():
|
||||||
|
print("\npytest fixture - resetting callbacks")
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
|
litellm.failure_callback = []
|
||||||
|
litellm.callbacks = []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_litellm_anthropic_prompt_caching_tools():
|
||||||
|
# Arrange: Set up the MagicMock for the httpx.AsyncClient
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
|
||||||
|
def return_val():
|
||||||
|
return {
|
||||||
|
"id": "msg_01XFDUDYJgAACzvnptvVoYEL",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": "Hello!"}],
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
"usage": {"input_tokens": 12, "output_tokens": 6},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
|
# Act: Call the litellm.acompletion function
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
api_key="mock_api_key",
|
||||||
|
model="anthropic/claude-3-5-sonnet-20240620",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "What's the weather like in Boston today?"}
|
||||||
|
],
|
||||||
|
tools=[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
extra_headers={
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"anthropic-beta": "prompt-caching-2024-07-31",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print what was called on the mock
|
||||||
|
print("call args=", mock_post.call_args)
|
||||||
|
|
||||||
|
expected_url = "https://api.anthropic.com/v1/messages"
|
||||||
|
expected_headers = {
|
||||||
|
"accept": "application/json",
|
||||||
|
"content-type": "application/json",
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"anthropic-beta": "prompt-caching-2024-07-31",
|
||||||
|
"x-api-key": "mock_api_key",
|
||||||
|
}
|
||||||
|
|
||||||
|
expected_json = {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's the weather like in Boston today?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
|
expected_url, json=expected_json, headers=expected_headers, timeout=600.0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_anthropic_api_prompt_caching_basic():
|
||||||
|
litellm.set_verbose = True
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="anthropic/claude-3-5-sonnet-20240620",
|
||||||
|
messages=[
|
||||||
|
# System Message
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Here is the full text of a complex legal agreement"
|
||||||
|
* 400,
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What are the key terms and conditions in this agreement?",
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
|
||||||
|
},
|
||||||
|
# The final turn is marked with cache-control, for continuing in followups.
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What are the key terms and conditions in this agreement?",
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=10,
|
||||||
|
extra_headers={
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"anthropic-beta": "prompt-caching-2024-07-31",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
print("response=", response)
|
||||||
|
|
||||||
|
assert "cache_read_input_tokens" in response.usage
|
||||||
|
assert "cache_creation_input_tokens" in response.usage
|
||||||
|
|
||||||
|
# Assert either a cache entry was created or cache was read - changes depending on the anthropic api ttl
|
||||||
|
assert (response.usage.cache_read_input_tokens > 0) or (
|
||||||
|
response.usage.cache_creation_input_tokens > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_litellm_anthropic_prompt_caching_system():
|
||||||
|
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#prompt-caching-examples
|
||||||
|
# LArge Context Caching Example
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
|
||||||
|
def return_val():
|
||||||
|
return {
|
||||||
|
"id": "msg_01XFDUDYJgAACzvnptvVoYEL",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": "Hello!"}],
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
"usage": {"input_tokens": 12, "output_tokens": 6},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
|
# Act: Call the litellm.acompletion function
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
api_key="mock_api_key",
|
||||||
|
model="anthropic/claude-3-5-sonnet-20240620",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "You are an AI assistant tasked with analyzing legal documents.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Here is the full text of a complex legal agreement",
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what are the key terms and conditions in this agreement?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
extra_headers={
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"anthropic-beta": "prompt-caching-2024-07-31",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print what was called on the mock
|
||||||
|
print("call args=", mock_post.call_args)
|
||||||
|
|
||||||
|
expected_url = "https://api.anthropic.com/v1/messages"
|
||||||
|
expected_headers = {
|
||||||
|
"accept": "application/json",
|
||||||
|
"content-type": "application/json",
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"anthropic-beta": "prompt-caching-2024-07-31",
|
||||||
|
"x-api-key": "mock_api_key",
|
||||||
|
}
|
||||||
|
|
||||||
|
expected_json = {
|
||||||
|
"system": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "You are an AI assistant tasked with analyzing legal documents.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Here is the full text of a complex legal agreement",
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "what are the key terms and conditions in this agreement?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
|
expected_url, json=expected_json, headers=expected_headers, timeout=600.0
|
||||||
|
)
|
|
@ -1159,8 +1159,8 @@ def test_bedrock_tools_pt_invalid_names():
|
||||||
assert result[1]["toolSpec"]["name"] == "another_invalid_name"
|
assert result[1]["toolSpec"]["name"] == "another_invalid_name"
|
||||||
|
|
||||||
|
|
||||||
def test_bad_request_error():
|
def test_not_found_error():
|
||||||
with pytest.raises(litellm.BadRequestError):
|
with pytest.raises(litellm.NotFoundError):
|
||||||
completion(
|
completion(
|
||||||
model="bedrock/bad_model",
|
model="bedrock/bad_model",
|
||||||
messages=[
|
messages=[
|
||||||
|
|
|
@ -14,7 +14,7 @@ sys.path.insert(
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||||
|
|
||||||
# litellm.num_retries = 3
|
# litellm.num_retries =3
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
user_message = "Write a short poem about the sky"
|
user_message = "Write a short poem about the sky"
|
||||||
|
@ -190,6 +190,31 @@ def test_completion_azure_command_r():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"api_base",
|
||||||
|
[
|
||||||
|
"https://litellm8397336933.openai.azure.com",
|
||||||
|
"https://litellm8397336933.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2023-03-15-preview",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_completion_azure_ai_gpt_4o(api_base):
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="azure_ai/gpt-4o",
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=os.getenv("AZURE_AI_OPENAI_KEY"),
|
||||||
|
messages=[{"role": "user", "content": "What is the meaning of life?"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_completion_databricks(sync_mode):
|
async def test_completion_databricks(sync_mode):
|
||||||
|
@ -3312,108 +3337,6 @@ def test_customprompt_together_ai():
|
||||||
# test_customprompt_together_ai()
|
# test_customprompt_together_ai()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
|
||||||
def test_completion_sagemaker():
|
|
||||||
try:
|
|
||||||
litellm.set_verbose = True
|
|
||||||
print("testing sagemaker")
|
|
||||||
response = completion(
|
|
||||||
model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-ins-20240329-150233",
|
|
||||||
model_id="huggingface-llm-mistral-7b-instruct-20240329-150233",
|
|
||||||
messages=messages,
|
|
||||||
temperature=0.2,
|
|
||||||
max_tokens=80,
|
|
||||||
aws_region_name=os.getenv("AWS_REGION_NAME_2"),
|
|
||||||
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID_2"),
|
|
||||||
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY_2"),
|
|
||||||
input_cost_per_second=0.000420,
|
|
||||||
)
|
|
||||||
# Add any assertions here to check the response
|
|
||||||
print(response)
|
|
||||||
cost = completion_cost(completion_response=response)
|
|
||||||
print("calculated cost", cost)
|
|
||||||
assert (
|
|
||||||
cost > 0.0 and cost < 1.0
|
|
||||||
) # should never be > $1 for a single completion call
|
|
||||||
except Exception as e:
|
|
||||||
pytest.fail(f"Error occurred: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
# test_completion_sagemaker()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_acompletion_sagemaker():
|
|
||||||
try:
|
|
||||||
litellm.set_verbose = True
|
|
||||||
print("testing sagemaker")
|
|
||||||
response = await litellm.acompletion(
|
|
||||||
model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-ins-20240329-150233",
|
|
||||||
model_id="huggingface-llm-mistral-7b-instruct-20240329-150233",
|
|
||||||
messages=messages,
|
|
||||||
temperature=0.2,
|
|
||||||
max_tokens=80,
|
|
||||||
aws_region_name=os.getenv("AWS_REGION_NAME_2"),
|
|
||||||
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID_2"),
|
|
||||||
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY_2"),
|
|
||||||
input_cost_per_second=0.000420,
|
|
||||||
)
|
|
||||||
# Add any assertions here to check the response
|
|
||||||
print(response)
|
|
||||||
cost = completion_cost(completion_response=response)
|
|
||||||
print("calculated cost", cost)
|
|
||||||
assert (
|
|
||||||
cost > 0.0 and cost < 1.0
|
|
||||||
) # should never be > $1 for a single completion call
|
|
||||||
except Exception as e:
|
|
||||||
pytest.fail(f"Error occurred: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
|
||||||
def test_completion_chat_sagemaker():
|
|
||||||
try:
|
|
||||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
|
||||||
litellm.set_verbose = True
|
|
||||||
response = completion(
|
|
||||||
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
|
|
||||||
messages=messages,
|
|
||||||
max_tokens=100,
|
|
||||||
temperature=0.7,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
# Add any assertions here to check the response
|
|
||||||
complete_response = ""
|
|
||||||
for chunk in response:
|
|
||||||
complete_response += chunk.choices[0].delta.content or ""
|
|
||||||
print(f"complete_response: {complete_response}")
|
|
||||||
assert len(complete_response) > 0
|
|
||||||
except Exception as e:
|
|
||||||
pytest.fail(f"Error occurred: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
# test_completion_chat_sagemaker()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
|
||||||
def test_completion_chat_sagemaker_mistral():
|
|
||||||
try:
|
|
||||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
|
||||||
|
|
||||||
response = completion(
|
|
||||||
model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct",
|
|
||||||
messages=messages,
|
|
||||||
max_tokens=100,
|
|
||||||
)
|
|
||||||
# Add any assertions here to check the response
|
|
||||||
print(response)
|
|
||||||
except Exception as e:
|
|
||||||
pytest.fail(f"An error occurred: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
# test_completion_chat_sagemaker_mistral()
|
|
||||||
|
|
||||||
|
|
||||||
def response_format_tests(response: litellm.ModelResponse):
|
def response_format_tests(response: litellm.ModelResponse):
|
||||||
assert isinstance(response.id, str)
|
assert isinstance(response.id, str)
|
||||||
assert response.id != ""
|
assert response.id != ""
|
||||||
|
@ -3449,7 +3372,6 @@ def response_format_tests(response: litellm.ModelResponse):
|
||||||
assert isinstance(response.usage.total_tokens, int) # type: ignore
|
assert isinstance(response.usage.total_tokens, int) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model",
|
"model",
|
||||||
[
|
[
|
||||||
|
@ -3463,6 +3385,7 @@ def response_format_tests(response: litellm.ModelResponse):
|
||||||
"cohere.command-text-v14",
|
"cohere.command-text-v14",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_completion_bedrock_httpx_models(sync_mode, model):
|
async def test_completion_bedrock_httpx_models(sync_mode, model):
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
@ -3705,19 +3628,21 @@ def test_completion_anyscale_api():
|
||||||
# test_completion_anyscale_api()
|
# test_completion_anyscale_api()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="flaky test, times out frequently")
|
# @pytest.mark.skip(reason="flaky test, times out frequently")
|
||||||
def test_completion_cohere():
|
def test_completion_cohere():
|
||||||
try:
|
try:
|
||||||
# litellm.set_verbose=True
|
# litellm.set_verbose=True
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You're a good bot"},
|
{"role": "system", "content": "You're a good bot"},
|
||||||
|
{"role": "assistant", "content": [{"text": "2", "type": "text"}]},
|
||||||
|
{"role": "assistant", "content": [{"text": "3", "type": "text"}]},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "Hey",
|
"content": "Hey",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
response = completion(
|
response = completion(
|
||||||
model="command-nightly",
|
model="command-r",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
|
@ -1,23 +1,27 @@
|
||||||
# What is this?
|
# What is this?
|
||||||
## Test to make sure function call response always works with json.loads() -> no extra parsing required. Relevant issue - https://github.com/BerriAI/litellm/issues/2654
|
## Test to make sure function call response always works with json.loads() -> no extra parsing required. Relevant issue - https://github.com/BerriAI/litellm/issues/2654
|
||||||
import sys, os
|
import os
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
import os, io
|
import io
|
||||||
|
import os
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest
|
|
||||||
import litellm
|
|
||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from litellm import completion
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
|
||||||
# Just a stub to keep the sample code simple
|
# Just a stub to keep the sample code simple
|
||||||
class Trade:
|
class Trade:
|
||||||
|
@ -78,58 +82,60 @@ def trade(model_name: str) -> List[Trade]:
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
response = completion(
|
try:
|
||||||
model_name,
|
response = completion(
|
||||||
[
|
model_name,
|
||||||
{
|
[
|
||||||
"role": "system",
|
{
|
||||||
"content": """You are an expert asset manager, managing a portfolio.
|
"role": "system",
|
||||||
|
"content": """You are an expert asset manager, managing a portfolio.
|
||||||
|
|
||||||
Always use the `trade` function. Make sure that you call it correctly. For example, the following is a valid call:
|
Always use the `trade` function. Make sure that you call it correctly. For example, the following is a valid call:
|
||||||
|
```
|
||||||
|
trade({
|
||||||
|
"orders": [
|
||||||
|
{"action": "buy", "asset": "BTC", "amount": 0.1},
|
||||||
|
{"action": "sell", "asset": "ETH", "amount": 0.2}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
If there are no trades to make, call `trade` with an empty array:
|
||||||
|
```
|
||||||
|
trade({ "orders": [] })
|
||||||
|
```
|
||||||
|
""",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": """Manage the portfolio.
|
||||||
|
|
||||||
|
Don't jabber.
|
||||||
|
|
||||||
|
This is the current market data:
|
||||||
```
|
```
|
||||||
trade({
|
{market_data}
|
||||||
"orders": [
|
|
||||||
{"action": "buy", "asset": "BTC", "amount": 0.1},
|
|
||||||
{"action": "sell", "asset": "ETH", "amount": 0.2}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
```
|
```
|
||||||
|
|
||||||
If there are no trades to make, call `trade` with an empty array:
|
Your portfolio is as follows:
|
||||||
```
|
```
|
||||||
trade({ "orders": [] })
|
{portfolio}
|
||||||
```
|
```
|
||||||
""",
|
""".replace(
|
||||||
|
"{market_data}", "BTC: 64,000 USD\nETH: 3,500 USD"
|
||||||
|
).replace(
|
||||||
|
"{portfolio}", "USD: 1000, BTC: 0.1, ETH: 0.2"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tools=[tool_spec],
|
||||||
|
tool_choice={
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": tool_spec["function"]["name"]}, # type: ignore
|
||||||
},
|
},
|
||||||
{
|
)
|
||||||
"role": "user",
|
except litellm.InternalServerError:
|
||||||
"content": """Manage the portfolio.
|
pass
|
||||||
|
|
||||||
Don't jabber.
|
|
||||||
|
|
||||||
This is the current market data:
|
|
||||||
```
|
|
||||||
{market_data}
|
|
||||||
```
|
|
||||||
|
|
||||||
Your portfolio is as follows:
|
|
||||||
```
|
|
||||||
{portfolio}
|
|
||||||
```
|
|
||||||
""".replace(
|
|
||||||
"{market_data}", "BTC: 64,000 USD\nETH: 3,500 USD"
|
|
||||||
).replace(
|
|
||||||
"{portfolio}", "USD: 1000, BTC: 0.1, ETH: 0.2"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
tools=[tool_spec],
|
|
||||||
tool_choice={
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": tool_spec["function"]["name"]}, # type: ignore
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
calls = response.choices[0].message.tool_calls
|
calls = response.choices[0].message.tool_calls
|
||||||
trades = [trade for call in calls for trade in parse_call(call)]
|
trades = [trade for call in calls for trade in parse_call(call)]
|
||||||
return trades
|
return trades
|
||||||
|
|
|
@ -147,6 +147,117 @@ async def test_basic_gcs_logger():
|
||||||
|
|
||||||
assert gcs_payload["response_cost"] > 0.0
|
assert gcs_payload["response_cost"] > 0.0
|
||||||
|
|
||||||
|
assert gcs_payload["log_event_type"] == "successful_api_call"
|
||||||
|
gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"])
|
||||||
|
|
||||||
|
assert (
|
||||||
|
gcs_payload["spend_log_metadata"]["user_api_key"]
|
||||||
|
== "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
gcs_payload["spend_log_metadata"]["user_api_key_user_id"]
|
||||||
|
== "116544810872468347480"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete Object from GCS
|
||||||
|
print("deleting object from GCS")
|
||||||
|
await gcs_logger.delete_gcs_object(object_name=object_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_gcs_logger_failure():
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
gcs_logger = GCSBucketLogger()
|
||||||
|
print("GCSBucketLogger", gcs_logger)
|
||||||
|
|
||||||
|
gcs_log_id = f"failure-test-{uuid.uuid4().hex}"
|
||||||
|
|
||||||
|
litellm.callbacks = [gcs_logger]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
temperature=0.7,
|
||||||
|
messages=[{"role": "user", "content": "This is a test"}],
|
||||||
|
max_tokens=10,
|
||||||
|
user="ishaan-2",
|
||||||
|
mock_response=litellm.BadRequestError(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
message="Error: 400: Bad Request: Invalid API key, please check your API key and try again.",
|
||||||
|
llm_provider="openai",
|
||||||
|
),
|
||||||
|
metadata={
|
||||||
|
"gcs_log_id": gcs_log_id,
|
||||||
|
"tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"],
|
||||||
|
"user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
|
||||||
|
"user_api_key_alias": None,
|
||||||
|
"user_api_end_user_max_budget": None,
|
||||||
|
"litellm_api_version": "0.0.0",
|
||||||
|
"global_max_parallel_requests": None,
|
||||||
|
"user_api_key_user_id": "116544810872468347480",
|
||||||
|
"user_api_key_org_id": None,
|
||||||
|
"user_api_key_team_id": None,
|
||||||
|
"user_api_key_team_alias": None,
|
||||||
|
"user_api_key_metadata": {},
|
||||||
|
"requester_ip_address": "127.0.0.1",
|
||||||
|
"spend_logs_metadata": {"hello": "world"},
|
||||||
|
"headers": {
|
||||||
|
"content-type": "application/json",
|
||||||
|
"user-agent": "PostmanRuntime/7.32.3",
|
||||||
|
"accept": "*/*",
|
||||||
|
"postman-token": "92300061-eeaa-423b-a420-0b44896ecdc4",
|
||||||
|
"host": "localhost:4000",
|
||||||
|
"accept-encoding": "gzip, deflate, br",
|
||||||
|
"connection": "keep-alive",
|
||||||
|
"content-length": "163",
|
||||||
|
},
|
||||||
|
"endpoint": "http://localhost:4000/chat/completions",
|
||||||
|
"model_group": "gpt-3.5-turbo",
|
||||||
|
"deployment": "azure/chatgpt-v-2",
|
||||||
|
"model_info": {
|
||||||
|
"id": "4bad40a1eb6bebd1682800f16f44b9f06c52a6703444c99c7f9f32e9de3693b4",
|
||||||
|
"db_model": False,
|
||||||
|
},
|
||||||
|
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
|
||||||
|
"caching_groups": None,
|
||||||
|
"raw_request": "\n\nPOST Request Sent from LiteLLM:\ncurl -X POST \\\nhttps://openai-gpt-4-test-v-1.openai.azure.com//openai/ \\\n-H 'Authorization: *****' \\\n-d '{'model': 'chatgpt-v-2', 'messages': [{'role': 'system', 'content': 'you are a helpful assistant.\\n'}, {'role': 'user', 'content': 'bom dia'}], 'stream': False, 'max_tokens': 10, 'user': '116544810872468347480', 'extra_body': {}}'\n",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
# Get the current date
|
||||||
|
# Get the current date
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
# Modify the object_name to include the date-based folder
|
||||||
|
object_name = gcs_log_id
|
||||||
|
|
||||||
|
print("object_name", object_name)
|
||||||
|
|
||||||
|
# Check if object landed on GCS
|
||||||
|
object_from_gcs = await gcs_logger.download_gcs_object(object_name=object_name)
|
||||||
|
print("object from gcs=", object_from_gcs)
|
||||||
|
# convert object_from_gcs from bytes to DICT
|
||||||
|
parsed_data = json.loads(object_from_gcs)
|
||||||
|
print("object_from_gcs as dict", parsed_data)
|
||||||
|
|
||||||
|
print("type of object_from_gcs", type(parsed_data))
|
||||||
|
|
||||||
|
gcs_payload = GCSBucketPayload(**parsed_data)
|
||||||
|
|
||||||
|
print("gcs_payload", gcs_payload)
|
||||||
|
|
||||||
|
assert gcs_payload["request_kwargs"]["model"] == "gpt-3.5-turbo"
|
||||||
|
assert gcs_payload["request_kwargs"]["messages"] == [
|
||||||
|
{"role": "user", "content": "This is a test"}
|
||||||
|
]
|
||||||
|
|
||||||
|
assert gcs_payload["response_cost"] == 0
|
||||||
|
assert gcs_payload["log_event_type"] == "failed_api_call"
|
||||||
|
|
||||||
gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"])
|
gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"])
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
|
|
|
@ -76,6 +76,6 @@ async def test_async_prometheus_success_logging():
    print("metrics from prometheus", metrics)
    assert metrics["litellm_requests_metric_total"] == 1.0
    assert metrics["litellm_total_tokens_total"] == 30.0
-   assert metrics["llm_deployment_success_responses_total"] == 1.0
+   assert metrics["litellm_deployment_success_responses_total"] == 1.0
-   assert metrics["llm_deployment_total_requests_total"] == 1.0
+   assert metrics["litellm_deployment_total_requests_total"] == 1.0
-   assert metrics["llm_deployment_latency_per_output_token_bucket"] == 1.0
+   assert metrics["litellm_deployment_latency_per_output_token_bucket"] == 1.0
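For reviewers who want to check the renamed `litellm_deployment_*` counters against a running proxy, here is a minimal sketch (not part of this commit). It assumes the proxy exposes Prometheus metrics at `http://localhost:4000/metrics` and that `requests` and `prometheus_client` are installed.

```python
# Hedged sketch: scrape the proxy's /metrics endpoint and collect the
# litellm_deployment_* samples. URL/port are assumptions, adjust as needed.
import requests
from prometheus_client.parser import text_string_to_metric_families


def read_deployment_metrics(base_url: str = "http://localhost:4000") -> dict:
    """Return all litellm_deployment_* samples as {sample_name: value}."""
    raw = requests.get(f"{base_url}/metrics").text
    samples = {}
    for family in text_string_to_metric_families(raw):
        if family.name.startswith("litellm_deployment"):
            for sample in family.samples:
                samples[sample.name] = sample.value
    return samples


if __name__ == "__main__":
    print(read_deployment_metrics())
```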
@ -260,3 +260,56 @@ def test_anthropic_messages_tool_call():
        translated_messages[-1]["content"][0]["tool_use_id"]
        == "bc8cb4b6-88c4-4138-8993-3a9d9cd51656"
    )
+
+
+def test_anthropic_cache_controls_pt():
+    "see anthropic docs for this: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation"
+    messages = [
+        # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+        },
+        # The final turn is marked with cache-control, for continuing in followups.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+            "cache_control": {"type": "ephemeral"},
+        },
+    ]
+
+    translated_messages = anthropic_messages_pt(
+        messages, model="claude-3-5-sonnet-20240620", llm_provider="anthropic"
+    )
+
+    for i, msg in enumerate(translated_messages):
+        if i == 0:
+            assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
+        elif i == 1:
+            assert "cache_controls" not in msg["content"][0]
+        elif i == 2:
+            assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
+        elif i == 3:
+            assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
+
+    print("translated_messages: ", translated_messages)
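A hedged usage sketch to go with the new cache-control test: the same `cache_control` blocks can be sent end-to-end through `litellm.completion`, assuming an `ANTHROPIC_API_KEY` is configured and the installed litellm version forwards the field to Anthropic. The key value below is a placeholder.

```python
# Hedged sketch (not from the diff): passing cache_control through completion.
import os

import litellm

os.environ.setdefault("ANTHROPIC_API_KEY", "sk-ant-...")  # placeholder, set a real key

response = litellm.completion(
    model="claude-3-5-sonnet-20240620",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What are the key terms and conditions in this agreement?",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        }
    ],
    max_tokens=100,
)
print(response.choices[0].message.content)
```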
@ -2,16 +2,19 @@
 # This tests setting provider specific configs across providers
 # There are 2 types of tests - changing config dynamically or by setting class variables

-import sys, os
+import os
+import sys
 import traceback

 import pytest

 sys.path.insert(
     0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch

 import litellm
-from litellm import completion
-from litellm import RateLimitError
+from litellm import RateLimitError, completion

 # Huggingface - Expensive to deploy models and keep them running. Maybe we can try doing this via baseten??
 # def hf_test_completion_tgi():
@ -513,102 +516,165 @@ def sagemaker_test_completion():
|
||||||
# sagemaker_test_completion()
|
# sagemaker_test_completion()
|
||||||
|
|
||||||
|
|
||||||
def test_sagemaker_default_region(mocker):
|
def test_sagemaker_default_region():
|
||||||
"""
|
"""
|
||||||
If no regions are specified in config or in environment, the default region is us-west-2
|
If no regions are specified in config or in environment, the default region is us-west-2
|
||||||
"""
|
"""
|
||||||
mock_client = mocker.patch("boto3.client")
|
mock_response = MagicMock()
|
||||||
try:
|
|
||||||
|
def return_val():
|
||||||
|
return {
|
||||||
|
"generated_text": "This is a mock response from SageMaker.",
|
||||||
|
"id": "cmpl-mockid",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1629800000,
|
||||||
|
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "This is a mock response from SageMaker.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "length",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model="sagemaker/mock-endpoint",
|
model="sagemaker/mock-endpoint",
|
||||||
messages=[
|
messages=[{"content": "Hello, world!", "role": "user"}],
|
||||||
{
|
|
||||||
"content": "Hello, world!",
|
|
||||||
"role": "user"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
except Exception:
|
mock_post.assert_called_once()
|
||||||
pass # expected serialization exception because AWS client was replaced with a Mock
|
_, kwargs = mock_post.call_args
|
||||||
assert mock_client.call_args.kwargs["region_name"] == "us-west-2"
|
args_to_sagemaker = kwargs["json"]
|
||||||
|
print("Arguments passed to sagemaker=", args_to_sagemaker)
|
||||||
|
print("url=", kwargs["url"])
|
||||||
|
|
||||||
|
assert (
|
||||||
|
kwargs["url"]
|
||||||
|
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/mock-endpoint/invocations"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# test_sagemaker_default_region()
|
# test_sagemaker_default_region()
|
||||||
|
|
||||||
|
|
||||||
def test_sagemaker_environment_region(mocker):
|
def test_sagemaker_environment_region():
|
||||||
"""
|
"""
|
||||||
If a region is specified in the environment, use that region instead of us-west-2
|
If a region is specified in the environment, use that region instead of us-west-2
|
||||||
"""
|
"""
|
||||||
expected_region = "us-east-1"
|
expected_region = "us-east-1"
|
||||||
os.environ["AWS_REGION_NAME"] = expected_region
|
os.environ["AWS_REGION_NAME"] = expected_region
|
||||||
mock_client = mocker.patch("boto3.client")
|
mock_response = MagicMock()
|
||||||
try:
|
|
||||||
|
def return_val():
|
||||||
|
return {
|
||||||
|
"generated_text": "This is a mock response from SageMaker.",
|
||||||
|
"id": "cmpl-mockid",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1629800000,
|
||||||
|
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "This is a mock response from SageMaker.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "length",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model="sagemaker/mock-endpoint",
|
model="sagemaker/mock-endpoint",
|
||||||
messages=[
|
messages=[{"content": "Hello, world!", "role": "user"}],
|
||||||
{
|
|
||||||
"content": "Hello, world!",
|
|
||||||
"role": "user"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
except Exception:
|
mock_post.assert_called_once()
|
||||||
pass # expected serialization exception because AWS client was replaced with a Mock
|
_, kwargs = mock_post.call_args
|
||||||
del os.environ["AWS_REGION_NAME"] # cleanup
|
args_to_sagemaker = kwargs["json"]
|
||||||
assert mock_client.call_args.kwargs["region_name"] == expected_region
|
print("Arguments passed to sagemaker=", args_to_sagemaker)
|
||||||
|
print("url=", kwargs["url"])
|
||||||
|
|
||||||
|
assert (
|
||||||
|
kwargs["url"]
|
||||||
|
== f"https://runtime.sagemaker.{expected_region}.amazonaws.com/endpoints/mock-endpoint/invocations"
|
||||||
|
)
|
||||||
|
|
||||||
|
del os.environ["AWS_REGION_NAME"] # cleanup
|
||||||
|
|
||||||
|
|
||||||
# test_sagemaker_environment_region()
|
# test_sagemaker_environment_region()
|
||||||
|
|
||||||
|
|
||||||
def test_sagemaker_config_region(mocker):
|
def test_sagemaker_config_region():
|
||||||
"""
|
"""
|
||||||
If a region is specified as part of the optional parameters of the completion, including as
|
If a region is specified as part of the optional parameters of the completion, including as
|
||||||
part of the config file, then use that region instead of us-west-2
|
part of the config file, then use that region instead of us-west-2
|
||||||
"""
|
"""
|
||||||
expected_region = "us-east-1"
|
expected_region = "us-east-1"
|
||||||
mock_client = mocker.patch("boto3.client")
|
mock_response = MagicMock()
|
||||||
try:
|
|
||||||
response = litellm.completion(
|
def return_val():
|
||||||
model="sagemaker/mock-endpoint",
|
return {
|
||||||
messages=[
|
"generated_text": "This is a mock response from SageMaker.",
|
||||||
|
"id": "cmpl-mockid",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1629800000,
|
||||||
|
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
"choices": [
|
||||||
{
|
{
|
||||||
"content": "Hello, world!",
|
"text": "This is a mock response from SageMaker.",
|
||||||
"role": "user"
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "length",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
|
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/mock-endpoint",
|
||||||
|
messages=[{"content": "Hello, world!", "role": "user"}],
|
||||||
aws_region_name=expected_region,
|
aws_region_name=expected_region,
|
||||||
)
|
)
|
||||||
except Exception:
|
|
||||||
pass # expected serialization exception because AWS client was replaced with a Mock
|
mock_post.assert_called_once()
|
||||||
assert mock_client.call_args.kwargs["region_name"] == expected_region
|
_, kwargs = mock_post.call_args
|
||||||
|
args_to_sagemaker = kwargs["json"]
|
||||||
|
print("Arguments passed to sagemaker=", args_to_sagemaker)
|
||||||
|
print("url=", kwargs["url"])
|
||||||
|
|
||||||
|
assert (
|
||||||
|
kwargs["url"]
|
||||||
|
== f"https://runtime.sagemaker.{expected_region}.amazonaws.com/endpoints/mock-endpoint/invocations"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# test_sagemaker_config_region()
|
# test_sagemaker_config_region()
|
||||||
|
|
||||||
|
|
||||||
def test_sagemaker_config_and_environment_region(mocker):
|
|
||||||
"""
|
|
||||||
If both the environment and config file specify a region, the environment region is expected
|
|
||||||
"""
|
|
||||||
expected_region = "us-east-1"
|
|
||||||
unexpected_region = "us-east-2"
|
|
||||||
os.environ["AWS_REGION_NAME"] = expected_region
|
|
||||||
mock_client = mocker.patch("boto3.client")
|
|
||||||
try:
|
|
||||||
response = litellm.completion(
|
|
||||||
model="sagemaker/mock-endpoint",
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"content": "Hello, world!",
|
|
||||||
"role": "user"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
aws_region_name=unexpected_region,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass # expected serialization exception because AWS client was replaced with a Mock
|
|
||||||
del os.environ["AWS_REGION_NAME"] # cleanup
|
|
||||||
assert mock_client.call_args.kwargs["region_name"] == expected_region
|
|
||||||
|
|
||||||
# test_sagemaker_config_and_environment_region()
|
# test_sagemaker_config_and_environment_region()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -229,8 +229,9 @@ def test_chat_completion_exception_any_model(client):
         )
         assert isinstance(openai_exception, openai.BadRequestError)
         _error_message = openai_exception.message
-        assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(
-            _error_message
+        assert (
+            "/chat/completions: Invalid model name passed in model=Lite-GPT-12"
+            in str(_error_message)
         )

     except Exception as e:

@ -259,7 +260,7 @@ def test_embedding_exception_any_model(client):
         print("Exception raised=", openai_exception)
         assert isinstance(openai_exception, openai.BadRequestError)
         _error_message = openai_exception.message
-        assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(
+        assert "/embeddings: Invalid model name passed in model=Lite-GPT-12" in str(
             _error_message
         )
@ -966,3 +966,203 @@ async def test_user_info_team_list(prisma_client):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
mock_client.assert_called()
|
mock_client.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Local test")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_callback_via_key(prisma_client):
|
||||||
|
"""
|
||||||
|
Test if callback specified in key, is used.
|
||||||
|
"""
|
||||||
|
global headers
|
||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request, Response
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import chat_completion
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Your test data
|
||||||
|
test_data = {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "write 1 sentence poem"},
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
|
"mock_response": "Hello world",
|
||||||
|
"api_key": "my-fake-key",
|
||||||
|
}
|
||||||
|
|
||||||
|
request = Request(scope={"type": "http", "method": "POST", "headers": {}})
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
|
||||||
|
json_bytes = json.dumps(test_data).encode("utf-8")
|
||||||
|
|
||||||
|
request._body = json_bytes
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
litellm.litellm_core_utils.litellm_logging,
|
||||||
|
"LangFuseLogger",
|
||||||
|
new=MagicMock(),
|
||||||
|
) as mock_client:
|
||||||
|
resp = await chat_completion(
|
||||||
|
request=request,
|
||||||
|
fastapi_response=Response(),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
metadata={
|
||||||
|
"logging": [
|
||||||
|
{
|
||||||
|
"callback_name": "langfuse", # 'otel', 'langfuse', 'lunary'
|
||||||
|
"callback_type": "success", # set, if required by integration - future improvement, have logging tools work for success + failure by default
|
||||||
|
"callback_vars": {
|
||||||
|
"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
|
||||||
|
"langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
|
||||||
|
"langfuse_host": "https://us.cloud.langfuse.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
print(resp)
|
||||||
|
mock_client.assert_called()
|
||||||
|
mock_client.return_value.log_event.assert_called()
|
||||||
|
args, kwargs = mock_client.return_value.log_event.call_args
|
||||||
|
kwargs = kwargs["kwargs"]
|
||||||
|
assert "user_api_key_metadata" in kwargs["litellm_params"]["metadata"]
|
||||||
|
assert (
|
||||||
|
"logging"
|
||||||
|
in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"]
|
||||||
|
)
|
||||||
|
checked_keys = False
|
||||||
|
for item in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"][
|
||||||
|
"logging"
|
||||||
|
]:
|
||||||
|
for k, v in item["callback_vars"].items():
|
||||||
|
print("k={}, v={}".format(k, v))
|
||||||
|
if "key" in k:
|
||||||
|
assert "os.environ" in v
|
||||||
|
checked_keys = True
|
||||||
|
|
||||||
|
assert checked_keys
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
|
||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request, Response
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
|
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
|
||||||
|
|
||||||
|
request = Request(scope={"type": "http", "method": "POST", "headers": {}})
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
|
||||||
|
test_data = {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "write 1 sentence poem"},
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
|
"mock_response": "Hello world",
|
||||||
|
"api_key": "my-fake-key",
|
||||||
|
}
|
||||||
|
|
||||||
|
json_bytes = json.dumps(test_data).encode("utf-8")
|
||||||
|
|
||||||
|
request._body = json_bytes
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"data": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"messages": [{"role": "user", "content": "write 1 sentence poem"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
"mock_response": "Hello world",
|
||||||
|
"api_key": "my-fake-key",
|
||||||
|
},
|
||||||
|
"request": request,
|
||||||
|
"user_api_key_dict": UserAPIKeyAuth(
|
||||||
|
token=None,
|
||||||
|
key_name=None,
|
||||||
|
key_alias=None,
|
||||||
|
spend=0.0,
|
||||||
|
max_budget=None,
|
||||||
|
expires=None,
|
||||||
|
models=[],
|
||||||
|
aliases={},
|
||||||
|
config={},
|
||||||
|
user_id=None,
|
||||||
|
team_id=None,
|
||||||
|
max_parallel_requests=None,
|
||||||
|
metadata={
|
||||||
|
"logging": [
|
||||||
|
{
|
||||||
|
"callback_name": "langfuse",
|
||||||
|
"callback_type": "success",
|
||||||
|
"callback_vars": {
|
||||||
|
"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
|
||||||
|
"langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
|
||||||
|
"langfuse_host": "https://us.cloud.langfuse.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
tpm_limit=None,
|
||||||
|
rpm_limit=None,
|
||||||
|
budget_duration=None,
|
||||||
|
budget_reset_at=None,
|
||||||
|
allowed_cache_controls=[],
|
||||||
|
permissions={},
|
||||||
|
model_spend={},
|
||||||
|
model_max_budget={},
|
||||||
|
soft_budget_cooldown=False,
|
||||||
|
litellm_budget_table=None,
|
||||||
|
org_id=None,
|
||||||
|
team_spend=None,
|
||||||
|
team_alias=None,
|
||||||
|
team_tpm_limit=None,
|
||||||
|
team_rpm_limit=None,
|
||||||
|
team_max_budget=None,
|
||||||
|
team_models=[],
|
||||||
|
team_blocked=False,
|
||||||
|
soft_budget=None,
|
||||||
|
team_model_aliases=None,
|
||||||
|
team_member_spend=None,
|
||||||
|
team_metadata=None,
|
||||||
|
end_user_id=None,
|
||||||
|
end_user_tpm_limit=None,
|
||||||
|
end_user_rpm_limit=None,
|
||||||
|
end_user_max_budget=None,
|
||||||
|
last_refreshed_at=None,
|
||||||
|
api_key=None,
|
||||||
|
user_role=None,
|
||||||
|
allowed_model_region=None,
|
||||||
|
parent_otel_span=None,
|
||||||
|
),
|
||||||
|
"proxy_config": proxy_config,
|
||||||
|
"general_settings": {},
|
||||||
|
"version": "0.0.0",
|
||||||
|
}
|
||||||
|
|
||||||
|
new_data = await add_litellm_data_to_request(**data)
|
||||||
|
|
||||||
|
assert "success_callback" in new_data
|
||||||
|
assert new_data["success_callback"] == ["langfuse"]
|
||||||
|
assert "langfuse_public_key" in new_data
|
||||||
|
assert "langfuse_secret_key" in new_data
|
||||||
|
|
litellm/tests/test_sagemaker.py (new file, 316 lines)
@ -0,0 +1,316 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||||
|
|
||||||
|
# litellm.num_retries =3
|
||||||
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
user_message = "Write a short poem about the sky"
|
||||||
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
|
||||||
|
|
||||||
|
def logger_fn(user_model_dict):
|
||||||
|
print(f"user_model_dict: {user_model_dict}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_callbacks():
|
||||||
|
print("\npytest fixture - resetting callbacks")
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
|
litellm.failure_callback = []
|
||||||
|
litellm.callbacks = []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
async def test_completion_sagemaker(sync_mode):
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
verbose_logger.setLevel(logging.DEBUG)
|
||||||
|
print("testing sagemaker")
|
||||||
|
if sync_mode is True:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=80,
|
||||||
|
input_cost_per_second=0.000420,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=80,
|
||||||
|
input_cost_per_second=0.000420,
|
||||||
|
)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
print(response)
|
||||||
|
cost = completion_cost(completion_response=response)
|
||||||
|
print("calculated cost", cost)
|
||||||
|
assert (
|
||||||
|
cost > 0.0 and cost < 1.0
|
||||||
|
) # should never be > $1 for a single completion call
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
@pytest.mark.parametrize("sync_mode", [False, True])
|
||||||
|
async def test_completion_sagemaker_stream(sync_mode):
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = False
|
||||||
|
print("testing sagemaker")
|
||||||
|
verbose_logger.setLevel(logging.DEBUG)
|
||||||
|
full_text = ""
|
||||||
|
if sync_mode is True:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi - what is ur name"},
|
||||||
|
],
|
||||||
|
temperature=0.2,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=80,
|
||||||
|
input_cost_per_second=0.000420,
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
full_text += chunk.choices[0].delta.content or ""
|
||||||
|
|
||||||
|
print("SYNC RESPONSE full text", full_text)
|
||||||
|
else:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi - what is ur name"},
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=80,
|
||||||
|
input_cost_per_second=0.000420,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("streaming response")
|
||||||
|
|
||||||
|
async for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
full_text += chunk.choices[0].delta.content or ""
|
||||||
|
|
||||||
|
print("ASYNC RESPONSE full text", full_text)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_acompletion_sagemaker_non_stream():
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
|
||||||
|
def return_val():
|
||||||
|
return {
|
||||||
|
"generated_text": "This is a mock response from SageMaker.",
|
||||||
|
"id": "cmpl-mockid",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1629800000,
|
||||||
|
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "This is a mock response from SageMaker.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "length",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
expected_payload = {
|
||||||
|
"inputs": "hi",
|
||||||
|
"parameters": {"temperature": 0.2, "max_new_tokens": 80},
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
|
# Act: Call the litellm.acompletion function
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=80,
|
||||||
|
input_cost_per_second=0.000420,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print what was called on the mock
|
||||||
|
print("call args=", mock_post.call_args)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
_, kwargs = mock_post.call_args
|
||||||
|
args_to_sagemaker = kwargs["json"]
|
||||||
|
print("Arguments passed to sagemaker=", args_to_sagemaker)
|
||||||
|
assert args_to_sagemaker == expected_payload
|
||||||
|
assert (
|
||||||
|
kwargs["url"]
|
||||||
|
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_completion_sagemaker_non_stream():
|
||||||
|
mock_response = MagicMock()
|
||||||
|
|
||||||
|
def return_val():
|
||||||
|
return {
|
||||||
|
"generated_text": "This is a mock response from SageMaker.",
|
||||||
|
"id": "cmpl-mockid",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1629800000,
|
||||||
|
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "This is a mock response from SageMaker.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "length",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
expected_payload = {
|
||||||
|
"inputs": "hi",
|
||||||
|
"parameters": {"temperature": 0.2, "max_new_tokens": 80},
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
|
# Act: Call the litellm.acompletion function
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=80,
|
||||||
|
input_cost_per_second=0.000420,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print what was called on the mock
|
||||||
|
print("call args=", mock_post.call_args)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
_, kwargs = mock_post.call_args
|
||||||
|
args_to_sagemaker = kwargs["json"]
|
||||||
|
print("Arguments passed to sagemaker=", args_to_sagemaker)
|
||||||
|
assert args_to_sagemaker == expected_payload
|
||||||
|
assert (
|
||||||
|
kwargs["url"]
|
||||||
|
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_completion_sagemaker_non_stream_with_aws_params():
|
||||||
|
mock_response = MagicMock()
|
||||||
|
|
||||||
|
def return_val():
|
||||||
|
return {
|
||||||
|
"generated_text": "This is a mock response from SageMaker.",
|
||||||
|
"id": "cmpl-mockid",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1629800000,
|
||||||
|
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "This is a mock response from SageMaker.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "length",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
expected_payload = {
|
||||||
|
"inputs": "hi",
|
||||||
|
"parameters": {"temperature": 0.2, "max_new_tokens": 80},
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
|
# Act: Call the litellm.acompletion function
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=80,
|
||||||
|
input_cost_per_second=0.000420,
|
||||||
|
aws_access_key_id="gm",
|
||||||
|
aws_secret_access_key="s",
|
||||||
|
aws_region_name="us-west-5",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print what was called on the mock
|
||||||
|
print("call args=", mock_post.call_args)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
_, kwargs = mock_post.call_args
|
||||||
|
args_to_sagemaker = kwargs["json"]
|
||||||
|
print("Arguments passed to sagemaker=", args_to_sagemaker)
|
||||||
|
assert args_to_sagemaker == expected_payload
|
||||||
|
assert (
|
||||||
|
kwargs["url"]
|
||||||
|
== "https://runtime.sagemaker.us-west-5.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
|
||||||
|
)
|
|
@ -1683,6 +1683,7 @@ def test_completion_bedrock_mistral_stream():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.skip(reason="stopped using TokenIterator")
 def test_sagemaker_weird_response():
     """
     When the stream ends, flush any remaining holding chunks.
@ -44,7 +44,7 @@ def test_check_valid_ip(

     request = Request(client_ip)

-    assert _check_valid_ip(allowed_ips, request) == expected_result  # type: ignore
+    assert _check_valid_ip(allowed_ips, request)[0] == expected_result  # type: ignore


 # test x-forwarder for is used when user has opted in

@ -72,7 +72,7 @@ def test_check_valid_ip_sent_with_x_forwarded_for(

     request = Request(client_ip, headers={"X-Forwarded-For": client_ip})

-    assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True) == expected_result  # type: ignore
+    assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True)[0] == expected_result  # type: ignore


 @pytest.mark.asyncio
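Since `_check_valid_ip` now returns a tuple whose first element is the pass/fail flag, downstream callers need a small adjustment. A hedged sketch follows; the import path and the meaning of the remaining tuple elements are assumptions, not taken from this diff.

```python
# Hedged sketch: adapting a caller to the new tuple return value.
# Only result[0] (the boolean) is relied on, matching the updated asserts above.
from litellm.proxy.auth.user_api_key_auth import _check_valid_ip  # path assumed


def is_request_allowed(allowed_ips, request, use_x_forwarded_for: bool = False) -> bool:
    result = _check_valid_ip(allowed_ips, request, use_x_forwarded_for=use_x_forwarded_for)
    is_valid = result[0]  # first element is the pass/fail flag
    return bool(is_valid)
```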
@ -15,9 +15,10 @@ class AnthropicMessagesTool(TypedDict, total=False):
     input_schema: Required[dict]


-class AnthropicMessagesTextParam(TypedDict):
+class AnthropicMessagesTextParam(TypedDict, total=False):
     type: Literal["text"]
     text: str
+    cache_control: Optional[dict]


 class AnthropicMessagesToolUseParam(TypedDict):

@ -54,9 +55,10 @@ class AnthropicImageParamSource(TypedDict):
     data: str


-class AnthropicMessagesImageParam(TypedDict):
+class AnthropicMessagesImageParam(TypedDict, total=False):
     type: Literal["image"]
     source: AnthropicImageParamSource
+    cache_control: Optional[dict]


 class AnthropicMessagesToolResultContent(TypedDict):

@ -92,6 +94,12 @@ class AnthropicMetadata(TypedDict, total=False):
     user_id: str


+class AnthropicSystemMessageContent(TypedDict, total=False):
+    type: str
+    text: str
+    cache_control: Optional[dict]
+
+
 class AnthropicMessagesRequest(TypedDict, total=False):
     model: Required[str]
     messages: Required[

@ -106,7 +114,7 @@ class AnthropicMessagesRequest(TypedDict, total=False):
     metadata: AnthropicMetadata
     stop_sequences: List[str]
     stream: bool
-    system: str
+    system: Union[str, List]
     temperature: float
     tool_choice: AnthropicMessagesToolChoice
     tools: List[AnthropicMessagesTool]
@ -361,7 +361,7 @@ class ChatCompletionToolMessage(TypedDict):

 class ChatCompletionSystemMessage(TypedDict, total=False):
     role: Required[Literal["system"]]
-    content: Required[str]
+    content: Required[Union[str, List]]
     name: str
@ -80,7 +80,7 @@ class ModelInfo(TypedDict, total=False):
     supports_assistant_prefill: Optional[bool]


-class GenericStreamingChunk(TypedDict):
+class GenericStreamingChunk(TypedDict, total=False):
     text: Required[str]
     tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
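With `total=False`, only the `Required[...]` fields have to be present when a provider handler builds a chunk. A minimal sketch follows; the module path and the absence of other optional fields (e.g. `finish_reason`, `usage`) are assumptions based on what this hunk shows.

```python
# Hedged sketch: constructing a GenericStreamingChunk with only the required
# keys from this hunk plus tool_use. Import path is assumed.
from litellm.types.utils import GenericStreamingChunk  # path assumed

chunk: GenericStreamingChunk = {
    "text": "Hello",
    "tool_use": None,
    "is_finished": False,
}
print(chunk)
```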
@ -4479,7 +4479,22 @@ def _is_non_openai_azure_model(model: str) -> bool:
             or f"mistral/{model_name}" in litellm.mistral_chat_models
         ):
             return True
-    except:
+    except Exception:
+        return False
+    return False
+
+
+def _is_azure_openai_model(model: str) -> bool:
+    try:
+        if "/" in model:
+            model = model.split("/", 1)[1]
+        if (
+            model in litellm.open_ai_chat_completion_models
+            or model in litellm.open_ai_text_completion_models
+            or model in litellm.open_ai_embedding_models
+        ):
+            return True
+    except Exception:
         return False
     return False

@ -4613,6 +4628,14 @@ def get_llm_provider(
         elif custom_llm_provider == "azure_ai":
             api_base = api_base or get_secret("AZURE_AI_API_BASE")  # type: ignore
             dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
+
+            if _is_azure_openai_model(model=model):
+                verbose_logger.debug(
+                    "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
+                        model
+                    )
+                )
+                custom_llm_provider = "azure"
         elif custom_llm_provider == "github":
             api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com"  # type: ignore
             dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
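A hedged sketch of what the new `_is_azure_openai_model` check in `get_llm_provider` does from the caller's side: when an `azure_ai/` model name maps onto a known Azure OpenAI model, the provider should be rewritten to `azure`. The return-tuple shape and the example model/endpoint below are assumptions, not taken from this diff.

```python
# Hedged sketch (not from the diff): observing the azure_ai -> azure rerouting.
import litellm

model, provider, api_key, api_base = litellm.get_llm_provider(
    model="azure_ai/gpt-4o-mini",  # an OpenAI-family model behind Azure AI Studio (example)
    api_base="https://example-resource.openai.azure.com/",  # placeholder endpoint
    api_key="placeholder-key",
)
print(provider)  # expected "azure" for Azure OpenAI models, "azure_ai" otherwise
```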
@ -9825,11 +9848,28 @@ class CustomStreamWrapper:
                    completion_obj["tool_calls"] = [response_obj["tool_use"]]

            elif self.custom_llm_provider == "sagemaker":
-               print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
-               response_obj = self.handle_sagemaker_stream(chunk)
+               from litellm.types.llms.bedrock import GenericStreamingChunk
+
+               if self.received_finish_reason is not None:
+                   raise StopIteration
+               response_obj: GenericStreamingChunk = chunk
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
+
+               if (
+                   self.stream_options
+                   and self.stream_options.get("include_usage", False) is True
+                   and response_obj["usage"] is not None
+               ):
+                   model_response.usage = litellm.Usage(
+                       prompt_tokens=response_obj["usage"]["inputTokens"],
+                       completion_tokens=response_obj["usage"]["outputTokens"],
+                       total_tokens=response_obj["usage"]["totalTokens"],
+                   )
+
+               if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                   completion_obj["tool_calls"] = [response_obj["tool_use"]]
            elif self.custom_llm_provider == "petals":
                if len(self.completion_stream) == 0:
                    if self.received_finish_reason is not None:
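A short, hedged sketch of what the new `stream_options` handling enables from the caller's side: when usage is requested, a chunk near the end of the stream may carry a `usage` object. `mock_response` keeps the example runnable without a live SageMaker endpoint, so whether usage is actually populated depends on the provider path taken.

```python
# Hedged sketch: reading usage from a streamed completion when requested.
import litellm

stream = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},
    mock_response="Hello world",  # avoids a real provider call
)

for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    usage = getattr(chunk, "usage", None)
    if usage is not None:
        print("\nusage:", usage)
```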
@ -57,6 +57,18 @@
|
||||||
"supports_parallel_function_calling": true,
|
"supports_parallel_function_calling": true,
|
||||||
"supports_vision": true
|
"supports_vision": true
|
||||||
},
|
},
|
||||||
|
"chatgpt-4o-latest": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.000005,
|
||||||
|
"output_cost_per_token": 0.000015,
|
||||||
|
"litellm_provider": "openai",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"supports_parallel_function_calling": true,
|
||||||
|
"supports_vision": true
|
||||||
|
},
|
||||||
"gpt-4o-2024-05-13": {
|
"gpt-4o-2024-05-13": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 128000,
|
"max_input_tokens": 128000,
|
||||||
|
@ -2062,7 +2074,8 @@
|
||||||
"litellm_provider": "vertex_ai-anthropic_models",
|
"litellm_provider": "vertex_ai-anthropic_models",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_vision": true
|
"supports_vision": true,
|
||||||
|
"supports_assistant_prefill": true
|
||||||
},
|
},
|
||||||
"vertex_ai/claude-3-5-sonnet@20240620": {
|
"vertex_ai/claude-3-5-sonnet@20240620": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
|
@ -2073,7 +2086,8 @@
|
||||||
"litellm_provider": "vertex_ai-anthropic_models",
|
"litellm_provider": "vertex_ai-anthropic_models",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_vision": true
|
"supports_vision": true,
|
||||||
|
"supports_assistant_prefill": true
|
||||||
},
|
},
|
||||||
"vertex_ai/claude-3-haiku@20240307": {
|
"vertex_ai/claude-3-haiku@20240307": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
|
@ -2084,7 +2098,8 @@
|
||||||
"litellm_provider": "vertex_ai-anthropic_models",
|
"litellm_provider": "vertex_ai-anthropic_models",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_vision": true
|
"supports_vision": true,
|
||||||
|
"supports_assistant_prefill": true
|
||||||
},
|
},
|
||||||
"vertex_ai/claude-3-opus@20240229": {
|
"vertex_ai/claude-3-opus@20240229": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
|
@ -2095,7 +2110,8 @@
|
||||||
"litellm_provider": "vertex_ai-anthropic_models",
|
"litellm_provider": "vertex_ai-anthropic_models",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_vision": true
|
"supports_vision": true,
|
||||||
|
"supports_assistant_prefill": true
|
||||||
},
|
},
|
||||||
"vertex_ai/meta/llama3-405b-instruct-maas": {
|
"vertex_ai/meta/llama3-405b-instruct-maas": {
|
||||||
"max_tokens": 32000,
|
"max_tokens": 32000,
|
||||||
|
@ -4519,6 +4535,69 @@
|
||||||
"litellm_provider": "perplexity",
|
"litellm_provider": "perplexity",
|
||||||
"mode": "chat"
|
"mode": "chat"
|
||||||
},
|
},
|
||||||
|
"perplexity/llama-3.1-70b-instruct": {
|
||||||
|
"max_tokens": 131072,
|
||||||
|
"max_input_tokens": 131072,
|
||||||
|
"max_output_tokens": 131072,
|
||||||
|
"input_cost_per_token": 0.000001,
|
||||||
|
"output_cost_per_token": 0.000001,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/llama-3.1-8b-instruct": {
|
||||||
|
"max_tokens": 131072,
|
||||||
|
"max_input_tokens": 131072,
|
||||||
|
"max_output_tokens": 131072,
|
||||||
|
"input_cost_per_token": 0.0000002,
|
||||||
|
"output_cost_per_token": 0.0000002,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/llama-3.1-sonar-huge-128k-online": {
|
||||||
|
"max_tokens": 127072,
|
||||||
|
"max_input_tokens": 127072,
|
||||||
|
"max_output_tokens": 127072,
|
||||||
|
"input_cost_per_token": 0.000005,
|
||||||
|
"output_cost_per_token": 0.000005,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/llama-3.1-sonar-large-128k-online": {
|
||||||
|
"max_tokens": 127072,
|
||||||
|
"max_input_tokens": 127072,
|
||||||
|
"max_output_tokens": 127072,
|
||||||
|
"input_cost_per_token": 0.000001,
|
||||||
|
"output_cost_per_token": 0.000001,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/llama-3.1-sonar-large-128k-chat": {
|
||||||
|
"max_tokens": 131072,
|
||||||
|
"max_input_tokens": 131072,
|
||||||
|
"max_output_tokens": 131072,
|
||||||
|
"input_cost_per_token": 0.000001,
|
||||||
|
"output_cost_per_token": 0.000001,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/llama-3.1-sonar-small-128k-chat": {
|
||||||
|
"max_tokens": 131072,
|
||||||
|
"max_input_tokens": 131072,
|
||||||
|
"max_output_tokens": 131072,
|
||||||
|
"input_cost_per_token": 0.0000002,
|
||||||
|
"output_cost_per_token": 0.0000002,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/llama-3.1-sonar-small-128k-online": {
|
||||||
|
"max_tokens": 127072,
|
||||||
|
"max_input_tokens": 127072,
|
||||||
|
"max_output_tokens": 127072,
|
||||||
|
"input_cost_per_token": 0.0000002,
|
||||||
|
"output_cost_per_token": 0.0000002,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
"perplexity/pplx-7b-chat": {
|
"perplexity/pplx-7b-chat": {
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"max_input_tokens": 8192,
|
"max_input_tokens": 8192,
|
||||||
|
|
|
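Once the pricing entries above land in `model_prices_and_context_window`, the new models can be inspected and priced through litellm's helpers. A hedged sketch is below; field names follow the JSON above, and the cost is a token-based estimate.

```python
# Hedged sketch: reading the new entries from the model cost map.
import litellm

info = litellm.get_model_info("perplexity/llama-3.1-sonar-small-128k-online")
print(info["max_input_tokens"], info["input_cost_per_token"])

# Token-based cost estimate for the new chatgpt-4o-latest entry.
cost = litellm.completion_cost(
    model="chatgpt-4o-latest",
    prompt="hi",
    completion="hello there",
)
print("estimated cost: $", cost)
```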
@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.43.9"
+version = "1.43.15"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"

@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.43.9"
+version = "1.43.15"
 version_files = [
     "pyproject.toml:^version"
 ]