Merge pull request #5391 from BerriAI/litellm_add_ai21_support
[Feat] Add Vertex AI21 support
commit 6ab8cbc105
10 changed files with 390 additions and 64 deletions

@@ -983,6 +983,85 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
</Tabs>

## AI21 Models

| Model Name | Function Call |
|------------------|--------------------------------------|
| jamba-1.5-mini@001 | `completion(model='vertex_ai/jamba-1.5-mini@001', messages)` |
| jamba-1.5-large@001 | `completion(model='vertex_ai/jamba-1.5-large@001', messages)` |

### Usage

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion
import os

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""

model = "jamba-1.5-mini@001"

vertex_ai_project = "your-vertex-project"  # can also set this as os.environ["VERTEXAI_PROJECT"]
vertex_ai_location = "your-vertex-location"  # can also set this as os.environ["VERTEXAI_LOCATION"]

response = completion(
    model="vertex_ai/" + model,
    messages=[{"role": "user", "content": "hi"}],
    vertex_ai_project=vertex_ai_project,
    vertex_ai_location=vertex_ai_location,
)
print("\nModel Response", response)
```
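
Streaming is expected to work through the usual LiteLLM `stream=True` flag (the AI21 handler added in this PR builds a `:streamRawPredict` endpoint when streaming). A minimal sketch, assuming the same credentials, project, and location setup as the example above:

```python
# Hedged sketch: stream the same Jamba model via LiteLLM's standard stream flag.
from litellm import completion

response = completion(
    model="vertex_ai/jamba-1.5-mini@001",
    messages=[{"role": "user", "content": "hi"}],
    vertex_ai_project="your-vertex-project",
    vertex_ai_location="your-vertex-location",
    stream=True,
)

# With stream=True the response is an iterator of chunks
for chunk in response:
    print(chunk)
```
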
</TabItem>
<TabItem value="proxy" label="Proxy">

**1. Add to config**

```yaml
model_list:
  - model_name: jamba-1.5-mini
    litellm_params:
      model: vertex_ai/jamba-1.5-mini@001
      vertex_ai_project: "my-test-project"
      vertex_ai_location: "us-east-1"
  - model_name: jamba-1.5-large
    litellm_params:
      model: vertex_ai/jamba-1.5-large@001
      vertex_ai_project: "my-test-project"
      vertex_ai_location: "us-west-1"
```

**2. Start proxy**

```bash
litellm --config /path/to/config.yaml

# RUNNING at http://0.0.0.0:4000
```

**3. Test it!**

```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "jamba-1.5-large",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ]
}'
```
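
If you prefer a client library over raw curl, here is a minimal sketch using the OpenAI Python SDK pointed at the proxy. The base URL and `sk-1234` key mirror the example above; adjust them for your deployment:

```python
# Hedged sketch: call the LiteLLM proxy (step 2) with the OpenAI Python SDK.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="jamba-1.5-large",  # model_name from the proxy config above
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response)
```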

</TabItem>
</Tabs>


### Usage - Codestral FIM

@@ -859,9 +859,12 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
    VertexAIAnthropicConfig,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
    VertexAILlama3Config,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
    VertexAIAi21Config,
)
from .llms.sagemaker.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig

@@ -48,6 +48,7 @@ def cost_router(
        "claude" in model
        or "llama" in model
        or "mistral" in model
        or "jamba" in model
        or "codestral" in model
    ):
        return "cost_per_token"

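Since jamba (and codestral) now route through `cost_per_token`, response costs should be computable the same way as for other Vertex AI partner models. A hedged sketch, assuming per-token prices for these models are registered in LiteLLM's model cost map:

```python
import litellm

# Hedged sketch: jamba models fall into the "cost_per_token" branch of cost_router,
# so completion_cost prices them per token (raises if the model has no registered price).
cost = litellm.completion_cost(
    model="vertex_ai/jamba-1.5-mini@001",
    prompt="hi",
    completion="hello there",
)
print(cost)
```
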
@@ -363,6 +363,12 @@ class DatabricksChatCompletion(BaseLLM):
        except Exception as e:
            raise DatabricksError(status_code=500, message=str(e))

        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=response_json,
            additional_args={"complete_input_dict": data},
        )
        response = ModelResponse(**response_json)

        if base_model is not None:

@@ -0,0 +1,53 @@
import types
from typing import Callable, Literal, Optional, Union

import litellm


class VertexAIAi21Config:
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21

    The class `VertexAIAi21Config` provides configuration for VertexAI's AI21 API interface.

    -> Supports all OpenAI parameters
    """

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
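
For orientation, a hedged sketch of how this config behaves: `get_supported_openai_params` and `map_openai_params` delegate straight to `litellm.OpenAIConfig()`, so supported OpenAI parameters pass through to the request unchanged. The example below assumes `VertexAIAi21Config` is exported at the package root, as the top-level import hunk above and the `get_optional_params` hunk below indicate; the parameter values are illustrative only:

```python
import litellm

config = litellm.VertexAIAi21Config()

# Delegates to OpenAIConfig, so this returns the usual OpenAI parameter names.
print(config.get_supported_openai_params())

# Map user-supplied OpenAI params onto the request's optional params.
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "top_p": 0.5},
    optional_params={},
    model="jamba-1.5-mini@001",
)
print(optional_params)  # likely {"max_tokens": 256, "top_p": 0.5}, since OpenAI params pass through
```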

@@ -0,0 +1,59 @@
import types
from typing import Callable, Literal, Optional, Union

import litellm


class VertexAILlama3Config:
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming

    The class `VertexAILlama3Config` provides configuration for VertexAI's Llama API interface. Below are the parameters:

    - `max_tokens` Required (integer) max tokens,

    Note: Please make sure to modify the default parameters as required for your use case.
    """

    max_tokens: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key == "max_tokens" and value is None:
                value = self.max_tokens
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )

@@ -1,6 +1,7 @@
# What is this?
## Handler for calling llama 3.1 API on Vertex AI
## API Handler for calling Vertex AI Partner Models
import types
from enum import Enum
from typing import Callable, Literal, Optional, Union

import httpx  # type: ignore

@@ -8,7 +9,13 @@ import httpx  # type: ignore
import litellm
from litellm.utils import ModelResponse

from ..base import BaseLLM
from ...base import BaseLLM


class VertexPartnerProvider(str, Enum):
    mistralai = "mistralai"
    llama = "llama"
    ai21 = "ai21"


class VertexAIError(Exception):

@@ -24,61 +31,6 @@ class VertexAIError(Exception):
        )  # Call the base class constructor with the parameters it needs


class VertexAILlama3Config:
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming

    The class `VertexAILlama3Config` provides configuration for VertexAI's Llama API interface. Below are the parameters:

    - `max_tokens` Required (integer) max tokens,

    Note: Please make sure to modify the default parameters as required for your use case.
    """

    max_tokens: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key == "max_tokens" and value is None:
                value = self.max_tokens
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )


class VertexAIPartnerModels(BaseLLM):
    def __init__(self) -> None:
        pass

@@ -87,17 +39,22 @@ class VertexAIPartnerModels(BaseLLM):
        self,
        vertex_location: str,
        vertex_project: str,
        partner: Literal["llama", "mistralai"],
        partner: VertexPartnerProvider,
        stream: Optional[bool],
        model: str,
    ) -> str:
        if partner == "llama":
        if partner == VertexPartnerProvider.llama:
            return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
        elif partner == "mistralai":
        elif partner == VertexPartnerProvider.mistralai:
            if stream:
                return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
            else:
                return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
        elif partner == VertexPartnerProvider.ai21:
            if stream:
                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
            else:
                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"

    def completion(
        self,

@@ -160,9 +117,12 @@ class VertexAIPartnerModels(BaseLLM):
            optional_params["stream"] = stream

            if "llama" in model:
                partner = "llama"
                partner = VertexPartnerProvider.llama
            elif "mistral" in model or "codestral" in model:
                partner = "mistralai"
                partner = VertexPartnerProvider.mistralai
                optional_params["custom_endpoint"] = True
            elif "jamba" in model:
                partner = VertexPartnerProvider.ai21
                optional_params["custom_endpoint"] = True

            api_base = self.create_vertex_url(
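
To make the routing above concrete, here is a hedged illustration of the endpoint `create_vertex_url` builds for a Jamba model. The project and location values are placeholders, not values from this PR:

```python
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
    VertexAIPartnerModels,
    VertexPartnerProvider,
)

# Hedged sketch: URL produced for an AI21 (Jamba) model without streaming.
url = VertexAIPartnerModels().create_vertex_url(
    vertex_location="us-central1",
    vertex_project="my-project",
    partner=VertexPartnerProvider.ai21,
    stream=False,
    model="jamba-1.5-mini@001",
)
# -> https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict
print(url)
```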

@@ -126,7 +126,7 @@ from .llms.vertex_ai_and_google_ai_studio import (
    vertex_ai_anthropic,
    vertex_ai_non_gemini,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
    VertexAIPartnerModels,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (

@@ -2080,6 +2080,7 @@ def completion(
            model.startswith("meta/")
            or model.startswith("mistral")
            or model.startswith("codestral")
            or model.startswith("jamba")
        ):
            model_response = vertex_partner_models_chat_completion.completion(
                model=model,

@@ -2356,3 +2356,157 @@ async def test_gemini_context_caching_anthropic_format():

    check_cache_mock.assert_called_once()
    assert mock_client.call_count == 3


@pytest.mark.asyncio
async def test_partner_models_httpx_ai21():
    litellm.set_verbose = True
    model = "vertex_ai/jamba-1.5-mini@001"

    messages = [
        {
            "role": "system",
            "content": "Your name is Litellm Bot, you are a helpful assistant",
        },
        {
            "role": "user",
            "content": "Hello, can you tell me the weather in San Francisco?",
        },
    ]

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        }
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    data = {
        "model": model,
        "messages": messages,
        "tools": tools,
        "top_p": 0.5,
    }

    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "chat-3d11cf95eb224966937b216d9494fe73",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": " Sure, let me check that for you.",
                        "tool_calls": [
                            {
                                "id": "b5cef16b-5946-4937-b9d5-beeaea871e77",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": '{"location": "San Francisco"}',
                                },
                            }
                        ],
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 158,
                "completion_tokens": 36,
                "total_tokens": 194,
            },
            "meta": {"requestDurationMillis": 501},
            "model": "jamba-1.5",
        }

    mock_response.json = return_val
    mock_response.status_code = 200

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.acompletion(**data)

        # Assert
        mock_post.assert_called_once()
        url, kwargs = mock_post.call_args
        print("url = ", url)
        print("call args = ", kwargs)

        print(kwargs["data"])

        assert (
            url[0]
            == "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict"
        )

        # json loads kwargs
        kwargs["data"] = json.loads(kwargs["data"])

        assert kwargs["data"] == {
            "model": "jamba-1.5-mini",
            "messages": [
                {
                    "role": "system",
                    "content": "Your name is Litellm Bot, you are a helpful assistant",
                },
                {
                    "role": "user",
                    "content": "Hello, can you tell me the weather in San Francisco?",
                },
            ],
            "top_p": 0.5,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state, e.g. San Francisco, CA",
                                }
                            },
                            "required": ["location"],
                        },
                    },
                }
            ],
            "stream": False,
        }

        assert response.id == "chat-3d11cf95eb224966937b216d9494fe73"
        assert len(response.choices) == 1
        assert (
            response.choices[0].message.content == " Sure, let me check that for you."
        )
        assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
        assert (
            response.choices[0].message.tool_calls[0].function.arguments
            == '{"location": "San Francisco"}'
        )
        assert response.usage.prompt_tokens == 158
        assert response.usage.completion_tokens == 36
        assert response.usage.total_tokens == 194

        print(f"response: {response}")

@@ -3267,6 +3267,16 @@ def get_optional_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
        )
    elif custom_llm_provider == "vertex_ai" and model in litellm.ai21_models:
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.VertexAIAi21Config().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
    elif custom_llm_provider == "sagemaker":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(