Merge pull request #5391 from BerriAI/litellm_add_ai21_support

[Feat] Add Vertex AI21 support
Ishaan Jaff 2024-08-27 15:06:26 -07:00 committed by GitHub
commit 6ab8cbc105
10 changed files with 390 additions and 64 deletions

View file

@@ -983,6 +983,85 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
</Tabs>
## AI21 Models
| Model Name | Function Call |
|------------------|--------------------------------------|
| jamba-1.5-mini@001 | `completion(model='vertex_ai/jamba-1.5-mini@001', messages)` |
| jamba-1.5-large@001 | `completion(model='vertex_ai/jamba-1.5-large@001', messages)` |
### Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""
model = "jamba-1.5-mini@001"

vertex_ai_project = "your-vertex-project" # can also set this as os.environ["VERTEXAI_PROJECT"]
vertex_ai_location = "your-vertex-location" # can also set this as os.environ["VERTEXAI_LOCATION"]

response = completion(
    model="vertex_ai/" + model,
    messages=[{"role": "user", "content": "hi"}],
    vertex_ai_project=vertex_ai_project,
    vertex_ai_location=vertex_ai_location,
)
print("\nModel Response", response)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
  - model_name: jamba-1.5-mini
    litellm_params:
      model: vertex_ai/jamba-1.5-mini@001
      vertex_ai_project: "my-test-project"
      vertex_ai_location: "us-east1"
  - model_name: jamba-1.5-large
    litellm_params:
      model: vertex_ai/jamba-1.5-large@001
      vertex_ai_project: "my-test-project"
      vertex_ai_location: "us-west1"
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "jamba-1.5-large",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ]
}'
```
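The proxy exposes an OpenAI-compatible endpoint, so the same deployment can also be called with the OpenAI Python SDK. A minimal sketch, assuming the proxy from step 2 is running locally with the example `sk-1234` key:

```python
# Hedged sketch: call the LiteLLM proxy's OpenAI-compatible endpoint.
# The base_url and api_key assume the local proxy setup shown above.
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="jamba-1.5-large",  # model_name from the config above
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response.choices[0].message.content)
```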
</TabItem>
</Tabs>
### Usage - Codestral FIM

View file

@@ -859,9 +859,12 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
    VertexAIAnthropicConfig,
)
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
    VertexAILlama3Config,
)
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
+    VertexAIAi21Config,
+)
from .llms.sagemaker.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig

View file

@@ -48,6 +48,7 @@ def cost_router(
        "claude" in model
        or "llama" in model
        or "mistral" in model
+        or "jamba" in model
        or "codestral" in model
    ):
        return "cost_per_token"
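In effect, Vertex AI `jamba` models are now routed to token-based pricing. A minimal sketch of that decision (the signature is simplified, and the `cost_per_character` fallback is an assumption about the unshown remainder of `cost_router`):

```python
# Illustrative only: mirrors the condition above with a simplified signature;
# the "cost_per_character" fallback is assumed, not shown in this hunk.
def cost_router(model: str) -> str:
    if (
        "claude" in model
        or "llama" in model
        or "mistral" in model
        or "jamba" in model
        or "codestral" in model
    ):
        return "cost_per_token"
    return "cost_per_character"

assert cost_router("jamba-1.5-mini@001") == "cost_per_token"
```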

View file

@@ -363,6 +363,12 @@ class DatabricksChatCompletion(BaseLLM):
        except Exception as e:
            raise DatabricksError(status_code=500, message=str(e))

+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response_json,
+            additional_args={"complete_input_dict": data},
+        )
        response = ModelResponse(**response_json)

        if base_model is not None:

View file

@@ -0,0 +1,53 @@
import types
from typing import Callable, Literal, Optional, Union

import litellm


class VertexAIAi21Config:
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21

    The class `VertexAIAi21Config` provides configuration for the VertexAI's AI21 API interface
    -> Supports all OpenAI parameters
    """

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
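For orientation, a hedged usage sketch of this config (values are arbitrary; the class simply defers to `OpenAIConfig`, so standard OpenAI parameters are expected to pass through):

```python
# Hedged usage sketch: VertexAIAi21Config reuses OpenAIConfig for parameter handling.
import litellm

config = litellm.VertexAIAi21Config()

# Supported params are whatever OpenAIConfig reports for gpt-3.5-turbo.
print(config.get_supported_openai_params())

# OpenAI-style params are expected to pass through into optional_params unchanged.
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "top_p": 0.5},
    optional_params={},
    model="jamba-1.5-mini@001",
)
print(optional_params)  # expected: {"max_tokens": 256, "top_p": 0.5}
```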

View file

@@ -0,0 +1,59 @@
import types
from typing import Callable, Literal, Optional, Union

import litellm


class VertexAILlama3Config:
    """
    Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming

    The class `VertexAILlama3Config` provides configuration for the VertexAI's Llama API interface. Below are the parameters:

    - `max_tokens` Required (integer) max tokens,

    Note: Please make sure to modify the default parameters as required for your use case.
    """

    max_tokens: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key == "max_tokens" and value is None:
                value = self.max_tokens
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )

View file

@@ -1,6 +1,7 @@
# What is this?
-## Handler for calling llama 3.1 API on Vertex AI
+## API Handler for calling Vertex AI Partner Models
import types
+from enum import Enum
from typing import Callable, Literal, Optional, Union

import httpx # type: ignore

@@ -8,7 +9,13 @@ import httpx # type: ignore
import litellm
from litellm.utils import ModelResponse

-from ..base import BaseLLM
+from ...base import BaseLLM

+class VertexPartnerProvider(str, Enum):
+    mistralai = "mistralai"
+    llama = "llama"
+    ai21 = "ai21"

class VertexAIError(Exception):
@@ -24,61 +31,6 @@
        ) # Call the base class constructor with the parameters it needs

-class VertexAILlama3Config:
-    """
-    Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
-
-    The class `VertexAILlama3Config` provides configuration for the VertexAI's Llama API interface. Below are the parameters:
-
-    - `max_tokens` Required (integer) max tokens,
-
-    Note: Please make sure to modify the default parameters as required for your use case.
-    """
-
-    max_tokens: Optional[int] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key == "max_tokens" and value is None:
-                value = self.max_tokens
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
-
-    def map_openai_params(
-        self, non_default_params: dict, optional_params: dict, model: str
-    ):
-        return litellm.OpenAIConfig().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-            model=model,
-        )

class VertexAIPartnerModels(BaseLLM):
    def __init__(self) -> None:
        pass
@@ -87,17 +39,22 @@ class VertexAIPartnerModels(BaseLLM):
        self,
        vertex_location: str,
        vertex_project: str,
-        partner: Literal["llama", "mistralai"],
+        partner: VertexPartnerProvider,
        stream: Optional[bool],
        model: str,
    ) -> str:
-        if partner == "llama":
+        if partner == VertexPartnerProvider.llama:
            return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
-        elif partner == "mistralai":
+        elif partner == VertexPartnerProvider.mistralai:
            if stream:
                return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
            else:
                return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
+        elif partner == VertexPartnerProvider.ai21:
+            if stream:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
+            else:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"

    def completion(
        self,
@@ -160,9 +117,12 @@
            optional_params["stream"] = stream

            if "llama" in model:
-                partner = "llama"
+                partner = VertexPartnerProvider.llama
            elif "mistral" in model or "codestral" in model:
-                partner = "mistralai"
+                partner = VertexPartnerProvider.mistralai
+                optional_params["custom_endpoint"] = True
+            elif "jamba" in model:
+                partner = VertexPartnerProvider.ai21
                optional_params["custom_endpoint"] = True

            api_base = self.create_vertex_url(
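To make the new provider concrete, here is a hedged sketch of the URL `create_vertex_url` builds for a jamba model (the project and location values are placeholders; the import path follows the module layout introduced in this PR):

```python
# Hedged sketch: build the Vertex AI rawPredict URL for an AI21 jamba model.
# "my-project" and "us-central1" are placeholder values.
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
    VertexAIPartnerModels,
    VertexPartnerProvider,
)

url = VertexAIPartnerModels().create_vertex_url(
    vertex_location="us-central1",
    vertex_project="my-project",
    partner=VertexPartnerProvider.ai21,
    stream=False,
    model="jamba-1.5-mini@001",
)
# -> https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict
print(url)
```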

View file

@@ -126,7 +126,7 @@ from .llms.vertex_ai_and_google_ai_studio import (
    vertex_ai_anthropic,
    vertex_ai_non_gemini,
)
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
    VertexAIPartnerModels,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
@@ -2080,6 +2080,7 @@ def completion(
            model.startswith("meta/")
            or model.startswith("mistral")
            or model.startswith("codestral")
+            or model.startswith("jamba")
        ):
            model_response = vertex_partner_models_chat_completion.completion(
                model=model,

View file

@@ -2356,3 +2356,157 @@ async def test_gemini_context_caching_anthropic_format():
    check_cache_mock.assert_called_once()
    assert mock_client.call_count == 3
@pytest.mark.asyncio
async def test_partner_models_httpx_ai21():
    litellm.set_verbose = True
    model = "vertex_ai/jamba-1.5-mini@001"

    messages = [
        {
            "role": "system",
            "content": "Your name is Litellm Bot, you are a helpful assistant",
        },
        {
            "role": "user",
            "content": "Hello, can you tell me the weather in San Francisco?",
        },
    ]

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        }
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    data = {
        "model": model,
        "messages": messages,
        "tools": tools,
        "top_p": 0.5,
    }

    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "chat-3d11cf95eb224966937b216d9494fe73",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": " Sure, let me check that for you.",
                        "tool_calls": [
                            {
                                "id": "b5cef16b-5946-4937-b9d5-beeaea871e77",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": '{"location": "San Francisco"}',
                                },
                            }
                        ],
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 158,
                "completion_tokens": 36,
                "total_tokens": 194,
            },
            "meta": {"requestDurationMillis": 501},
            "model": "jamba-1.5",
        }

    mock_response.json = return_val
    mock_response.status_code = 200

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.acompletion(**data)

        # Assert
        mock_post.assert_called_once()
        url, kwargs = mock_post.call_args
        print("url = ", url)
        print("call args = ", kwargs)

        print(kwargs["data"])

        assert (
            url[0]
            == "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict"
        )

        # json loads kwargs
        kwargs["data"] = json.loads(kwargs["data"])

        assert kwargs["data"] == {
            "model": "jamba-1.5-mini",
            "messages": [
                {
                    "role": "system",
                    "content": "Your name is Litellm Bot, you are a helpful assistant",
                },
                {
                    "role": "user",
                    "content": "Hello, can you tell me the weather in San Francisco?",
                },
            ],
            "top_p": 0.5,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state, e.g. San Francisco, CA",
                                }
                            },
                            "required": ["location"],
                        },
                    },
                }
            ],
            "stream": False,
        }

        assert response.id == "chat-3d11cf95eb224966937b216d9494fe73"
        assert len(response.choices) == 1
        assert (
            response.choices[0].message.content == " Sure, let me check that for you."
        )
        assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
        assert (
            response.choices[0].message.tool_calls[0].function.arguments
            == '{"location": "San Francisco"}'
        )

        assert response.usage.prompt_tokens == 158
        assert response.usage.completion_tokens == 36
        assert response.usage.total_tokens == 194

        print(f"response: {response}")

View file

@@ -3267,6 +3267,16 @@ def get_optional_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
        )
+    elif custom_llm_provider == "vertex_ai" and model in litellm.ai21_models:
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexAIAi21Config().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
    elif custom_llm_provider == "sagemaker":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(
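For orientation, the new branch in isolation looks roughly like the sketch below (hedged: `_check_valid_arg` is internal to `get_optional_params`, so a plain assert stands in for it, and the parameter values are arbitrary):

```python
# Hedged sketch of the branch above in isolation: gate on provider and model
# family, validate the requested params, then map them via VertexAIAi21Config.
import litellm
from litellm import get_supported_openai_params

model = "jamba-1.5-mini@001"
custom_llm_provider = "vertex_ai"
non_default_params = {"max_tokens": 256, "top_p": 0.5}

if custom_llm_provider == "vertex_ai" and model in litellm.ai21_models:
    supported = get_supported_openai_params(
        model=model, custom_llm_provider=custom_llm_provider
    )
    # rough stand-in for the internal _check_valid_arg helper
    assert all(param in supported for param in non_default_params)
    optional_params = litellm.VertexAIAi21Config().map_openai_params(
        non_default_params=non_default_params,
        optional_params={},
        model=model,
    )
```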