From c119bad5f93a5c4054e22c759f1b5f1f702ad3c2 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 14 Nov 2024 20:37:55 -0800 Subject: [PATCH] (feat) Vertex AI - add support for fine tuned embedding models (#6749) * fix use fine tuned vertex embedding models * test_vertex_embedding_url * add _transform_openai_request_to_fine_tuned_embedding_request * add _transform_openai_request_to_fine_tuned_embedding_request * add transform_openai_request_to_vertex_embedding_request * add _transform_vertex_response_to_openai_for_fine_tuned_models * test_vertexai_embedding for ft models * fix test_vertexai_embedding_finetuned * doc fine tuned / custom embedding models * fix test test_partner_models_httpx --- docs/my-website/docs/providers/vertex.md | 48 +++++++++++ .../common_utils.py | 3 + .../vertex_embeddings/embedding_handler.py | 4 +- .../vertex_embeddings/transformation.py | 85 ++++++++++++++++++- .../vertex_embeddings/types.py | 17 +++- tests/llm_translation/test_vertex.py | 37 ++++++++ .../test_amazing_vertex_completion.py | 72 ++++++++++++++++ 7 files changed, 261 insertions(+), 5 deletions(-) diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index b69e8ee56..921db9e73 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -1562,6 +1562,10 @@ curl http://0.0.0.0:4000/v1/chat/completions \ ## **Embedding Models** #### Usage - Embedding + + + + ```python import litellm from litellm import embedding @@ -1574,6 +1578,49 @@ response = embedding( ) print(response) ``` + + + + + +1. Add model to config.yaml +```yaml +model_list: + - model_name: snowflake-arctic-embed-m-long-1731622468876 + litellm_params: + model: vertex_ai/<ENDPOINT_ID> + vertex_project: "adroit-crow-413218" + vertex_location: "us-central1" + vertex_credentials: adroit-crow-413218-a956eef1a2a8.json + +litellm_settings: + drop_params: True +``` + +2. Start Proxy + +``` +$ litellm --config /path/to/config.yaml +``` + +3. 
Make Request using OpenAI Python SDK, Langchain Python SDK + +```python +import openai + +client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000") + +response = client.embeddings.create( + model="snowflake-arctic-embed-m-long-1731622468876", + input = ["good morning from litellm", "this is another item"], +) + +print(response) +``` + + + + #### Supported Embedding Models All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported @@ -1589,6 +1636,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02 | textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` | | text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` | | text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` | +| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/<ENDPOINT_ID>", input)` | ### Supported OpenAI (Unified) Params diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py index 0f95b222c..74bab0b26 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py @@ -89,6 +89,9 @@ def _get_vertex_url( elif mode == "embedding": endpoint = "predict" url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" + if model.isdigit(): + # https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict + url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}" if not url or not endpoint: raise 
ValueError(f"Unable to get vertex url/endpoint for mode: {mode}") diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py index 0cde5c3b5..26741ff4f 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py @@ -96,7 +96,7 @@ class VertexEmbedding(VertexBase): headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) vertex_request: VertexEmbeddingRequest = ( litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params + input=input, optional_params=optional_params, model=model ) ) @@ -188,7 +188,7 @@ class VertexEmbedding(VertexBase): headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) vertex_request: VertexEmbeddingRequest = ( litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params + input=input, optional_params=optional_params, model=model ) ) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py index 1ca405392..6f4b25cef 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py @@ -101,11 +101,16 @@ class VertexAITextEmbeddingConfig(BaseModel): return optional_params def transform_openai_request_to_vertex_embedding_request( - self, input: Union[list, str], optional_params: dict + self, input: Union[list, str], optional_params: dict, model: str ) -> VertexEmbeddingRequest: """ Transforms an openai request to a vertex embedding request. 
""" + if model.isdigit(): + return self._transform_openai_request_to_fine_tuned_embedding_request( + input, optional_params, model + ) + vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest() vertex_text_embedding_input_list: List[TextEmbeddingInput] = [] task_type: Optional[TaskType] = optional_params.get("task_type") @@ -125,6 +130,47 @@ class VertexAITextEmbeddingConfig(BaseModel): return vertex_request + def _transform_openai_request_to_fine_tuned_embedding_request( + self, input: Union[list, str], optional_params: dict, model: str + ) -> VertexEmbeddingRequest: + """ + Transforms an openai request to a vertex fine-tuned embedding request. + + Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22)) + Sample Request: + + ```json + { + "instances" : [ + { + "inputs": "How would the Future of AI in 10 Years look?", + "parameters": { + "max_new_tokens": 128, + "temperature": 1.0, + "top_p": 0.9, + "top_k": 10 + } + } + ] + } + ``` + """ + vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest() + vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = [] + if isinstance(input, str): + input = [input] # Convert single string to list for uniform processing + + for text in input: + embedding_input = TextEmbeddingFineTunedInput(inputs=text) + vertex_text_embedding_input_list.append(embedding_input) + + vertex_request["instances"] = vertex_text_embedding_input_list + vertex_request["parameters"] = TextEmbeddingFineTunedParameters( + **optional_params + ) + + return vertex_request + def create_embedding_input( self, content: str, @@ -157,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel): """ Transforms a vertex embedding response to an openai response. 
""" + if model.isdigit(): + return self._transform_vertex_response_to_openai_for_fine_tuned_models( + response, model, model_response + ) + _predictions = response["predictions"] embedding_response = [] @@ -181,3 +232,35 @@ class VertexAITextEmbeddingConfig(BaseModel): ) setattr(model_response, "usage", usage) return model_response + + def _transform_vertex_response_to_openai_for_fine_tuned_models( + self, response: dict, model: str, model_response: litellm.EmbeddingResponse + ) -> litellm.EmbeddingResponse: + """ + Transforms a vertex fine-tuned model embedding response to an openai response format. + """ + _predictions = response["predictions"] + + embedding_response = [] + # For fine-tuned models, we don't get token counts in the response + input_tokens = 0 + + for idx, embedding_values in enumerate(_predictions): + embedding_response.append( + { + "object": "embedding", + "index": idx, + "embedding": embedding_values[ + 0 + ], # The embedding values are nested one level deeper + } + ) + + model_response.object = "list" + model_response.data = embedding_response + model_response.model = model + usage = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + ) + setattr(model_response, "usage", usage) + return model_response diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py index 311809c82..433305516 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py @@ -23,14 +23,27 @@ class TextEmbeddingInput(TypedDict, total=False): title: Optional[str] +# Fine-tuned models require a different input format +# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22)) +class 
TextEmbeddingFineTunedInput(TypedDict, total=False): + inputs: str + + +class TextEmbeddingFineTunedParameters(TypedDict, total=False): + max_new_tokens: Optional[int] + temperature: Optional[float] + top_p: Optional[float] + top_k: Optional[int] + + class EmbeddingParameters(TypedDict, total=False): auto_truncate: Optional[bool] output_dimensionality: Optional[int] class VertexEmbeddingRequest(TypedDict, total=False): - instances: List[TextEmbeddingInput] - parameters: Optional[EmbeddingParameters] + instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]] + parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]] # Example usage: diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index a06179a49..73960020d 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -16,6 +16,7 @@ import pytest import litellm from litellm import get_optional_params from litellm.llms.custom_httpx.http_handler import HTTPHandler +import httpx def test_completion_pydantic_obj_2(): @@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url): mock_post.assert_called_once() print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"]) assert mock_post.call_args.kwargs["json"] == expected_request_body + + +@pytest.mark.parametrize( + "model, expected_url", + [ + ( + "textembedding-gecko@001", + "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict", + ), + ( + "123456789", + "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict", + ), + ], +) +def test_vertex_embedding_url(model, expected_url): + """ + Test URL generation for embedding models, including numeric model IDs (fine-tuned models) + + Relevant issue: https://github.com/BerriAI/litellm/issues/6482 + + When a fine-tuned 
embedding model is used, the URL is different from the standard one. + """ + from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import _get_vertex_url + + url, endpoint = _get_vertex_url( + mode="embedding", + model=model, + stream=False, + vertex_project="project-id", + vertex_location="us-central1", + vertex_api_version="v1", + ) + + assert url == expected_url + assert endpoint == "predict" diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 2de53696f..5a07d17b7 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -18,6 +18,8 @@ import json import os import tempfile from unittest.mock import AsyncMock, MagicMock, patch +from respx import MockRouter +import httpx import pytest @@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode): data = { "model": model, "messages": messages, + "timeout": 10, } if sync_mode: response = litellm.completion(**data) @@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode): assert isinstance(response._hidden_params["response_cost"], float) except litellm.RateLimitError as e: pass + except litellm.Timeout as e: + pass except litellm.InternalServerError as e: pass except Exception as e: @@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base): assert url == api_base + ":" else: assert url == test_endpoint + + +@pytest.mark.asyncio +@pytest.mark.respx +async def test_vertexai_embedding_finetuned(respx_mock: MockRouter): + """ + Tests that: + - Request URL and body are correctly formatted for Vertex AI embeddings + - Response is properly parsed into litellm's embedding response format + """ + load_vertex_ai_credentials() + litellm.set_verbose = True + + # Test input + input_text = ["good morning from litellm", "this is another item"] + + # Expected request/response + expected_url = 
"https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict" + expected_request = { + "instances": [ + {"inputs": "good morning from litellm"}, + {"inputs": "this is another item"}, + ], + "parameters": {}, + } + + mock_response = { + "predictions": [ + [[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector + [[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector + ], + "deployedModelId": "2275167734310371328", + "model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876", + "modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876", + "modelVersionId": "1", + } + + # Setup mock request + mock_request = respx_mock.post(expected_url).mock( + return_value=httpx.Response(200, json=mock_response) + ) + + # Make request + response = await litellm.aembedding( + vertex_project="633608382793", + model="vertex_ai/1004708436694269952", + input=input_text, + ) + + # Assert request was made correctly + assert mock_request.called + request_body = json.loads(mock_request.calls[0].request.content) + print("\n\nrequest_body", request_body) + print("\n\nexpected_request", expected_request) + assert request_body == expected_request + + # Assert response structure + assert response is not None + assert hasattr(response, "data") + assert len(response.data) == len(input_text) + + # Assert embedding structure + for embedding in response.data: + assert "embedding" in embedding + assert isinstance(embedding["embedding"], list) + assert len(embedding["embedding"]) > 0 + assert all(isinstance(x, float) for x in embedding["embedding"])