Compare commits

4 commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Ishaan Jaff | 969dad19a1 | docs model garden | 2024-11-15 12:54:40 -08:00 |
| Ishaan Jaff | 0b8aa778bf | test_vertexai_model_garden_model_completion | 2024-11-15 12:45:15 -08:00 |
| Ishaan Jaff | 34c1dc675a | VertexAIModelGardenModels | 2024-11-15 12:36:49 -08:00 |
| Ishaan Jaff | 8d28003099 | add VertexAIModelGardenModels | 2024-11-15 10:00:23 -08:00 |
4 changed files with 356 additions and 3 deletions

@@ -1161,12 +1161,96 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
## Model Garden
| Model Name | Function Call |
|------------------|--------------------------------------|
| llama2 | `completion('vertex_ai/<endpoint_id>', messages)` |
:::tip
All OpenAI compatible models from Vertex Model Garden are supported.
:::
#### Using Model Garden
**Almost all Vertex Model Garden models are OpenAI compatible.**
<Tabs>
<TabItem value="openai" label="OpenAI Compatible Models">
| Property | Details |
|----------|---------|
| Provider Route | `vertex_ai/openai/{MODEL_ID}` |
| Vertex Documentation | [Vertex Model Garden - OpenAI Chat Completions](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gradio_streaming_chat_completions.ipynb), [Vertex Model Garden](https://cloud.google.com/model-garden?hl=en) |
| Supported Operations | `/chat/completions`, `/embeddings` |
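Since the route also lists `/embeddings`, an embedding call should look like this (a minimal sketch, assuming your Model Garden endpoint serves an embedding model; the project and endpoint IDs below are placeholders):

```python
from litellm import embedding
import os

os.environ["VERTEXAI_PROJECT"] = "my-test-project"  # placeholder project
os.environ["VERTEXAI_LOCATION"] = "us-central1"

response = embedding(
    model="vertex_ai/openai/<your-endpoint-id>",  # placeholder endpoint ID
    input=["hello from litellm"],
)
```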
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
## set ENV variables
os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
os.environ["VERTEXAI_LOCATION"] = "us-central1"
response = completion(
    model="vertex_ai/openai/<your-endpoint-id>",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
```
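Streaming is forwarded to the endpoint like any other OpenAI-compatible parameter, so a streaming call is just (a sketch; the endpoint ID is a placeholder):

```python
from litellm import completion

# stream=True is passed through to the OpenAI-compatible endpoint
response = completion(
    model="vertex_ai/openai/<your-endpoint-id>",  # placeholder endpoint ID
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    stream=True,
)
for chunk in response:
    print(chunk)
```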
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
  - model_name: llama3-1-8b-instruct
    litellm_params:
      model: vertex_ai/openai/5464397967697903616
      vertex_ai_project: "my-test-project"
      vertex_ai_location: "us-east1"
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
# "model" is the `model_name` set in the config
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "llama3-1-8b-instruct",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ]
}'
```
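Since the proxy speaks the OpenAI API, the same request works with the OpenAI SDK (a sketch, reusing the key and `model_name` from the steps above):

```python
import openai

client = openai.OpenAI(
    api_key="sk-1234",               # the proxy key used in the curl example
    base_url="http://0.0.0.0:4000",  # the running LiteLLM proxy
)

response = client.chat.completions.create(
    model="llama3-1-8b-instruct",  # the `model_name` from the config
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response.choices[0].message.content)
```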
</TabItem>
</Tabs>
</TabItem>
<TabItem value="non-openai" label="Non-OpenAI Compatible Models">
```python
from litellm import completion
import os
@@ -1181,6 +1265,11 @@ response = completion(
)
```
</TabItem>
</Tabs>
## Gemini Pro
| Model Name | Function Call |
|------------------|--------------------------------------|

@@ -0,0 +1,156 @@
"""
API Handler for calling Vertex AI Model Garden Models
Most Vertex Model Garden Models are OpenAI compatible - so this handler calls `openai_like_chat_completions`
Usage:
response = litellm.completion(
model="vertex_ai/openai/5464397967697903616",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)
Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`
Vertex Documentation for using the OpenAI /chat/completions endpoint: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_llama3_deployment.ipynb
"""
import types
from enum import Enum
from typing import Callable, Literal, Optional, Union
import httpx # type: ignore
import litellm
from litellm.utils import ModelResponse
from ..common_utils import VertexAIError
from ..vertex_llm_base import VertexBase
def create_vertex_url(
    vertex_location: str,
    vertex_project: str,
    stream: Optional[bool],
    model: str,
    api_base: Optional[str] = None,
) -> str:
    """Return the base url for the vertex garden models"""
    # e.g. https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project}/locations/{location}/endpoints/{endpoint_id}
    return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}"
class VertexAIModelGardenModels(VertexBase):
    def __init__(self) -> None:
        pass

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        logger_fn=None,
        acompletion: bool = False,
        client=None,
    ):
        """
        Handles calling Vertex AI Model Garden Models in OpenAI compatible format

        Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`
        """
        try:
            import vertexai
            from google.cloud import aiplatform

            from litellm.llms.anthropic.chat import AnthropicChatCompletion
            from litellm.llms.databricks.chat import DatabricksChatCompletion
            from litellm.llms.OpenAI.openai import OpenAIChatCompletion
            from litellm.llms.text_completion_codestral import CodestralTextCompletion
            from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
                VertexLLM,
            )
        except Exception:
            raise VertexAIError(
                status_code=400,
                message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""",
            )

        if not (
            hasattr(vertexai, "preview")
            and hasattr(vertexai.preview, "language_models")
        ):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )
        try:
            model = model.replace("openai/", "")
            vertex_httpx_logic = VertexLLM()

            access_token, project_id = vertex_httpx_logic._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai",
            )

            openai_like_chat_completions = DatabricksChatCompletion()

            ## CONSTRUCT API BASE
            stream: bool = optional_params.get("stream", False) or False
            optional_params["stream"] = stream

            default_api_base = create_vertex_url(
                vertex_location=vertex_location or "us-central1",
                vertex_project=vertex_project or project_id,
                stream=stream,
                model=model,
            )
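            # pass the portion after the last ":" (if any) to _check_custom_proxy
            # as the endpoint suffix; it only matters when a custom api_base is set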
            if len(default_api_base.split(":")) > 1:
                endpoint = default_api_base.split(":")[-1]
            else:
                endpoint = ""

            _, api_base = self._check_custom_proxy(
                api_base=api_base,
                custom_llm_provider="vertex_ai",
                gemini_api_key=None,
                endpoint=endpoint,
                stream=stream,
                auth_header=None,
                url=default_api_base,
            )
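            # the endpoint ID is already in the URL, so the request body's "model"
            # field is intentionally sent empty (see expected_request in the test)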
            model = ""
            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=access_token,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                logging_obj=logging_obj,
                optional_params=optional_params,
                acompletion=acompletion,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                client=client,
                timeout=timeout,
                encoding=encoding,
                custom_llm_provider="vertex_ai",
            )

        except Exception as e:
            raise VertexAIError(status_code=500, message=str(e))

@@ -158,6 +158,9 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
    VertexEmbedding,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
    VertexAIModelGardenModels,
)
from .llms.watsonx.chat.handler import WatsonXChatHandler
from .llms.watsonx.completion.handler import IBMWatsonXAI
from .types.llms.openai import (
@@ -221,6 +224,7 @@ vertex_multimodal_embedding = VertexMultimodalEmbedding()
vertex_image_generation = VertexImageGeneration()
google_batch_embeddings = GoogleBatchEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_model_garden_chat_completion = VertexAIModelGardenModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
@@ -2355,6 +2359,28 @@ def completion( # type: ignore # noqa: PLR0915
                api_base=api_base,
                extra_headers=extra_headers,
            )
        elif "openai" in model:
            # Vertex Model Garden - OpenAI compatible models
            # (`model` arrives here as "openai/{ENDPOINT_ID}"; the handler strips the prefix)
            model_response = vertex_model_garden_chat_completion.completion(
                model=model,
                messages=messages,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=new_params,
                litellm_params=litellm_params,  # type: ignore
                logger_fn=logger_fn,
                encoding=encoding,
                api_base=api_base,
                vertex_location=vertex_ai_location,
                vertex_project=vertex_ai_project,
                vertex_credentials=vertex_credentials,
                logging_obj=logging,
                acompletion=acompletion,
                headers=headers,
                custom_prompt_dict=custom_prompt_dict,
                timeout=timeout,
                client=client,
            )
        else:
            model_response = vertex_ai_non_gemini.completion(
                model=model,

@@ -3123,3 +3123,85 @@ async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
        assert isinstance(embedding["embedding"], list)
        assert len(embedding["embedding"]) > 0
        assert all(isinstance(x, float) for x in embedding["embedding"])


@pytest.mark.asyncio
@pytest.mark.respx
async def test_vertexai_model_garden_model_completion(respx_mock: MockRouter):
    """
    Relevant issue: https://github.com/BerriAI/litellm/issues/6480

    Using OpenAI compatible models from Vertex Model Garden
    """
    load_vertex_ai_credentials()
    litellm.set_verbose = True

    # Test input
    messages = [
        {
            "role": "system",
            "content": "Your name is Litellm Bot, you are a helpful assistant",
        },
        {
            "role": "user",
            "content": "Hello, what is your name and can you tell me the weather?",
        },
    ]

    # Expected request/response
    expected_url = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/633608382793/locations/us-central1/endpoints/5464397967697903616/chat/completions"
    expected_request = {"model": "", "messages": messages, "stream": False}
    mock_response = {
        "id": "chat-09940d4e99e3488aa52a6f5e2ecf35b1",
        "object": "chat.completion",
        "created": 1731702782,
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello, my name is Litellm Bot. I'm a helpful assistant here to provide information and answer your questions.\n\nTo check the weather for you, I'll need to know your location. Could you please provide me with your city or zip code? That way, I can give you the most accurate and up-to-date weather information.\n\nIf you don't have your location handy, I can also suggest some popular weather websites or apps that you can use to check the weather for your area.\n\nLet me know how I can assist you!",
                    "tool_calls": [],
                },
                "logprobs": None,
                "finish_reason": "stop",
                "stop_reason": None,
            }
        ],
        "usage": {"prompt_tokens": 63, "total_tokens": 172, "completion_tokens": 109},
        "prompt_logprobs": None,
    }

    # Setup mock request
    mock_request = respx_mock.post(expected_url).mock(
        return_value=httpx.Response(200, json=mock_response)
    )

    # Make request
    response = await litellm.acompletion(
        model="vertex_ai/openai/5464397967697903616",
        messages=messages,
        vertex_project="633608382793",
        vertex_location="us-central1",
    )

    # Assert request was made correctly
    assert mock_request.called
    request_body = json.loads(mock_request.calls[0].request.content)
    assert request_body == expected_request

    # Assert response structure
    assert response.id == "chat-09940d4e99e3488aa52a6f5e2ecf35b1"
    assert response.created == 1731702782
    assert response.model == "vertex_ai/meta-llama/Llama-3.1-8B-Instruct"
    assert len(response.choices) == 1
    assert response.choices[0].message.role == "assistant"
    assert response.choices[0].message.content.startswith(
        "Hello, my name is Litellm Bot"
    )
    assert response.choices[0].finish_reason == "stop"
    assert response.usage.completion_tokens == 109
    assert response.usage.prompt_tokens == 63
    assert response.usage.total_tokens == 172