diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py
index ea9763681..e4f8af6ac 100644
--- a/tests/llm_translation/test_vertex.py
+++ b/tests/llm_translation/test_vertex.py
@@ -16,6 +16,8 @@ import pytest
 import litellm
 from litellm import get_optional_params
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
+import httpx
+from respx import MockRouter
 
 
 def test_completion_pydantic_obj_2():
@@ -1353,3 +1355,69 @@ def test_vertex_embedding_url(model, expected_url):
 
     assert url == expected_url
     assert endpoint == "predict"
+
+
+@pytest.mark.asyncio
+@pytest.mark.respx
+async def test_vertexai_embedding(respx_mock: MockRouter):
+    """
+    Tests that:
+    - Request URL and body are correctly formatted for Vertex AI embeddings
+    - Response is properly parsed into litellm's embedding response format
+    """
+    litellm.set_verbose = True
+
+    # Test input
+    input_text = ["good morning from litellm", "this is another item"]
+
+    # Expected request/response
+    expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
+    expected_request = {
+        "instances": [
+            {"inputs": "good morning from litellm"},
+            {"inputs": "this is another item"},
+        ],
+        "parameters": {},
+    }
+
+    mock_response = {
+        "predictions": [
+            [[-0.000431762, -0.04416759, -0.03443353]],  # Truncated embedding vector
+            [[-0.000431762, -0.04416759, -0.03443353]],  # Truncated embedding vector
+        ],
+        "deployedModelId": "2275167734310371328",
+        "model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
+        "modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
+        "modelVersionId": "1",
+    }
+
+    # Setup mock request
+    mock_request = respx_mock.post(expected_url).mock(
+        return_value=httpx.Response(200, json=mock_response)
+    )
+
+    # Make request
+    response = await litellm.aembedding(
+        vertex_project="633608382793",
+        model="vertex_ai/1004708436694269952",
+        input=input_text,
+    )
+
+    # Assert request was made correctly
+    assert mock_request.called
+    request_body = json.loads(mock_request.calls[0].request.content)
+    print("\n\nrequest_body", request_body)
+    print("\n\nexpected_request", expected_request)
+    assert request_body == expected_request
+
+    # Assert response structure
+    assert response is not None
+    assert hasattr(response, "data")
+    assert len(response.data) == len(input_text)
+
+    # Assert embedding structure
+    for embedding in response.data:
+        assert "embedding" in embedding
+        assert isinstance(embedding["embedding"], list)
+        assert len(embedding["embedding"]) > 0
+        assert all(isinstance(x, float) for x in embedding["embedding"])
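
For reviewers unfamiliar with respx, here is a minimal standalone sketch of the mocking pattern the new test relies on: register a route for a POST URL, answer it with a canned httpx.Response, then inspect the calls the route intercepted. The URL and payload below are illustrative placeholders, not litellm's actual Vertex AI request shape.

# Minimal sketch of the respx pattern used in the test above.
# Placeholder URL/payload, not litellm's real Vertex AI routing;
# assumes only that httpx and respx are installed.
import httpx
import respx


@respx.mock
def demo_respx_pattern():
    route = respx.post("https://example.googleapis.com/v1/endpoints/123:predict").mock(
        return_value=httpx.Response(200, json={"predictions": [[[0.1, 0.2, 0.3]]]})
    )

    # Any httpx POST matching the route is intercepted and answered by the mock.
    resp = httpx.post(
        "https://example.googleapis.com/v1/endpoints/123:predict",
        json={"instances": [{"inputs": "hello"}], "parameters": {}},
    )

    assert route.called
    assert resp.json()["predictions"][0][0] == [0.1, 0.2, 0.3]


demo_respx_pattern()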