Merge branch 'main' into litellm_default_router_retries

Krish Dholakia 2024-04-27 11:21:57 -07:00 committed by GitHub
commit 1a06f009d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 1663 additions and 44 deletions

View file

@ -227,6 +227,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
| [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
| [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
| [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ |
| [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ |
| [xinference [Xorbits Inference]](https://docs.litellm.ai/docs/providers/xinference) | | | | | ✅ |

cookbook/liteLLM_IBM_Watsonx.ipynb vendored Normal file

File diff suppressed because one or more lines are too long

View file

@ -53,6 +53,50 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported.
| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` |
## Function Calling
```python
import os
from litellm import completion
# set env
os.environ["MISTRAL_API_KEY"] = "your-api-key"
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
model="mistral/mistral-large-latest",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions here to check the response
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
```
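The `arguments` field on a tool call is returned as a JSON-encoded string (which is what the assertions above check). A minimal sketch of turning it into a Python dict — not Mistral-specific, just standard handling of OpenAI-format tool calls:

```python
import json

tool_call = response.choices[0].message.tool_calls[0]
args = json.loads(tool_call.function.arguments)  # e.g. {"location": "Boston, MA"} (model output may vary)
print(tool_call.function.name, args)
```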
## Sample Usage - Embedding
```python
from litellm import embedding

View file

@ -0,0 +1,284 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# IBM watsonx.ai
LiteLLM supports all IBM [watsonx.ai](https://watsonx.ai/) foundational models and embeddings.
## Environment Variables
```python
os.environ["WATSONX_URL"] = "" # (required) Base URL of your WatsonX instance
# (required) either one of the following:
os.environ["WATSONX_APIKEY"] = "" # IBM cloud API key
os.environ["WATSONX_TOKEN"] = "" # IAM auth token
# optional - can also be passed as params to completion() or embedding()
os.environ["WATSONX_PROJECT_ID"] = "" # Project ID of your WatsonX instance
os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = "" # ID of your deployment space to use deployed models
```
See [here](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information on how to get an access token to authenticate to watsonx.ai.
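If you only have an IBM Cloud API key, LiteLLM exchanges it for an IAM token for you (see `generate_iam_token` in `litellm/llms/watsonx.py`, added in this commit). A minimal sketch of that same exchange, if you want to generate a token yourself:

```python
import httpx

def get_iam_token(api_key: str) -> str:
    # exchange an IBM Cloud API key for an IAM access token
    response = httpx.post(
        "https://iam.cloud.ibm.com/identity/token",
        headers={
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        },
        data={
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": api_key,
        },
    )
    response.raise_for_status()
    return response.json()["access_token"]

# os.environ["WATSONX_TOKEN"] = get_iam_token(os.environ["WATSONX_APIKEY"])
```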
## Usage
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_IBM_Watsonx.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
```python
import os
from litellm import completion
os.environ["WATSONX_URL"] = ""
os.environ["WATSONX_APIKEY"] = ""
response = completion(
model="watsonx/ibm/granite-13b-chat-v2",
messages=[{ "content": "what is your favorite colour?","role": "user"}],
project_id="<my-project-id>" # or pass with os.environ["WATSONX_PROJECT_ID"]
)
response = completion(
model="watsonx/meta-llama/llama-3-8b-instruct",
messages=[{ "content": "what is your favorite colour?","role": "user"}],
project_id="<my-project-id>"
)
```
## Usage - Streaming
```python
import os
from litellm import completion
os.environ["WATSONX_URL"] = ""
os.environ["WATSONX_APIKEY"] = ""
os.environ["WATSONX_PROJECT_ID"] = ""
response = completion(
model="watsonx/ibm/granite-13b-chat-v2",
messages=[{ "content": "what is your favorite colour?","role": "user"}],
stream=True
)
for chunk in response:
print(chunk)
```
#### Example Streaming Output Chunk
```json
{
"choices": [
{
"finish_reason": null,
"index": 0,
"delta": {
"content": "I don't have a favorite color, but I do like the color blue. What's your favorite color?"
}
}
],
"created": null,
"model": "watsonx/ibm/granite-13b-chat-v2",
"usage": {
"prompt_tokens": null,
"completion_tokens": null,
"total_tokens": null
}
}
```
## Usage - Models in deployment spaces
Models that have been deployed to a deployment space (e.g.: tuned models) can be called using the `deployment/<deployment_id>` format (where `<deployment_id>` is the ID of the deployed model in your deployment space).
The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=<deployment_space_id>`.
```python
import litellm
response = litellm.completion(
model="watsonx/deployment/<deployment_id>",
messages=[{"content": "Hello, how are you?", "role": "user"}],
space_id="<deployment_space_id>"
)
```
## Usage - Embeddings
LiteLLM also supports making requests to IBM watsonx.ai embedding models. The credential needed for this is the same as for completion.
```python
from litellm import embedding
response = embedding(
model="watsonx/ibm/slate-30m-english-rtrvr",
input=["What is the capital of France?"],
project_id="<my-project-id>"
)
print(response)
# EmbeddingResponse(model='ibm/slate-30m-english-rtrvr', data=[{'object': 'embedding', 'index': 0, 'embedding': [-0.037463713, -0.02141933, -0.02851813, 0.015519324, ..., -0.0021367231, -0.01704561, -0.001425816, 0.0035238306]}], object='list', usage=Usage(prompt_tokens=8, total_tokens=8))
```
## OpenAI Proxy Usage
Here's how to call IBM watsonx.ai with the LiteLLM Proxy Server
### 1. Save keys in your environment
```bash
export WATSONX_URL=""
export WATSONX_APIKEY=""
export WATSONX_PROJECT_ID=""
```
### 2. Start the proxy
<Tabs>
<TabItem value="cli" label="CLI">
```bash
$ litellm --model watsonx/meta-llama/llama-3-8b-instruct
# Server running on http://0.0.0.0:4000
```
</TabItem>
<TabItem value="config" label="config.yaml">
```yaml
model_list:
- model_name: llama-3-8b
litellm_params:
# all params accepted by litellm.completion()
model: watsonx/meta-llama/llama-3-8b-instruct
api_key: "os.environ/WATSONX_API_KEY" # does os.getenv("WATSONX_API_KEY")
```
</TabItem>
</Tabs>
### 3. Test it
<Tabs>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "llama-3-8b",
"messages": [
{
"role": "user",
"content": "what is your favorite colour?"
}
]
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="llama-3-8b", messages=[
{
"role": "user",
"content": "what is your favorite colour?"
}
])
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy
model = "llama-3-8b",
temperature=0.1
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
## Authentication
### Passing credentials as parameters
You can also pass the credentials as parameters to the completion and embedding functions.
```python
import os
from litellm import completion
response = completion(
model="watsonx/ibm/granite-13b-chat-v2",
messages=[{ "content": "What is your favorite color?","role": "user"}],
url="",
api_key="",
project_id=""
)
```
## Supported IBM watsonx.ai Models
Here are some examples of models available in IBM watsonx.ai that you can use with LiteLLM:
| Model Name | Command |
| ---------- | --------- |
| Flan T5 XXL | `completion(model="watsonx/google/flan-t5-xxl", messages=messages)` |
| Flan Ul2 | `completion(model="watsonx/google/flan-ul2", messages=messages)` |
| Mt0 XXL | `completion(model="watsonx/bigscience/mt0-xxl", messages=messages)` |
| Gpt Neox | `completion(model="watsonx/eleutherai/gpt-neox-20b", messages=messages)` |
| Mpt 7B Instruct2 | `completion(model="watsonx/ibm/mpt-7b-instruct2", messages=messages)` |
| Starcoder | `completion(model="watsonx/bigcode/starcoder", messages=messages)` |
| Llama 2 70B Chat | `completion(model="watsonx/meta-llama/llama-2-70b-chat", messages=messages)` |
| Llama 2 13B Chat | `completion(model="watsonx/meta-llama/llama-2-13b-chat", messages=messages)` |
| Granite 13B Instruct | `completion(model="watsonx/ibm/granite-13b-instruct-v1", messages=messages)` |
| Granite 13B Chat | `completion(model="watsonx/ibm/granite-13b-chat-v1", messages=messages)` |
| Flan T5 XL | `completion(model="watsonx/google/flan-t5-xl", messages=messages)` |
| Granite 13B Chat V2 | `completion(model="watsonx/ibm/granite-13b-chat-v2", messages=messages)` |
| Granite 13B Instruct V2 | `completion(model="watsonx/ibm/granite-13b-instruct-v2", messages=messages)` |
| Elyza Japanese Llama 2 7B Instruct | `completion(model="watsonx/elyza/elyza-japanese-llama-2-7b-instruct", messages=messages)` |
| Mixtral 8X7B Instruct V01 Q | `completion(model="watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q", messages=messages)` |
For a list of all available models in watsonx.ai, see [here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx&locale=en&audience=wdp).
## Supported IBM watsonx.ai Embedding Models
| Model Name | Function Call |
|----------------------|---------------------------------------------|
| Slate 30m | `embedding(model="watsonx/ibm/slate-30m-english-rtrvr", input=input)` |
| Slate 125m | `embedding(model="watsonx/ibm/slate-125m-english-rtrvr", input=input)` |
For a list of all available embedding models in watsonx.ai, see [here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx).

View file

@ -148,6 +148,7 @@ const sidebars = {
"providers/openrouter", "providers/openrouter",
"providers/custom_openai_proxy", "providers/custom_openai_proxy",
"providers/petals", "providers/petals",
"providers/watsonx",
], ],
}, },
"proxy/custom_pricing", "proxy/custom_pricing",

View file

@ -77,6 +77,7 @@ baseten_key: Optional[str] = None
aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None
use_client: bool = False
ssl_verify: bool = True
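# when set to False, the httpx clients that the Router builds skip SSL certificate
# verification (passed as `verify=litellm.ssl_verify` in the router.py changes below)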
disable_streaming_logging: bool = False
### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None
@ -298,6 +299,7 @@ aleph_alpha_models: List = []
bedrock_models: List = []
deepinfra_models: List = []
perplexity_models: List = []
watsonx_models: List = []
for key, value in model_cost.items():
if value.get("litellm_provider") == "openai":
open_ai_chat_completion_models.append(key)
@ -342,6 +344,8 @@ for key, value in model_cost.items():
deepinfra_models.append(key)
elif value.get("litellm_provider") == "perplexity":
perplexity_models.append(key)
elif value.get("litellm_provider") == "watsonx":
watsonx_models.append(key)
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
openai_compatible_endpoints: List = [
@ -478,6 +482,7 @@ model_list = (
+ perplexity_models
+ maritalk_models
+ vertex_language_models
+ watsonx_models
)
provider_list: List = [
@ -516,6 +521,7 @@ provider_list: List = [
"cloudflare", "cloudflare",
"xinference", "xinference",
"fireworks_ai", "fireworks_ai",
"watsonx",
"custom", # custom apis "custom", # custom apis
] ]
@ -537,6 +543,7 @@ models_by_provider: dict = {
"deepinfra": deepinfra_models, "deepinfra": deepinfra_models,
"perplexity": perplexity_models, "perplexity": perplexity_models,
"maritalk": maritalk_models, "maritalk": maritalk_models,
"watsonx": watsonx_models,
}
# mapping for those models which have larger equivalents
@ -650,6 +657,7 @@ from .llms.bedrock import (
)
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
from .llms.watsonx import IBMWatsonXAIConfig
from .main import * # type: ignore
from .integrations import *
from .exceptions import (

View file

@ -430,6 +430,32 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
prompt = default_pt(messages)
return prompt
### IBM Granite
def ibm_granite_pt(messages: list):
"""
IBM's Granite models use the template:
<|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message}
See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
"""
return custom_prompt(
messages=messages,
role_dict={
'system': {
'pre_message': '<|system|>\n',
'post_message': '\n',
},
'user': {
'pre_message': '<|user|>\n',
'post_message': '\n',
},
'assistant': {
'pre_message': '<|assistant|>\n',
'post_message': '\n',
}
}
).strip()
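# Illustrative output, assuming custom_prompt simply concatenates pre_message + content + post_message per turn:
# ibm_granite_pt([{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}])
# -> "<|system|>\nBe brief.\n<|user|>\nHi"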
### ANTHROPIC ###
@ -1359,6 +1385,25 @@ def prompt_factory(
return messages
elif custom_llm_provider == "azure_text":
return azure_text_pt(messages=messages)
elif custom_llm_provider == "watsonx":
if "granite" in model and "chat" in model:
# granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template
return ibm_granite_pt(messages=messages)
elif "ibm-mistral" in model and "instruct" in model:
# models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct prompt template
return mistral_instruct_pt(messages=messages)
elif "meta-llama/llama-3" in model and "instruct" in model:
# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
return custom_prompt(
role_dict={
"system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"},
"user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"},
"assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"},
},
messages=messages,
initial_prompt_value="<|begin_of_text|>",
final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
)
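# The resulting prompt roughly follows the documented Llama-3 chat format, e.g. for a single user turn:
# "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"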
try:
if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)

View file

@ -112,10 +112,16 @@ def start_prediction(
}
initial_prediction_data = {
"version": version_id,
"input": input_data, "input": input_data,
} }
if ":" in version_id and len(version_id) > 64:
model_parts = version_id.split(":")
if (
len(model_parts) > 1 and len(model_parts[1]) == 64
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
initial_prediction_data["version"] = model_parts[1]
## LOGGING
logging_obj.pre_call(
input=input_data["prompt"],

View file

@ -529,6 +529,7 @@ def completion(
"instances": instances, "instances": instances,
"vertex_location": vertex_location, "vertex_location": vertex_location,
"vertex_project": vertex_project, "vertex_project": vertex_project,
"safety_settings":safety_settings,
**optional_params,
}
if optional_params.get("stream", False) is True:
@ -813,6 +814,7 @@ async def async_completion(
instances=None,
vertex_project=None,
vertex_location=None,
safety_settings=None,
**optional_params,
):
"""
@ -844,6 +846,7 @@ async def async_completion(
response = await llm_model._generate_content_async(
contents=content,
generation_config=optional_params,
safety_settings=safety_settings,
tools=tools,
)

litellm/llms/watsonx.py Normal file
View file

@ -0,0 +1,591 @@
from enum import Enum
import json, types, time # noqa: E401
from contextlib import contextmanager
from typing import Callable, Dict, Optional, Any, Union, List
import httpx
import requests
import litellm
from litellm.utils import ModelResponse, get_secret, Usage
from .base import BaseLLM
from .prompt_templates import factory as ptf
class WatsonXAIError(Exception):
def __init__(self, status_code, message, url: Optional[str] = None):
self.status_code = status_code
self.message = message
url = url or "https://https://us-south.ml.cloud.ibm.com"
self.request = httpx.Request(method="POST", url=url)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class IBMWatsonXAIConfig:
"""
Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
(See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
Supported params for all available watsonx.ai foundational models.
- `decoding_method` (str): One of "greedy" or "sample"
- `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.
- `max_new_tokens` (integer): Maximum length of the generated tokens.
- `min_new_tokens` (integer): Minimum number of new tokens to be generated.
- `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".
- `stop_sequences` (string[]): list of strings to use as stop sequences.
- `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.
- `top_p` (integer): top p for sampling - not available when decoding_method='greedy'.
- `repetition_penalty` (float): token repetition penalty during text generation.
- `truncate_input_tokens` (integer): Truncate input tokens to this length.
- `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.
- `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.
- `random_seed` (integer): Random seed for text generation.
- `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.
- `stream` (bool): If True, the model will return a stream of responses.
"""
decoding_method: Optional[str] = "sample"
temperature: Optional[float] = None
max_new_tokens: Optional[int] = None # litellm.max_tokens
min_new_tokens: Optional[int] = None
length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5}
stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."]
top_k: Optional[int] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
truncate_input_tokens: Optional[int] = None
include_stop_sequences: Optional[bool] = False
return_options: Optional[Dict[str, bool]] = None
random_seed: Optional[int] = None # e.g 42
moderations: Optional[dict] = None
stream: Optional[bool] = False
def __init__(
self,
decoding_method: Optional[str] = None,
temperature: Optional[float] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
length_penalty: Optional[dict] = None,
stop_sequences: Optional[List[str]] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
truncate_input_tokens: Optional[int] = None,
include_stop_sequences: Optional[bool] = None,
return_options: Optional[dict] = None,
random_seed: Optional[int] = None,
moderations: Optional[dict] = None,
stream: Optional[bool] = None,
**kwargs,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"temperature", # equivalent to temperature
"max_tokens", # equivalent to max_new_tokens
"top_p", # equivalent to top_p
"frequency_penalty", # equivalent to repetition_penalty
"stop", # equivalent to stop_sequences
"seed", # equivalent to random_seed
"stream", # equivalent to stream
]
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# map the chat messages to a watsonx.ai prompt, using a registered custom prompt template if one exists for this model
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_dict = custom_prompt_dict[model]
prompt = ptf.custom_prompt(
messages=messages,
role_dict=model_prompt_dict.get(
"role_dict", model_prompt_dict.get("roles")
),
initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
bos_token=model_prompt_dict.get("bos_token", ""),
eos_token=model_prompt_dict.get("eos_token", ""),
)
return prompt
elif provider == "ibm":
prompt = ptf.prompt_factory(
model=model, messages=messages, custom_llm_provider="watsonx"
)
elif provider == "ibm-mistralai":
prompt = ptf.mistral_instruct_pt(messages=messages)
else:
prompt = ptf.prompt_factory(
model=model, messages=messages, custom_llm_provider="watsonx"
)
return prompt
class WatsonXAIEndpoint(str, Enum):
TEXT_GENERATION = "/ml/v1/text/generation"
TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"
DEPLOYMENT_TEXT_GENERATION_STREAM = (
"/ml/v1/deployments/{deployment_id}/text/generation_stream"
)
EMBEDDINGS = "/ml/v1/text/embeddings"
PROMPTS = "/ml/v1/prompts"
class IBMWatsonXAI(BaseLLM):
"""
Class to interface with IBM Watsonx.ai API for text generation and embeddings.
Reference: https://cloud.ibm.com/apidocs/watsonx-ai
"""
api_version = "2024-03-13"
def __init__(self) -> None:
super().__init__()
def _prepare_text_generation_req(
self,
model_id: str,
prompt: str,
stream: bool,
optional_params: dict,
print_verbose: Optional[Callable] = None,
) -> dict:
"""
Get the request parameters for text generation.
"""
api_params = self._get_api_params(optional_params, print_verbose=print_verbose)
# build auth headers
api_token = api_params.get("token")
headers = {
"Authorization": f"Bearer {api_token}",
"Content-Type": "application/json",
"Accept": "application/json",
}
extra_body_params = optional_params.pop("extra_body", {})
optional_params.update(extra_body_params)
# init the payload to the text generation call
payload = {
"input": prompt,
"moderations": optional_params.pop("moderations", {}),
"parameters": optional_params,
}
request_params = dict(version=api_params["api_version"])
# text generation endpoint deployment or model / stream or not
if model_id.startswith("deployment/"):
# deployment models are passed in as 'deployment/<deployment_id>'
if api_params.get("space_id") is None:
raise WatsonXAIError(
status_code=401,
url=api_params["url"],
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
deployment_id = "/".join(model_id.split("/")[1:])
endpoint = (
WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
if stream
else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
)
endpoint = endpoint.format(deployment_id=deployment_id)
else:
payload["model_id"] = model_id
payload["project_id"] = api_params["project_id"]
endpoint = (
WatsonXAIEndpoint.TEXT_GENERATION_STREAM
if stream
else WatsonXAIEndpoint.TEXT_GENERATION
)
url = api_params["url"].rstrip("/") + endpoint
return dict(
method="POST", url=url, headers=headers, json=payload, params=request_params
)
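# Illustrative return value (exact values depend on your credentials and params):
# {"method": "POST", "url": "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation",
#  "headers": {"Authorization": "Bearer ...", ...},
#  "json": {"input": "<prompt>", "moderations": {}, "parameters": {...},
#           "model_id": "ibm/granite-13b-chat-v2", "project_id": "..."},
#  "params": {"version": "2024-03-13"}}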
def _get_api_params(
self, params: dict, print_verbose: Optional[Callable] = None
) -> dict:
"""
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
"""
# Load auth variables from params
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
api_key = params.pop("apikey", None)
token = params.pop("token", None)
project_id = params.pop(
"project_id", params.pop("watsonx_project", None)
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
region_name = params.pop("region_name", params.pop("region", None))
if region_name is None:
region_name = params.pop(
"watsonx_region_name", params.pop("watsonx_region", None)
) # consistent with how vertex ai + aws regions are accepted
wx_credentials = params.pop(
"wx_credentials",
params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
api_version = params.pop("api_version", IBMWatsonXAI.api_version)
# Load auth variables from environment variables
if url is None:
url = (
get_secret("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret("WATSONX_URL")
or get_secret("WX_URL")
or get_secret("WML_URL")
)
if api_key is None:
api_key = (
get_secret("WATSONX_APIKEY")
or get_secret("WATSONX_API_KEY")
or get_secret("WX_API_KEY")
)
if token is None:
token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN")
if project_id is None:
project_id = (
get_secret("WATSONX_PROJECT_ID")
or get_secret("WX_PROJECT_ID")
or get_secret("PROJECT_ID")
)
if region_name is None:
region_name = (
get_secret("WATSONX_REGION")
or get_secret("WX_REGION")
or get_secret("REGION")
)
if space_id is None:
space_id = (
get_secret("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret("WATSONX_SPACE_ID")
or get_secret("WX_SPACE_ID")
or get_secret("SPACE_ID")
)
# credentials parsing
if wx_credentials is not None:
url = wx_credentials.get("url", url)
api_key = wx_credentials.get(
"apikey", wx_credentials.get("api_key", api_key)
)
token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", token
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
# verify that all required credentials are present
if url is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
)
if token is None and api_key is not None:
# generate the auth token
if print_verbose:
print_verbose("Generating IAM token for Watsonx.ai")
token = self.generate_iam_token(api_key)
elif token is None and api_key is None:
raise WatsonXAIError(
status_code=401,
url=url,
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
)
if project_id is None:
raise WatsonXAIError(
status_code=401,
url=url,
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
)
return {
"url": url,
"api_key": api_key,
"token": token,
"project_id": project_id,
"space_id": space_id,
"region_name": region_name,
"api_version": api_version,
}
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
litellm_params: Optional[dict] = None,
logger_fn=None,
timeout: Optional[float] = None,
):
"""
Send a text generation request to the IBM Watsonx.ai API.
Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
"""
stream = optional_params.pop("stream", False)
# Load default configs
config = IBMWatsonXAIConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
# Make prompt to send to model
provider = model.split("/")[0]
# model_name = "/".join(model.split("/")[1:])
prompt = convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
def process_text_request(request_params: dict) -> ModelResponse:
with self._manage_response(
request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
) as resp:
json_resp = resp.json()
generated_text = json_resp["results"][0]["generated_text"]
prompt_tokens = json_resp["results"][0]["input_token_count"]
completion_tokens = json_resp["results"][0]["generated_token_count"]
model_response["choices"][0]["message"]["content"] = generated_text
model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
model_response["created"] = int(time.time())
model_response["model"] = model
setattr(
model_response,
"usage",
Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return model_response
def process_stream_request(
request_params: dict,
) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
with self._manage_response(
request_params,
logging_obj=logging_obj,
stream=True,
input=prompt,
timeout=timeout,
) as resp:
response = litellm.CustomStreamWrapper(
resp.iter_lines(),
model=model,
custom_llm_provider="watsonx",
logging_obj=logging_obj,
)
return response
try:
## Get the response from the model
req_params = self._prepare_text_generation_req(
model_id=model,
prompt=prompt,
stream=stream,
optional_params=optional_params,
print_verbose=print_verbose,
)
if stream:
return process_stream_request(req_params)
else:
return process_text_request(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
def embedding(
self,
model: str,
input: Union[list, str],
api_key: Optional[str] = None,
logging_obj=None,
model_response=None,
optional_params=None,
encoding=None,
):
"""
Send a text embedding request to the IBM Watsonx.ai API.
"""
if optional_params is None:
optional_params = {}
# Load default configs
config = IBMWatsonXAIConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
# Load auth variables from environment variables
if isinstance(input, str):
input = [input]
if api_key is not None:
optional_params["api_key"] = api_key
api_params = self._get_api_params(optional_params)
# build auth headers
api_token = api_params.get("token")
headers = {
"Authorization": f"Bearer {api_token}",
"Content-Type": "application/json",
"Accept": "application/json",
}
# init the payload to the text generation call
payload = {
"inputs": input,
"model_id": model,
"project_id": api_params["project_id"],
"parameters": optional_params,
}
request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
# request = httpx.Request(
# "POST", url, headers=headers, json=payload, params=request_params
# )
req_params = {
"method": "POST",
"url": url,
"headers": headers,
"json": payload,
"params": request_params,
}
with self._manage_response(
req_params, logging_obj=logging_obj, input=input
) as resp:
json_resp = resp.json()
results = json_resp.get("results", [])
embedding_response = []
for idx, result in enumerate(results):
embedding_response.append(
{"object": "embedding", "index": idx, "embedding": result["embedding"]}
)
model_response["object"] = "list"
model_response["data"] = embedding_response
model_response["model"] = model
input_tokens = json_resp.get("input_token_count", 0)
model_response.usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
return model_response
def generate_iam_token(self, api_key=None, **params):
headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
if api_key is None:
api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY")
if api_key is None:
raise ValueError("API key is required")
headers["Accept"] = "application/json"
data = {
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
"apikey": api_key,
}
response = httpx.post(
"https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
)
response.raise_for_status()
json_data = response.json()
iam_access_token = json_data["access_token"]
self.token = iam_access_token
return iam_access_token
@contextmanager
def _manage_response(
self,
request_params: dict,
logging_obj: Any,
stream: bool = False,
input: Optional[Any] = None,
timeout: Optional[float] = None,
):
request_str = (
f"response = {request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params['json']},\n"
f")"
)
logging_obj.pre_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
additional_args={
"complete_input_dict": request_params["json"],
"request_str": request_str,
},
)
if timeout:
request_params["timeout"] = timeout
try:
if stream:
resp = requests.request(
**request_params,
stream=True,
)
resp.raise_for_status()
yield resp
else:
resp = requests.request(**request_params)
resp.raise_for_status()
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params["json"],
},
)

View file

@ -62,6 +62,7 @@ from .llms import (
vertex_ai,
vertex_ai_anthropic,
maritalk,
watsonx,
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion
@ -1862,6 +1863,43 @@ def completion(
## RESPONSE OBJECT
response = response
elif custom_llm_provider == "watsonx":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = watsonx.IBMWatsonXAI().completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
response = CustomStreamWrapper(
iter(response),
model,
custom_llm_provider="watsonx",
logging_obj=logging,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
)
## RESPONSE OBJECT
response = response
elif custom_llm_provider == "vllm": elif custom_llm_provider == "vllm":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = vllm.completion( model_response = vllm.completion(
@ -2941,6 +2979,15 @@ def embedding(
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "watsonx":
response = watsonx.IBMWatsonXAI().embedding(
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
)
else:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")

View file

@ -6,4 +6,3 @@ model_list:
model_name: fake-openai-endpoint
router_settings:
num_retries: 0

View file

@ -1937,6 +1937,7 @@ class Router:
)
default_api_base = api_base
default_api_key = api_key
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider in litellm.openai_compatible_providers
@ -1972,6 +1973,23 @@ class Router:
api_base = litellm.get_secret(api_base_env_name)
litellm_params["api_base"] = api_base
## AZURE AI STUDIO MISTRAL CHECK ##
"""
Make sure api base ends in /v1/
if not, add it - https://github.com/BerriAI/litellm/issues/2279
"""
if (
custom_llm_provider == "openai"
and api_base is not None
and not api_base.endswith("/v1/")
):
# check if it ends with a trailing slash
if api_base.endswith("/"):
api_base += "v1/"
else:
api_base += "/v1/"
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
@ -2062,9 +2080,11 @@ class Router:
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
@ -2084,9 +2104,11 @@ class Router:
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
@ -2106,9 +2128,11 @@ class Router:
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
@ -2128,9 +2152,11 @@ class Router:
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
@ -2168,9 +2194,11 @@ class Router:
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
@ -2188,9 +2216,11 @@ class Router:
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
verify=litellm.ssl_verify,
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
),
mounts=sync_proxy_mounts,
), # type: ignore
@ -2209,9 +2239,11 @@ class Router:
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
),
@ -2229,9 +2261,11 @@ class Router:
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
),
@ -2259,9 +2293,11 @@ class Router:
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
@ -2281,9 +2317,11 @@ class Router:
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
@ -2304,9 +2342,11 @@ class Router:
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
@ -2327,9 +2367,11 @@ class Router:
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore

View file

@ -2655,6 +2655,42 @@ def test_completion_palm_stream():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_completion_watsonx():
litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2"
try:
response = completion(
model=model_name,
messages=messages,
stop=["stop"],
max_tokens=20,
)
# Add any assertions here to check the response
print(response)
except litellm.APIError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_acompletion_watsonx():
litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2"
print("testing watsonx")
try:
response = await litellm.acompletion(
model=model_name,
messages=messages,
temperature=0.2,
max_tokens=80,
)
# Add any assertions here to check the response
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_palm_stream()
# test_completion_deep_infra()

View file

@ -483,6 +483,18 @@ def test_mistral_embeddings():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_watsonx_embeddings():
try:
litellm.set_verbose = True
response = litellm.embedding(
model="watsonx/ibm/slate-30m-english-rtrvr",
input=["good morning from litellm"],
)
print(f"response: {response}")
assert isinstance(response.usage, litellm.Usage)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_mistral_embeddings()

View file

@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
import os, httpx
load_dotenv()
@ -56,6 +57,87 @@ def test_router_num_retries_init(num_retries, max_retries):
else:
assert getattr(model_client, "max_retries") == 0
@pytest.mark.parametrize(
"timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)]
)
@pytest.mark.parametrize("ssl_verify", [True, False])
def test_router_timeout_init(timeout, ssl_verify):
"""
Allow user to pass httpx.Timeout
related issue - https://github.com/BerriAI/litellm/issues/3162
"""
litellm.ssl_verify = ssl_verify
router = Router(
model_list=[
{
"model_name": "test-model",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION"),
"timeout": timeout,
},
"model_info": {"id": 1234},
}
]
)
model_client = router._get_client(
deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={}
)
assert getattr(model_client, "timeout") == timeout
print(f"vars model_client: {vars(model_client)}")
http_client = getattr(model_client, "_client")
print(f"http client: {vars(http_client)}, ssl_Verify={ssl_verify}")
if ssl_verify == False:
assert http_client._transport._pool._ssl_context.verify_mode.name == "CERT_NONE"
else:
assert (
http_client._transport._pool._ssl_context.verify_mode.name
== "CERT_REQUIRED"
)
@pytest.mark.parametrize(
"mistral_api_base",
[
"os.environ/AZURE_MISTRAL_API_BASE",
"https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1/",
"https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1",
"https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/",
"https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com",
],
)
def test_router_azure_ai_studio_init(mistral_api_base):
router = Router(
model_list=[
{
"model_name": "test-model",
"litellm_params": {
"model": "azure/mistral-large-latest",
"api_key": "os.environ/AZURE_MISTRAL_API_KEY",
"api_base": mistral_api_base,
},
"model_info": {"id": 1234},
}
]
)
model_client = router._get_client(
deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={}
)
url = getattr(model_client, "_base_url")
uri_reference = str(getattr(url, "_uri_reference"))
print(f"uri_reference: {uri_reference}")
assert "/v1/" in uri_reference
def test_exception_raising():
# this tests if the router raises an exception when invalid params are set

View file

@ -1271,6 +1271,32 @@ def test_completion_sagemaker_stream():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_completion_watsonx_stream():
litellm.set_verbose = True
try:
response = completion(
model="watsonx/ibm/granite-13b-chat-v2",
messages=messages,
temperature=0.5,
max_tokens=20,
stream=True,
)
complete_response = ""
has_finish_reason = False
# Add any assertions here to check the response
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
if finished:
break
complete_response += chunk
if has_finish_reason is False:
raise Exception("finish reason not set for last chunk")
if complete_response.strip() == "":
raise Exception("Empty response received")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_sagemaker_stream()

View file

@ -1,5 +1,5 @@
from typing import List, Optional, Union, Dict, Tuple, Literal
import httpx
from pydantic import BaseModel, validator
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
@ -104,7 +104,9 @@ class LiteLLM_Params(BaseModel):
api_key: Optional[str] = None
api_base: Optional[str] = None
api_version: Optional[str] = None
timeout: Optional[Union[float, str, httpx.Timeout]] = (
None # if str, pass in as os.environ/
)
stream_timeout: Optional[Union[float, str]] = (
None # timeout when making stream=True calls, if str, pass in as os.environ/
)
@ -152,6 +154,7 @@ class LiteLLM_Params(BaseModel):
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key):
# Define custom behavior for the 'in' operator

View file

@ -5427,6 +5427,45 @@ def get_optional_params(
optional_params["extra_body"] = ( optional_params["extra_body"] = (
extra_body # openai client supports `extra_body` param extra_body # openai client supports `extra_body` param
) )
elif custom_llm_provider == "watsonx":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if max_tokens is not None:
optional_params["max_new_tokens"] = max_tokens
if stream:
optional_params["stream"] = stream
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if frequency_penalty is not None:
optional_params["repetition_penalty"] = frequency_penalty
if seed is not None:
optional_params["random_seed"] = seed
if stop is not None:
optional_params["stop_sequences"] = stop
# WatsonX-only parameters
extra_body = {}
if "decoding_method" in passed_params:
extra_body["decoding_method"] = passed_params.pop("decoding_method")
if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens"))
if "top_k" in passed_params:
extra_body["top_k"] = passed_params.pop("top_k")
if "truncate_input_tokens" in passed_params:
extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
if "length_penalty" in passed_params:
extra_body["length_penalty"] = passed_params.pop("length_penalty")
if "time_limit" in passed_params:
extra_body["time_limit"] = passed_params.pop("time_limit")
if "return_options" in passed_params:
extra_body["return_options"] = passed_params.pop("return_options")
optional_params["extra_body"] = (
extra_body # openai client supports `extra_body` param
)
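# Illustrative mapping: completion(..., max_tokens=20, seed=42) would leave optional_params
# roughly as {"max_new_tokens": 20, "random_seed": 42, "extra_body": {...}}; watsonx-only
# params (decoding_method, top_k, ...) move into extra_body when present in passed_params.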
else: # assume passing in params for openai/azure openai
print_verbose(
f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
@ -5829,6 +5868,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"frequency_penalty", "frequency_penalty",
"presence_penalty", "presence_penalty",
] ]
elif custom_llm_provider == "watsonx":
return litellm.IBMWatsonXAIConfig().get_supported_openai_params()
def get_formatted_prompt(
@ -6056,6 +6097,8 @@ def get_llm_provider(
model in litellm.bedrock_models or model in litellm.bedrock_embedding_models
):
custom_llm_provider = "bedrock"
elif model in litellm.watsonx_models:
custom_llm_provider = "watsonx"
# openai embeddings
elif model in litellm.open_ai_embedding_models:
custom_llm_provider = "openai"
@ -6520,7 +6563,7 @@ def validate_environment(model: Optional[str] = None) -> dict:
if "VERTEXAI_PROJECT" in os.environ and "VERTEXAI_LOCATION" in os.environ: if "VERTEXAI_PROJECT" in os.environ and "VERTEXAI_LOCATION" in os.environ:
keys_in_environment = True keys_in_environment = True
else: else:
missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_PROJECT"]) missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_LOCATION"])
elif custom_llm_provider == "huggingface": elif custom_llm_provider == "huggingface":
if "HUGGINGFACE_API_KEY" in os.environ: if "HUGGINGFACE_API_KEY" in os.environ:
keys_in_environment = True keys_in_environment = True
@ -9751,6 +9794,37 @@ class CustomStreamWrapper:
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
def handle_watsonx_stream(self, chunk):
try:
if isinstance(chunk, dict):
parsed_response = chunk
elif isinstance(chunk, (str, bytes)):
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
if 'generated_text' in chunk:
response = chunk.replace('data: ', '').strip()
parsed_response = json.loads(response)
else:
return {"text": "", "is_finished": False}
else:
print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
raise ValueError(f"Unable to parse response. Original response: {chunk}")
results = parsed_response.get("results", [])
if len(results) > 0:
text = results[0].get("generated_text", "")
finish_reason = results[0].get("stop_reason")
is_finished = finish_reason != 'not_finished'
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
"prompt_tokens": results[0].get("input_token_count", None),
"completion_tokens": results[0].get("generated_token_count", None),
}
return {"text": "", "is_finished": False}
except Exception as e:
raise e
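# Illustrative stream line this parses (fields per the watsonx.ai generation_stream response):
# data: {"results": [{"generated_text": "blue", "stop_reason": "not_finished",
#                     "input_token_count": 8, "generated_token_count": 1}]}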
def model_response_creator(self):
model_response = ModelResponse(stream=True, model=self.model)
if self.response_id is not None:
@ -10006,6 +10080,21 @@ class CustomStreamWrapper:
print_verbose(f"completion obj content: {completion_obj['content']}") print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]: if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"] self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "watsonx":
response_obj = self.handle_watsonx_stream(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj.get("prompt_tokens") is not None:
prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
if response_obj.get("completion_tokens") is not None:
model_response.usage.completion_tokens = response_obj["completion_tokens"]
model_response.usage.total_tokens = (
getattr(model_response.usage, "prompt_tokens", 0)
+ getattr(model_response.usage, "completion_tokens", 0)
)
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "text-completion-openai": elif self.custom_llm_provider == "text-completion-openai":
response_obj = self.handle_openai_text_completion_chunk(chunk) response_obj = self.handle_openai_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"] completion_obj["content"] = response_obj["text"]

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.35.30"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.35.30"
version_files = [
"pyproject.toml:^version"
]