mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
(Feat) Add Vertex Model Garden llama 3.1 models (#6763)
* add VertexAIModelGardenModels * VertexAIModelGardenModels * test_vertexai_model_garden_model_completion * docs model garden
This commit is contained in:
parent
0f7ea14992
commit
9ba8f40bd1
4 changed files with 356 additions and 3 deletions
|
@ -1161,12 +1161,96 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
|
||||||
|
|
||||||
## Model Garden
|
## Model Garden
|
||||||
| Model Name | Function Call |
|
|
||||||
|------------------|--------------------------------------|
|
:::tip
|
||||||
| llama2 | `completion('vertex_ai/<endpoint_id>', messages)` |
|
|
||||||
|
All OpenAI compatible models from Vertex Model Garden are supported.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
#### Using Model Garden
|
#### Using Model Garden
|
||||||
|
|
||||||
|
**Almost all Vertex Model Garden models are OpenAI compatible.**
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="openai" label="OpenAI Compatible Models">
|
||||||
|
|
||||||
|
| Property | Details |
|
||||||
|
|----------|---------|
|
||||||
|
| Provider Route | `vertex_ai/openai/{MODEL_ID}` |
|
||||||
|
| Vertex Documentation | [Vertex Model Garden - OpenAI Chat Completions](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gradio_streaming_chat_completions.ipynb), [Vertex Model Garden](https://cloud.google.com/model-garden?hl=en) |
|
||||||
|
| Supported Operations | `/chat/completions`, `/embeddings` |
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
|
||||||
|
## set ENV variables
|
||||||
|
os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
|
||||||
|
os.environ["VERTEXAI_LOCATION"] = "us-central1"
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="vertex_ai/openai/<your-endpoint-id>",
|
||||||
|
messages=[{ "content": "Hello, how are you?","role": "user"}]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="proxy" label="Proxy">
|
||||||
|
|
||||||
|
|
||||||
|
**1. Add to config**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: llama3-1-8b-instruct
|
||||||
|
litellm_params:
|
||||||
|
model: vertex_ai/openai/5464397967697903616
|
||||||
|
vertex_ai_project: "my-test-project"
|
||||||
|
vertex_ai_location: "us-east1"
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Start proxy**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
|
||||||
|
# RUNNING at http://0.0.0.0:4000
|
||||||
|
```
|
||||||
|
|
||||||
|
**3. Test it!**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# "model" must match the 'model_name' set in the config
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama3-1-8b-instruct",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what llm are you"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="non-openai" label="Non-OpenAI Compatible Models">
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
import os
|
import os
|
||||||
|
@ -1181,6 +1265,11 @@ response = completion(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
## Gemini Pro
|
## Gemini Pro
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|------------------|--------------------------------------|
|
|------------------|--------------------------------------|
|
||||||
|
|
|
@ -0,0 +1,156 @@
|
||||||
|
"""
|
||||||
|
API Handler for calling Vertex AI Model Garden Models
|
||||||
|
|
||||||
|
Most Vertex Model Garden Models are OpenAI compatible - so this handler calls `openai_like_chat_completions`
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
response = litellm.completion(
|
||||||
|
model="vertex_ai/openai/5464397967697903616",
|
||||||
|
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`
|
||||||
|
|
||||||
|
|
||||||
|
Vertex Documentation for using the OpenAI /chat/completions endpoint: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_llama3_deployment.ipynb
|
||||||
|
"""
|
||||||
|
|
||||||
|
import types
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable, Literal, Optional, Union
|
||||||
|
|
||||||
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.utils import ModelResponse
|
||||||
|
|
||||||
|
from ..common_utils import VertexAIError
|
||||||
|
from ..vertex_llm_base import VertexBase
|
||||||
|
|
||||||
|
|
||||||
|
def create_vertex_url(
    vertex_location: str,
    vertex_project: str,
    stream: Optional[bool],
    model: str,
    api_base: Optional[str] = None,
) -> str:
    """Build the Vertex AI endpoint URL for a Model Garden deployment.

    Args:
        vertex_location: Region used both in the hostname and the resource path.
        vertex_project: Project identifier placed in the resource path.
        stream: Accepted for signature compatibility; not used to build the URL.
        model: Endpoint ID appended after ``/endpoints/``.
        api_base: Accepted for signature compatibility; not used to build the URL.

    Returns:
        The ``v1beta1`` endpoint URL, without any trailing route.
    """
    host = f"https://{vertex_location}-aiplatform.googleapis.com"
    resource_path = (
        f"v1beta1/projects/{vertex_project}"
        f"/locations/{vertex_location}"
        f"/endpoints/{model}"
    )
    return f"{host}/{resource_path}"
|
||||||
|
|
||||||
|
|
||||||
|
class VertexAIModelGardenModels(VertexBase):
    """Handler for OpenAI-compatible models served from Vertex AI Model Garden."""

    def __init__(self) -> None:
        # Intentionally empty: the base-class initializer is not invoked here.
        pass

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        logger_fn=None,
        acompletion: bool = False,
        client=None,
    ):
        """
        Handles calling Vertex AI Model Garden Models in OpenAI compatible format

        Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`

        Args:
            model: Model string; every ``"openai/"`` occurrence is stripped and the
                remainder is used as the Vertex endpoint ID.
            messages: Chat messages forwarded unchanged to the OpenAI-like handler.
            api_base: Optional custom base URL, resolved via ``_check_custom_proxy``.
            optional_params: Request parameters. NOTE: mutated in place - ``"stream"``
                is always set to a bool before the request is forwarded.
            headers: Accepted for interface compatibility; not used by this method.
            vertex_project: Project ID; falls back to the project resolved from the
                credentials when falsy.
            vertex_location: Region; defaults to ``"us-central1"`` when falsy.
            vertex_credentials: Credentials used to obtain the access token.

            The remaining arguments are passed straight through to
            ``DatabricksChatCompletion.completion``.

        Returns:
            Whatever ``DatabricksChatCompletion.completion`` returns.

        Raises:
            VertexAIError: status 400 if the Vertex SDK cannot be imported or lacks
                ``vertexai.preview``; status 500 wrapping any error raised while
                building or sending the request.
        """
        try:
            import vertexai
            from google.cloud import aiplatform  # noqa: F401 - install probe only

            from litellm.llms.databricks.chat import DatabricksChatCompletion
            from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
                VertexLLM,
            )
        except Exception as e:
            raise VertexAIError(
                status_code=400,
                message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""",
            ) from e

        # The previous check was `hasattr(vertexai, "preview") or
        # hasattr(vertexai.preview, ...)`: whenever `preview` was missing, the second
        # operand raised AttributeError before the intended VertexAIError could be
        # raised. The pass/fail condition is unchanged; only the error type is fixed.
        if not hasattr(vertexai, "preview"):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )

        try:
            model = model.replace("openai/", "")
            vertex_httpx_logic = VertexLLM()

            access_token, project_id = vertex_httpx_logic._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai",
            )

            openai_like_chat_completions = DatabricksChatCompletion()

            ## CONSTRUCT API BASE
            stream: bool = optional_params.get("stream", False) or False
            optional_params["stream"] = stream
            default_api_base = create_vertex_url(
                vertex_location=vertex_location or "us-central1",
                vertex_project=vertex_project or project_id,
                stream=stream,
                model=model,
            )

            # `endpoint` is the text after the last ":" of the default URL.
            # NOTE(review): the default URL has no ":<action>" suffix, so this looks
            # like it yields part of the URL itself rather than an action name -
            # confirm how `_check_custom_proxy` uses it when a custom `api_base`
            # is supplied.
            if ":" in default_api_base:
                endpoint = default_api_base.split(":")[-1]
            else:
                endpoint = ""

            _, api_base = self._check_custom_proxy(
                api_base=api_base,
                custom_llm_provider="vertex_ai",
                gemini_api_key=None,
                endpoint=endpoint,
                stream=stream,
                auth_header=None,
                url=default_api_base,
            )

            # The endpoint ID is already part of the URL, so the request is sent
            # with an empty model name.
            model = ""
            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=access_token,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                logging_obj=logging_obj,
                optional_params=optional_params,
                acompletion=acompletion,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                client=client,
                timeout=timeout,
                encoding=encoding,
                custom_llm_provider="vertex_ai",
            )

        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise VertexAIError(status_code=500, message=str(e)) from e
