forked from phoenix/litellm-mirror
(feat) Vertex AI - add support for fine tuned embedding models (#6749)
* fix use fine tuned vertex embedding models * test_vertex_embedding_url * add _transform_openai_request_to_fine_tuned_embedding_request * add _transform_openai_request_to_fine_tuned_embedding_request * add transform_openai_request_to_vertex_embedding_request * add _transform_vertex_response_to_openai_for_fine_tuned_models * test_vertexai_embedding for ft models * fix test_vertexai_embedding_finetuned * doc fine tuned / custom embedding models * fix test test_partner_models_httpx
This commit is contained in:
parent
c03351328f
commit
c119bad5f9
7 changed files with 261 additions and 5 deletions
|
@ -1562,6 +1562,10 @@ curl http://0.0.0.0:4000/v1/chat/completions \
|
|||
## **Embedding Models**
|
||||
|
||||
#### Usage - Embedding
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm import embedding
|
||||
|
@ -1574,6 +1578,49 @@ response = embedding(
|
|||
)
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="proxy" label="LiteLLM PROXY">
|
||||
|
||||
|
||||
1. Add model to config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: snowflake-arctic-embed-m-long-1731622468876
|
||||
litellm_params:
|
||||
model: vertex_ai/<your-model-id>
|
||||
vertex_project: "adroit-crow-413218"
|
||||
vertex_location: "us-central1"
|
||||
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
|
||||
```
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Make Request using OpenAI Python SDK, Langchain Python SDK
|
||||
|
||||
```python
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||
|
||||
response = client.embeddings.create(
|
||||
model="snowflake-arctic-embed-m-long-1731622468876",
|
||||
input = ["good morning from litellm", "this is another item"],
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
#### Supported Embedding Models
|
||||
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
|
||||
|
@ -1589,6 +1636,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
|
|||
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
|
||||
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||
| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/<your-model-id>", input)` |
|
||||
|
||||
### Supported OpenAI (Unified) Params
|
||||
|
||||
|
|
|
@ -89,6 +89,9 @@ def _get_vertex_url(
|
|||
elif mode == "embedding":
|
||||
endpoint = "predict"
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
||||
if model.isdigit():
|
||||
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||
|
||||
if not url or not endpoint:
|
||||
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
|
||||
|
|
|
@ -96,7 +96,7 @@ class VertexEmbedding(VertexBase):
|
|||
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
|
||||
vertex_request: VertexEmbeddingRequest = (
|
||||
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
|
||||
input=input, optional_params=optional_params
|
||||
input=input, optional_params=optional_params, model=model
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -188,7 +188,7 @@ class VertexEmbedding(VertexBase):
|
|||
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
|
||||
vertex_request: VertexEmbeddingRequest = (
|
||||
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
|
||||
input=input, optional_params=optional_params
|
||||
input=input, optional_params=optional_params, model=model
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -101,11 +101,16 @@ class VertexAITextEmbeddingConfig(BaseModel):
|
|||
return optional_params
|
||||
|
||||
def transform_openai_request_to_vertex_embedding_request(
|
||||
self, input: Union[list, str], optional_params: dict
|
||||
self, input: Union[list, str], optional_params: dict, model: str
|
||||
) -> VertexEmbeddingRequest:
|
||||
"""
|
||||
Transforms an openai request to a vertex embedding request.
|
||||
"""
|
||||
if model.isdigit():
|
||||
return self._transform_openai_request_to_fine_tuned_embedding_request(
|
||||
input, optional_params, model
|
||||
)
|
||||
|
||||
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
|
||||
vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
|
||||
task_type: Optional[TaskType] = optional_params.get("task_type")
|
||||
|
@ -125,6 +130,47 @@ class VertexAITextEmbeddingConfig(BaseModel):
|
|||
|
||||
return vertex_request
|
||||
|
||||
def _transform_openai_request_to_fine_tuned_embedding_request(
|
||||
self, input: Union[list, str], optional_params: dict, model: str
|
||||
) -> VertexEmbeddingRequest:
|
||||
"""
|
||||
Transforms an openai request to a vertex fine-tuned embedding request.
|
||||
|
||||
Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
|
||||
Sample Request:
|
||||
|
||||
```json
|
||||
{
|
||||
"instances" : [
|
||||
{
|
||||
"inputs": "How would the Future of AI in 10 Years look?",
|
||||
"parameters": {
|
||||
"max_new_tokens": 128,
|
||||
"temperature": 1.0,
|
||||
"top_p": 0.9,
|
||||
"top_k": 10
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
|
||||
vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = []
|
||||
if isinstance(input, str):
|
||||
input = [input] # Convert single string to list for uniform processing
|
||||
|
||||
for text in input:
|
||||
embedding_input = TextEmbeddingFineTunedInput(inputs=text)
|
||||
vertex_text_embedding_input_list.append(embedding_input)
|
||||
|
||||
vertex_request["instances"] = vertex_text_embedding_input_list
|
||||
vertex_request["parameters"] = TextEmbeddingFineTunedParameters(
|
||||
**optional_params
|
||||
)
|
||||
|
||||
return vertex_request
|
||||
|
||||
def create_embedding_input(
|
||||
self,
|
||||
content: str,
|
||||
|
@ -157,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel):
|
|||
"""
|
||||
Transforms a vertex embedding response to an openai response.
|
||||
"""
|
||||
if model.isdigit():
|
||||
return self._transform_vertex_response_to_openai_for_fine_tuned_models(
|
||||
response, model, model_response
|
||||
)
|
||||
|
||||
_predictions = response["predictions"]
|
||||
|
||||
embedding_response = []
|
||||
|
@ -181,3 +232,35 @@ class VertexAITextEmbeddingConfig(BaseModel):
|
|||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def _transform_vertex_response_to_openai_for_fine_tuned_models(
|
||||
self, response: dict, model: str, model_response: litellm.EmbeddingResponse
|
||||
) -> litellm.EmbeddingResponse:
|
||||
"""
|
||||
Transforms a vertex fine-tuned model embedding response to an openai response format.
|
||||
"""
|
||||
_predictions = response["predictions"]
|
||||
|
||||
embedding_response = []
|
||||
# For fine-tuned models, we don't get token counts in the response
|
||||
input_tokens = 0
|
||||
|
||||
for idx, embedding_values in enumerate(_predictions):
|
||||
embedding_response.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding_values[
|
||||
0
|
||||
], # The embedding values are nested one level deeper
|
||||
}
|
||||
)
|
||||
|
||||
model_response.object = "list"
|
||||
model_response.data = embedding_response
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
|
|
@ -23,14 +23,27 @@ class TextEmbeddingInput(TypedDict, total=False):
|
|||
title: Optional[str]
|
||||
|
||||
|
||||
# Fine-tuned models require a different input format
|
||||
# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
|
||||
class TextEmbeddingFineTunedInput(TypedDict, total=False):
|
||||
inputs: str
|
||||
|
||||
|
||||
class TextEmbeddingFineTunedParameters(TypedDict, total=False):
|
||||
max_new_tokens: Optional[int]
|
||||
temperature: Optional[float]
|
||||
top_p: Optional[float]
|
||||
top_k: Optional[int]
|
||||
|
||||
|
||||
class EmbeddingParameters(TypedDict, total=False):
|
||||
auto_truncate: Optional[bool]
|
||||
output_dimensionality: Optional[int]
|
||||
|
||||
|
||||
class VertexEmbeddingRequest(TypedDict, total=False):
|
||||
instances: List[TextEmbeddingInput]
|
||||
parameters: Optional[EmbeddingParameters]
|
||||
instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]]
|
||||
parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
|
||||
|
||||
|
||||
# Example usage:
|
||||
|
|
|
@ -16,6 +16,7 @@ import pytest
|
|||
import litellm
|
||||
from litellm import get_optional_params
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
import httpx
|
||||
|
||||
|
||||
def test_completion_pydantic_obj_2():
|
||||
|
@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url):
|
|||
mock_post.assert_called_once()
|
||||
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
|
||||
assert mock_post.call_args.kwargs["json"] == expected_request_body
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expected_url",
|
||||
[
|
||||
(
|
||||
"textembedding-gecko@001",
|
||||
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict",
|
||||
),
|
||||
(
|
||||
"123456789",
|
||||
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_vertex_embedding_url(model, expected_url):
|
||||
"""
|
||||
Test URL generation for embedding models, including numeric model IDs (fine-tuned models
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/6482
|
||||
|
||||
When a fine-tuned embedding model is used, the URL is different from the standard one.
|
||||
"""
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import _get_vertex_url
|
||||
|
||||
url, endpoint = _get_vertex_url(
|
||||
mode="embedding",
|
||||
model=model,
|
||||
stream=False,
|
||||
vertex_project="project-id",
|
||||
vertex_location="us-central1",
|
||||
vertex_api_version="v1",
|
||||
)
|
||||
|
||||
assert url == expected_url
|
||||
assert endpoint == "predict"
|
||||
|
|
|
@ -18,6 +18,8 @@ import json
|
|||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from respx import MockRouter
|
||||
import httpx
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode):
|
|||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"timeout": 10,
|
||||
}
|
||||
if sync_mode:
|
||||
response = litellm.completion(**data)
|
||||
|
@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode):
|
|||
assert isinstance(response._hidden_params["response_cost"], float)
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except litellm.InternalServerError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base):
|
|||
assert url == api_base + ":"
|
||||
else:
|
||||
assert url == test_endpoint
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.respx
|
||||
async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
|
||||
"""
|
||||
Tests that:
|
||||
- Request URL and body are correctly formatted for Vertex AI embeddings
|
||||
- Response is properly parsed into litellm's embedding response format
|
||||
"""
|
||||
load_vertex_ai_credentials()
|
||||
litellm.set_verbose = True
|
||||
|
||||
# Test input
|
||||
input_text = ["good morning from litellm", "this is another item"]
|
||||
|
||||
# Expected request/response
|
||||
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
|
||||
expected_request = {
|
||||
"instances": [
|
||||
{"inputs": "good morning from litellm"},
|
||||
{"inputs": "this is another item"},
|
||||
],
|
||||
"parameters": {},
|
||||
}
|
||||
|
||||
mock_response = {
|
||||
"predictions": [
|
||||
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
||||
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
||||
],
|
||||
"deployedModelId": "2275167734310371328",
|
||||
"model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
|
||||
"modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
|
||||
"modelVersionId": "1",
|
||||
}
|
||||
|
||||
# Setup mock request
|
||||
mock_request = respx_mock.post(expected_url).mock(
|
||||
return_value=httpx.Response(200, json=mock_response)
|
||||
)
|
||||
|
||||
# Make request
|
||||
response = await litellm.aembedding(
|
||||
vertex_project="633608382793",
|
||||
model="vertex_ai/1004708436694269952",
|
||||
input=input_text,
|
||||
)
|
||||
|
||||
# Assert request was made correctly
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
print("\n\nrequest_body", request_body)
|
||||
print("\n\nexpected_request", expected_request)
|
||||
assert request_body == expected_request
|
||||
|
||||
# Assert response structure
|
||||
assert response is not None
|
||||
assert hasattr(response, "data")
|
||||
assert len(response.data) == len(input_text)
|
||||
|
||||
# Assert embedding structure
|
||||
for embedding in response.data:
|
||||
assert "embedding" in embedding
|
||||
assert isinstance(embedding["embedding"], list)
|
||||
assert len(embedding["embedding"]) > 0
|
||||
assert all(isinstance(x, float) for x in embedding["embedding"])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue