(feat) Vertex AI - add support for fine tuned embedding models (#6749)

* fix use fine tuned vertex embedding models

* test_vertex_embedding_url

* add _transform_openai_request_to_fine_tuned_embedding_request

* add _transform_openai_request_to_fine_tuned_embedding_request

* add transform_openai_request_to_vertex_embedding_request

* add _transform_vertex_response_to_openai_for_fine_tuned_models

* test_vertexai_embedding for ft models

* fix test_vertexai_embedding_finetuned

* doc fine tuned / custom embedding models

* fix test test_partner_models_httpx
This commit is contained in:
Ishaan Jaff 2024-11-14 20:37:55 -08:00 committed by GitHub
parent c03351328f
commit c119bad5f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 261 additions and 5 deletions

View file

@ -1562,6 +1562,10 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## **Embedding Models**
#### Usage - Embedding
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import litellm
from litellm import embedding
@ -1574,6 +1578,49 @@ response = embedding(
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM PROXY">
1. Add model to config.yaml
```yaml
model_list:
- model_name: snowflake-arctic-embed-m-long-1731622468876
litellm_params:
model: vertex_ai/<your-model-id>
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
litellm_settings:
drop_params: True
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make Request using OpenAI Python SDK, Langchain Python SDK
```python
import openai
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
response = client.embeddings.create(
model="snowflake-arctic-embed-m-long-1731622468876",
input = ["good morning from litellm", "this is another item"],
)
print(response)
```
</TabItem>
</Tabs>
#### Supported Embedding Models
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
@ -1589,6 +1636,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/<your-model-id>", input)` |
### Supported OpenAI (Unified) Params

View file

@ -89,6 +89,9 @@ def _get_vertex_url(
elif mode == "embedding":
endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if model.isdigit():
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")

View file

@ -96,7 +96,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
input=input, optional_params=optional_params
input=input, optional_params=optional_params, model=model
)
)
@ -188,7 +188,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
input=input, optional_params=optional_params
input=input, optional_params=optional_params, model=model
)
)

View file

@ -101,11 +101,16 @@ class VertexAITextEmbeddingConfig(BaseModel):
return optional_params
def transform_openai_request_to_vertex_embedding_request(
self, input: Union[list, str], optional_params: dict
self, input: Union[list, str], optional_params: dict, model: str
) -> VertexEmbeddingRequest:
"""
Transforms an openai request to a vertex embedding request.
"""
if model.isdigit():
return self._transform_openai_request_to_fine_tuned_embedding_request(
input, optional_params, model
)
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
task_type: Optional[TaskType] = optional_params.get("task_type")
@ -125,6 +130,47 @@ class VertexAITextEmbeddingConfig(BaseModel):
return vertex_request
def _transform_openai_request_to_fine_tuned_embedding_request(
self, input: Union[list, str], optional_params: dict, model: str
) -> VertexEmbeddingRequest:
"""
Transforms an openai request to a vertex fine-tuned embedding request.
Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
Sample Request:
```json
{
"instances" : [
{
"inputs": "How would the Future of AI in 10 Years look?",
"parameters": {
"max_new_tokens": 128,
"temperature": 1.0,
"top_p": 0.9,
"top_k": 10
}
}
]
}
```
"""
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = []
if isinstance(input, str):
input = [input] # Convert single string to list for uniform processing
for text in input:
embedding_input = TextEmbeddingFineTunedInput(inputs=text)
vertex_text_embedding_input_list.append(embedding_input)
vertex_request["instances"] = vertex_text_embedding_input_list
vertex_request["parameters"] = TextEmbeddingFineTunedParameters(
**optional_params
)
return vertex_request
def create_embedding_input(
self,
content: str,
@ -157,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel):
"""
Transforms a vertex embedding response to an openai response.
"""
if model.isdigit():
return self._transform_vertex_response_to_openai_for_fine_tuned_models(
response, model, model_response
)
_predictions = response["predictions"]
embedding_response = []
@ -181,3 +232,35 @@ class VertexAITextEmbeddingConfig(BaseModel):
)
setattr(model_response, "usage", usage)
return model_response
def _transform_vertex_response_to_openai_for_fine_tuned_models(
self, response: dict, model: str, model_response: litellm.EmbeddingResponse
) -> litellm.EmbeddingResponse:
"""
Transforms a vertex fine-tuned model embedding response to an openai response format.
"""
_predictions = response["predictions"]
embedding_response = []
# For fine-tuned models, we don't get token counts in the response
input_tokens = 0
for idx, embedding_values in enumerate(_predictions):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding_values[
0
], # The embedding values are nested one level deeper
}
)
model_response.object = "list"
model_response.data = embedding_response
model_response.model = model
usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
setattr(model_response, "usage", usage)
return model_response

View file

@ -23,14 +23,27 @@ class TextEmbeddingInput(TypedDict, total=False):
title: Optional[str]
# Fine-tuned models require a different input format
# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
class TextEmbeddingFineTunedInput(TypedDict, total=False):
inputs: str
class TextEmbeddingFineTunedParameters(TypedDict, total=False):
max_new_tokens: Optional[int]
temperature: Optional[float]
top_p: Optional[float]
top_k: Optional[int]
class EmbeddingParameters(TypedDict, total=False):
auto_truncate: Optional[bool]
output_dimensionality: Optional[int]
class VertexEmbeddingRequest(TypedDict, total=False):
instances: List[TextEmbeddingInput]
parameters: Optional[EmbeddingParameters]
instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]]
parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
# Example usage:

View file

@ -16,6 +16,7 @@ import pytest
import litellm
from litellm import get_optional_params
from litellm.llms.custom_httpx.http_handler import HTTPHandler
import httpx
def test_completion_pydantic_obj_2():
@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url):
mock_post.assert_called_once()
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
assert mock_post.call_args.kwargs["json"] == expected_request_body
@pytest.mark.parametrize(
"model, expected_url",
[
(
"textembedding-gecko@001",
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict",
),
(
"123456789",
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict",
),
],
)
def test_vertex_embedding_url(model, expected_url):
"""
Test URL generation for embedding models, including numeric model IDs (fine-tuned models
Relevant issue: https://github.com/BerriAI/litellm/issues/6482
When a fine-tuned embedding model is used, the URL is different from the standard one.
"""
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import _get_vertex_url
url, endpoint = _get_vertex_url(
mode="embedding",
model=model,
stream=False,
vertex_project="project-id",
vertex_location="us-central1",
vertex_api_version="v1",
)
assert url == expected_url
assert endpoint == "predict"

View file

@ -18,6 +18,8 @@ import json
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
from respx import MockRouter
import httpx
import pytest
@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode):
data = {
"model": model,
"messages": messages,
"timeout": 10,
}
if sync_mode:
response = litellm.completion(**data)
@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode):
assert isinstance(response._hidden_params["response_cost"], float)
except litellm.RateLimitError as e:
pass
except litellm.Timeout as e:
pass
except litellm.InternalServerError as e:
pass
except Exception as e:
@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base):
assert url == api_base + ":"
else:
assert url == test_endpoint
@pytest.mark.asyncio
@pytest.mark.respx
async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
"""
Tests that:
- Request URL and body are correctly formatted for Vertex AI embeddings
- Response is properly parsed into litellm's embedding response format
"""
load_vertex_ai_credentials()
litellm.set_verbose = True
# Test input
input_text = ["good morning from litellm", "this is another item"]
# Expected request/response
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
expected_request = {
"instances": [
{"inputs": "good morning from litellm"},
{"inputs": "this is another item"},
],
"parameters": {},
}
mock_response = {
"predictions": [
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
],
"deployedModelId": "2275167734310371328",
"model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
"modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
"modelVersionId": "1",
}
# Setup mock request
mock_request = respx_mock.post(expected_url).mock(
return_value=httpx.Response(200, json=mock_response)
)
# Make request
response = await litellm.aembedding(
vertex_project="633608382793",
model="vertex_ai/1004708436694269952",
input=input_text,
)
# Assert request was made correctly
assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
print("\n\nrequest_body", request_body)
print("\n\nexpected_request", expected_request)
assert request_body == expected_request
# Assert response structure
assert response is not None
assert hasattr(response, "data")
assert len(response.data) == len(input_text)
# Assert embedding structure
for embedding in response.data:
assert "embedding" in embedding
assert isinstance(embedding["embedding"], list)
assert len(embedding["embedding"]) > 0
assert all(isinstance(x, float) for x in embedding["embedding"])