Merge pull request #5478 from BerriAI/litellm_Add_ai21
[Feat] Add AI21 /chat API
Commit dc1b0ec182

9 changed files with 412 additions and 8 deletions
@@ -1,8 +1,17 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# AI21

LiteLLM supports j2-light, j2-mid and j2-ultra from [AI21](https://www.ai21.com/studio/pricing).

They're available to use without a waitlist.

:::tip

**We support ALL AI21 models, just set `model=ai21/<any-model-on-ai21>` as a prefix when sending litellm requests**.
**See all litellm supported AI21 models [here](https://models.litellm.ai)**

:::

### API KEYS

```python
@@ -10,6 +19,7 @@ import os
os.environ["AI21_API_KEY"] = "your-api-key"
```

## **LiteLLM Python SDK Usage**
### Sample Usage

```python
@@ -23,10 +33,177 @@ messages = [{"role": "user", "content": "Write me a poem about the blue sky"}]
completion(model="j2-light", messages=messages)
```
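
The sample above uses a bare Jurassic-2 model name. As a minimal sketch of the `ai21/` prefix form described in the tip (the specific model name below is just an example; any model listed on AI21 works):

```python
import os
from litellm import completion

os.environ["AI21_API_KEY"] = "your-api-key"

# prefixing with ai21/ routes the request through the AI21 provider
response = completion(
    model="ai21/jamba-1.5-mini",
    messages=[{"role": "user", "content": "Write me a poem about the blue sky"}],
)
print(response)
```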

### AI21 Models

## **LiteLLM Proxy Server Usage**

Here's how to call an AI21 model with the LiteLLM Proxy Server

1. Modify the config.yaml

```yaml
model_list:
  - model_name: my-model
    litellm_params:
      model: ai21/<your-model-name> # add ai21/ prefix to route as ai21 provider
      api_key: api-key              # api key to use with your model
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml
```

3. Send Request to LiteLLM Proxy Server

<Tabs>

<TabItem value="openai" label="OpenAI Python v1.0.0+">

```python
import openai
client = openai.OpenAI(
    api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
    base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)

response = client.chat.completions.create(
    model="my-model",
    messages=[
        {
            "role": "user",
            "content": "what llm are you"
        }
    ],
)

print(response)
```
</TabItem>

<TabItem value="curl" label="curl">

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "my-model",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ]
}'
```
</TabItem>

</Tabs>

## Supported OpenAI Parameters

| [param](../completion/input) | type | AI21 equivalent |
|-------|-------------|------------------|
| `tools` | **Optional[list]** | `tools` |
| `response_format` | **Optional[dict]** | `response_format` |
| `max_tokens` | **Optional[int]** | `max_tokens` |
| `temperature` | **Optional[float]** | `temperature` |
| `top_p` | **Optional[float]** | `top_p` |
| `stop` | **Optional[Union[str, list]]** | `stop` |
| `n` | **Optional[int]** | `n` |
| `stream` | **Optional[bool]** | `stream` |
| `seed` | **Optional[int]** | `seed` |
| `tool_choice` | **Optional[str]** | `tool_choice` |
| `user` | **Optional[str]** | `user` |
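
As a hedged sketch of how the mapped parameters above are used in practice (plain `litellm.completion` keyword arguments; the model name and values are illustrative):

```python
from litellm import completion

# OpenAI-style params from the table are forwarded to AI21's chat API
response = completion(
    model="ai21/jamba-1.5-mini",
    messages=[{"role": "user", "content": "Write me a poem about the blue sky"}],
    max_tokens=256,
    temperature=0.7,
    top_p=0.9,
    stop=["###"],
)
print(response.choices[0].message.content)
```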

## Supported AI21 Parameters

| param | type | [AI21 equivalent](https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters) |
|-----------|------|-------------|
| `documents` | **Optional[List[Dict]]** | `documents` |
## Passing AI21 Specific Parameters - `documents`

LiteLLM allows you to pass all AI21 specific parameters to the `litellm.completion` function. Here is an example of passing the `documents` parameter:

<Tabs>

<TabItem value="python" label="LiteLLM Python SDK">

```python
response = await litellm.acompletion(
    model="jamba-1.5-large",
    messages=[{"role": "user", "content": "what does the document say"}],
    documents=[
        {
            "content": "hello world",
            "metadata": {
                "source": "google",
                "author": "ishaan"
            }
        }
    ]
)
```
</TabItem>

<TabItem value="proxy" label="LiteLLM Proxy Server">

```python
import openai
client = openai.OpenAI(
    api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
    base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)

response = client.chat.completions.create(
    model="my-model",
    messages=[
        {
            "role": "user",
            "content": "what llm are you"
        }
    ],
    extra_body={
        "documents": [
            {
                "content": "hello world",
                "metadata": {
                    "source": "google",
                    "author": "ishaan"
                }
            }
        ]
    }
)

print(response)
```

</TabItem>
</Tabs>

:::tip

**We support ALL AI21 models, just set `model=ai21/<any-model-on-ai21>` as a prefix when sending litellm requests**
**See all litellm supported AI21 models [here](https://models.litellm.ai)**
:::

## AI21 Models

| Model Name | Function Call | Required OS Variables |
|------------------|--------------------------------------------|--------------------------------------|
| jamba-1.5-mini | `completion('jamba-1.5-mini', messages)` | `os.environ['AI21_API_KEY']` |
| jamba-1.5-large | `completion('jamba-1.5-large', messages)` | `os.environ['AI21_API_KEY']` |
| j2-light | `completion('j2-light', messages)` | `os.environ['AI21_API_KEY']` |
| j2-mid | `completion('j2-mid', messages)` | `os.environ['AI21_API_KEY']` |
| j2-ultra | `completion('j2-ultra', messages)` | `os.environ['AI21_API_KEY']` |
@@ -364,6 +364,7 @@ vertex_llama3_models: List = []
vertex_ai_ai21_models: List = []
vertex_mistral_models: List = []
ai21_models: List = []
ai21_chat_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
bedrock_models: List = []

@@ -416,6 +417,9 @@ for key, value in model_cost.items():
        key = key.replace("vertex_ai/", "")
        vertex_ai_ai21_models.append(key)
    elif value.get("litellm_provider") == "ai21":
        if value.get("mode") == "chat":
            ai21_chat_models.append(key)
        else:
            ai21_models.append(key)
    elif value.get("litellm_provider") == "nlp_cloud":
        nlp_cloud_models.append(key)
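
A hedged illustration of what the classification above does; the `model_cost` entries below are made up for the example and are not taken from the real cost map:

```python
# Illustrative only: entries whose litellm_provider is "ai21" are split by mode.
model_cost = {
    "jamba-1.5-large": {"litellm_provider": "ai21", "mode": "chat"},  # -> ai21_chat_models
    "j2-ultra": {"litellm_provider": "ai21", "mode": "completion"},   # -> ai21_models
}

ai21_models, ai21_chat_models = [], []
for key, value in model_cost.items():
    if value.get("litellm_provider") == "ai21":
        if value.get("mode") == "chat":
            ai21_chat_models.append(key)
        else:
            ai21_models.append(key)

print(ai21_chat_models)  # ['jamba-1.5-large']
print(ai21_models)       # ['j2-ultra']
```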

@@ -456,6 +460,7 @@ openai_compatible_providers: List = [
    "groq",
    "nvidia_nim",
    "cerebras",
    "ai21_chat",
    "volcengine",
    "codestral",
    "deepseek",

@@ -644,6 +649,7 @@ model_list = (
    + vertex_chat_models
    + vertex_text_models
    + ai21_models
    + ai21_chat_models
    + together_ai_models
    + baseten_models
    + aleph_alpha_models

@@ -695,6 +701,7 @@ provider_list: List = [
    "groq",
    "nvidia_nim",
    "cerebras",
    "ai21_chat",
    "volcengine",
    "codestral",
    "text-completion-codestral",

@@ -853,7 +860,8 @@ from .llms.predibase import PredibaseConfig
from .llms.replicate import ReplicateConfig
from .llms.cohere.completion import CohereConfig
from .llms.clarifai import ClarifaiConfig
from .llms.ai21 import AI21Config
from .llms.AI21.completion import AI21Config
from .llms.AI21.chat import AI21ChatConfig
from .llms.together_ai import TogetherAIConfig
from .llms.cloudflare import CloudflareConfig
from .llms.palm import PalmConfig

@@ -919,6 +927,7 @@ from .llms.openai import (
)
from .llms.nvidia_nim import NvidiaNimConfig
from .llms.cerebras.chat import CerebrasConfig
from .llms.AI21.chat import AI21ChatConfig
from .llms.fireworks_ai import FireworksAIConfig
from .llms.volcengine import VolcEngineConfig
from .llms.text_completion_codestral import MistralTextCompletionConfig
litellm/llms/AI21/chat.py (Normal file, 95 lines added)

@@ -0,0 +1,95 @@
"""
AI21 Chat Completions API

this is OpenAI compatible - no translation needed / occurs
"""

import types
from typing import Optional, Union


class AI21ChatConfig:
    """
    Reference: https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters

    Below are the parameters:
    """

    tools: Optional[list] = None
    response_format: Optional[dict] = None
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    seed: Optional[int] = None
    tool_choice: Optional[str] = None
    user: Optional[str] = None

    def __init__(
        self,
        tools: Optional[list] = None,
        response_format: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        n: Optional[int] = None,
        stream: Optional[bool] = None,
        seed: Optional[int] = None,
        tool_choice: Optional[str] = None,
        user: Optional[str] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the given model

        """

        return [
            "tools",
            "response_format",
            "max_tokens",
            "temperature",
            "top_p",
            "stop",
            "n",
            "stream",
            "seed",
            "tool_choice",
            "user",
        ]

    def map_openai_params(
        self, model: str, non_default_params: dict, optional_params: dict
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params
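
A minimal usage sketch of the new config class (standalone call; parameter names and values are illustrative), showing how `map_openai_params` keeps only the params listed in `get_supported_openai_params` and drops the rest:

```python
config = AI21ChatConfig()

non_default_params = {
    "max_tokens": 256,
    "temperature": 0.7,
    "logprobs": True,  # not in the supported list, so it is dropped
}

optional_params = config.map_openai_params(
    model="jamba-1.5-large",
    non_default_params=non_default_params,
    optional_params={},
)
print(optional_params)  # {'max_tokens': 256, 'temperature': 0.7}
```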

@@ -75,7 +75,6 @@ from litellm.utils import (
from ._logging import verbose_logger
from .caching import disable_cache, enable_cache, update_cache
from .llms import (
    ai21,
    aleph_alpha,
    baseten,
    clarifai,

@@ -91,6 +90,7 @@ from .llms import (
    replicate,
    vllm,
)
from .llms.AI21 import completion as ai21
from .llms.anthropic.chat import AnthropicChatCompletion
from .llms.anthropic.completion import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params

@@ -387,6 +387,7 @@ async def acompletion(
        or custom_llm_provider == "groq"
        or custom_llm_provider == "nvidia_nim"
        or custom_llm_provider == "cerebras"
        or custom_llm_provider == "ai21_chat"
        or custom_llm_provider == "volcengine"
        or custom_llm_provider == "codestral"
        or custom_llm_provider == "text-completion-codestral"

@@ -1293,6 +1294,7 @@ def completion(
        or custom_llm_provider == "groq"
        or custom_llm_provider == "nvidia_nim"
        or custom_llm_provider == "cerebras"
        or custom_llm_provider == "ai21_chat"
        or custom_llm_provider == "volcengine"
        or custom_llm_provider == "codestral"
        or custom_llm_provider == "deepseek"

@@ -3143,6 +3145,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
        or custom_llm_provider == "groq"
        or custom_llm_provider == "nvidia_nim"
        or custom_llm_provider == "cerebras"
        or custom_llm_provider == "ai21_chat"
        or custom_llm_provider == "volcengine"
        or custom_llm_provider == "deepseek"
        or custom_llm_provider == "fireworks_ai"

@@ -3807,6 +3810,7 @@ async def atext_completion(
        or custom_llm_provider == "groq"
        or custom_llm_provider == "nvidia_nim"
        or custom_llm_provider == "cerebras"
        or custom_llm_provider == "ai21_chat"
        or custom_llm_provider == "volcengine"
        or custom_llm_provider == "text-completion-codestral"
        or custom_llm_provider == "deepseek"
@@ -4481,3 +4481,23 @@ async def test_dynamic_azure_params(stream, sync_mode):
    except Exception as e:
        traceback.print_stack()
        raise e


@pytest.mark.asyncio()
@pytest.mark.flaky(retries=3, delay=1)
async def test_completion_ai21_chat():
    litellm.set_verbose = True
    response = await litellm.acompletion(
        model="jamba-1.5-large",
        user="ishaan",
        tool_choice="auto",
        seed=123,
        messages=[{"role": "user", "content": "what does the document say"}],
        documents=[
            {
                "content": "hello world",
                "metadata": {"source": "google", "author": "ishaan"},
            }
        ],
    )
    pass
@@ -68,3 +68,28 @@ def test_get_llm_provider_deepseek_custom_api_base():
    assert api_base == "MY-FAKE-BASE"

    os.environ.pop("DEEPSEEK_API_BASE")


def test_get_llm_provider_ai21_chat():
    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
        model="jamba-1.5-large",
    )
    assert custom_llm_provider == "ai21_chat"
    assert model == "jamba-1.5-large"
    assert api_base == "https://api.ai21.com/studio/v1"


def test_get_llm_provider_ai21_chat_test2():
    """
    If the user prefixes with ai21/ and calls jamba-1.5-large, the provider should resolve to ai21_chat.
    """
    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
        model="ai21/jamba-1.5-large",
    )

    print("model=", model)
    print("custom_llm_provider=", custom_llm_provider)
    print("api_base=", api_base)
    assert custom_llm_provider == "ai21_chat"
    assert model == "jamba-1.5-large"
    assert api_base == "https://api.ai21.com/studio/v1"
@@ -586,6 +586,37 @@ async def test_completion_predibase_streaming(sync_mode):
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio()
@pytest.mark.flaky(retries=3, delay=1)
async def test_completion_ai21_stream():
    litellm.set_verbose = True
    response = await litellm.acompletion(
        model="ai21_chat/jamba-1.5-large",
        user="ishaan",
        stream=True,
        seed=123,
        messages=[{"role": "user", "content": "hi my name is ishaan"}],
    )
    complete_response = ""
    idx = 0
    async for init_chunk in response:
        chunk, finished = streaming_format_tests(idx, init_chunk)
        complete_response += chunk
        custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
        print(f"custom_llm_provider: {custom_llm_provider}")
        assert custom_llm_provider == "ai21_chat"
        idx += 1
        if finished:
            assert isinstance(init_chunk.choices[0], litellm.utils.StreamingChoices)
            break
    if complete_response.strip() == "":
        raise Exception("Empty response received")

    print(f"complete_response: {complete_response}")

    pass


def test_completion_azure_function_calling_stream():
    try:
        litellm.set_verbose = False
@@ -2887,6 +2887,7 @@ def get_optional_params(
        and custom_llm_provider != "groq"
        and custom_llm_provider != "nvidia_nim"
        and custom_llm_provider != "cerebras"
        and custom_llm_provider != "ai21_chat"
        and custom_llm_provider != "volcengine"
        and custom_llm_provider != "deepseek"
        and custom_llm_provider != "codestral"

@@ -3656,6 +3657,16 @@ def get_optional_params(
            optional_params=optional_params,
            model=model,
        )
    elif custom_llm_provider == "ai21_chat":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.AI21ChatConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
    elif custom_llm_provider == "fireworks_ai":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider

@@ -4283,6 +4294,8 @@ def get_supported_openai_params(
        return litellm.NvidiaNimConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "cerebras":
        return litellm.CerebrasConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "ai21_chat":
        return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "volcengine":
        return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "groq":

@@ -4671,6 +4684,7 @@ def get_llm_provider(
    ):
        custom_llm_provider = model.split("/", 1)[0]
        model = model.split("/", 1)[1]

        if custom_llm_provider == "perplexity":
            # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai
            api_base = api_base or get_secret("PERPLEXITY_API_BASE") or "https://api.perplexity.ai"  # type: ignore

@@ -4717,6 +4731,16 @@ def get_llm_provider(
                or "https://api.cerebras.ai/v1"
            )  # type: ignore
            dynamic_api_key = api_key or get_secret("CEREBRAS_API_KEY")
        elif (custom_llm_provider == "ai21_chat") or (
            custom_llm_provider == "ai21" and model in litellm.ai21_chat_models
        ):
            api_base = (
                api_base
                or get_secret("AI21_API_BASE")
                or "https://api.ai21.com/studio/v1"
            )  # type: ignore
            dynamic_api_key = api_key or get_secret("AI21_API_KEY")
            custom_llm_provider = "ai21_chat"
        elif custom_llm_provider == "volcengine":
            # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
            api_base = (

@@ -4870,6 +4894,9 @@ def get_llm_provider(
        elif endpoint == "https://api.cerebras.ai/v1":
            custom_llm_provider = "cerebras"
            dynamic_api_key = get_secret("CEREBRAS_API_KEY")
        elif endpoint == "https://api.ai21.com/studio/v1":
            custom_llm_provider = "ai21_chat"
            dynamic_api_key = get_secret("AI21_API_KEY")
        elif endpoint == "https://codestral.mistral.ai/v1":
            custom_llm_provider = "codestral"
            dynamic_api_key = get_secret("CODESTRAL_API_KEY")

@@ -4953,6 +4980,14 @@ def get_llm_provider(
        ## ai21
        elif model in litellm.ai21_models:
            custom_llm_provider = "ai21"
        elif model in litellm.ai21_chat_models:
            custom_llm_provider = "ai21_chat"
            api_base = (
                api_base
                or get_secret("AI21_API_BASE")
                or "https://api.ai21.com/studio/v1"
            )  # type: ignore
            dynamic_api_key = api_key or get_secret("AI21_API_KEY")
        ## aleph_alpha
        elif model in litellm.aleph_alpha_models:
            custom_llm_provider = "aleph_alpha"
@@ -5800,6 +5835,11 @@ def validate_environment(
            keys_in_environment = True
        else:
            missing_keys.append("CEREBRAS_API_KEY")
    elif custom_llm_provider == "ai21_chat":
        if "AI21_API_KEY" in os.environ:
            keys_in_environment = True
        else:
            missing_keys.append("AI21_API_KEY")
    elif custom_llm_provider == "volcengine":
        if "VOLCENGINE_API_KEY" in os.environ:
            keys_in_environment = True
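
A hedged sketch of how this branch is exercised through `litellm.validate_environment` (the keys_in_environment / missing_keys shape is the usual return of that helper; exact extra fields may differ by version):

```python
import litellm

# With AI21_API_KEY unset, the ai21_chat branch above reports it as missing.
check = litellm.validate_environment(model="ai21_chat/jamba-1.5-large")
print(check["keys_in_environment"])  # False when AI21_API_KEY is not set
print(check["missing_keys"])         # ['AI21_API_KEY']
```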

@@ -6211,7 +6251,10 @@ def convert_to_model_response_object(
        if "model" in response_object:
            if model_response_object.model is None:
                model_response_object.model = response_object["model"]
            elif "/" in model_response_object.model:
            elif (
                "/" in model_response_object.model
                and response_object["model"] is not None
            ):
                openai_compatible_provider = model_response_object.model.split("/")[
                    0
                ]