forked from phoenix/litellm-mirror

Merge branch 'main' into litellm_default_router_retries

Commit 1a06f009d1 (20 changed files with 1663 additions and 44 deletions)
@@ -227,6 +227,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
| [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
| [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
| [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ |
| [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ |
| [xinference [Xorbits Inference]](https://docs.litellm.ai/docs/providers/xinference) | | | | | ✅ |
cookbook/liteLLM_IBM_Watsonx.ipynb (vendored, new file, +300 lines)
File diff suppressed because one or more lines are too long
@@ -53,6 +53,50 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported.

| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` |

## Function Calling

```python
import os
from litellm import completion

# set env
os.environ["MISTRAL_API_KEY"] = "your-api-key"

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]

response = completion(
    model="mistral/mistral-large-latest",
    messages=messages,
    tools=tools,
    tool_choice="auto",
)
# Add any assertions here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
    response.choices[0].message.tool_calls[0].function.arguments, str
)
```

## Sample Usage - Embedding
```python
from litellm import embedding
docs/my-website/docs/providers/watsonx.md (new file, +284 lines)
@@ -0,0 +1,284 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# IBM watsonx.ai

LiteLLM supports all IBM [watsonx.ai](https://watsonx.ai/) foundational models and embeddings.

## Environment Variables
```python
os.environ["WATSONX_URL"] = ""  # (required) Base URL of your WatsonX instance
# (required) either one of the following:
os.environ["WATSONX_APIKEY"] = ""  # IBM cloud API key
os.environ["WATSONX_TOKEN"] = ""  # IAM auth token
# optional - can also be passed as params to completion() or embedding()
os.environ["WATSONX_PROJECT_ID"] = ""  # Project ID of your WatsonX instance
os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = ""  # ID of your deployment space to use deployed models
```

See [here](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information on how to get an access token to authenticate to watsonx.ai.

## Usage

<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_IBM_Watsonx.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

```python
import os
from litellm import completion

os.environ["WATSONX_URL"] = ""
os.environ["WATSONX_APIKEY"] = ""

response = completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"content": "what is your favorite colour?", "role": "user"}],
    project_id="<my-project-id>"  # or pass with os.environ["WATSONX_PROJECT_ID"]
)

response = completion(
    model="watsonx/meta-llama/llama-3-8b-instruct",
    messages=[{"content": "what is your favorite colour?", "role": "user"}],
    project_id="<my-project-id>"
)
```

## Usage - Streaming
```python
import os
from litellm import completion

os.environ["WATSONX_URL"] = ""
os.environ["WATSONX_APIKEY"] = ""
os.environ["WATSONX_PROJECT_ID"] = ""

response = completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"content": "what is your favorite colour?", "role": "user"}],
    stream=True
)
for chunk in response:
    print(chunk)
```

#### Example Streaming Output Chunk
```json
{
  "choices": [
    {
      "finish_reason": null,
      "index": 0,
      "delta": {
        "content": "I don't have a favorite color, but I do like the color blue. What's your favorite color?"
      }
    }
  ],
  "created": null,
  "model": "watsonx/ibm/granite-13b-chat-v2",
  "usage": {
    "prompt_tokens": null,
    "completion_tokens": null,
    "total_tokens": null
  }
}
```

## Usage - Models in deployment spaces

Models that have been deployed to a deployment space (e.g.: tuned models) can be called using the `deployment/<deployment_id>` format (where `<deployment_id>` is the ID of the deployed model in your deployment space).

The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=<deployment_space_id>`.

```python
import litellm
response = litellm.completion(
    model="watsonx/deployment/<deployment_id>",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    space_id="<deployment_space_id>"
)
```

## Usage - Embeddings

LiteLLM also supports making requests to IBM watsonx.ai embedding models. The credentials needed for this are the same as for completion.

```python
from litellm import embedding

response = embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",
    input=["What is the capital of France?"],
    project_id="<my-project-id>"
)
print(response)
# EmbeddingResponse(model='ibm/slate-30m-english-rtrvr', data=[{'object': 'embedding', 'index': 0, 'embedding': [-0.037463713, -0.02141933, -0.02851813, 0.015519324, ..., -0.0021367231, -0.01704561, -0.001425816, 0.0035238306]}], object='list', usage=Usage(prompt_tokens=8, total_tokens=8))
```

## OpenAI Proxy Usage

Here's how to call IBM watsonx.ai with the LiteLLM Proxy Server.

### 1. Save keys in your environment

```bash
export WATSONX_URL=""
export WATSONX_APIKEY=""
export WATSONX_PROJECT_ID=""
```

### 2. Start the proxy

<Tabs>
<TabItem value="cli" label="CLI">

```bash
$ litellm --model watsonx/meta-llama/llama-3-8b-instruct

# Server running on http://0.0.0.0:4000
```

</TabItem>
<TabItem value="config" label="config.yaml">

```yaml
model_list:
  - model_name: llama-3-8b
    litellm_params:
      # all params accepted by litellm.completion()
      model: watsonx/meta-llama/llama-3-8b-instruct
      api_key: "os.environ/WATSONX_API_KEY" # does os.getenv("WATSONX_API_KEY")
```

</TabItem>
</Tabs>

### 3. Test it

<Tabs>
<TabItem value="Curl" label="Curl Request">

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
    "model": "llama-3-8b",
    "messages": [
        {
            "role": "user",
            "content": "what is your favorite colour?"
        }
    ]
}
'
```

</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">

```python
import openai
client = openai.OpenAI(
    api_key="anything",
    base_url="http://0.0.0.0:4000"
)

# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="llama-3-8b", messages=[
    {
        "role": "user",
        "content": "what is your favorite colour?"
    }
])

print(response)
```

</TabItem>
<TabItem value="langchain" label="Langchain">

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage

chat = ChatOpenAI(
    openai_api_base="http://0.0.0.0:4000",  # set openai_api_base to the LiteLLM Proxy
    model = "llama-3-8b",
    temperature=0.1
)

messages = [
    SystemMessage(
        content="You are a helpful assistant that im using to make a test request to."
    ),
    HumanMessage(
        content="test from litellm. tell me why it's amazing in 1 sentence"
    ),
]
response = chat(messages)

print(response)
```

</TabItem>
</Tabs>

## Authentication

### Passing credentials as parameters

You can also pass the credentials as parameters to the completion and embedding functions.

```python
import os
from litellm import completion

response = completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"content": "What is your favorite color?", "role": "user"}],
    url="",
    api_key="",
    project_id=""
)
```

## Supported IBM watsonx.ai Models

Here are some examples of models available in IBM watsonx.ai that you can use with LiteLLM:

| Model Name | Command |
| ---------- | --------- |
| Flan T5 XXL | `completion(model=watsonx/google/flan-t5-xxl, messages=messages)` |
| Flan Ul2 | `completion(model=watsonx/google/flan-ul2, messages=messages)` |
| Mt0 XXL | `completion(model=watsonx/bigscience/mt0-xxl, messages=messages)` |
| Gpt Neox | `completion(model=watsonx/eleutherai/gpt-neox-20b, messages=messages)` |
| Mpt 7B Instruct2 | `completion(model=watsonx/ibm/mpt-7b-instruct2, messages=messages)` |
| Starcoder | `completion(model=watsonx/bigcode/starcoder, messages=messages)` |
| Llama 2 70B Chat | `completion(model=watsonx/meta-llama/llama-2-70b-chat, messages=messages)` |
| Llama 2 13B Chat | `completion(model=watsonx/meta-llama/llama-2-13b-chat, messages=messages)` |
| Granite 13B Instruct | `completion(model=watsonx/ibm/granite-13b-instruct-v1, messages=messages)` |
| Granite 13B Chat | `completion(model=watsonx/ibm/granite-13b-chat-v1, messages=messages)` |
| Flan T5 XL | `completion(model=watsonx/google/flan-t5-xl, messages=messages)` |
| Granite 13B Chat V2 | `completion(model=watsonx/ibm/granite-13b-chat-v2, messages=messages)` |
| Granite 13B Instruct V2 | `completion(model=watsonx/ibm/granite-13b-instruct-v2, messages=messages)` |
| Elyza Japanese Llama 2 7B Instruct | `completion(model=watsonx/elyza/elyza-japanese-llama-2-7b-instruct, messages=messages)` |
| Mixtral 8X7B Instruct V01 Q | `completion(model=watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q, messages=messages)` |

For a list of all available models in watsonx.ai, see [here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx&locale=en&audience=wdp).

## Supported IBM watsonx.ai Embedding Models

| Model Name | Function Call |
|----------------------|---------------------------------------------|
| Slate 30m | `embedding(model="watsonx/ibm/slate-30m-english-rtrvr", input=input)` |
| Slate 125m | `embedding(model="watsonx/ibm/slate-125m-english-rtrvr", input=input)` |

For a list of all available embedding models in watsonx.ai, see [here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx).
@@ -148,6 +148,7 @@ const sidebars = {
              "providers/openrouter",
              "providers/custom_openai_proxy",
              "providers/petals",
              "providers/watsonx",
            ],
          },
          "proxy/custom_pricing",
@@ -77,6 +77,7 @@ baseten_key: Optional[str] = None
aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None
use_client: bool = False
ssl_verify: bool = True
disable_streaming_logging: bool = False
### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None
@@ -298,6 +299,7 @@ aleph_alpha_models: List = []
bedrock_models: List = []
deepinfra_models: List = []
perplexity_models: List = []
watsonx_models: List = []
for key, value in model_cost.items():
    if value.get("litellm_provider") == "openai":
        open_ai_chat_completion_models.append(key)
@@ -342,6 +344,8 @@ for key, value in model_cost.items():
        deepinfra_models.append(key)
    elif value.get("litellm_provider") == "perplexity":
        perplexity_models.append(key)
    elif value.get("litellm_provider") == "watsonx":
        watsonx_models.append(key)

# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
openai_compatible_endpoints: List = [
@@ -478,6 +482,7 @@ model_list = (
    + perplexity_models
    + maritalk_models
    + vertex_language_models
    + watsonx_models
)

provider_list: List = [
@@ -516,6 +521,7 @@ provider_list: List = [
    "cloudflare",
    "xinference",
    "fireworks_ai",
    "watsonx",
    "custom", # custom apis
]

@@ -537,6 +543,7 @@ models_by_provider: dict = {
    "deepinfra": deepinfra_models,
    "perplexity": perplexity_models,
    "maritalk": maritalk_models,
    "watsonx": watsonx_models,
}

# mapping for those models which have larger equivalents
@@ -650,6 +657,7 @@ from .llms.bedrock import (
)
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
from .llms.watsonx import IBMWatsonXAIConfig
from .main import *  # type: ignore
from .integrations import *
from .exceptions import (
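The registrations above are easy to sanity-check from user code. A minimal sketch, assuming a litellm build that includes these changes:

```python
# Illustrative check of the new provider entries and the ssl_verify flag.
import litellm

assert "watsonx" in litellm.provider_list
print(litellm.models_by_provider.get("watsonx", []))  # models whose litellm_provider is "watsonx"
print(litellm.ssl_verify)  # True by default; consumed later by the Router's httpx clients
```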
@@ -430,6 +430,32 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
        prompt = default_pt(messages)
    return prompt


### IBM Granite

def ibm_granite_pt(messages: list):
    """
    IBM's Granite models use the template:
    <|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message}

    See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
    """
    return custom_prompt(
        messages=messages,
        role_dict={
            'system': {
                'pre_message': '<|system|>\n',
                'post_message': '\n',
            },
            'user': {
                'pre_message': '<|user|>\n',
                'post_message': '\n',
            },
            'assistant': {
                'pre_message': '<|assistant|>\n',
                'post_message': '\n',
            }
        }
    ).strip()


### ANTHROPIC ###

@@ -1359,6 +1385,25 @@ def prompt_factory(
        return messages
    elif custom_llm_provider == "azure_text":
        return azure_text_pt(messages=messages)
    elif custom_llm_provider == "watsonx":
        if "granite" in model and "chat" in model:
            # granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template
            return ibm_granite_pt(messages=messages)
        elif "ibm-mistral" in model and "instruct" in model:
            # models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct prompt template
            return mistral_instruct_pt(messages=messages)
        elif "meta-llama/llama-3" in model and "instruct" in model:
            # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
            return custom_prompt(
                role_dict={
                    "system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"},
                    "user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"},
                    "assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"},
                },
                messages=messages,
                initial_prompt_value="<|begin_of_text|>",
                final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
            )
    try:
        if "meta-llama/llama-2" in model and "chat" in model:
            return llama_2_chat_pt(messages=messages)
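For illustration, a minimal sketch of what the Granite template above renders for a short conversation, assuming `ibm_granite_pt` is importable from the factory module as added here; the expected output shape is approximate:

```python
# Hypothetical usage of the Granite prompt template added above.
from litellm.llms.prompt_templates.factory import ibm_granite_pt

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is IBM watsonx.ai?"},
]

print(ibm_granite_pt(messages=messages))
# Roughly:
# <|system|>
# You are a helpful assistant.
# <|user|>
# What is IBM watsonx.ai?
```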
@@ -112,10 +112,16 @@ def start_prediction(
    }

    initial_prediction_data = {
        "version": version_id,
        "input": input_data,
    }

    if ":" in version_id and len(version_id) > 64:
        model_parts = version_id.split(":")
        if (
            len(model_parts) > 1 and len(model_parts[1]) == 64
        ):  ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
            initial_prediction_data["version"] = model_parts[1]

    ## LOGGING
    logging_obj.pre_call(
        input=input_data["prompt"],
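A standalone sketch of the version-id handling added above; the model string and hash are the example values from the comment:

```python
# Replicate model strings may carry a 64-character version hash after ":".
version_id = "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
input_data = {"prompt": "hello"}

initial_prediction_data = {"version": version_id, "input": input_data}
if ":" in version_id and len(version_id) > 64:
    model_parts = version_id.split(":")
    if len(model_parts) > 1 and len(model_parts[1]) == 64:
        # keep only the bare version hash that the Replicate predictions API expects
        initial_prediction_data["version"] = model_parts[1]

print(initial_prediction_data["version"])  # 02e509c7...badfe68c9e3
```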
@@ -529,6 +529,7 @@ def completion(
            "instances": instances,
            "vertex_location": vertex_location,
            "vertex_project": vertex_project,
            "safety_settings": safety_settings,
            **optional_params,
        }
        if optional_params.get("stream", False) is True:
@@ -813,6 +814,7 @@ async def async_completion(
    instances=None,
    vertex_project=None,
    vertex_location=None,
    safety_settings=None,
    **optional_params,
):
    """
@@ -844,6 +846,7 @@ async def async_completion(
        response = await llm_model._generate_content_async(
            contents=content,
            generation_config=optional_params,
            safety_settings=safety_settings,
            tools=tools,
        )
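From the caller's side, the new `safety_settings` argument would be supplied through `litellm.completion`; a hedged sketch, where the model name and the category/threshold values are illustrative and follow Google's SafetySetting format:

```python
from litellm import completion

# Sketch: pass Vertex AI safety settings through to the generate-content call.
response = completion(
    model="vertex_ai/gemini-pro",
    messages=[{"role": "user", "content": "hello"}],
    safety_settings=[
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
    ],
)
print(response.choices[0].message.content)
```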
litellm/llms/watsonx.py (new file, +591 lines)
@@ -0,0 +1,591 @@
from enum import Enum
import json, types, time  # noqa: E401
from contextlib import contextmanager
from typing import Callable, Dict, Optional, Any, Union, List

import httpx
import requests
import litellm
from litellm.utils import ModelResponse, get_secret, Usage

from .base import BaseLLM
from .prompt_templates import factory as ptf


class WatsonXAIError(Exception):
    def __init__(self, status_code, message, url: Optional[str] = None):
        self.status_code = status_code
        self.message = message
        url = url or "https://us-south.ml.cloud.ibm.com"
        self.request = httpx.Request(method="POST", url=url)
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
class IBMWatsonXAIConfig:
    """
    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)

    Supported params for all available watsonx.ai foundational models.

    - `decoding_method` (str): One of "greedy" or "sample"

    - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.

    - `max_new_tokens` (integer): Maximum length of the generated tokens.

    - `min_new_tokens` (integer): Minimum length of the generated tokens.

    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".

    - `stop_sequences` (string[]): list of strings to use as stop sequences.

    - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.

    - `top_p` (float): top p for sampling - not available when decoding_method='greedy'.

    - `repetition_penalty` (float): token repetition penalty during text generation.

    - `truncate_input_tokens` (integer): Truncate input tokens to this length. Any more than this will be truncated.

    - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.

    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.

    - `random_seed` (integer): Random seed for text generation.

    - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.

    - `stream` (bool): If True, the model will return a stream of responses.
    """

    decoding_method: Optional[str] = "sample"
    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None  # litellm.max_tokens
    min_new_tokens: Optional[int] = None
    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    truncate_input_tokens: Optional[int] = None
    include_stop_sequences: Optional[bool] = False
    return_options: Optional[Dict[str, bool]] = None
    random_seed: Optional[int] = None  # e.g 42
    moderations: Optional[dict] = None
    stream: Optional[bool] = False

    def __init__(
        self,
        decoding_method: Optional[str] = None,
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        length_penalty: Optional[dict] = None,
        stop_sequences: Optional[List[str]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        truncate_input_tokens: Optional[int] = None,
        include_stop_sequences: Optional[bool] = None,
        return_options: Optional[dict] = None,
        random_seed: Optional[int] = None,
        moderations: Optional[dict] = None,
        stream: Optional[bool] = None,
        **kwargs,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "temperature",  # equivalent to temperature
            "max_tokens",  # equivalent to max_new_tokens
            "top_p",  # equivalent to top_p
            "frequency_penalty",  # equivalent to repetition_penalty
            "stop",  # equivalent to stop_sequences
            "seed",  # equivalent to random_seed
            "stream",  # equivalent to stream
        ]
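As an aside (not part of watsonx.py): a rough sketch of how the OpenAI-style parameters listed in `get_supported_openai_params()` would be supplied from user code; the comments name the watsonx.ai fields they are expected to map to, per the config above:

```python
from litellm import completion

response = completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Write a haiku about the ocean."}],
    temperature=0.7,  # -> temperature
    max_tokens=200,   # -> max_new_tokens
    top_p=0.9,        # -> top_p
    stop=["\n\n"],    # -> stop_sequences
    seed=42,          # -> random_seed
)
```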
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
    # handle anthropic prompts and amazon titan prompts
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_dict = custom_prompt_dict[model]
        prompt = ptf.custom_prompt(
            messages=messages,
            role_dict=model_prompt_dict.get(
                "role_dict", model_prompt_dict.get("roles")
            ),
            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
            final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
            bos_token=model_prompt_dict.get("bos_token", ""),
            eos_token=model_prompt_dict.get("eos_token", ""),
        )
        return prompt
    elif provider == "ibm":
        prompt = ptf.prompt_factory(
            model=model, messages=messages, custom_llm_provider="watsonx"
        )
    elif provider == "ibm-mistralai":
        prompt = ptf.mistral_instruct_pt(messages=messages)
    else:
        prompt = ptf.prompt_factory(
            model=model, messages=messages, custom_llm_provider="watsonx"
        )
    return prompt


class WatsonXAIEndpoint(str, Enum):
    TEXT_GENERATION = "/ml/v1/text/generation"
    TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
    DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"
    DEPLOYMENT_TEXT_GENERATION_STREAM = (
        "/ml/v1/deployments/{deployment_id}/text/generation_stream"
    )
    EMBEDDINGS = "/ml/v1/text/embeddings"
    PROMPTS = "/ml/v1/prompts"
class IBMWatsonXAI(BaseLLM):
    """
    Class to interface with IBM Watsonx.ai API for text generation and embeddings.

    Reference: https://cloud.ibm.com/apidocs/watsonx-ai
    """

    api_version = "2024-03-13"

    def __init__(self) -> None:
        super().__init__()

    def _prepare_text_generation_req(
        self,
        model_id: str,
        prompt: str,
        stream: bool,
        optional_params: dict,
        print_verbose: Optional[Callable] = None,
    ) -> dict:
        """
        Get the request parameters for text generation.
        """
        api_params = self._get_api_params(optional_params, print_verbose=print_verbose)
        # build auth headers
        api_token = api_params.get("token")

        headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        extra_body_params = optional_params.pop("extra_body", {})
        optional_params.update(extra_body_params)
        # init the payload to the text generation call
        payload = {
            "input": prompt,
            "moderations": optional_params.pop("moderations", {}),
            "parameters": optional_params,
        }
        request_params = dict(version=api_params["api_version"])
        # text generation endpoint deployment or model / stream or not
        if model_id.startswith("deployment/"):
            # deployment models are passed in as 'deployment/<deployment_id>'
            if api_params.get("space_id") is None:
                raise WatsonXAIError(
                    status_code=401,
                    url=api_params["url"],
                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
                )
            deployment_id = "/".join(model_id.split("/")[1:])
            endpoint = (
                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
                if stream
                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
            )
            endpoint = endpoint.format(deployment_id=deployment_id)
        else:
            payload["model_id"] = model_id
            payload["project_id"] = api_params["project_id"]
            endpoint = (
                WatsonXAIEndpoint.TEXT_GENERATION_STREAM
                if stream
                else WatsonXAIEndpoint.TEXT_GENERATION
            )
        url = api_params["url"].rstrip("/") + endpoint
        return dict(
            method="POST", url=url, headers=headers, json=payload, params=request_params
        )
    def _get_api_params(
        self, params: dict, print_verbose: Optional[Callable] = None
    ) -> dict:
        """
        Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
        """
        # Load auth variables from params
        url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
        api_key = params.pop("apikey", None)
        token = params.pop("token", None)
        project_id = params.pop(
            "project_id", params.pop("watsonx_project", None)
        )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
        space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
        region_name = params.pop("region_name", params.pop("region", None))
        if region_name is None:
            region_name = params.pop(
                "watsonx_region_name", params.pop("watsonx_region", None)
            )  # consistent with how vertex ai + aws regions are accepted
        wx_credentials = params.pop(
            "wx_credentials",
            params.pop(
                "watsonx_credentials", None
            ),  # follow {provider}_credentials, same as vertex ai
        )
        api_version = params.pop("api_version", IBMWatsonXAI.api_version)
        # Load auth variables from environment variables
        if url is None:
            url = (
                get_secret("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
                or get_secret("WATSONX_URL")
                or get_secret("WX_URL")
                or get_secret("WML_URL")
            )
        if api_key is None:
            api_key = (
                get_secret("WATSONX_APIKEY")
                or get_secret("WATSONX_API_KEY")
                or get_secret("WX_API_KEY")
            )
        if token is None:
            token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN")
        if project_id is None:
            project_id = (
                get_secret("WATSONX_PROJECT_ID")
                or get_secret("WX_PROJECT_ID")
                or get_secret("PROJECT_ID")
            )
        if region_name is None:
            region_name = (
                get_secret("WATSONX_REGION")
                or get_secret("WX_REGION")
                or get_secret("REGION")
            )
        if space_id is None:
            space_id = (
                get_secret("WATSONX_DEPLOYMENT_SPACE_ID")
                or get_secret("WATSONX_SPACE_ID")
                or get_secret("WX_SPACE_ID")
                or get_secret("SPACE_ID")
            )

        # credentials parsing
        if wx_credentials is not None:
            url = wx_credentials.get("url", url)
            api_key = wx_credentials.get(
                "apikey", wx_credentials.get("api_key", api_key)
            )
            token = wx_credentials.get(
                "token",
                wx_credentials.get(
                    "watsonx_token", token
                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
            )

        # verify that all required credentials are present
        if url is None:
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
            )
        if token is None and api_key is not None:
            # generate the auth token
            if print_verbose:
                print_verbose("Generating IAM token for Watsonx.ai")
            token = self.generate_iam_token(api_key)
        elif token is None and api_key is None:
            raise WatsonXAIError(
                status_code=401,
                url=url,
                message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
            )
        if project_id is None:
            raise WatsonXAIError(
                status_code=401,
                url=url,
                message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
            )

        return {
            "url": url,
            "api_key": api_key,
            "token": token,
            "project_id": project_id,
            "space_id": space_id,
            "region_name": region_name,
            "api_version": api_version,
        }
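As an aside (not part of watsonx.py): the environment variables that `_get_api_params` falls back to, shown as a sketch; only one variable from each group is needed, and the values below are placeholders:

```python
import os

# Base URL: WATSONX_API_BASE, WATSONX_URL, WX_URL, or WML_URL
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
# Credentials: an API key (WATSONX_APIKEY / WATSONX_API_KEY / WX_API_KEY)
# or a ready-made IAM token (WATSONX_TOKEN / WX_TOKEN)
os.environ["WATSONX_APIKEY"] = "<ibm-cloud-api-key>"
# Project, and optionally a deployment space for deployed models
os.environ["WATSONX_PROJECT_ID"] = "<project-id>"
# os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = "<space-id>"
```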
    def completion(
        self,
        model: str,
        messages: list,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        optional_params: dict,
        litellm_params: Optional[dict] = None,
        logger_fn=None,
        timeout: Optional[float] = None,
    ):
        """
        Send a text generation request to the IBM Watsonx.ai API.
        Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
        """
        stream = optional_params.pop("stream", False)

        # Load default configs
        config = IBMWatsonXAIConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        # Make prompt to send to model
        provider = model.split("/")[0]
        # model_name = "/".join(model.split("/")[1:])
        prompt = convert_messages_to_prompt(
            model, messages, provider, custom_prompt_dict
        )

        def process_text_request(request_params: dict) -> ModelResponse:
            with self._manage_response(
                request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
            ) as resp:
                json_resp = resp.json()

            generated_text = json_resp["results"][0]["generated_text"]
            prompt_tokens = json_resp["results"][0]["input_token_count"]
            completion_tokens = json_resp["results"][0]["generated_token_count"]
            model_response["choices"][0]["message"]["content"] = generated_text
            model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
            model_response["created"] = int(time.time())
            model_response["model"] = model
            setattr(
                model_response,
                "usage",
                Usage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                ),
            )
            return model_response

        def process_stream_request(
            request_params: dict,
        ) -> litellm.CustomStreamWrapper:
            # stream the response - generated chunks will be handled
            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
            with self._manage_response(
                request_params,
                logging_obj=logging_obj,
                stream=True,
                input=prompt,
                timeout=timeout,
            ) as resp:
                response = litellm.CustomStreamWrapper(
                    resp.iter_lines(),
                    model=model,
                    custom_llm_provider="watsonx",
                    logging_obj=logging_obj,
                )
            return response

        try:
            ## Get the response from the model
            req_params = self._prepare_text_generation_req(
                model_id=model,
                prompt=prompt,
                stream=stream,
                optional_params=optional_params,
                print_verbose=print_verbose,
            )
            if stream:
                return process_stream_request(req_params)
            else:
                return process_text_request(req_params)
        except WatsonXAIError as e:
            raise e
        except Exception as e:
            raise WatsonXAIError(status_code=500, message=str(e))
    def embedding(
        self,
        model: str,
        input: Union[list, str],
        api_key: Optional[str] = None,
        logging_obj=None,
        model_response=None,
        optional_params=None,
        encoding=None,
    ):
        """
        Send a text embedding request to the IBM Watsonx.ai API.
        """
        if optional_params is None:
            optional_params = {}
        # Load default configs
        config = IBMWatsonXAIConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        # Load auth variables from environment variables
        if isinstance(input, str):
            input = [input]
        if api_key is not None:
            optional_params["api_key"] = api_key
        api_params = self._get_api_params(optional_params)
        # build auth headers
        api_token = api_params.get("token")
        headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        # init the payload to the text generation call
        payload = {
            "inputs": input,
            "model_id": model,
            "project_id": api_params["project_id"],
            "parameters": optional_params,
        }
        request_params = dict(version=api_params["api_version"])
        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
        # request = httpx.Request(
        #     "POST", url, headers=headers, json=payload, params=request_params
        # )
        req_params = {
            "method": "POST",
            "url": url,
            "headers": headers,
            "json": payload,
            "params": request_params,
        }
        with self._manage_response(
            req_params, logging_obj=logging_obj, input=input
        ) as resp:
            json_resp = resp.json()

        results = json_resp.get("results", [])
        embedding_response = []
        for idx, result in enumerate(results):
            embedding_response.append(
                {"object": "embedding", "index": idx, "embedding": result["embedding"]}
            )
        model_response["object"] = "list"
        model_response["data"] = embedding_response
        model_response["model"] = model
        input_tokens = json_resp.get("input_token_count", 0)
        model_response.usage = Usage(
            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
        )
        return model_response
    def generate_iam_token(self, api_key=None, **params):
        headers = {}
        headers["Content-Type"] = "application/x-www-form-urlencoded"
        if api_key is None:
            api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY")
        if api_key is None:
            raise ValueError("API key is required")
        headers["Accept"] = "application/json"
        data = {
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": api_key,
        }
        response = httpx.post(
            "https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
        )
        response.raise_for_status()
        json_data = response.json()
        iam_access_token = json_data["access_token"]
        self.token = iam_access_token
        return iam_access_token
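As an aside (not part of watsonx.py), a hedged usage sketch of the helper above: it exchanges an IBM Cloud API key for an IAM bearer token via `https://iam.cloud.ibm.com/identity/token` and caches it on the instance. The API key below is a placeholder:

```python
from litellm.llms.watsonx import IBMWatsonXAI

client = IBMWatsonXAI()
token = client.generate_iam_token(api_key="<ibm-cloud-api-key>")  # placeholder key
print(token[:12], "...")  # bearer token used in the Authorization header
```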
    @contextmanager
    def _manage_response(
        self,
        request_params: dict,
        logging_obj: Any,
        stream: bool = False,
        input: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        request_str = (
            f"response = {request_params['method']}(\n"
            f"\turl={request_params['url']},\n"
            f"\tjson={request_params['json']},\n"
            f")"
        )
        logging_obj.pre_call(
            input=input,
            api_key=request_params["headers"].get("Authorization"),
            additional_args={
                "complete_input_dict": request_params["json"],
                "request_str": request_str,
            },
        )
        if timeout:
            request_params["timeout"] = timeout
        try:
            if stream:
                resp = requests.request(
                    **request_params,
                    stream=True,
                )
                resp.raise_for_status()
                yield resp
            else:
                resp = requests.request(**request_params)
                resp.raise_for_status()
                yield resp
        except Exception as e:
            raise WatsonXAIError(status_code=500, message=str(e))
        if not stream:
            logging_obj.post_call(
                input=input,
                api_key=request_params["headers"].get("Authorization"),
                original_response=json.dumps(resp.json()),
                additional_args={
                    "status_code": resp.status_code,
                    "complete_input_dict": request_params["json"],
                },
            )
@@ -62,6 +62,7 @@ from .llms import (
    vertex_ai,
    vertex_ai_anthropic,
    maritalk,
    watsonx,
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion
@@ -1862,6 +1863,43 @@ def completion(
            ## RESPONSE OBJECT
            response = response
        elif custom_llm_provider == "watsonx":
            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
            response = watsonx.IBMWatsonXAI().completion(
                model=model,
                messages=messages,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,  # type: ignore
                logger_fn=logger_fn,
                encoding=encoding,
                logging_obj=logging,
                timeout=timeout,
            )
            if (
                "stream" in optional_params
                and optional_params["stream"] == True
                and not isinstance(response, CustomStreamWrapper)
            ):
                # don't try to access stream object,
                response = CustomStreamWrapper(
                    iter(response),
                    model,
                    custom_llm_provider="watsonx",
                    logging_obj=logging,
                )

            if optional_params.get("stream", False):
                ## LOGGING
                logging.post_call(
                    input=messages,
                    api_key=None,
                    original_response=response,
                )
            ## RESPONSE OBJECT
            response = response
        elif custom_llm_provider == "vllm":
            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
            model_response = vllm.completion(
@@ -2941,6 +2979,15 @@ def embedding(
            client=client,
            aembedding=aembedding,
        )
    elif custom_llm_provider == "watsonx":
        response = watsonx.IBMWatsonXAI().embedding(
            model=model,
            input=input,
            encoding=encoding,
            logging_obj=logging,
            optional_params=optional_params,
            model_response=EmbeddingResponse(),
        )
    else:
        args = locals()
        raise ValueError(f"No valid embedding model args passed in - {args}")
@@ -6,4 +6,3 @@ model_list:
      model_name: fake-openai-endpoint
router_settings:
  num_retries: 0
@@ -1937,6 +1937,7 @@ class Router:
                )
                default_api_base = api_base
                default_api_key = api_key

            if (
                model_name in litellm.open_ai_chat_completion_models
                or custom_llm_provider in litellm.openai_compatible_providers
@@ -1972,6 +1973,23 @@ class Router:
                api_base = litellm.get_secret(api_base_env_name)
                litellm_params["api_base"] = api_base

            ## AZURE AI STUDIO MISTRAL CHECK ##
            """
            Make sure api base ends in /v1/

            if not, add it - https://github.com/BerriAI/litellm/issues/2279
            """
            if (
                custom_llm_provider == "openai"
                and api_base is not None
                and not api_base.endswith("/v1/")
            ):
                # check if it ends with a trailing slash
                if api_base.endswith("/"):
                    api_base += "v1/"
                else:
                    api_base += "/v1/"

            api_version = litellm_params.get("api_version")
            if api_version and api_version.startswith("os.environ/"):
                api_version_env_name = api_version.replace("os.environ/", "")
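A standalone sketch of the api_base normalization above; the hostname is a made-up Azure AI Studio endpoint used only for illustration:

```python
def ensure_v1_suffix(api_base: str) -> str:
    # Make sure the base URL ends in /v1/ - see https://github.com/BerriAI/litellm/issues/2279
    if not api_base.endswith("/v1/"):
        api_base += "v1/" if api_base.endswith("/") else "/v1/"
    return api_base

print(ensure_v1_suffix("https://example-mistral.westus2.inference.ai.azure.com"))   # ...azure.com/v1/
print(ensure_v1_suffix("https://example-mistral.westus2.inference.ai.azure.com/"))  # ...azure.com/v1/
```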
@@ -2062,9 +2080,11 @@ class Router:
                    timeout=timeout,
                    max_retries=max_retries,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=async_proxy_mounts,
                    ), # type: ignore
@@ -2084,9 +2104,11 @@ class Router:
                    timeout=timeout,
                    max_retries=max_retries,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=sync_proxy_mounts,
                    ), # type: ignore
@@ -2106,9 +2128,11 @@ class Router:
                    timeout=stream_timeout,
                    max_retries=max_retries,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=async_proxy_mounts,
                    ), # type: ignore
@@ -2128,9 +2152,11 @@ class Router:
                    timeout=stream_timeout,
                    max_retries=max_retries,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=sync_proxy_mounts,
                    ), # type: ignore
@@ -2168,9 +2194,11 @@ class Router:
                    timeout=timeout,
                    max_retries=max_retries,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=async_proxy_mounts,
                    ), # type: ignore
@@ -2188,9 +2216,11 @@ class Router:
                    timeout=timeout,
                    max_retries=max_retries,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=sync_proxy_mounts,
                    ), # type: ignore
@@ -2209,9 +2239,11 @@ class Router:
                    timeout=stream_timeout,
                    max_retries=max_retries,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=async_proxy_mounts,
                    ),
@@ -2229,9 +2261,11 @@ class Router:
                    timeout=stream_timeout,
                    max_retries=max_retries,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=sync_proxy_mounts,
                    ),
@ -2259,9 +2293,11 @@ class Router:
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
organization=organization,
|
organization=organization,
|
||||||
http_client=httpx.AsyncClient(
|
http_client=httpx.AsyncClient(
|
||||||
transport=AsyncCustomHTTPTransport(),
|
transport=AsyncCustomHTTPTransport(
|
||||||
limits=httpx.Limits(
|
limits=httpx.Limits(
|
||||||
max_connections=1000, max_keepalive_connections=100
|
max_connections=1000, max_keepalive_connections=100
|
||||||
|
),
|
||||||
|
verify=litellm.ssl_verify,
|
||||||
),
|
),
|
||||||
mounts=async_proxy_mounts,
|
mounts=async_proxy_mounts,
|
||||||
), # type: ignore
|
), # type: ignore
|
||||||
|
@ -2281,9 +2317,11 @@ class Router:
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
organization=organization,
|
organization=organization,
|
||||||
http_client=httpx.Client(
|
http_client=httpx.Client(
|
||||||
transport=CustomHTTPTransport(),
|
transport=CustomHTTPTransport(
|
||||||
limits=httpx.Limits(
|
limits=httpx.Limits(
|
||||||
max_connections=1000, max_keepalive_connections=100
|
max_connections=1000, max_keepalive_connections=100
|
||||||
|
),
|
||||||
|
verify=litellm.ssl_verify,
|
||||||
),
|
),
|
||||||
mounts=sync_proxy_mounts,
|
mounts=sync_proxy_mounts,
|
||||||
), # type: ignore
|
), # type: ignore
|
||||||
|
@ -2304,9 +2342,11 @@ class Router:
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
organization=organization,
|
organization=organization,
|
||||||
http_client=httpx.AsyncClient(
|
http_client=httpx.AsyncClient(
|
||||||
transport=AsyncCustomHTTPTransport(),
|
transport=AsyncCustomHTTPTransport(
|
||||||
limits=httpx.Limits(
|
limits=httpx.Limits(
|
||||||
max_connections=1000, max_keepalive_connections=100
|
max_connections=1000, max_keepalive_connections=100
|
||||||
|
),
|
||||||
|
verify=litellm.ssl_verify,
|
||||||
),
|
),
|
||||||
mounts=async_proxy_mounts,
|
mounts=async_proxy_mounts,
|
||||||
), # type: ignore
|
), # type: ignore
|
||||||
|
@ -2327,9 +2367,11 @@ class Router:
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
organization=organization,
|
organization=organization,
|
||||||
http_client=httpx.Client(
|
http_client=httpx.Client(
|
||||||
transport=CustomHTTPTransport(),
|
transport=CustomHTTPTransport(
|
||||||
limits=httpx.Limits(
|
limits=httpx.Limits(
|
||||||
max_connections=1000, max_keepalive_connections=100
|
max_connections=1000, max_keepalive_connections=100
|
||||||
|
),
|
||||||
|
verify=litellm.ssl_verify,
|
||||||
),
|
),
|
||||||
mounts=sync_proxy_mounts,
|
mounts=sync_proxy_mounts,
|
||||||
), # type: ignore
|
), # type: ignore
|
||||||
|
|
|
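The hunks above all apply the same change across the Router's cached client setup: the connection limits move inside the custom transport and `verify=litellm.ssl_verify` is passed alongside them, so the module-level `litellm.ssl_verify` flag now reaches every Azure/OpenAI client the router builds. A minimal sketch of the resulting construction, using plain `httpx` transports in place of litellm's `CustomHTTPTransport` / `AsyncCustomHTTPTransport` wrappers:

```python
import httpx
import litellm

# Assumption for illustration: disable certificate verification globally,
# the same toggle the new router test flips.
litellm.ssl_verify = False

# Rough equivalent of what the router now builds per deployment; the real
# code uses litellm's CustomHTTPTransport subclasses instead of httpx's.
client = httpx.Client(
    transport=httpx.HTTPTransport(
        limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100),
        verify=litellm.ssl_verify,  # TLS verification is now configurable per transport
    ),
)
```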
@@ -2655,6 +2655,42 @@ def test_completion_palm_stream():
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_completion_watsonx():
+    litellm.set_verbose = True
+    model_name = "watsonx/ibm/granite-13b-chat-v2"
+    try:
+        response = completion(
+            model=model_name,
+            messages=messages,
+            stop=["stop"],
+            max_tokens=20,
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except litellm.APIError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_acompletion_watsonx():
+    litellm.set_verbose = True
+    model_name = "watsonx/ibm/granite-13b-chat-v2"
+    print("testing watsonx")
+    try:
+        response = await litellm.acompletion(
+            model=model_name,
+            messages=messages,
+            temperature=0.2,
+            max_tokens=80,
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # test_completion_palm_stream()

 # test_completion_deep_infra()
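The new tests exercise the watsonx provider through the standard `completion()` / `acompletion()` entrypoints. A minimal usage sketch is below; the credential environment variable names (`WATSONX_URL`, `WATSONX_APIKEY`, `WATSONX_PROJECT_ID`) are assumptions here, so check docs/my-website/docs/providers/watsonx.md added in this commit for the authoritative setup:

```python
import os
from litellm import completion

# Assumed credential env vars for watsonx.ai; confirm the exact names in the
# provider docs added in this commit.
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_APIKEY"] = "your-ibm-cloud-api-key"
os.environ["WATSONX_PROJECT_ID"] = "your-project-id"

response = completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=20,
)
print(response.choices[0].message.content)
```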
@@ -483,6 +483,18 @@ def test_mistral_embeddings():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_watsonx_embeddings():
+    try:
+        litellm.set_verbose = True
+        response = litellm.embedding(
+            model="watsonx/ibm/slate-30m-english-rtrvr",
+            input=["good morning from litellm"],
+        )
+        print(f"response: {response}")
+        assert isinstance(response.usage, litellm.Usage)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # test_mistral_embeddings()
@@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
 from concurrent.futures import ThreadPoolExecutor
 from collections import defaultdict
 from dotenv import load_dotenv
+import os, httpx

 load_dotenv()

@@ -56,6 +57,87 @@ def test_router_num_retries_init(num_retries, max_retries):
     else:
         assert getattr(model_client, "max_retries") == 0
+
+
+@pytest.mark.parametrize(
+    "timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)]
+)
+@pytest.mark.parametrize("ssl_verify", [True, False])
+def test_router_timeout_init(timeout, ssl_verify):
+    """
+    Allow user to pass httpx.Timeout
+
+    related issue - https://github.com/BerriAI/litellm/issues/3162
+    """
+    litellm.ssl_verify = ssl_verify
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "test-model",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_version": os.getenv("AZURE_API_VERSION"),
+                    "timeout": timeout,
+                },
+                "model_info": {"id": 1234},
+            }
+        ]
+    )
+
+    model_client = router._get_client(
+        deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={}
+    )
+
+    assert getattr(model_client, "timeout") == timeout
+
+    print(f"vars model_client: {vars(model_client)}")
+    http_client = getattr(model_client, "_client")
+    print(f"http client: {vars(http_client)}, ssl_Verify={ssl_verify}")
+    if ssl_verify == False:
+        assert http_client._transport._pool._ssl_context.verify_mode.name == "CERT_NONE"
+    else:
+        assert (
+            http_client._transport._pool._ssl_context.verify_mode.name
+            == "CERT_REQUIRED"
+        )
+
+
+@pytest.mark.parametrize(
+    "mistral_api_base",
+    [
+        "os.environ/AZURE_MISTRAL_API_BASE",
+        "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1/",
+        "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1",
+        "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/",
+        "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com",
+    ],
+)
+def test_router_azure_ai_studio_init(mistral_api_base):
+    router = Router(
+        model_list=[
+            {
+                "model_name": "test-model",
+                "litellm_params": {
+                    "model": "azure/mistral-large-latest",
+                    "api_key": "os.environ/AZURE_MISTRAL_API_KEY",
+                    "api_base": mistral_api_base,
+                },
+                "model_info": {"id": 1234},
+            }
+        ]
+    )
+
+    model_client = router._get_client(
+        deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={}
+    )
+    url = getattr(model_client, "_base_url")
+    uri_reference = str(getattr(url, "_uri_reference"))
+
+    print(f"uri_reference: {uri_reference}")
+
+    assert "/v1/" in uri_reference
+
+
 def test_exception_raising():
     # this tests if the router raises an exception when invalid params are set
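The new `test_router_timeout_init` covers exactly the configuration this change is meant to allow: passing a granular `httpx.Timeout` straight through `litellm_params`. A minimal sketch, reusing the placeholder deployment and env vars from the test above:

```python
import os
import httpx
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "test-model",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",  # placeholder deployment from the test above
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                # 300s total budget, but fail fast if the TCP connect takes > 20s
                "timeout": httpx.Timeout(timeout=300.0, connect=20.0),
            },
        }
    ]
)
```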
@@ -1271,6 +1271,32 @@ def test_completion_sagemaker_stream():
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_completion_watsonx_stream():
+    litellm.set_verbose = True
+    try:
+        response = completion(
+            model="watsonx/ibm/granite-13b-chat-v2",
+            messages=messages,
+            temperature=0.5,
+            max_tokens=20,
+            stream=True,
+        )
+        complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            if finished:
+                break
+            complete_response += chunk
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # test_completion_sagemaker_stream()
@@ -1,5 +1,5 @@
 from typing import List, Optional, Union, Dict, Tuple, Literal
+import httpx
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest

@@ -104,7 +104,9 @@ class LiteLLM_Params(BaseModel):
     api_key: Optional[str] = None
     api_base: Optional[str] = None
     api_version: Optional[str] = None
-    timeout: Optional[Union[float, str]] = None  # if str, pass in as os.environ/
+    timeout: Optional[Union[float, str, httpx.Timeout]] = (
+        None  # if str, pass in as os.environ/
+    )
     stream_timeout: Optional[Union[float, str]] = (
         None  # timeout when making stream=True calls, if str, pass in as os.environ/
     )

@@ -152,6 +154,7 @@ class LiteLLM_Params(BaseModel):

     class Config:
         extra = "allow"
+        arbitrary_types_allowed = True

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
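`httpx.Timeout` is not a pydantic-native type, which is why the `timeout` annotation is widened and `arbitrary_types_allowed = True` is added to the model config; without it, pydantic would reject the new field type. A small sketch of what should now validate (the model name is illustrative):

```python
import httpx
from litellm.router import LiteLLM_Params

# Accepted after this change: a structured httpx.Timeout instead of a bare float/str.
params = LiteLLM_Params(
    model="azure/chatgpt-v-2",  # illustrative model name
    timeout=httpx.Timeout(timeout=300.0, connect=20.0),
)
print(params.timeout.connect)  # 20.0
```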
|
@ -5427,6 +5427,45 @@ def get_optional_params(
|
||||||
optional_params["extra_body"] = (
|
optional_params["extra_body"] = (
|
||||||
extra_body # openai client supports `extra_body` param
|
extra_body # openai client supports `extra_body` param
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "watsonx":
|
||||||
|
supported_params = get_supported_openai_params(
|
||||||
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
_check_valid_arg(supported_params=supported_params)
|
||||||
|
if max_tokens is not None:
|
||||||
|
optional_params["max_new_tokens"] = max_tokens
|
||||||
|
if stream:
|
||||||
|
optional_params["stream"] = stream
|
||||||
|
if temperature is not None:
|
||||||
|
optional_params["temperature"] = temperature
|
||||||
|
if top_p is not None:
|
||||||
|
optional_params["top_p"] = top_p
|
||||||
|
if frequency_penalty is not None:
|
||||||
|
optional_params["repetition_penalty"] = frequency_penalty
|
||||||
|
if seed is not None:
|
||||||
|
optional_params["random_seed"] = seed
|
||||||
|
if stop is not None:
|
||||||
|
optional_params["stop_sequences"] = stop
|
||||||
|
|
||||||
|
# WatsonX-only parameters
|
||||||
|
extra_body = {}
|
||||||
|
if "decoding_method" in passed_params:
|
||||||
|
extra_body["decoding_method"] = passed_params.pop("decoding_method")
|
||||||
|
if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
|
||||||
|
extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens"))
|
||||||
|
if "top_k" in passed_params:
|
||||||
|
extra_body["top_k"] = passed_params.pop("top_k")
|
||||||
|
if "truncate_input_tokens" in passed_params:
|
||||||
|
extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
|
||||||
|
if "length_penalty" in passed_params:
|
||||||
|
extra_body["length_penalty"] = passed_params.pop("length_penalty")
|
||||||
|
if "time_limit" in passed_params:
|
||||||
|
extra_body["time_limit"] = passed_params.pop("time_limit")
|
||||||
|
if "return_options" in passed_params:
|
||||||
|
extra_body["return_options"] = passed_params.pop("return_options")
|
||||||
|
optional_params["extra_body"] = (
|
||||||
|
extra_body # openai client supports `extra_body` param
|
||||||
|
)
|
||||||
else: # assume passing in params for openai/azure openai
|
else: # assume passing in params for openai/azure openai
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
|
f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
|
||||||
|
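In practice this mapping means the OpenAI-style arguments are translated (`max_tokens` to `max_new_tokens`, `frequency_penalty` to `repetition_penalty`, `seed` to `random_seed`, `stop` to `stop_sequences`) while watsonx-only knobs ride along in `extra_body`. A hedged sketch of a call that exercises both paths:

```python
from litellm import completion

# Standard OpenAI params are remapped for watsonx; the extra keyword
# arguments below are watsonx-only and are collected into
# optional_params["extra_body"] by the branch above.
response = completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Summarize LiteLLM in one sentence."}],
    max_tokens=50,                # -> max_new_tokens
    temperature=0.2,
    stop=["\n\n"],                # -> stop_sequences
    decoding_method="greedy",     # watsonx-only, passed via extra_body
    truncate_input_tokens=2048,   # watsonx-only, passed via extra_body
)
```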
@@ -5829,6 +5868,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "frequency_penalty",
             "presence_penalty",
         ]
+    elif custom_llm_provider == "watsonx":
+        return litellm.IBMWatsonXAIConfig().get_supported_openai_params()


 def get_formatted_prompt(

@@ -6056,6 +6097,8 @@ def get_llm_provider(
             model in litellm.bedrock_models or model in litellm.bedrock_embedding_models
         ):
             custom_llm_provider = "bedrock"
+        elif model in litellm.watsonx_models:
+            custom_llm_provider = "watsonx"
         # openai embeddings
         elif model in litellm.open_ai_embedding_models:
             custom_llm_provider = "openai"

@@ -6520,7 +6563,7 @@ def validate_environment(model: Optional[str] = None) -> dict:
         if "VERTEXAI_PROJECT" in os.environ and "VERTEXAI_LOCATION" in os.environ:
             keys_in_environment = True
         else:
-            missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_PROJECT"])
+            missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_LOCATION"])
     elif custom_llm_provider == "huggingface":
         if "HUGGINGFACE_API_KEY" in os.environ:
             keys_in_environment = True
@@ -9751,6 +9794,37 @@ class CustomStreamWrapper:
                 "finish_reason": finish_reason,
             }

+    def handle_watsonx_stream(self, chunk):
+        try:
+            if isinstance(chunk, dict):
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if 'generated_text' in chunk:
+                    response = chunk.replace('data: ', '').strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {"text": "", "is_finished": False}
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+            results = parsed_response.get("results", [])
+            if len(results) > 0:
+                text = results[0].get("generated_text", "")
+                finish_reason = results[0].get("stop_reason")
+                is_finished = finish_reason != 'not_finished'
+                return {
+                    "text": text,
+                    "is_finished": is_finished,
+                    "finish_reason": finish_reason,
+                    "prompt_tokens": results[0].get("input_token_count", None),
+                    "completion_tokens": results[0].get("generated_token_count", None),
+                }
+            return {"text": "", "is_finished": False}
+        except Exception as e:
+            raise e
+
     def model_response_creator(self):
         model_response = ModelResponse(stream=True, model=self.model)
         if self.response_id is not None:
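For context on what `handle_watsonx_stream` consumes: watsonx text-generation streams arrive as SSE lines whose JSON body carries a `results` list with `generated_text`, token counts, and a `stop_reason`, where `not_finished` marks intermediate chunks. A small illustration of that parsing contract, using an invented chunk:

```python
import json

# Illustrative SSE line in the shape handle_watsonx_stream parses; the
# payload values are invented for the example.
raw_chunk = (
    'data: {"results": [{"generated_text": "Hello", '
    '"generated_token_count": 3, "input_token_count": 12, '
    '"stop_reason": "not_finished"}]}'
)

parsed = json.loads(raw_chunk.replace("data: ", "").strip())
result = parsed["results"][0]
print(result["generated_text"])                 # "Hello"
print(result["stop_reason"] != "not_finished")  # False -> stream not finished yet
```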
@@ -10006,6 +10080,21 @@ class CustomStreamWrapper:
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider == "watsonx":
+                response_obj = self.handle_watsonx_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj.get("prompt_tokens") is not None:
+                    prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
+                    model_response.usage.prompt_tokens = (prompt_token_count + response_obj["prompt_tokens"])
+                if response_obj.get("completion_tokens") is not None:
+                    model_response.usage.completion_tokens = response_obj["completion_tokens"]
+                    model_response.usage.total_tokens = (
+                        getattr(model_response.usage, "prompt_tokens", 0)
+                        + getattr(model_response.usage, "completion_tokens", 0)
+                    )
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":
                 response_obj = self.handle_openai_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.35.29"
+version = "1.35.30"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"

@@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.35.29"
+version = "1.35.30"
 version_files = [
     "pyproject.toml:^version"
 ]