|
|
@ -158,6 +158,9 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
|
||||||
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
|
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
|
||||||
VertexEmbedding,
|
VertexEmbedding,
|
||||||
)
|
)
|
||||||
|
from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
|
||||||
|
VertexAIModelGardenModels,
|
||||||
|
)
|
||||||
from .llms.watsonx.chat.handler import WatsonXChatHandler
|
from .llms.watsonx.chat.handler import WatsonXChatHandler
|
||||||
from .llms.watsonx.completion.handler import IBMWatsonXAI
|
from .llms.watsonx.completion.handler import IBMWatsonXAI
|
||||||
from .types.llms.openai import (
|
from .types.llms.openai import (
|
||||||
|
@ -221,6 +224,7 @@ vertex_multimodal_embedding = VertexMultimodalEmbedding()
|
||||||
vertex_image_generation = VertexImageGeneration()
|
vertex_image_generation = VertexImageGeneration()
|
||||||
google_batch_embeddings = GoogleBatchEmbeddings()
|
google_batch_embeddings = GoogleBatchEmbeddings()
|
||||||
vertex_partner_models_chat_completion = VertexAIPartnerModels()
|
vertex_partner_models_chat_completion = VertexAIPartnerModels()
|
||||||
|
vertex_model_garden_chat_completion = VertexAIModelGardenModels()
|
||||||
vertex_text_to_speech = VertexTextToSpeechAPI()
|
vertex_text_to_speech = VertexTextToSpeechAPI()
|
||||||
watsonxai = IBMWatsonXAI()
|
watsonxai = IBMWatsonXAI()
|
||||||
sagemaker_llm = SagemakerLLM()
|
sagemaker_llm = SagemakerLLM()
|
||||||
|
@ -2355,6 +2359,28 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
extra_headers=extra_headers,
|
extra_headers=extra_headers,
|
||||||
)
|
)
|
||||||
|
elif "openai" in model:
|
||||||
|
# Vertex Model Garden - OpenAI compatible models
|
||||||
|
model_response = vertex_model_garden_chat_completion.completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=new_params,
|
||||||
|
litellm_params=litellm_params, # type: ignore
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
encoding=encoding,
|
||||||
|
api_base=api_base,
|
||||||
|
vertex_location=vertex_ai_location,
|
||||||
|
vertex_project=vertex_ai_project,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
|
logging_obj=logging,
|
||||||
|
acompletion=acompletion,
|
||||||
|
headers=headers,
|
||||||
|
custom_prompt_dict=custom_prompt_dict,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model_response = vertex_ai_non_gemini.completion(
|
model_response = vertex_ai_non_gemini.completion(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
@ -3123,3 +3123,85 @@ async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
|
||||||
assert isinstance(embedding["embedding"], list)
|
assert isinstance(embedding["embedding"], list)
|
||||||
assert len(embedding["embedding"]) > 0
|
assert len(embedding["embedding"]) > 0
|
||||||
assert all(isinstance(x, float) for x in embedding["embedding"])
|
assert all(isinstance(x, float) for x in embedding["embedding"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.respx
async def test_vertexai_model_garden_model_completion(respx_mock: MockRouter):
    """
    Relevant issue: https://github.com/BerriAI/litellm/issues/6480

    Calls an OpenAI compatible Vertex Model Garden endpoint through
    `litellm.acompletion` against a mocked HTTP route, then checks both the
    outgoing request body and the parsed response.
    """
    load_vertex_ai_credentials()
    litellm.set_verbose = True

    project = "633608382793"
    location = "us-central1"
    endpoint_id = "5464397967697903616"

    # Conversation sent to the model
    system_turn = {
        "role": "system",
        "content": "Your name is Litellm Bot, you are a helpful assistant",
    }
    user_turn = {
        "role": "user",
        "content": "Hello, what is your name and can you tell me the weather?",
    }
    messages = [system_turn, user_turn]

    # What should go over the wire, and what the mocked endpoint answers
    expected_url = (
        f"https://{location}-aiplatform.googleapis.com/v1beta1"
        f"/projects/{project}/locations/{location}"
        f"/endpoints/{endpoint_id}/chat/completions"
    )
    expected_request = {"model": "", "messages": messages, "stream": False}

    assistant_message = {
        "role": "assistant",
        "content": "Hello, my name is Litellm Bot. I'm a helpful assistant here to provide information and answer your questions.\n\nTo check the weather for you, I'll need to know your location. Could you please provide me with your city or zip code? That way, I can give you the most accurate and up-to-date weather information.\n\nIf you don't have your location handy, I can also suggest some popular weather websites or apps that you can use to check the weather for your area.\n\nLet me know how I can assist you!",
        "tool_calls": [],
    }
    mock_response = {
        "id": "chat-09940d4e99e3488aa52a6f5e2ecf35b1",
        "object": "chat.completion",
        "created": 1731702782,
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "choices": [
            {
                "index": 0,
                "message": assistant_message,
                "logprobs": None,
                "finish_reason": "stop",
                "stop_reason": None,
            }
        ],
        "usage": {"prompt_tokens": 63, "total_tokens": 172, "completion_tokens": 109},
        "prompt_logprobs": None,
    }

    # Register the mocked route
    mocked_reply = httpx.Response(200, json=mock_response)
    mock_request = respx_mock.post(expected_url).mock(return_value=mocked_reply)

    # Exercise the code under test
    response = await litellm.acompletion(
        model=f"vertex_ai/openai/{endpoint_id}",
        messages=messages,
        vertex_project=project,
        vertex_location=location,
    )

    # The mocked endpoint must have been hit with the expected payload
    assert mock_request.called
    sent_request = mock_request.calls[0].request
    request_body = json.loads(sent_request.content)
    assert request_body == expected_request

    # The response must mirror the mocked payload
    assert response.id == "chat-09940d4e99e3488aa52a6f5e2ecf35b1"
    assert response.created == 1731702782
    assert response.model == "vertex_ai/meta-llama/Llama-3.1-8B-Instruct"
    assert len(response.choices) == 1

    first_choice = response.choices[0]
    assert first_choice.message.role == "assistant"
    assert first_choice.message.content.startswith("Hello, my name is Litellm Bot")
    assert first_choice.finish_reason == "stop"

    usage = response.usage
    assert usage.completion_tokens == 109
    assert usage.prompt_tokens == 63
    assert usage.total_tokens == 172
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue