(feat) Vertex AI - add support for fine tuned embedding models (#6749)

* fix use fine tuned vertex embedding models

* test_vertex_embedding_url

* add _transform_openai_request_to_fine_tuned_embedding_request

* add _transform_openai_request_to_fine_tuned_embedding_request

* add transform_openai_request_to_vertex_embedding_request

* add _transform_vertex_response_to_openai_for_fine_tuned_models

* test_vertexai_embedding for ft models

* fix test_vertexai_embedding_finetuned

* doc fine tuned / custom embedding models

* fix test test_partner_models_httpx
This commit is contained in:
Ishaan Jaff 2024-11-14 20:37:55 -08:00 committed by GitHub
parent c03351328f
commit c119bad5f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 261 additions and 5 deletions

View file

@ -1562,6 +1562,10 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## **Embedding Models** ## **Embedding Models**
#### Usage - Embedding #### Usage - Embedding
<Tabs>
<TabItem value="sdk" label="SDK">
```python ```python
import litellm import litellm
from litellm import embedding from litellm import embedding
@ -1574,6 +1578,49 @@ response = embedding(
) )
print(response) print(response)
``` ```
</TabItem>
<TabItem value="proxy" label="LiteLLM PROXY">
1. Add model to config.yaml
```yaml
model_list:
- model_name: snowflake-arctic-embed-m-long-1731622468876
litellm_params:
model: vertex_ai/<your-model-id> # numeric Vertex AI endpoint ID of the fine-tuned / custom model
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
litellm_settings:
drop_params: True
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make a request using the OpenAI Python SDK (other OpenAI-compatible clients, e.g. the Langchain Python SDK, can be pointed at the proxy the same way)
```python
import openai
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
response = client.embeddings.create(
model="snowflake-arctic-embed-m-long-1731622468876",
input = ["good morning from litellm", "this is another item"],
)
print(response)
```
</TabItem>
</Tabs>
#### Supported Embedding Models #### Supported Embedding Models
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
@ -1589,6 +1636,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` | | textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` | | text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` | | text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
| Fine-tuned or custom embedding models (use the numeric Vertex AI endpoint ID as `<your-model-id>`) | `embedding(model="vertex_ai/<your-model-id>", input)` |
### Supported OpenAI (Unified) Params ### Supported OpenAI (Unified) Params

View file

@ -89,6 +89,9 @@ def _get_vertex_url(
elif mode == "embedding": elif mode == "embedding":
endpoint = "predict" endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if model.isdigit():
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if not url or not endpoint: if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}") raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")

View file

@ -96,7 +96,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = ( vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
input=input, optional_params=optional_params input=input, optional_params=optional_params, model=model
) )
) )
@ -188,7 +188,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = ( vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
input=input, optional_params=optional_params input=input, optional_params=optional_params, model=model
) )
) )

View file

@ -101,11 +101,16 @@ class VertexAITextEmbeddingConfig(BaseModel):
return optional_params return optional_params
def transform_openai_request_to_vertex_embedding_request( def transform_openai_request_to_vertex_embedding_request(
self, input: Union[list, str], optional_params: dict self, input: Union[list, str], optional_params: dict, model: str
) -> VertexEmbeddingRequest: ) -> VertexEmbeddingRequest:
""" """
Transforms an openai request to a vertex embedding request. Transforms an openai request to a vertex embedding request.
""" """
if model.isdigit():
return self._transform_openai_request_to_fine_tuned_embedding_request(
input, optional_params, model
)
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest() vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
vertex_text_embedding_input_list: List[TextEmbeddingInput] = [] vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
task_type: Optional[TaskType] = optional_params.get("task_type") task_type: Optional[TaskType] = optional_params.get("task_type")
@ -125,6 +130,47 @@ class VertexAITextEmbeddingConfig(BaseModel):
return vertex_request return vertex_request
def _transform_openai_request_to_fine_tuned_embedding_request(
    self, input: Union[list, str], optional_params: dict, model: str
) -> VertexEmbeddingRequest:
    """
    Build the request body for a fine-tuned / custom Vertex AI embedding model.

    Every input text becomes one `{"inputs": <text>}` instance, and
    `optional_params` is forwarded unchanged as the top-level `parameters`
    object of the request.

    Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))

    Args:
        input: A single string, or a list of strings, to embed.
        optional_params: Keyword parameters copied into `parameters`.
        model: The model identifier; not read when building the body.

    Returns:
        The request dict with `instances` and `parameters` keys.

    Sample request built for two inputs and empty `optional_params`:

    ```json
    {
        "instances": [
            {"inputs": "good morning from litellm"},
            {"inputs": "this is another item"}
        ],
        "parameters": {}
    }
    ```
    """
    # A bare string is treated as a one-element batch.
    texts = [input] if isinstance(input, str) else input

    instances: List[TextEmbeddingFineTunedInput] = [
        TextEmbeddingFineTunedInput(inputs=text) for text in texts
    ]

    return VertexEmbeddingRequest(
        instances=instances,
        parameters=TextEmbeddingFineTunedParameters(**optional_params),
    )
def create_embedding_input( def create_embedding_input(
self, self,
content: str, content: str,
@ -157,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel):
""" """
Transforms a vertex embedding response to an openai response. Transforms a vertex embedding response to an openai response.
""" """
if model.isdigit():
return self._transform_vertex_response_to_openai_for_fine_tuned_models(
response, model, model_response
)
_predictions = response["predictions"] _predictions = response["predictions"]
embedding_response = [] embedding_response = []
@ -181,3 +232,35 @@ class VertexAITextEmbeddingConfig(BaseModel):
) )
setattr(model_response, "usage", usage) setattr(model_response, "usage", usage)
return model_response return model_response
def _transform_vertex_response_to_openai_for_fine_tuned_models(
    self, response: dict, model: str, model_response: litellm.EmbeddingResponse
) -> litellm.EmbeddingResponse:
    """
    Convert a fine-tuned Vertex AI embedding response to the OpenAI format.

    Args:
        response: Decoded Vertex response; its `predictions` list is read.
        model: Value stored on `model_response.model`.
        model_response: Response object that is filled in and returned.

    Returns:
        The same `model_response`, with `object`, `data`, `model` and
        `usage` set.
    """
    # Each prediction nests the embedding vector one level deeper, so the
    # vector itself is the first element of the prediction.
    openai_embeddings = [
        {
            "object": "embedding",
            "index": position,
            "embedding": prediction[0],
        }
        for position, prediction in enumerate(response["predictions"])
    ]

    model_response.object = "list"
    model_response.data = openai_embeddings
    model_response.model = model

    # For fine-tuned models, the response carries no token counts, so usage
    # is reported as zero.
    prompt_token_count = 0
    setattr(
        model_response,
        "usage",
        Usage(
            prompt_tokens=prompt_token_count,
            completion_tokens=0,
            total_tokens=prompt_token_count,
        ),
    )
    return model_response

View file

@ -23,14 +23,27 @@ class TextEmbeddingInput(TypedDict, total=False):
title: Optional[str] title: Optional[str]
# Fine-tuned models require a different input format
# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
class TextEmbeddingFineTunedInput(TypedDict, total=False):
    """One entry of `instances` in a fine-tuned / custom embedding request."""

    # Text to embed for this instance.
    inputs: str
class TextEmbeddingFineTunedParameters(TypedDict, total=False):
    """
    The `parameters` object of a fine-tuned / custom embedding request.

    Declared with `total=False`, so every key may be omitted.
    """

    max_new_tokens: Optional[int]
    temperature: Optional[float]
    top_p: Optional[float]
    top_k: Optional[int]
class EmbeddingParameters(TypedDict, total=False): class EmbeddingParameters(TypedDict, total=False):
auto_truncate: Optional[bool] auto_truncate: Optional[bool]
output_dimensionality: Optional[int] output_dimensionality: Optional[int]
class VertexEmbeddingRequest(TypedDict, total=False): class VertexEmbeddingRequest(TypedDict, total=False):
instances: List[TextEmbeddingInput] instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]]
parameters: Optional[EmbeddingParameters] parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
# Example usage: # Example usage:

View file

