Merge pull request #5391 from BerriAI/litellm_add_ai21_support

[Feat] Add Vertex AI21 support
Ishaan Jaff 2024-08-27 15:06:26 -07:00 committed by GitHub
commit 6ab8cbc105
10 changed files with 390 additions and 64 deletions

@ -983,6 +983,85 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
</Tabs>
## AI21 Models
| Model Name | Function Call |
|------------------|--------------------------------------|
| jamba-1.5-mini@001 | `completion(model='vertex_ai/jamba-1.5-mini@001', messages)` |
| jamba-1.5-large@001 | `completion(model='vertex_ai/jamba-1.5-large@001', messages)` |
### Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""
model = "jamba-1.5-mini@001"
vertex_ai_project = "your-vertex-project" # can also set this as os.environ["VERTEXAI_PROJECT"]
vertex_ai_location = "your-vertex-location" # can also set this as os.environ["VERTEXAI_LOCATION"]
response = completion(
    model="vertex_ai/" + model,
    messages=[{"role": "user", "content": "hi"}],
    vertex_ai_project=vertex_ai_project,
    vertex_ai_location=vertex_ai_location,
)
print("\nModel Response", response)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
  - model_name: jamba-1.5-mini
    litellm_params:
      model: vertex_ai/jamba-1.5-mini@001
      vertex_ai_project: "my-test-project"
      vertex_ai_location: "us-east1"
  - model_name: jamba-1.5-large
    litellm_params:
      model: vertex_ai/jamba-1.5-large@001
      vertex_ai_project: "my-test-project"
      vertex_ai_location: "us-west1"
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "jamba-1.5-large",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ]
}'
```
</TabItem>
</Tabs>
### Usage - Codestral FIM

@ -859,9 +859,12 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
VertexAIAnthropicConfig,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
VertexAILlama3Config,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
VertexAIAi21Config,
)
from .llms.sagemaker.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig

@ -48,6 +48,7 @@ def cost_router(
"claude" in model
or "llama" in model
or "mistral" in model
or "jamba" in model
or "codestral" in model
):
return "cost_per_token"
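
For context, this change just adds Jamba models to the set of Vertex AI partner models routed to token-based cost calculation. A standalone sketch of the routing logic (illustrative only — the real `cost_router` takes more arguments, and the fallback value here is an assumption, not taken from this diff):

```python
# Simplified, standalone illustration of the substring routing above.
def route_cost_strategy(model: str) -> str:
    partner_tags = ("claude", "llama", "mistral", "jamba", "codestral")
    if any(tag in model for tag in partner_tags):
        return "cost_per_token"
    return "cost_per_character"  # assumed fallback; not taken from this diff

assert route_cost_strategy("jamba-1.5-mini@001") == "cost_per_token"
```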

@ -363,6 +363,12 @@ class DatabricksChatCompletion(BaseLLM):
except Exception as e:
raise DatabricksError(status_code=500, message=str(e))
logging_obj.post_call(
input=messages,
api_key="",
original_response=response_json,
additional_args={"complete_input_dict": data},
)
response = ModelResponse(**response_json)
if base_model is not None:


@ -0,0 +1,53 @@
import types
from typing import Callable, Literal, Optional, Union
import litellm
class VertexAIAi21Config:
"""
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21
The class `VertexAIAi21Config` provides configuration for Vertex AI's AI21 API interface
-> Supports all OpenAI parameters
"""
def __init__(
self,
max_tokens: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
):
return litellm.OpenAIConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
)
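
Since `VertexAIAi21Config` is exported from `litellm/__init__.py` (see the import change above), it can be exercised directly. A minimal sketch — the parameter values below are illustrative, not taken from this commit:

```python
import litellm

config = litellm.VertexAIAi21Config()

# Supported params are delegated to OpenAIConfig (the gpt-3.5-turbo set).
print(config.get_supported_openai_params())

# OpenAI-style params are expected to pass through unchanged for the
# AI21 endpoint, since mapping is delegated to OpenAIConfig as well.
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "top_p": 0.5},
    optional_params={},
    model="jamba-1.5-mini@001",
)
print(optional_params)
```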

@ -0,0 +1,59 @@
import types
from typing import Callable, Literal, Optional, Union
import litellm
class VertexAILlama3Config:
"""
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
The class `VertexAILlama3Config` provides configuration for Vertex AI's Llama API interface. Below are the parameters:
- `max_tokens` Required (integer): maximum number of output tokens
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = None
def __init__(
self,
max_tokens: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
):
return litellm.OpenAIConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
)
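
Both this moved `VertexAILlama3Config` and the new AI21 config share the same pattern: values passed to `__init__` are stored as class attributes and collected by `get_config`. A small sketch of that behaviour, with a purely illustrative value:

```python
import litellm

# Passing max_tokens stores it on the class ...
litellm.VertexAILlama3Config(max_tokens=1024)

# ... and get_config() collects non-callable, non-dunder, non-None attributes.
print(litellm.VertexAILlama3Config.get_config())  # expected: {'max_tokens': 1024}
```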

@ -1,6 +1,7 @@
# What is this?
## Handler for calling llama 3.1 API on Vertex AI
## API Handler for calling Vertex AI Partner Models
import types
from enum import Enum
from typing import Callable, Literal, Optional, Union
import httpx # type: ignore
@ -8,7 +9,13 @@ import httpx # type: ignore
import litellm
from litellm.utils import ModelResponse
from ..base import BaseLLM
from ...base import BaseLLM
class VertexPartnerProvider(str, Enum):
mistralai = "mistralai"
llama = "llama"
ai21 = "ai21"
class VertexAIError(Exception):
@ -24,61 +31,6 @@ class VertexAIError(Exception):
) # Call the base class constructor with the parameters it needs
class VertexAILlama3Config:
"""
Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
The class `VertexAILlama3Config` provides configuration for the VertexAI's Llama API interface. Below are the parameters:
- `max_tokens` Required (integer) max tokens,
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = None
def __init__(
self,
max_tokens: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
):
return litellm.OpenAIConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
)
class VertexAIPartnerModels(BaseLLM):
def __init__(self) -> None:
pass
@ -87,17 +39,22 @@ class VertexAIPartnerModels(BaseLLM):
self,
vertex_location: str,
vertex_project: str,
partner: Literal["llama", "mistralai"],
partner: VertexPartnerProvider,
stream: Optional[bool],
model: str,
) -> str:
if partner == "llama":
if partner == VertexPartnerProvider.llama:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
elif partner == "mistralai":
elif partner == VertexPartnerProvider.mistralai:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
elif partner == VertexPartnerProvider.ai21:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
def completion(
self,
@ -160,9 +117,12 @@ class VertexAIPartnerModels(BaseLLM):
optional_params["stream"] = stream
if "llama" in model:
partner = "llama"
partner = VertexPartnerProvider.llama
elif "mistral" in model or "codestral" in model:
partner = "mistralai"
partner = VertexPartnerProvider.mistralai
optional_params["custom_endpoint"] = True
elif "jamba" in model:
partner = VertexPartnerProvider.ai21
optional_params["custom_endpoint"] = True
api_base = self.create_vertex_url(
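
Pulling these pieces together, the new enum member and `create_vertex_url` branch produce the AI21 endpoint. A sketch — project, location, and model values are placeholders, and the expected URL simply restates the `ai21` branch above:

```python
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
    VertexAIPartnerModels,
    VertexPartnerProvider,
)

url = VertexAIPartnerModels().create_vertex_url(
    vertex_location="us-central1",       # placeholder
    vertex_project="my-project",         # placeholder
    partner=VertexPartnerProvider.ai21,
    stream=False,                        # True targets :streamRawPredict instead
    model="jamba-1.5-mini@001",
)
# Expected:
# https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict
print(url)
```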

@ -126,7 +126,7 @@ from .llms.vertex_ai_and_google_ai_studio import (
vertex_ai_anthropic,
vertex_ai_non_gemini,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
VertexAIPartnerModels,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
@ -2080,6 +2080,7 @@ def completion(
model.startswith("meta/")
or model.startswith("mistral")
or model.startswith("codestral")
or model.startswith("jamba")
):
model_response = vertex_partner_models_chat_completion.completion(
model=model,

@ -2356,3 +2356,157 @@ async def test_gemini_context_caching_anthropic_format():
check_cache_mock.assert_called_once()
assert mock_client.call_count == 3
@pytest.mark.asyncio
async def test_partner_models_httpx_ai21():
litellm.set_verbose = True
model = "vertex_ai/jamba-1.5-mini@001"
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
{
"role": "user",
"content": "Hello, can you tell me the weather in San Francisco?",
},
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
}
},
"required": ["location"],
},
},
}
]
data = {
"model": model,
"messages": messages,
"tools": tools,
"top_p": 0.5,
}
mock_response = AsyncMock()
def return_val():
return {
"id": "chat-3d11cf95eb224966937b216d9494fe73",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": " Sure, let me check that for you.",
"tool_calls": [
{
"id": "b5cef16b-5946-4937-b9d5-beeaea871e77",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
},
}
],
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 158,
"completion_tokens": 36,
"total_tokens": 194,
},
"meta": {"requestDurationMillis": 501},
"model": "jamba-1.5",
}
mock_response.json = return_val
mock_response.status_code = 200
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response,
) as mock_post:
response = await litellm.acompletion(**data)
# Assert
mock_post.assert_called_once()
url, kwargs = mock_post.call_args
print("url = ", url)
print("call args = ", kwargs)
print(kwargs["data"])
assert (
url[0]
== "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict"
)
# json loads kwargs
kwargs["data"] = json.loads(kwargs["data"])
assert kwargs["data"] == {
"model": "jamba-1.5-mini",
"messages": [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
{
"role": "user",
"content": "Hello, can you tell me the weather in San Francisco?",
},
],
"top_p": 0.5,
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
}
},
"required": ["location"],
},
},
}
],
"stream": False,
}
assert response.id == "chat-3d11cf95eb224966937b216d9494fe73"
assert len(response.choices) == 1
assert (
response.choices[0].message.content == " Sure, let me check that for you."
)
assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
assert (
response.choices[0].message.tool_calls[0].function.arguments
== '{"location": "San Francisco"}'
)
assert response.usage.prompt_tokens == 158
assert response.usage.completion_tokens == 36
assert response.usage.total_tokens == 194
print(f"response: {response}")

@ -3267,6 +3267,16 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif custom_llm_provider == "vertex_ai" and model in litellm.ai21_models:
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.VertexAIAi21Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
)
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
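
One way to sanity-check the new parameter path from user code — this assumes `litellm.get_supported_openai_params` resolves Jamba models on `vertex_ai` through the new config, which is what the added `elif` branch implies:

```python
import litellm

# Hedged sketch: model name and provider as used elsewhere in this commit.
supported = litellm.get_supported_openai_params(
    model="jamba-1.5-mini@001",
    custom_llm_provider="vertex_ai",
)
print(supported)  # expected to mirror the OpenAI (gpt-3.5-turbo) parameter set

# Unsupported params passed to completion() would then be rejected by
# _check_valid_arg unless litellm.drop_params is enabled.
```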