Merge branch 'main' into litellm_track_imagen_spend_logs
Commit c1adb0b7f2

32 changed files with 1384 additions and 226 deletions
@@ -1,12 +1,12 @@
 repos:
 - repo: local
   hooks:
-    # - id: mypy
-    #   name: mypy
-    #   entry: python3 -m mypy --ignore-missing-imports
-    #   language: system
-    #   types: [python]
-    #   files: ^litellm/
+    - id: mypy
+      name: mypy
+      entry: python3 -m mypy --ignore-missing-imports
+      language: system
+      types: [python]
+      files: ^litellm/
     - id: isort
       name: isort
       entry: isort

@@ -1,8 +1,17 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # AI21
 
-LiteLLM supports j2-light, j2-mid and j2-ultra from [AI21](https://www.ai21.com/studio/pricing).
+LiteLLM supports j2-light, j2-mid and j2-ultra from [AI21](https://www.ai21.com/studio/pricing)
 
-They're available to use without a waitlist.
+:::tip
+
+**We support ALL AI21 models, just set `model=ai21/<any-model-on-ai21>` as a prefix when sending litellm requests**.
+**See all litellm supported AI21 models [here](https://models.litellm.ai)**
+
+:::
 
 ### API KEYS
 ```python
@@ -10,6 +19,7 @@ import os
 os.environ["AI21_API_KEY"] = "your-api-key"
 ```
 
+## **LiteLLM Python SDK Usage**
 ### Sample Usage
 
 ```python
@@ -23,10 +33,177 @@ messages = [{"role": "user", "content": "Write me a poem about the blue sky"}]
 completion(model="j2-light", messages=messages)
 ```
 
-### AI21 Models
+## **LiteLLM Proxy Server Usage**
+
+Here's how to call a ai21 model with the LiteLLM Proxy Server
+
+1. Modify the config.yaml
+
+```yaml
+model_list:
+  - model_name: my-model
+    litellm_params:
+      model: ai21/<your-model-name>  # add ai21/ prefix to route as ai21 provider
+      api_key: api-key               # api key to send your model
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Send Request to LiteLLM Proxy Server
+
+<Tabs>
+
+<TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
+    base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+)
+
+response = client.chat.completions.create(
+    model="my-model",
+    messages = [
+        {
+            "role": "user",
+            "content": "what llm are you"
+        }
+    ],
+)
+
+print(response)
+```
+</TabItem>
+
+<TabItem value="curl" label="curl">
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+    --header 'Authorization: Bearer sk-1234' \
+    --header 'Content-Type: application/json' \
+    --data '{
+    "model": "my-model",
+    "messages": [
+        {
+        "role": "user",
+        "content": "what llm are you"
+        }
+    ],
+}'
+```
+</TabItem>
+
+</Tabs>
+
+## Supported OpenAI Parameters
+
+| [param](../completion/input) | type | AI21 equivalent |
+|-------|-------------|------------------|
+| `tools` | **Optional[list]** | `tools` |
+| `response_format` | **Optional[dict]** | `response_format` |
+| `max_tokens` | **Optional[int]** | `max_tokens` |
+| `temperature` | **Optional[float]** | `temperature` |
+| `top_p` | **Optional[float]** | `top_p` |
+| `stop` | **Optional[Union[str, list]]** | `stop` |
+| `n` | **Optional[int]** | `n` |
+| `stream` | **Optional[bool]** | `stream` |
+| `seed` | **Optional[int]** | `seed` |
+| `tool_choice` | **Optional[str]** | `tool_choice` |
+| `user` | **Optional[str]** | `user` |
+
+## Supported AI21 Parameters
+
+| param | type | [AI21 equivalent](https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters) |
+|-----------|------|-------------|
+| `documents` | **Optional[List[Dict]]** | `documents` |
+
+## Passing AI21 Specific Parameters - `documents`
+
+LiteLLM allows you to pass all AI21 specific parameters to the `litellm.completion` function. Here is an example of how to pass the `documents` parameter to the `litellm.completion` function.
+
+<Tabs>
+
+<TabItem value="python" label="LiteLLM Python SDK">
+
+```python
+response = await litellm.acompletion(
+    model="jamba-1.5-large",
+    messages=[{"role": "user", "content": "what does the document say"}],
+    documents = [
+        {
+            "content": "hello world",
+            "metadata": {
+                "source": "google",
+                "author": "ishaan"
+            }
+        }
+    ]
+)
+```
+</TabItem>
+
+<TabItem value="proxy" label="LiteLLM Proxy Server">
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
+    base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+)
+
+response = client.chat.completions.create(
+    model="my-model",
+    messages = [
+        {
+            "role": "user",
+            "content": "what llm are you"
+        }
+    ],
+    extra_body = {
+        "documents": [
+            {
+                "content": "hello world",
+                "metadata": {
+                    "source": "google",
+                    "author": "ishaan"
+                }
+            }
+        ]
+    }
+)
+
+print(response)
+```
+
+</TabItem>
+</Tabs>
+
+:::tip
+
+**We support ALL AI21 models, just set `model=ai21/<any-model-on-ai21>` as a prefix when sending litellm requests**
+**See all litellm supported AI21 models [here](https://models.litellm.ai)**
+:::
+
+## AI21 Models
 
 | Model Name | Function Call | Required OS Variables |
 |------------------|--------------------------------------------|--------------------------------------|
+| jamba-1.5-mini | `completion('jamba-1.5-mini', messages)` | `os.environ['AI21_API_KEY']` |
+| jamba-1.5-large | `completion('jamba-1.5-large', messages)` | `os.environ['AI21_API_KEY']` |
 | j2-light | `completion('j2-light', messages)` | `os.environ['AI21_API_KEY']` |
 | j2-mid | `completion('j2-mid', messages)` | `os.environ['AI21_API_KEY']` |
 | j2-ultra | `completion('j2-ultra', messages)` | `os.environ['AI21_API_KEY']` |

@@ -190,6 +190,36 @@ curl -i http://localhost:4000/v1/chat/completions \
 ```
 
+
+## Advanced - provide multiple slack channels for a given alert type
+
+Just add it like this - `alert_type: [<hook_url_channel_1>, <hook_url_channel_2>]`.
+
+1. Setup config.yaml
+
+```yaml
+general_settings:
+  master_key: sk-1234
+  alerting: ["slack"]
+  alert_to_webhook_url: {
+    "spend_reports": ["https://webhook.site/7843a980-a494-4967-80fb-d502dbc16886", "https://webhook.site/28cfb179-f4fb-4408-8129-729ff55cf213"]
+  }
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```bash
+curl -X GET 'http://0.0.0.0:4000/health/services?service=slack' \
+-H 'Authorization: Bearer sk-1234'
+```
+
+In case of error, check server logs for the error message!
+
 ## Advanced - Using MS Teams Webhooks
 
 MS Teams provides a slack compatible webhook url that you can use for alerting

@@ -116,6 +116,7 @@ ssl_certificate: Optional[str] = None
 disable_streaming_logging: bool = False
 in_memory_llm_clients_cache: dict = {}
 safe_memory_mode: bool = False
+enable_azure_ad_token_refresh: Optional[bool] = False
 ### DEFAULT AZURE API VERSION ###
 AZURE_DEFAULT_API_VERSION = "2024-07-01-preview"  # this is updated to the latest
 ### COHERE EMBEDDINGS DEFAULT TYPE ###
@@ -364,6 +365,7 @@ vertex_llama3_models: List = []
 vertex_ai_ai21_models: List = []
 vertex_mistral_models: List = []
 ai21_models: List = []
+ai21_chat_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
 bedrock_models: List = []
@@ -419,6 +421,9 @@ for key, value in model_cost.items():
         key = key.replace("vertex_ai/", "")
         vertex_ai_image_models.append(key)
     elif value.get("litellm_provider") == "ai21":
+        if value.get("mode") == "chat":
+            ai21_chat_models.append(key)
+        else:
             ai21_models.append(key)
     elif value.get("litellm_provider") == "nlp_cloud":
         nlp_cloud_models.append(key)
@@ -459,6 +464,7 @@ openai_compatible_providers: List = [
     "groq",
     "nvidia_nim",
     "cerebras",
+    "ai21_chat",
    "volcengine",
     "codestral",
     "deepseek",
@@ -647,6 +653,7 @@ model_list = (
     + vertex_chat_models
     + vertex_text_models
     + ai21_models
+    + ai21_chat_models
     + together_ai_models
     + baseten_models
     + aleph_alpha_models
@@ -698,6 +705,7 @@ provider_list: List = [
     "groq",
     "nvidia_nim",
     "cerebras",
+    "ai21_chat",
     "volcengine",
     "codestral",
     "text-completion-codestral",
@@ -856,7 +864,8 @@ from .llms.predibase import PredibaseConfig
 from .llms.replicate import ReplicateConfig
 from .llms.cohere.completion import CohereConfig
 from .llms.clarifai import ClarifaiConfig
-from .llms.ai21 import AI21Config
+from .llms.AI21.completion import AI21Config
+from .llms.AI21.chat import AI21ChatConfig
 from .llms.together_ai import TogetherAIConfig
 from .llms.cloudflare import CloudflareConfig
 from .llms.palm import PalmConfig
@@ -903,6 +912,14 @@ from .llms.bedrock.common_utils import (
     AmazonMistralConfig,
     AmazonBedrockGlobalConfig,
 )
+from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
+from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
+    AmazonTitanMultimodalEmbeddingG1Config,
+)
+from .llms.bedrock.embed.amazon_titan_v2_transformation import (
+    AmazonTitanV2Config,
+)
+from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
 from .llms.openai import (
     OpenAIConfig,
     OpenAITextCompletionConfig,
@@ -914,6 +931,7 @@ from .llms.openai import (
 )
 from .llms.nvidia_nim import NvidiaNimConfig
 from .llms.cerebras.chat import CerebrasConfig
+from .llms.AI21.chat import AI21ChatConfig
 from .llms.fireworks_ai import FireworksAIConfig
 from .llms.volcengine import VolcEngineConfig
 from .llms.text_completion_codestral import MistralTextCompletionConfig

@@ -1514,7 +1514,9 @@ Model Info:
             self.alert_to_webhook_url is not None
             and alert_type in self.alert_to_webhook_url
         ):
-            slack_webhook_url = self.alert_to_webhook_url[alert_type]
+            slack_webhook_url: Optional[Union[str, List[str]]] = (
+                self.alert_to_webhook_url[alert_type]
+            )
         elif self.default_webhook_url is not None:
             slack_webhook_url = self.default_webhook_url
         else:
@@ -1525,11 +1527,32 @@ Model Info:
         payload = {"text": formatted_message}
         headers = {"Content-type": "application/json"}
 
-        response = await self.async_http_handler.post(
-            url=slack_webhook_url,
-            headers=headers,
-            data=json.dumps(payload),
-        )
+        async def send_to_webhook(url: str):
+            return await self.async_http_handler.post(
+                url=url,
+                headers=headers,
+                data=json.dumps(payload),
+            )
+
+        if isinstance(slack_webhook_url, list):
+            # Parallelize the calls if it's a list of URLs
+            responses = await asyncio.gather(
+                *[send_to_webhook(url) for url in slack_webhook_url]
+            )
+
+            for response, url in zip(responses, slack_webhook_url):
+                if response.status_code == 200:
+                    pass
+                else:
+                    verbose_proxy_logger.debug(
+                        "Error sending slack alert to url={}. Error={}".format(
+                            url, response.text
+                        )
+                    )
+        else:
+            # Single call if it's a single URL
+            response = await send_to_webhook(slack_webhook_url)
+
         if response.status_code == 200:
             pass
         else:
@@ -1718,7 +1741,9 @@ Model Info:
         try:
             from calendar import monthrange
 
-            from litellm.proxy.proxy_server import _get_spend_report_for_time_range
+            from litellm.proxy.spend_tracking.spend_management_endpoints import (
+                _get_spend_report_for_time_range,
+            )
 
             todays_date = datetime.datetime.now().date()
             first_day_of_month = todays_date.replace(day=1)
@@ -1763,7 +1788,7 @@ Model Info:
                 alerting_metadata={},
             )
         except Exception as e:
-            verbose_proxy_logger.error("Error sending weekly spend report %s", e)
+            verbose_proxy_logger.exception("Error sending weekly spend report %s", e)
 
     async def send_fallback_stats_from_prometheus(self):
         """
litellm/llms/AI21/chat.py (new file, 95 lines)

@@ -0,0 +1,95 @@
+"""
+AI21 Chat Completions API
+
+this is OpenAI compatible - no translation needed / occurs
+"""
+
+import types
+from typing import Optional, Union
+
+
+class AI21ChatConfig:
+    """
+    Reference: https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters
+
+    Below are the parameters:
+    """
+
+    tools: Optional[list] = None
+    response_format: Optional[dict] = None
+    documents: Optional[list] = None
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    stop: Optional[Union[str, list]] = None
+    n: Optional[int] = None
+    stream: Optional[bool] = None
+    seed: Optional[int] = None
+    tool_choice: Optional[str] = None
+    user: Optional[str] = None
+
+    def __init__(
+        self,
+        tools: Optional[list] = None,
+        response_format: Optional[dict] = None,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        stop: Optional[Union[str, list]] = None,
+        n: Optional[int] = None,
+        stream: Optional[bool] = None,
+        seed: Optional[int] = None,
+        tool_choice: Optional[str] = None,
+        user: Optional[str] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self, model: str) -> list:
+        """
+        Get the supported OpenAI params for the given model
+
+        """
+
+        return [
+            "tools",
+            "response_format",
+            "max_tokens",
+            "temperature",
+            "top_p",
+            "stop",
+            "n",
+            "stream",
+            "seed",
+            "tool_choice",
+            "user",
+        ]
+
+    def map_openai_params(
+        self, model: str, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params(model=model)
+        for param, value in non_default_params.items():
+            if param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
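Editor's aside: a minimal sketch of how `AI21ChatConfig.map_openai_params` behaves, based only on the new file above. The import path comes from the `litellm/__init__.py` change in this diff; the parameter values below are invented for illustration and are not part of the commit.

```python
# Sketch only: exercises AI21ChatConfig exactly as defined in the new file above.
from litellm.llms.AI21.chat import AI21ChatConfig

config = AI21ChatConfig()
mapped = config.map_openai_params(
    model="jamba-1.5-large",
    non_default_params={"max_tokens": 256, "logit_bias": {"1234": -100}},  # invented values
    optional_params={},
)
# Only params returned by get_supported_openai_params() are copied through,
# so "logit_bias" is dropped while "max_tokens" is kept:
print(mapped)  # {'max_tokens': 256}
```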
@@ -30,6 +30,7 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.types.llms.anthropic import (
     AnthopicMessagesAssistantMessageParam,
+    AnthropicChatCompletionUsageBlock,
     AnthropicFinishReason,
     AnthropicMessagesRequest,
     AnthropicMessagesTool,
@@ -1177,6 +1178,30 @@ class ModelResponseIterator:
                 return True
         return False
 
+    def _handle_usage(
+        self, anthropic_usage_chunk: dict
+    ) -> AnthropicChatCompletionUsageBlock:
+        special_fields = ["input_tokens", "output_tokens"]
+
+        usage_block = AnthropicChatCompletionUsageBlock(
+            prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
+            completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
+            total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
+            + anthropic_usage_chunk.get("output_tokens", 0),
+        )
+
+        if "cache_creation_input_tokens" in anthropic_usage_chunk:
+            usage_block["cache_creation_input_tokens"] = anthropic_usage_chunk[
+                "cache_creation_input_tokens"
+            ]
+
+        if "cache_read_input_tokens" in anthropic_usage_chunk:
+            usage_block["cache_read_input_tokens"] = anthropic_usage_chunk[
+                "cache_read_input_tokens"
+            ]
+
+        return usage_block
+
     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
             type_chunk = chunk.get("type", "") or ""
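Editor's aside: a rough, standalone mirror of the `_handle_usage` mapping added above, just to show how Anthropic cache fields end up in the usage block. This is not the litellm implementation, and the payload values are invented.

```python
# Illustration only - mirrors the field mapping of _handle_usage above.
def handle_usage_sketch(anthropic_usage_chunk: dict) -> dict:
    usage = {
        "prompt_tokens": anthropic_usage_chunk.get("input_tokens", 0),
        "completion_tokens": anthropic_usage_chunk.get("output_tokens", 0),
        "total_tokens": anthropic_usage_chunk.get("input_tokens", 0)
        + anthropic_usage_chunk.get("output_tokens", 0),
    }
    # Cache fields are only carried over when present in the Anthropic chunk.
    for cache_key in ("cache_creation_input_tokens", "cache_read_input_tokens"):
        if cache_key in anthropic_usage_chunk:
            usage[cache_key] = anthropic_usage_chunk[cache_key]
    return usage

# Invented example payload:
print(handle_usage_sketch({"input_tokens": 120, "output_tokens": 30, "cache_read_input_tokens": 40}))
# {'prompt_tokens': 120, 'completion_tokens': 30, 'total_tokens': 150, 'cache_read_input_tokens': 40}
```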
@@ -1252,12 +1277,7 @@ class ModelResponseIterator:
                     finish_reason=message_delta["delta"].get("stop_reason", "stop")
                     or "stop"
                 )
-                usage = ChatCompletionUsageBlock(
-                    prompt_tokens=message_delta["usage"].get("input_tokens", 0),
-                    completion_tokens=message_delta["usage"].get("output_tokens", 0),
-                    total_tokens=message_delta["usage"].get("input_tokens", 0)
-                    + message_delta["usage"].get("output_tokens", 0),
-                )
+                usage = self._handle_usage(anthropic_usage_chunk=message_delta["usage"])
                 is_finished = True
             elif type_chunk == "message_start":
                 """
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
message_start_block = MessageStartBlock(**chunk) # type: ignore
|
message_start_block = MessageStartBlock(**chunk) # type: ignore
|
||||||
usage = ChatCompletionUsageBlock(
|
usage = self._handle_usage(
|
||||||
prompt_tokens=message_start_block["message"]
|
anthropic_usage_chunk=message_start_block["message"]["usage"]
|
||||||
.get("usage", {})
|
|
||||||
.get("input_tokens", 0),
|
|
||||||
completion_tokens=message_start_block["message"]
|
|
||||||
.get("usage", {})
|
|
||||||
.get("output_tokens", 0),
|
|
||||||
total_tokens=message_start_block["message"]
|
|
||||||
.get("usage", {})
|
|
||||||
.get("input_tokens", 0)
|
|
||||||
+ message_start_block["message"]
|
|
||||||
.get("usage", {})
|
|
||||||
.get("output_tokens", 0),
|
|
||||||
)
|
)
|
||||||
elif type_chunk == "error":
|
elif type_chunk == "error":
|
||||||
"""
|
"""
|
||||||
|
|
|
@@ -43,6 +43,10 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionToolChoiceFunctionParam,
+    ChatCompletionToolChoiceObjectParam,
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
     ChatCompletionUsageBlock,
 )
 from litellm.types.utils import GenericStreamingChunk as GChunk
@@ -1152,6 +1156,7 @@ class AmazonConverseConfig:
             "temperature",
             "top_p",
             "extra_headers",
+            "response_format",
         ]
 
         if (
@@ -1210,6 +1215,48 @@ class AmazonConverseConfig:
         drop_params: bool,
     ) -> dict:
         for param, value in non_default_params.items():
+            if param == "response_format":
+                json_schema: Optional[dict] = None
+                schema_name: str = ""
+                if "response_schema" in value:
+                    json_schema = value["response_schema"]
+                    schema_name = "json_tool_call"
+                elif "json_schema" in value:
+                    json_schema = value["json_schema"]["schema"]
+                    schema_name = value["json_schema"]["name"]
+                """
+                Follow similar approach to anthropic - translate to a single tool call.
+
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
+                """
+                if json_schema is not None:
+                    _tool_choice = self.map_tool_choice_values(
+                        model=model, tool_choice="required", drop_params=drop_params  # type: ignore
+                    )
+
+                    _tool = ChatCompletionToolParam(
+                        type="function",
+                        function=ChatCompletionToolParamFunctionChunk(
+                            name=schema_name, parameters=json_schema
+                        ),
+                    )
+
+                    optional_params["tools"] = [_tool]
+                    optional_params["tool_choice"] = _tool_choice
+                    optional_params["json_mode"] = True
+                else:
+                    if litellm.drop_params is True or drop_params is True:
+                        pass
+                    else:
+                        raise litellm.utils.UnsupportedParamsError(
+                            message="Bedrock doesn't support response_format={}. To drop it from the call, set `litellm.drop_params = True.".format(
+                                value
+                            ),
+                            status_code=400,
+                        )
             if param == "max_tokens":
                 optional_params["maxTokens"] = value
             if param == "stream":
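Editor's aside: to make the translation above concrete, here is a hedged before/after sketch of what this branch produces for an OpenAI-style `json_schema` request. The schema, names, and values are invented; `tool_choice` is whatever `map_tool_choice_values(..., "required")` returns for the model.

```python
# Invented example input; the resulting shape mirrors the response_format branch above.
non_default_params = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "weather_report",  # hypothetical schema name
            "schema": {
                "type": "object",
                "properties": {"temperature_c": {"type": "number"}},
            },
        },
    }
}
# After AmazonConverseConfig.map_openai_params, optional_params would roughly contain:
# {
#     "tools": [{
#         "type": "function",
#         "function": {"name": "weather_report",
#                      "parameters": {"type": "object", "properties": {...}}},
#     }],
#     "tool_choice": <result of map_tool_choice_values(model, "required", drop_params)>,
#     "json_mode": True,
# }
```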
@@ -1263,7 +1310,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             additional_args={"complete_input_dict": data},
         )
         print_verbose(f"raw model_response: {response.text}")
-
+        json_mode: Optional[bool] = optional_params.pop("json_mode", None)
         ## RESPONSE OBJECT
         try:
             completion_response = ConverseResponseBlock(**response.json())  # type: ignore
|
||||||
name=response_tool_name,
|
name=response_tool_name,
|
||||||
arguments=json.dumps(content["toolUse"]["input"]),
|
arguments=json.dumps(content["toolUse"]["input"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
_tool_response_chunk = ChatCompletionToolCallChunk(
|
_tool_response_chunk = ChatCompletionToolCallChunk(
|
||||||
id=content["toolUse"]["toolUseId"],
|
id=content["toolUse"]["toolUseId"],
|
||||||
type="function",
|
type="function",
|
||||||
|
@@ -1340,6 +1388,13 @@ class BedrockConverseLLM(BaseAWSLLM):
                 )
                 tools.append(_tool_response_chunk)
         chat_completion_message["content"] = content_str
+
+        if json_mode is True and tools is not None and len(tools) == 1:
+            # to support 'json_schema' logic on bedrock models
+            json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments")
+            if json_mode_content_str is not None:
+                chat_completion_message["content"] = json_mode_content_str
+        else:
             chat_completion_message["tool_calls"] = tools
 
         ## CALCULATING USAGE - bedrock returns usage in the headers
@@ -1586,6 +1641,9 @@ class BedrockConverseLLM(BaseAWSLLM):
         supported_converse_params = AmazonConverseConfig.__annotations__.keys()
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
+        json_mode: Optional[bool] = inference_params.pop(
+            "json_mode", None
+        )  # used for handling json_schema
         ## TRANSFORMATION ##
 
         bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
@@ -2028,8 +2086,14 @@ class MockResponseIterator:  # for returning ai21 streaming responses
                 text=chunk_data.choices[0].message.content or "",  # type: ignore
                 tool_use=None,
                 is_finished=True,
-                finish_reason=chunk_data.choices[0].finish_reason,  # type: ignore
-                usage=chunk_usage,  # type: ignore
+                finish_reason=map_finish_reason(
+                    finish_reason=chunk_data.choices[0].finish_reason or ""
+                ),
+                usage=ChatCompletionUsageBlock(
+                    prompt_tokens=chunk_usage.prompt_tokens,
+                    completion_tokens=chunk_usage.completion_tokens,
+                    total_tokens=chunk_usage.total_tokens,
+                ),
                 index=0,
             )
             return processed_chunk

@@ -15,8 +15,6 @@ from typing import List, Optional
 from litellm.types.llms.bedrock import (
     AmazonTitanG1EmbeddingRequest,
     AmazonTitanG1EmbeddingResponse,
-    AmazonTitanV2EmbeddingRequest,
-    AmazonTitanV2EmbeddingResponse,
 )
 from litellm.types.utils import Embedding, EmbeddingResponse, Usage
 
@@ -52,6 +50,14 @@ class AmazonTitanG1Config:
             and v is not None
         }
 
+    def get_supported_openai_params(self) -> List[str]:
+        return []
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        return optional_params
+
     def _transform_request(
         self, input: str, inference_params: dict
     ) -> AmazonTitanG1EmbeddingRequest:
@@ -80,70 +86,3 @@ class AmazonTitanG1Config:
             total_tokens=total_prompt_tokens,
         )
         return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
-
-
-class AmazonTitanV2Config:
-    """
-    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
-
-    normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
-    dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
-    """
-
-    normalize: Optional[bool] = None
-    dimensions: Optional[int] = None
-
-    def __init__(
-        self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def _transform_request(
-        self, input: str, inference_params: dict
-    ) -> AmazonTitanV2EmbeddingRequest:
-        return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params)  # type: ignore
-
-    def _transform_response(
-        self, response_list: List[dict], model: str
-    ) -> EmbeddingResponse:
-        total_prompt_tokens = 0
-
-        transformed_responses: List[Embedding] = []
-        for index, response in enumerate(response_list):
-            _parsed_response = AmazonTitanV2EmbeddingResponse(**response)  # type: ignore
-            transformed_responses.append(
-                Embedding(
-                    embedding=_parsed_response["embedding"],
-                    index=index,
-                    object="embedding",
-                )
-            )
-            total_prompt_tokens += _parsed_response["inputTextTokenCount"]
-
-        usage = Usage(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=0,
-            total_tokens=total_prompt_tokens,
-        )
-        return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)

@@ -17,13 +17,36 @@ from litellm.types.utils import Embedding, EmbeddingResponse, Usage
 from litellm.utils import is_base64_encoded
 
 
-def _transform_request(
-    input: str, inference_params: dict
-) -> AmazonTitanMultimodalEmbeddingRequest:
+class AmazonTitanMultimodalEmbeddingG1Config:
+    """
+    Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_supported_openai_params(self) -> List[str]:
+        return ["dimensions"]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for k, v in non_default_params.items():
+            if k == "dimensions":
+                optional_params["embeddingConfig"] = (
+                    AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
+                )
+        return optional_params
+
+    def _transform_request(
+        self, input: str, inference_params: dict
+    ) -> AmazonTitanMultimodalEmbeddingRequest:
         ## check if b64 encoded str or not ##
         is_encoded = is_base64_encoded(input)
         if is_encoded:  # check if string is b64 encoded image or not
-            transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputImage=input)
+            transformed_request = AmazonTitanMultimodalEmbeddingRequest(
+                inputImage=input
+            )
         else:
             transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputText=input)
|
||||||
|
|
||||||
return transformed_request
|
return transformed_request
|
||||||
|
|
||||||
|
def _transform_response(
|
||||||
def _transform_response(response_list: List[dict], model: str) -> EmbeddingResponse:
|
self, response_list: List[dict], model: str
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
|
||||||
total_prompt_tokens = 0
|
total_prompt_tokens = 0
|
||||||
transformed_responses: List[Embedding] = []
|
transformed_responses: List[Embedding] = []
|
||||||
|
@@ -41,7 +65,9 @@ def _transform_response(response_list: List[dict], model: str) -> EmbeddingRespo
             _parsed_response = AmazonTitanMultimodalEmbeddingResponse(**response)  # type: ignore
             transformed_responses.append(
                 Embedding(
-                    embedding=_parsed_response["embedding"], index=index, object="embedding"
+                    embedding=_parsed_response["embedding"],
+                    index=index,
+                    object="embedding",
                 )
             )
             total_prompt_tokens += _parsed_response["inputTextTokenCount"]

@@ -56,6 +56,17 @@ class AmazonTitanV2Config:
             and v is not None
         }
 
+    def get_supported_openai_params(self) -> List[str]:
+        return ["dimensions"]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for k, v in non_default_params.items():
+            if k == "dimensions":
+                optional_params["dimensions"] = v
+        return optional_params
+
     def _transform_request(
         self, input: str, inference_params: dict
     ) -> AmazonTitanV2EmbeddingRequest:

@@ -11,9 +11,24 @@ from litellm.types.llms.bedrock import CohereEmbeddingRequest, CohereEmbeddingRe
 from litellm.types.utils import Embedding, EmbeddingResponse
 
 
-def _transform_request(
-    input: List[str], inference_params: dict
-) -> CohereEmbeddingRequest:
+class BedrockCohereEmbeddingConfig:
+    def __init__(self) -> None:
+        pass
+
+    def get_supported_openai_params(self) -> List[str]:
+        return ["encoding_format"]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for k, v in non_default_params.items():
+            if k == "encoding_format":
+                optional_params["embedding_types"] = v
+        return optional_params
+
+    def _transform_request(
+        self, input: List[str], inference_params: dict
+    ) -> CohereEmbeddingRequest:
         transformed_request = CohereEmbeddingRequest(
             texts=input,
             input_type=litellm.COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,  # type: ignore

@@ -16,6 +16,7 @@ from litellm.llms.cohere.embed import embedding as cohere_embedding
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
+    _get_async_httpx_client,
     _get_httpx_client,
 )
 from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
@@ -25,13 +26,10 @@ from ...base_aws_llm import BaseAWSLLM
 from ..common_utils import BedrockError, get_runtime_endpoint
 from .amazon_titan_g1_transformation import AmazonTitanG1Config
 from .amazon_titan_multimodal_transformation import (
-    _transform_request as amazon_multimodal_transform_request,
-)
-from .amazon_titan_multimodal_transformation import (
-    _transform_response as amazon_multimodal_transform_response,
+    AmazonTitanMultimodalEmbeddingG1Config,
 )
 from .amazon_titan_v2_transformation import AmazonTitanV2Config
-from .cohere_transformation import _transform_request as cohere_transform_request
+from .cohere_transformation import BedrockCohereEmbeddingConfig
 
 
 class BedrockEmbedding(BaseAWSLLM):
@@ -118,6 +116,35 @@ class BedrockEmbedding(BaseAWSLLM):
 
         return response.json()
 
+    async def _make_async_call(
+        self,
+        client: Optional[AsyncHTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        headers: dict,
+        data: dict,
+    ) -> dict:
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = _get_async_httpx_client(_params)  # type: ignore
+        else:
+            client = client
+
+        try:
+            response = await client.post(url=api_base, headers=headers, data=json.dumps(data))  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise BedrockError(status_code=error_code, message=response.text)
+        except httpx.TimeoutException:
+            raise BedrockError(status_code=408, message="Timeout error occurred.")
+
+        return response.json()
+
     def _single_func_embeddings(
         self,
         client: Optional[HTTPHandler],
@@ -186,9 +213,102 @@ class BedrockEmbedding(BaseAWSLLM):
 
         ## TRANSFORM RESPONSE ##
         if model == "amazon.titan-embed-image-v1":
-            returned_response = amazon_multimodal_transform_response(
+            returned_response = (
+                AmazonTitanMultimodalEmbeddingG1Config()._transform_response(
                 response_list=responses, model=model
             )
+            )
+        elif model == "amazon.titan-embed-text-v1":
+            returned_response = AmazonTitanG1Config()._transform_response(
+                response_list=responses, model=model
+            )
+        elif model == "amazon.titan-embed-text-v2:0":
+            returned_response = AmazonTitanV2Config()._transform_response(
+                response_list=responses, model=model
+            )
+
+        if returned_response is None:
+            raise Exception(
+                "Unable to map model response to known provider format. model={}".format(
+                    model
+                )
+            )
+
+        return returned_response
+
+    async def _async_single_func_embeddings(
+        self,
+        client: Optional[AsyncHTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        batch_data: List[dict],
+        credentials: Any,
+        extra_headers: Optional[dict],
+        endpoint_url: str,
+        aws_region_name: str,
+        model: str,
+        logging_obj: Any,
+    ):
+        try:
+            import boto3
+            from botocore.auth import SigV4Auth
+            from botocore.awsrequest import AWSRequest
+            from botocore.credentials import Credentials
+        except ImportError:
+            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+
+        responses: List[dict] = []
+        for data in batch_data:
+            sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
+            headers = {"Content-Type": "application/json"}
+            if extra_headers is not None:
+                headers = {"Content-Type": "application/json", **extra_headers}
+            request = AWSRequest(
+                method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
+            )
+            sigv4.add_auth(request)
+            if (
+                extra_headers is not None and "Authorization" in extra_headers
+            ):  # prevent sigv4 from overwriting the auth header
+                request.headers["Authorization"] = extra_headers["Authorization"]
+            prepped = request.prepare()
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data,
+                api_key="",
+                additional_args={
+                    "complete_input_dict": data,
+                    "api_base": prepped.url,
+                    "headers": prepped.headers,
+                },
+            )
+            response = await self._make_async_call(
+                client=client,
+                timeout=timeout,
+                api_base=prepped.url,
+                headers=prepped.headers,
+                data=data,
+            )
+
+            ## LOGGING
+            logging_obj.post_call(
+                input=data,
+                api_key="",
+                original_response=response,
+                additional_args={"complete_input_dict": data},
+            )
+
+            responses.append(response)
+
+        returned_response: Optional[EmbeddingResponse] = None
+
+        ## TRANSFORM RESPONSE ##
+        if model == "amazon.titan-embed-image-v1":
+            returned_response = (
+                AmazonTitanMultimodalEmbeddingG1Config()._transform_response(
+                    response_list=responses, model=model
+                )
+            )
         elif model == "amazon.titan-embed-text-v1":
             returned_response = AmazonTitanG1Config()._transform_response(
                 response_list=responses, model=model
@@ -246,7 +366,7 @@ class BedrockEmbedding(BaseAWSLLM):
         data: Optional[CohereEmbeddingRequest] = None
         batch_data: Optional[List] = None
         if provider == "cohere":
-            data = cohere_transform_request(
+            data = BedrockCohereEmbeddingConfig()._transform_request(
                 input=input, inference_params=inference_params
             )
         elif provider == "amazon" and model in [
@@ -257,11 +377,11 @@ class BedrockEmbedding(BaseAWSLLM):
             batch_data = []
             for i in input:
                 if model == "amazon.titan-embed-image-v1":
-                    transformed_request: AmazonEmbeddingRequest = (
-                        amazon_multimodal_transform_request(
+                    transformed_request: (
+                        AmazonEmbeddingRequest
+                    ) = AmazonTitanMultimodalEmbeddingG1Config()._transform_request(
                         input=i, inference_params=inference_params
                     )
-                    )
                 elif model == "amazon.titan-embed-text-v1":
                     transformed_request = AmazonTitanG1Config()._transform_request(
                         input=i, inference_params=inference_params
@@ -283,6 +403,22 @@ class BedrockEmbedding(BaseAWSLLM):
         endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
 
         if batch_data is not None:
+            if aembedding:
+                return self._async_single_func_embeddings(  # type: ignore
+                    client=(
+                        client
+                        if client is not None and isinstance(client, AsyncHTTPHandler)
+                        else None
+                    ),
+                    timeout=timeout,
+                    batch_data=batch_data,
+                    credentials=credentials,
+                    extra_headers=extra_headers,
+                    endpoint_url=endpoint_url,
+                    aws_region_name=aws_region_name,
+                    model=model,
+                    logging_obj=logging_obj,
+                )
             return self._single_func_embeddings(
                 client=(
                     client

@@ -703,8 +703,16 @@ class ModelResponseIterator:
             is_finished = True
             finish_reason = processed_chunk.choices[0].finish_reason
 
-            if hasattr(processed_chunk, "usage"):
-                usage = processed_chunk.usage  # type: ignore
+            if hasattr(processed_chunk, "usage") and isinstance(
+                processed_chunk.usage, litellm.Usage
+            ):
+                usage_chunk: litellm.Usage = processed_chunk.usage
+
+                usage = ChatCompletionUsageBlock(
+                    prompt_tokens=usage_chunk.prompt_tokens,
+                    completion_tokens=usage_chunk.completion_tokens,
+                    total_tokens=usage_chunk.total_tokens,
+                )
+
             return GenericStreamingChunk(
                 text=text,

@@ -75,7 +75,6 @@ from litellm.utils import (
 from ._logging import verbose_logger
 from .caching import disable_cache, enable_cache, update_cache
 from .llms import (
-    ai21,
     aleph_alpha,
     baseten,
     clarifai,
@@ -91,6 +90,7 @@ from .llms import (
     replicate,
     vllm,
 )
+from .llms.AI21 import completion as ai21
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.anthropic.completion import AnthropicTextCompletion
 from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params
@@ -391,6 +391,7 @@ async def acompletion(
         or custom_llm_provider == "groq"
         or custom_llm_provider == "nvidia_nim"
         or custom_llm_provider == "cerebras"
+        or custom_llm_provider == "ai21_chat"
         or custom_llm_provider == "volcengine"
         or custom_llm_provider == "codestral"
         or custom_llm_provider == "text-completion-codestral"
@@ -1297,6 +1298,7 @@ def completion(
             or custom_llm_provider == "groq"
             or custom_llm_provider == "nvidia_nim"
             or custom_llm_provider == "cerebras"
+            or custom_llm_provider == "ai21_chat"
             or custom_llm_provider == "volcengine"
             or custom_llm_provider == "codestral"
             or custom_llm_provider == "deepseek"
@@ -3147,6 +3149,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "groq"
         or custom_llm_provider == "nvidia_nim"
         or custom_llm_provider == "cerebras"
+        or custom_llm_provider == "ai21_chat"
         or custom_llm_provider == "volcengine"
         or custom_llm_provider == "deepseek"
         or custom_llm_provider == "fireworks_ai"
@@ -3811,6 +3814,7 @@ async def atext_completion(
         or custom_llm_provider == "groq"
         or custom_llm_provider == "nvidia_nim"
         or custom_llm_provider == "cerebras"
+        or custom_llm_provider == "ai21_chat"
         or custom_llm_provider == "volcengine"
         or custom_llm_provider == "text-completion-codestral"
         or custom_llm_provider == "deepseek"
@@ -5435,6 +5439,9 @@ def stream_chunk_builder(
     # # Update usage information if needed
     prompt_tokens = 0
     completion_tokens = 0
+    ## anthropic prompt caching information ##
+    cache_creation_input_tokens: Optional[int] = None
+    cache_read_input_tokens: Optional[int] = None
     for chunk in chunks:
         usage_chunk: Optional[Usage] = None
         if "usage" in chunk:
@@ -5446,6 +5453,13 @@ def stream_chunk_builder(
                 prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
             if "completion_tokens" in usage_chunk:
                 completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
+            if "cache_creation_input_tokens" in usage_chunk:
+                cache_creation_input_tokens = usage_chunk.get(
+                    "cache_creation_input_tokens"
+                )
+            if "cache_read_input_tokens" in usage_chunk:
+                cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
+
     try:
         response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
             model=model, messages=messages
|
@ -5464,6 +5478,13 @@ def stream_chunk_builder(
|
||||||
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
|
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cache_creation_input_tokens is not None:
|
||||||
|
response["usage"][
|
||||||
|
"cache_creation_input_tokens"
|
||||||
|
] = cache_creation_input_tokens
|
||||||
|
if cache_read_input_tokens is not None:
|
||||||
|
response["usage"]["cache_read_input_tokens"] = cache_read_input_tokens
|
||||||
|
|
||||||
return convert_to_model_response_object(
|
return convert_to_model_response_object(
|
||||||
response_object=response,
|
response_object=response,
|
||||||
model_response_object=model_response,
|
model_response_object=model_response,
|
||||||
|
|
|
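With these hunks, the usage block rebuilt from streamed chunks carries the Anthropic prompt-caching counters through to the final response instead of dropping them. A minimal sketch of the aggregated shape, assuming the provider reported cache counters in its streamed usage (the numbers are illustrative only):

```python
# Illustrative only: the "usage" dict produced by stream_chunk_builder
# when the streamed chunks included Anthropic cache counters.
usage = {
    "prompt_tokens": 2006,
    "completion_tokens": 10,
    "total_tokens": 2016,
    # present only when the chunks contained them
    "cache_creation_input_tokens": 2000,
    "cache_read_input_tokens": 0,
}
```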
@@ -286,8 +286,35 @@
        "mode": "chat"
    },
    "ft:gpt-3.5-turbo": {
-       "max_tokens": 4097,
-       "max_input_tokens": 4097,
+       "max_tokens": 4096,
+       "max_input_tokens": 16385,
+       "max_output_tokens": 4096,
+       "input_cost_per_token": 0.000003,
+       "output_cost_per_token": 0.000006,
+       "litellm_provider": "openai",
+       "mode": "chat"
+   },
+   "ft:gpt-3.5-turbo-0125": {
+       "max_tokens": 4096,
+       "max_input_tokens": 16385,
+       "max_output_tokens": 4096,
+       "input_cost_per_token": 0.000003,
+       "output_cost_per_token": 0.000006,
+       "litellm_provider": "openai",
+       "mode": "chat"
+   },
+   "ft:gpt-3.5-turbo-1106": {
+       "max_tokens": 4096,
+       "max_input_tokens": 16385,
+       "max_output_tokens": 4096,
+       "input_cost_per_token": 0.000003,
+       "output_cost_per_token": 0.000006,
+       "litellm_provider": "openai",
+       "mode": "chat"
+   },
+   "ft:gpt-3.5-turbo-0613": {
+       "max_tokens": 4096,
+       "max_input_tokens": 4096,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000006,
@@ -2,3 +2,16 @@ model_list:
  - model_name: "gpt-3.5-turbo"
    litellm_params:
      model: "gpt-3.5-turbo"
+
+litellm_settings:
+  max_internal_user_budget: 0.02 # amount in USD
+  internal_user_budget_duration: "1s" # reset every second
+
+general_settings:
+  master_key: sk-1234
+  alerting: ["slack"]
+  alerting_threshold: 0.0001 # (Seconds) set an artificially low threshold for testing alerting
+  alert_to_webhook_url: {
+    "spend_reports": ["https://webhook.site/7843a980-a494-4967-80fb-d502dbc16886", "https://webhook.site/28cfb179-f4fb-4408-8129-729ff55cf213"]
+  }
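The two new `litellm_settings` keys map onto module-level attributes that the proxy's user-update code (further down in this diff) reads when a user is given the internal-user role. A rough Python equivalent of the YAML above, assuming the settings are applied directly rather than via the config file (values illustrative):

```python
import litellm

# Rough equivalent of the config.yaml settings above; attribute names are
# the ones referenced by the user_update hunk later in this diff.
litellm.max_internal_user_budget = 0.02        # USD
litellm.internal_user_budget_duration = "1s"   # budget reset window
```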
@@ -1632,6 +1632,16 @@ class AllCallbacks(LiteLLMBase):
        ui_callback_name="Langsmith",
    )

+   lago: CallbackOnUI = CallbackOnUI(
+       litellm_callback_name="lago",
+       litellm_callback_params=[
+           "LAGO_API_BASE",
+           "LAGO_API_KEY",
+           "LAGO_API_EVENT_CODE",
+       ],
+       ui_callback_name="Lago Billing",
+   )
+

class SpendLogsMetadata(TypedDict):
    """
@@ -505,6 +505,10 @@ async def user_update(
            ):  # models default to [], spend defaults to 0, we should not reset these values
                non_default_values[k] = v

+       is_internal_user = False
+       if data.user_role == LitellmUserRoles.INTERNAL_USER:
+           is_internal_user = True
+
        if "budget_duration" in non_default_values:
            duration_s = _duration_in_seconds(
                duration=non_default_values["budget_duration"]
@@ -512,6 +516,20 @@ async def user_update(
            user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
            non_default_values["budget_reset_at"] = user_reset_at

+       if "max_budget" not in non_default_values:
+           if (
+               is_internal_user and litellm.max_internal_user_budget is not None
+           ):  # applies internal user limits, if user role updated
+               non_default_values["max_budget"] = litellm.max_internal_user_budget
+
+       if (
+           "budget_duration" not in non_default_values
+       ):  # applies internal user limits, if user role updated
+           if is_internal_user and litellm.internal_user_budget_duration is not None:
+               non_default_values["budget_duration"] = (
+                   litellm.internal_user_budget_duration
+               )
+
        ## ADD USER, IF NEW ##
        verbose_proxy_logger.debug("/user/update: Received data = %s", data)
        if data.user_id is not None and len(data.user_id) > 0:
32  litellm/proxy/secret_managers/get_azure_ad_token_provider.py  Normal file

@@ -0,0 +1,32 @@
+import os
+from typing import Callable
+
+
+def get_azure_ad_token_provider() -> Callable[[], str]:
+    """
+    Get Azure AD token provider based on Service Principal with Secret workflow.
+
+    Based on: https://github.com/openai/openai-python/blob/main/examples/azure_ad.py
+    See Also:
+        https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret;
+        https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python.
+
+    Returns:
+        Callable that returns a temporary authentication token.
+    """
+    from azure.identity import ClientSecretCredential
+    from azure.identity import get_bearer_token_provider
+
+    try:
+        credential = ClientSecretCredential(
+            client_id=os.environ["AZURE_CLIENT_ID"],
+            client_secret=os.environ["AZURE_CLIENT_SECRET"],
+            tenant_id=os.environ["AZURE_TENANT_ID"],
+        )
+    except KeyError as e:
+        raise ValueError("Missing environment variable required by Azure AD workflow.") from e
+
+    return get_bearer_token_provider(
+        credential,
+        "https://cognitiveservices.azure.com/.default",
+    )
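A minimal usage sketch for the new helper, assuming the three `AZURE_*` environment variables it reads are set and the `azure-identity` package is installed; the return value is a callable that fetches a fresh bearer token on demand:

```python
# Sketch only: requires AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID
# in the environment, plus the azure-identity package.
from litellm.proxy.secret_managers.get_azure_ad_token_provider import (
    get_azure_ad_token_provider,
)

token_provider = get_azure_ad_token_provider()  # Callable[[], str]
token = token_provider()  # fetches a token scoped to cognitiveservices.azure.com
```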
@@ -1,7 +1,7 @@
import asyncio
import os
import traceback
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Optional

import httpx
import openai
@@ -9,6 +9,9 @@ import openai
import litellm
from litellm._logging import verbose_router_logger
from litellm.llms.azure import get_azure_ad_token_from_oidc
+from litellm.proxy.secret_managers.get_azure_ad_token_provider import (
+    get_azure_ad_token_provider,
+)
from litellm.utils import calculate_max_parallel_requests

if TYPE_CHECKING:
@@ -172,7 +175,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                organization_env_name = organization.replace("os.environ/", "")
                organization = litellm.get_secret(organization_env_name)
            litellm_params["organization"] = organization
-           azure_ad_token_provider = None
+           azure_ad_token_provider: Optional[Callable[[], str]] = None
            if litellm_params.get("tenant_id"):
                verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
                azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
@@ -197,6 +200,16 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                if azure_ad_token is not None:
                    if azure_ad_token.startswith("oidc/"):
                        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+               elif (
+                   azure_ad_token_provider is None
+                   and litellm.enable_azure_ad_token_refresh is True
+               ):
+                   try:
+                       azure_ad_token_provider = get_azure_ad_token_provider()
+                   except ValueError:
+                       verbose_router_logger.debug(
+                           "Azure AD Token Provider could not be used."
+                       )
                if api_version is None:
                    api_version = os.getenv(
                        "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
@@ -211,6 +224,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                        _client = openai.AsyncAzureOpenAI(
                            api_key=api_key,
                            azure_ad_token=azure_ad_token,
+                           azure_ad_token_provider=azure_ad_token_provider,
                            base_url=api_base,
                            api_version=api_version,
                            timeout=timeout,
@@ -236,6 +250,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                        _client = openai.AzureOpenAI(  # type: ignore
                            api_key=api_key,
                            azure_ad_token=azure_ad_token,
+                           azure_ad_token_provider=azure_ad_token_provider,
                            base_url=api_base,
                            api_version=api_version,
                            timeout=timeout,
@@ -258,6 +273,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                        _client = openai.AsyncAzureOpenAI(  # type: ignore
                            api_key=api_key,
                            azure_ad_token=azure_ad_token,
+                           azure_ad_token_provider=azure_ad_token_provider,
                            base_url=api_base,
                            api_version=api_version,
                            timeout=stream_timeout,
@@ -283,6 +299,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                        _client = openai.AzureOpenAI(  # type: ignore
                            api_key=api_key,
                            azure_ad_token=azure_ad_token,
+                           azure_ad_token_provider=azure_ad_token_provider,
                            base_url=api_base,
                            api_version=api_version,
                            timeout=stream_timeout,
@@ -313,6 +330,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                        "azure_endpoint": api_base,
                        "api_version": api_version,
                        "azure_ad_token": azure_ad_token,
+                       "azure_ad_token_provider": azure_ad_token_provider,
                    }

                    if azure_ad_token_provider is not None:
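The new fallback is opted into with a flag plus the service-principal environment variables, mirroring the router tests added later in this diff. A hedged sketch of the intended setup (endpoint and deployment names here are placeholders, not values from this PR):

```python
import litellm
from litellm import Router

litellm.enable_azure_ad_token_refresh = True  # flag checked in set_client above

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                # no api_key here: AZURE_CLIENT_ID / AZURE_CLIENT_SECRET /
                # AZURE_TENANT_ID are read by get_azure_ad_token_provider()
                "model": "gpt-4o",
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder
                "api_version": "2024-01-01-preview",
                "custom_llm_provider": "azure",
            },
        }
    ]
)
```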
@@ -282,6 +282,83 @@ async def test_anthropic_api_prompt_caching_no_headers():
    )


+@pytest.mark.asyncio()
+@pytest.mark.flaky(retries=3, delay=1)
+async def test_anthropic_api_prompt_caching_streaming():
+    from litellm.tests.test_streaming import streaming_format_tests
+
+    response = await litellm.acompletion(
+        model="anthropic/claude-3-5-sonnet-20240620",
+        messages=[
+            # System Message
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Here is the full text of a complex legal agreement"
+                        * 400,
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+            # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "What are the key terms and conditions in this agreement?",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+            },
+            # The final turn is marked with cache-control, for continuing in followups.
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "What are the key terms and conditions in this agreement?",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+        ],
+        temperature=0.2,
+        max_tokens=10,
+        stream=True,
+        stream_options={"include_usage": True},
+    )
+
+    idx = 0
+    is_cache_read_input_tokens_in_usage = False
+    is_cache_creation_input_tokens_in_usage = False
+    async for chunk in response:
+        streaming_format_tests(idx=idx, chunk=chunk)
+        # Assert either a cache entry was created or cache was read - changes depending on the anthropic api ttl
+        if hasattr(chunk, "usage"):
+            print("Received final usage - {}".format(chunk.usage))
+        if hasattr(chunk, "usage") and hasattr(chunk.usage, "cache_read_input_tokens"):
+            is_cache_read_input_tokens_in_usage = True
+        if hasattr(chunk, "usage") and hasattr(
+            chunk.usage, "cache_creation_input_tokens"
+        ):
+            is_cache_creation_input_tokens_in_usage = True
+
+        idx += 1
+
+    print("response=", response)
+
+    assert (
+        is_cache_read_input_tokens_in_usage and is_cache_creation_input_tokens_in_usage
+    )
+
+
@pytest.mark.asyncio
async def test_litellm_anthropic_prompt_caching_system():
    # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#prompt-caching-examples
@@ -2172,7 +2172,14 @@ def test_completion_openai():
        pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("model", ["gpt-4o-2024-08-06", "azure/chatgpt-v-2"])
+@pytest.mark.parametrize(
+    "model",
+    [
+        "gpt-4o-2024-08-06",
+        "azure/chatgpt-v-2",
+        "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+    ],
+)
def test_completion_openai_pydantic(model):
    try:
        litellm.set_verbose = True
@@ -2201,7 +2208,7 @@ def test_completion_openai_pydantic(model):
                )
                break
            except litellm.JSONSchemaValidationError:
-               print("ERROR OCCURRED! INVALID JSON")
+               pytest.fail("ERROR OCCURRED! INVALID JSON")

        print("This is the response object\n", response)

@@ -4474,3 +4481,23 @@ async def test_dynamic_azure_params(stream, sync_mode):
    except Exception as e:
        traceback.print_stack()
        raise e


+@pytest.mark.asyncio()
+@pytest.mark.flaky(retries=3, delay=1)
+async def test_completion_ai21_chat():
+    litellm.set_verbose = True
+    response = await litellm.acompletion(
+        model="jamba-1.5-large",
+        user="ishaan",
+        tool_choice="auto",
+        seed=123,
+        messages=[{"role": "user", "content": "what does the document say"}],
+        documents=[
+            {
+                "content": "hello world",
+                "metadata": {"source": "google", "author": "ishaan"},
+            }
+        ],
+    )
+    pass
@@ -319,9 +319,52 @@ async def test_cohere_embedding3(custom_llm_provider):
        "bedrock/amazon.titan-embed-text-v2:0",
    ],
)
-@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize("sync_mode", [True, False])  # ,
@pytest.mark.asyncio
async def test_bedrock_embedding_titan(model, sync_mode):
+    try:
+        # this tests if we support str input for bedrock embedding
+        litellm.set_verbose = True
+        litellm.enable_cache()
+        import time
+
+        current_time = str(time.time())
+        # DO NOT MAKE THE INPUT A LIST in this test
+        if sync_mode:
+            response = embedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+                aws_region_name="us-west-2",
+            )
+        else:
+            response = await litellm.aembedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+                aws_region_name="us-west-2",
+            )
+        print("response:", response)
+        assert isinstance(
+            response["data"][0]["embedding"], list
+        ), "Expected response to be a list"
+        print("type of first embedding:", type(response["data"][0]["embedding"][0]))
+        assert all(
+            isinstance(x, float) for x in response["data"][0]["embedding"]
+        ), "Expected response to be a list of floats"
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "bedrock/amazon.titan-embed-text-v1",
+        "bedrock/amazon.titan-embed-image-v1",
+        "bedrock/amazon.titan-embed-text-v2:0",
+    ],
+)
+@pytest.mark.parametrize("sync_mode", [True])  # True,
+@pytest.mark.asyncio
+async def test_bedrock_embedding_titan_caching(model, sync_mode):
    try:
        # this tests if we support str input for bedrock embedding
        litellm.set_verbose = True
@@ -75,3 +75,28 @@ def test_get_llm_provider_vertex_ai_image_models():
        model="imagegeneration@006", custom_llm_provider=None
    )
    assert custom_llm_provider == "vertex_ai"


+def test_get_llm_provider_ai21_chat():
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="jamba-1.5-large",
+    )
+    assert custom_llm_provider == "ai21_chat"
+    assert model == "jamba-1.5-large"
+    assert api_base == "https://api.ai21.com/studio/v1"
+
+
+def test_get_llm_provider_ai21_chat_test2():
+    """
+    if user prefix with ai21/ but calls jamba-1.5-large then it should be ai21_chat provider
+    """
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="ai21/jamba-1.5-large",
+    )
+
+    print("model=", model)
+    print("custom_llm_provider=", custom_llm_provider)
+    print("api_base=", api_base)
+    assert custom_llm_provider == "ai21_chat"
+    assert model == "jamba-1.5-large"
+    assert api_base == "https://api.ai21.com/studio/v1"
@@ -70,13 +70,43 @@ def test_anthropic_optional_params(stop_sequence, expected_count):
def test_bedrock_optional_params_embeddings():
    litellm.drop_params = True
    optional_params = get_optional_params_embeddings(
-       user="John", encoding_format=None, custom_llm_provider="bedrock"
+       model="", user="John", encoding_format=None, custom_llm_provider="bedrock"
    )
    assert len(optional_params) == 0


+@pytest.mark.parametrize(
+    "model, expected_dimensions, dimensions_kwarg",
+    [
+        ("bedrock/amazon.titan-embed-text-v1", False, None),
+        ("bedrock/amazon.titan-embed-image-v1", True, "embeddingConfig"),
+        ("bedrock/amazon.titan-embed-text-v2:0", True, "dimensions"),
+        ("bedrock/cohere.embed-multilingual-v3", False, None),
+    ],
+)
+def test_bedrock_optional_params_embeddings_dimension(
+    model, expected_dimensions, dimensions_kwarg
+):
+    litellm.drop_params = True
+    optional_params = get_optional_params_embeddings(
+        model=model,
+        user="John",
+        encoding_format=None,
+        dimensions=20,
+        custom_llm_provider="bedrock",
+    )
+    if expected_dimensions:
+        assert len(optional_params) == 1
+    else:
+        assert len(optional_params) == 0
+
+    if dimensions_kwarg is not None:
+        assert dimensions_kwarg in optional_params
+
+
def test_google_ai_studio_optional_params_embeddings():
    optional_params = get_optional_params_embeddings(
+       model="",
        user="John",
        encoding_format=None,
        custom_llm_provider="gemini",
@@ -88,7 +118,7 @@ def test_google_ai_studio_optional_params_embeddings():
def test_openai_optional_params_embeddings():
    litellm.drop_params = True
    optional_params = get_optional_params_embeddings(
-       user="John", encoding_format=None, custom_llm_provider="openai"
+       model="", user="John", encoding_format=None, custom_llm_provider="openai"
    )
    assert len(optional_params) == 1
    assert optional_params["user"] == "John"
@@ -97,7 +127,10 @@ def test_openai_optional_params_embeddings():
def test_azure_optional_params_embeddings():
    litellm.drop_params = True
    optional_params = get_optional_params_embeddings(
-       user="John", encoding_format=None, custom_llm_provider="azure"
+       model="chatgpt-v-2",
+       user="John",
+       encoding_format=None,
+       custom_llm_provider="azure",
    )
    assert len(optional_params) == 1
    assert optional_params["user"] == "John"
@@ -455,6 +488,7 @@ def test_get_optional_params_image_gen():

def test_bedrock_optional_params_embeddings_provider_specific_params():
    optional_params = get_optional_params_embeddings(
+       model="my-custom-model",
        custom_llm_provider="huggingface",
        wait_for_model=True,
    )
@@ -1,17 +1,25 @@
#### What this tests ####
# This tests client initialization + reinitialization on the router
+import asyncio
+import os

#### What this tests ####
# This tests caching on the router
-import sys, os, time
-import traceback, asyncio
+import sys
+import time
+import traceback
+from typing import Dict
+from unittest.mock import MagicMock, PropertyMock, patch

import pytest
+from openai.lib.azure import OpenAIError

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
-from litellm import Router
+from litellm import APIConnectionError, Router


async def test_router_init():
@@ -75,4 +83,133 @@ async def test_router_init():
    )


+@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os")
+def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
+    mocked_os_lib: MagicMock,
+) -> None:
+    """
+    Test router initialization with neither API key nor using Azure Service Principal with Secret authentication
+    workflow (having not provided environment variables).
+    """
+    litellm.enable_azure_ad_token_refresh = True
+    # mock EMPTY environment variables
+    environment_variables_expected_to_use: Dict = {}
+    mocked_environ = PropertyMock(return_value=environment_variables_expected_to_use)
+    # Because of the way mock attributes are stored you can’t directly attach a PropertyMock to a mock object.
+    # https://docs.python.org/3.11/library/unittest.mock.html#unittest.mock.PropertyMock
+    type(mocked_os_lib).environ = mocked_environ
+
+    # define the model list
+    model_list = [
+        {
+            # test case for Azure Service Principal with Secret authentication
+            "model_name": "gpt-4o",
+            "litellm_params": {
+                # checkout there is no api_key here -
+                # AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID environment variables should be used instead
+                "model": "gpt-4o",
+                "base_model": "gpt-4o",
+                "api_base": "test_api_base",
+                "api_version": "2024-01-01-preview",
+                "custom_llm_provider": "azure",
+            },
+            "model_info": {"mode": "completion"},
+        },
+    ]
+
+    # initialize the router
+    with pytest.raises(OpenAIError):
+        # it would raise an error, because environment variables were not provided => azure_ad_token_provider is None
+        Router(model_list=model_list)
+
+    # check if the mocked environment variables were reached
+    mocked_environ.assert_called()
+
+
+@patch("azure.identity.get_bearer_token_provider")
+@patch("azure.identity.ClientSecretCredential")
+@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os")
+def test_router_init_azure_service_principal_with_secret_with_environment_variables(
+    mocked_os_lib: MagicMock,
+    mocked_credential: MagicMock,
+    mocked_get_bearer_token_provider: MagicMock,
+) -> None:
+    """
+    Test router initialization and sample completion using Azure Service Principal with Secret authentication workflow,
+    having provided the (mocked) credentials in environment variables and not provided any API key.
+
+    To allow for local testing without real credentials, first must mock Azure SDK authentication functions
+    and environment variables.
+    """
+    litellm.enable_azure_ad_token_refresh = True
+    # mock the token provider function
+    mocked_func_generating_token = MagicMock(return_value="test_token")
+    mocked_get_bearer_token_provider.return_value = mocked_func_generating_token
+
+    # mock the environment variables with mocked credentials
+    environment_variables_expected_to_use = {
+        "AZURE_CLIENT_ID": "test_client_id",
+        "AZURE_CLIENT_SECRET": "test_client_secret",
+        "AZURE_TENANT_ID": "test_tenant_id",
+    }
+    mocked_environ = PropertyMock(return_value=environment_variables_expected_to_use)
+    # Because of the way mock attributes are stored you can’t directly attach a PropertyMock to a mock object.
+    # https://docs.python.org/3.11/library/unittest.mock.html#unittest.mock.PropertyMock
+    type(mocked_os_lib).environ = mocked_environ
+
+    # define the model list
+    model_list = [
+        {
+            # test case for Azure Service Principal with Secret authentication
+            "model_name": "gpt-4o",
+            "litellm_params": {
+                # checkout there is no api_key here -
+                # AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID environment variables should be used instead
+                "model": "gpt-4o",
+                "base_model": "gpt-4o",
+                "api_base": "test_api_base",
+                "api_version": "2024-01-01-preview",
+                "custom_llm_provider": "azure",
+            },
+            "model_info": {"mode": "completion"},
+        },
+    ]
+
+    # initialize the router
+    router = Router(model_list=model_list)
+
+    # first check if environment variables were used at all
+    mocked_environ.assert_called()
+    # then check if the client was initialized with the correct environment variables
+    mocked_credential.assert_called_with(
+        **{
+            "client_id": environment_variables_expected_to_use["AZURE_CLIENT_ID"],
+            "client_secret": environment_variables_expected_to_use[
+                "AZURE_CLIENT_SECRET"
+            ],
+            "tenant_id": environment_variables_expected_to_use["AZURE_TENANT_ID"],
+        }
+    )
+    # check if the token provider was called at all
+    mocked_get_bearer_token_provider.assert_called()
+    # then check if the token provider was initialized with the mocked credential
+    for call_args in mocked_get_bearer_token_provider.call_args_list:
+        assert call_args.args[0] == mocked_credential.return_value
+    # however, at this point token should not be fetched yet
+    mocked_func_generating_token.assert_not_called()
+
+    # now let's try to make a completion call
+    deployment = model_list[0]
+    model = deployment["model_name"]
+    messages = [
+        {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
+    ]
+    with pytest.raises(APIConnectionError):
+        # of course, it will raise an error, because URL is mocked
+        router.completion(model=model, messages=messages, temperature=1)  # type: ignore
+
+    # finally verify if the mocked token was used by Azure SDK
+    mocked_func_generating_token.assert_called()
+
+
# asyncio.run(test_router_init())
@@ -586,6 +586,37 @@ async def test_completion_predibase_streaming(sync_mode):
        pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio()
+@pytest.mark.flaky(retries=3, delay=1)
+async def test_completion_ai21_stream():
+    litellm.set_verbose = True
+    response = await litellm.acompletion(
+        model="ai21_chat/jamba-1.5-large",
+        user="ishaan",
+        stream=True,
+        seed=123,
+        messages=[{"role": "user", "content": "hi my name is ishaan"}],
+    )
+    complete_response = ""
+    idx = 0
+    async for init_chunk in response:
+        chunk, finished = streaming_format_tests(idx, init_chunk)
+        complete_response += chunk
+        custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
+        print(f"custom_llm_provider: {custom_llm_provider}")
+        assert custom_llm_provider == "ai21_chat"
+        idx += 1
+        if finished:
+            assert isinstance(init_chunk.choices[0], litellm.utils.StreamingChoices)
+            break
+    if complete_response.strip() == "":
+        raise Exception("Empty response received")
+
+    print(f"complete_response: {complete_response}")
+
+    pass
+
+
def test_completion_azure_function_calling_stream():
    try:
        litellm.set_verbose = False
@@ -287,3 +287,11 @@ class AnthropicResponse(BaseModel):

    usage: AnthropicResponseUsageBlock
    """Billing and rate-limit usage."""


+class AnthropicChatCompletionUsageBlock(TypedDict, total=False):
+    prompt_tokens: Required[int]
+    completion_tokens: Required[int]
+    total_tokens: Required[int]
+    cache_creation_input_tokens: int
+    cache_read_input_tokens: int
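A sketch of a populated block under the new TypedDict; because `total=False`, only the three `Required` fields must be present and the cache counters are optional. The inline re-definition below just mirrors the diff so the snippet stands alone:

```python
from typing import TypedDict
from typing_extensions import Required


class AnthropicChatCompletionUsageBlock(TypedDict, total=False):
    prompt_tokens: Required[int]
    completion_tokens: Required[int]
    total_tokens: Required[int]
    cache_creation_input_tokens: int
    cache_read_input_tokens: int


# Illustrative values; cache fields may be omitted entirely.
usage: AnthropicChatCompletionUsageBlock = {
    "prompt_tokens": 2006,
    "completion_tokens": 10,
    "total_tokens": 2016,
    "cache_creation_input_tokens": 2000,
}
```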
137  litellm/utils.py

@@ -2550,7 +2550,7 @@ def get_optional_params_image_gen(

def get_optional_params_embeddings(
    # 2 optional params
-   model=None,
+   model: str,
    user=None,
    encoding_format=None,
    dimensions=None,
@@ -2606,7 +2606,7 @@ def get_optional_params_embeddings(
        ):
            raise UnsupportedParamsError(
                status_code=500,
-               message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
+               message="Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
            )
    if custom_llm_provider == "triton":
        keys = list(non_default_params.keys())
@@ -2641,39 +2641,57 @@ def get_optional_params_embeddings(
        )
        final_params = {**optional_params, **kwargs}
        return final_params
-   if custom_llm_provider == "vertex_ai":
-       if len(non_default_params.keys()) > 0:
-           if litellm.drop_params is True:  # drop the unsupported non-default values
-               keys = list(non_default_params.keys())
-               for k in keys:
-                   non_default_params.pop(k, None)
-               final_params = {**non_default_params, **kwargs}
-               return final_params
-           raise UnsupportedParamsError(
-               status_code=500,
-               message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
-           )
    if custom_llm_provider == "bedrock":
        # if dimensions is in non_default_params -> pass it for model=bedrock/amazon.titan-embed-text-v2
-       if (
-           "dimensions" in non_default_params.keys()
-           and "amazon.titan-embed-text-v2" in model
-       ):
-           kwargs["dimensions"] = non_default_params["dimensions"]
-           non_default_params.pop("dimensions", None)
-
-       if len(non_default_params.keys()) > 0:
-           if litellm.drop_params is True:  # drop the unsupported non-default values
-               keys = list(non_default_params.keys())
-               for k in keys:
-                   non_default_params.pop(k, None)
-               final_params = {**non_default_params, **kwargs}
-               return final_params
-           raise UnsupportedParamsError(
-               status_code=500,
-               message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
-           )
-       return {**non_default_params, **kwargs}
+       if "amazon.titan-embed-text-v1" in model:
+           object: Any = litellm.AmazonTitanG1Config()
+       elif "amazon.titan-embed-image-v1" in model:
+           object = litellm.AmazonTitanMultimodalEmbeddingG1Config()
+       elif "amazon.titan-embed-text-v2:0" in model:
+           object = litellm.AmazonTitanV2Config()
+       elif "cohere.embed-multilingual-v3" in model:
+           object = litellm.BedrockCohereEmbeddingConfig()
+       else:  # unmapped model
+           supported_params = []
+           _check_valid_arg(supported_params=supported_params)
+           final_params = {**kwargs}
+           return final_params
+
+       supported_params = object.get_supported_openai_params()
+       _check_valid_arg(supported_params=supported_params)
+       optional_params = object.map_openai_params(
+           non_default_params=non_default_params, optional_params={}
+       )
+       final_params = {**optional_params, **kwargs}
+       return final_params
+       # elif model == "amazon.titan-embed-image-v1":
+       #     supported_params = litellm.AmazonTitanG1Config().get_supported_openai_params()
+       #     _check_valid_arg(supported_params=supported_params)
+       #     optional_params = litellm.AmazonTitanG1Config().map_openai_params(
+       #         non_default_params=non_default_params, optional_params={}
+       #     )
+       #     final_params = {**optional_params, **kwargs}
+       #     return final_params
+
+       # if (
+       #     "dimensions" in non_default_params.keys()
+       #     and "amazon.titan-embed-text-v2" in model
+       # ):
+       #     kwargs["dimensions"] = non_default_params["dimensions"]
+       #     non_default_params.pop("dimensions", None)
+
+       # if len(non_default_params.keys()) > 0:
+       #     if litellm.drop_params is True:  # drop the unsupported non-default values
+       #         keys = list(non_default_params.keys())
+       #         for k in keys:
+       #             non_default_params.pop(k, None)
+       #         final_params = {**non_default_params, **kwargs}
+       #         return final_params
+       #     raise UnsupportedParamsError(
+       #         status_code=500,
+       #         message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
+       #     )
+       # return {**non_default_params, **kwargs}
    if custom_llm_provider == "mistral":
        supported_params = get_supported_openai_params(
            model=model,
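The practical effect of the new per-model Bedrock mapping, mirroring the embeddings test added earlier in this diff: Titan v2 keeps the OpenAI-style `dimensions` argument, while unmapped or non-supporting models drop it when `litellm.drop_params` is enabled. A hedged sketch (the exact returned dict is what the test asserts, not a guarantee for every model family):

```python
import litellm
from litellm.utils import get_optional_params_embeddings

litellm.drop_params = True  # as in the test: unsupported params are dropped

params = get_optional_params_embeddings(
    model="bedrock/amazon.titan-embed-text-v2:0",
    user="John",
    encoding_format=None,
    dimensions=20,
    custom_llm_provider="bedrock",
)
# per the test in this diff: only the mapped param survives
# params == {"dimensions": 20}
```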
@@ -2869,6 +2887,7 @@ def get_optional_params(
        and custom_llm_provider != "groq"
        and custom_llm_provider != "nvidia_nim"
        and custom_llm_provider != "cerebras"
+       and custom_llm_provider != "ai21_chat"
        and custom_llm_provider != "volcengine"
        and custom_llm_provider != "deepseek"
        and custom_llm_provider != "codestral"
@@ -3638,6 +3657,16 @@ def get_optional_params(
            optional_params=optional_params,
            model=model,
        )
+   elif custom_llm_provider == "ai21_chat":
+       supported_params = get_supported_openai_params(
+           model=model, custom_llm_provider=custom_llm_provider
+       )
+       _check_valid_arg(supported_params=supported_params)
+       optional_params = litellm.AI21ChatConfig().map_openai_params(
+           non_default_params=non_default_params,
+           optional_params=optional_params,
+           model=model,
+       )
    elif custom_llm_provider == "fireworks_ai":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
@@ -4265,6 +4294,8 @@ def get_supported_openai_params(
        return litellm.NvidiaNimConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "cerebras":
        return litellm.CerebrasConfig().get_supported_openai_params(model=model)
+   elif custom_llm_provider == "ai21_chat":
+       return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "volcengine":
        return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "groq":
@@ -4653,6 +4684,7 @@ def get_llm_provider(
    ):
        custom_llm_provider = model.split("/", 1)[0]
        model = model.split("/", 1)[1]

        if custom_llm_provider == "perplexity":
            # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai
            api_base = api_base or get_secret("PERPLEXITY_API_BASE") or "https://api.perplexity.ai"  # type: ignore
@@ -4699,6 +4731,16 @@ def get_llm_provider(
                or "https://api.cerebras.ai/v1"
            )  # type: ignore
            dynamic_api_key = api_key or get_secret("CEREBRAS_API_KEY")
+       elif (custom_llm_provider == "ai21_chat") or (
+           custom_llm_provider == "ai21" and model in litellm.ai21_chat_models
+       ):
+           api_base = (
+               api_base
+               or get_secret("AI21_API_BASE")
+               or "https://api.ai21.com/studio/v1"
+           )  # type: ignore
+           dynamic_api_key = api_key or get_secret("AI21_API_KEY")
+           custom_llm_provider = "ai21_chat"
        elif custom_llm_provider == "volcengine":
            # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
            api_base = (
@@ -4852,6 +4894,9 @@ def get_llm_provider(
            elif endpoint == "https://api.cerebras.ai/v1":
                custom_llm_provider = "cerebras"
                dynamic_api_key = get_secret("CEREBRAS_API_KEY")
+           elif endpoint == "https://api.ai21.com/studio/v1":
+               custom_llm_provider = "ai21_chat"
+               dynamic_api_key = get_secret("AI21_API_KEY")
            elif endpoint == "https://codestral.mistral.ai/v1":
                custom_llm_provider = "codestral"
                dynamic_api_key = get_secret("CODESTRAL_API_KEY")
@@ -4936,6 +4981,14 @@ def get_llm_provider(
        ## ai21
        elif model in litellm.ai21_models:
            custom_llm_provider = "ai21"
+       elif model in litellm.ai21_chat_models:
+           custom_llm_provider = "ai21_chat"
+           api_base = (
+               api_base
+               or get_secret("AI21_API_BASE")
+               or "https://api.ai21.com/studio/v1"
+           )  # type: ignore
+           dynamic_api_key = api_key or get_secret("AI21_API_KEY")
        ## aleph_alpha
        elif model in litellm.aleph_alpha_models:
            custom_llm_provider = "aleph_alpha"
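With the routing above, Jamba chat models resolve to the new `ai21_chat` provider whether or not the `ai21/` prefix is used, which is exactly what the provider tests earlier in this diff assert. A short sketch of that behaviour:

```python
import litellm

model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="jamba-1.5-large"
)
assert provider == "ai21_chat"
assert api_base == "https://api.ai21.com/studio/v1"

# same resolution with the explicit prefix
model, provider, _, api_base = litellm.get_llm_provider(model="ai21/jamba-1.5-large")
assert provider == "ai21_chat" and model == "jamba-1.5-large"
```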
@@ -5783,6 +5836,11 @@ def validate_environment(
            keys_in_environment = True
        else:
            missing_keys.append("CEREBRAS_API_KEY")
+   elif custom_llm_provider == "ai21_chat":
+       if "AI21_API_KEY" in os.environ:
+           keys_in_environment = True
+       else:
+           missing_keys.append("AI21_API_KEY")
    elif custom_llm_provider == "volcengine":
        if "VOLCENGINE_API_KEY" in os.environ:
            keys_in_environment = True
@@ -6194,7 +6252,10 @@ def convert_to_model_response_object(
        if "model" in response_object:
            if model_response_object.model is None:
                model_response_object.model = response_object["model"]
-           elif "/" in model_response_object.model:
+           elif (
+               "/" in model_response_object.model
+               and response_object["model"] is not None
+           ):
                openai_compatible_provider = model_response_object.model.split("/")[
                    0
                ]
@@ -9889,11 +9950,7 @@ class CustomStreamWrapper:

                if anthropic_response_obj["usage"] is not None:
                    model_response.usage = litellm.Usage(
-                       prompt_tokens=anthropic_response_obj["usage"]["prompt_tokens"],
-                       completion_tokens=anthropic_response_obj["usage"][
-                           "completion_tokens"
-                       ],
-                       total_tokens=anthropic_response_obj["usage"]["total_tokens"],
+                       **anthropic_response_obj["usage"]
                    )

                if (
@@ -10508,10 +10565,10 @@ class CustomStreamWrapper:
                        original_chunk.system_fingerprint
                    )
                    print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
-                   if self.sent_first_chunk == False:
+                   if self.sent_first_chunk is False:
                        model_response.choices[0].delta["role"] = "assistant"
                        self.sent_first_chunk = True
-                   elif self.sent_first_chunk == True and hasattr(
+                   elif self.sent_first_chunk is True and hasattr(
                        model_response.choices[0].delta, "role"
                    ):
                        _initial_delta = model_response.choices[
@@ -10576,7 +10633,7 @@ class CustomStreamWrapper:
                    model_response.choices[0].delta.tool_calls is not None
                    or model_response.choices[0].delta.function_call is not None
                ):
-                   if self.sent_first_chunk == False:
+                   if self.sent_first_chunk is False:
                        model_response.choices[0].delta["role"] = "assistant"
                        self.sent_first_chunk = True
                    return model_response
@@ -286,8 +286,35 @@
        "mode": "chat"
    },
    "ft:gpt-3.5-turbo": {
-       "max_tokens": 4097,
-       "max_input_tokens": 4097,
+       "max_tokens": 4096,
+       "max_input_tokens": 16385,
+       "max_output_tokens": 4096,
+       "input_cost_per_token": 0.000003,
+       "output_cost_per_token": 0.000006,
+       "litellm_provider": "openai",
+       "mode": "chat"
+   },
+   "ft:gpt-3.5-turbo-0125": {
+       "max_tokens": 4096,
+       "max_input_tokens": 16385,
+       "max_output_tokens": 4096,
+       "input_cost_per_token": 0.000003,
+       "output_cost_per_token": 0.000006,
+       "litellm_provider": "openai",
+       "mode": "chat"
+   },
+   "ft:gpt-3.5-turbo-1106": {
+       "max_tokens": 4096,
+       "max_input_tokens": 16385,
+       "max_output_tokens": 4096,
+       "input_cost_per_token": 0.000003,
+       "output_cost_per_token": 0.000006,
+       "litellm_provider": "openai",
+       "mode": "chat"
+   },
+   "ft:gpt-3.5-turbo-0613": {
+       "max_tokens": 4096,
+       "max_input_tokens": 4096,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000006,