@ -16,6 +16,7 @@ import pytest
import litellm import litellm
from litellm import get_optional_params from litellm import get_optional_params
from litellm.llms.custom_httpx.http_handler import HTTPHandler from litellm.llms.custom_httpx.http_handler import HTTPHandler
import httpx
def test_completion_pydantic_obj_2(): def test_completion_pydantic_obj_2():
@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url):
mock_post.assert_called_once() mock_post.assert_called_once()
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"]) print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
assert mock_post.call_args.kwargs["json"] == expected_request_body assert mock_post.call_args.kwargs["json"] == expected_request_body
@pytest.mark.parametrize(
    "model, expected_url",
    [
        (
            "textembedding-gecko@001",
            "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict",
        ),
        (
            "123456789",
            "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict",
        ),
    ],
)
def test_vertex_embedding_url(model, expected_url):
    """
    Check the URL built for embedding models, including numeric model IDs
    (fine-tuned models).

    Relevant issue: https://github.com/BerriAI/litellm/issues/6482

    A named model is expected under `publishers/google/models/...`, while an
    all-digit model ID is expected under `endpoints/...`.
    """
    from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import _get_vertex_url

    built_url, built_endpoint = _get_vertex_url(
        mode="embedding",
        model=model,
        stream=False,
        vertex_project="project-id",
        vertex_location="us-central1",
        vertex_api_version="v1",
    )

    assert built_url == expected_url
    assert built_endpoint == "predict"

View file

@ -18,6 +18,8 @@ import json
import os import os
import tempfile import tempfile
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from respx import MockRouter
import httpx
import pytest import pytest
@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode):
data = { data = {
"model": model, "model": model,
"messages": messages, "messages": messages,
"timeout": 10,
} }
if sync_mode: if sync_mode:
response = litellm.completion(**data) response = litellm.completion(**data)
@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode):
assert isinstance(response._hidden_params["response_cost"], float) assert isinstance(response._hidden_params["response_cost"], float)
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
pass pass
except litellm.Timeout as e:
pass
except litellm.InternalServerError as e: except litellm.InternalServerError as e:
pass pass
except Exception as e: except Exception as e:
@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base):
assert url == api_base + ":" assert url == api_base + ":"
else: else:
assert url == test_endpoint assert url == test_endpoint
@pytest.mark.asyncio
@pytest.mark.respx
async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
    """
    Exercise `litellm.aembedding` against a mocked fine-tuned Vertex AI endpoint.

    Verifies that:
    - the request is sent to the endpoint `:predict` URL with the expected body
    - the mocked response is parsed into litellm's embedding response format
    """
    load_vertex_ai_credentials()
    litellm.set_verbose = True

    # Input texts sent to the embedding call
    texts = ["good morning from litellm", "this is another item"]

    # URL and body the mocked endpoint is expected to receive
    endpoint_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
    expected_request = {
        "instances": [
            {"inputs": "good morning from litellm"},
            {"inputs": "this is another item"},
        ],
        "parameters": {},
    }

    # Truncated embedding vector, nested one level as in the mocked prediction
    truncated_prediction = [[-0.000431762, -0.04416759, -0.03443353]]
    mocked_payload = {
        "predictions": [truncated_prediction, truncated_prediction],
        "deployedModelId": "2275167734310371328",
        "model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
        "modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
        "modelVersionId": "1",
    }

    # Register the mocked route
    route = respx_mock.post(endpoint_url).mock(
        return_value=httpx.Response(200, json=mocked_payload)
    )

    # Call the code under test
    response = await litellm.aembedding(
        vertex_project="633608382793",
        model="vertex_ai/1004708436694269952",
        input=texts,
    )

    # The route must have been hit with the expected JSON body
    assert route.called
    request_body = json.loads(route.calls[0].request.content)
    print("\n\nrequest_body", request_body)
    print("\n\nexpected_request", expected_request)
    assert request_body == expected_request

    # One embedding per input text
    assert response is not None
    assert hasattr(response, "data")
    assert len(response.data) == len(texts)

    # Every embedding is a non-empty list of floats
    for item in response.data:
        assert "embedding" in item
        vector = item["embedding"]
        assert isinstance(vector, list)
        assert len(vector) > 0
        assert all(isinstance(x, float) for x in vector)