forked from phoenix/litellm-mirror
fix test_vertexai_embedding_finetuned
This commit is contained in:
parent
d7eefe0e0e
commit
0fc6d8c8d3
2 changed files with 69 additions and 67 deletions
|
@ -17,7 +17,6 @@ import litellm
|
||||||
from litellm import get_optional_params
|
from litellm import get_optional_params
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
import httpx
|
import httpx
|
||||||
from respx import MockRouter
|
|
||||||
|
|
||||||
|
|
||||||
def test_completion_pydantic_obj_2():
|
def test_completion_pydantic_obj_2():
|
||||||
|
@ -1355,69 +1354,3 @@ def test_vertex_embedding_url(model, expected_url):
|
||||||
|
|
||||||
assert url == expected_url
|
assert url == expected_url
|
||||||
assert endpoint == "predict"
|
assert endpoint == "predict"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.respx
|
|
||||||
async def test_vertexai_embedding(respx_mock: MockRouter):
|
|
||||||
"""
|
|
||||||
Tests that:
|
|
||||||
- Request URL and body are correctly formatted for Vertex AI embeddings
|
|
||||||
- Response is properly parsed into litellm's embedding response format
|
|
||||||
"""
|
|
||||||
litellm.set_verbose = True
|
|
||||||
|
|
||||||
# Test input
|
|
||||||
input_text = ["good morning from litellm", "this is another item"]
|
|
||||||
|
|
||||||
# Expected request/response
|
|
||||||
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
|
|
||||||
expected_request = {
|
|
||||||
"instances": [
|
|
||||||
{"inputs": "good morning from litellm"},
|
|
||||||
{"inputs": "this is another item"},
|
|
||||||
],
|
|
||||||
"parameters": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_response = {
|
|
||||||
"predictions": [
|
|
||||||
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
|
||||||
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
|
||||||
],
|
|
||||||
"deployedModelId": "2275167734310371328",
|
|
||||||
"model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
|
|
||||||
"modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
|
|
||||||
"modelVersionId": "1",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Setup mock request
|
|
||||||
mock_request = respx_mock.post(expected_url).mock(
|
|
||||||
return_value=httpx.Response(200, json=mock_response)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make request
|
|
||||||
response = await litellm.aembedding(
|
|
||||||
vertex_project="633608382793",
|
|
||||||
model="vertex_ai/1004708436694269952",
|
|
||||||
input=input_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert request was made correctly
|
|
||||||
assert mock_request.called
|
|
||||||
request_body = json.loads(mock_request.calls[0].request.content)
|
|
||||||
print("\n\nrequest_body", request_body)
|
|
||||||
print("\n\nexpected_request", expected_request)
|
|
||||||
assert request_body == expected_request
|
|
||||||
|
|
||||||
# Assert response structure
|
|
||||||
assert response is not None
|
|
||||||
assert hasattr(response, "data")
|
|
||||||
assert len(response.data) == len(input_text)
|
|
||||||
|
|
||||||
# Assert embedding structure
|
|
||||||
for embedding in response.data:
|
|
||||||
assert "embedding" in embedding
|
|
||||||
assert isinstance(embedding["embedding"], list)
|
|
||||||
assert len(embedding["embedding"]) > 0
|
|
||||||
assert all(isinstance(x, float) for x in embedding["embedding"])
|
|
||||||
|
|
|
@ -18,6 +18,8 @@ import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from respx import MockRouter
|
||||||
|
import httpx
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -3051,3 +3053,70 @@ def test_custom_api_base(api_base):
|
||||||
assert url == api_base + ":"
|
assert url == api_base + ":"
|
||||||
else:
|
else:
|
||||||
assert url == test_endpoint
|
assert url == test_endpoint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.respx
|
||||||
|
async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
|
||||||
|
"""
|
||||||
|
Tests that:
|
||||||
|
- Request URL and body are correctly formatted for Vertex AI embeddings
|
||||||
|
- Response is properly parsed into litellm's embedding response format
|
||||||
|
"""
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
# Test input
|
||||||
|
input_text = ["good morning from litellm", "this is another item"]
|
||||||
|
|
||||||
|
# Expected request/response
|
||||||
|
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
|
||||||
|
expected_request = {
|
||||||
|
"instances": [
|
||||||
|
{"inputs": "good morning from litellm"},
|
||||||
|
{"inputs": "this is another item"},
|
||||||
|
],
|
||||||
|
"parameters": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response = {
|
||||||
|
"predictions": [
|
||||||
|
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
||||||
|
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
||||||
|
],
|
||||||
|
"deployedModelId": "2275167734310371328",
|
||||||
|
"model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
|
||||||
|
"modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
|
||||||
|
"modelVersionId": "1",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = respx_mock.post(expected_url).mock(
|
||||||
|
return_value=httpx.Response(200, json=mock_response)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make request
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
vertex_project="633608382793",
|
||||||
|
model="vertex_ai/1004708436694269952",
|
||||||
|
input=input_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert request was made correctly
|
||||||
|
assert mock_request.called
|
||||||
|
request_body = json.loads(mock_request.calls[0].request.content)
|
||||||
|
print("\n\nrequest_body", request_body)
|
||||||
|
print("\n\nexpected_request", expected_request)
|
||||||
|
assert request_body == expected_request
|
||||||
|
|
||||||
|
# Assert response structure
|
||||||
|
assert response is not None
|
||||||
|
assert hasattr(response, "data")
|
||||||
|
assert len(response.data) == len(input_text)
|
||||||
|
|
||||||
|
# Assert embedding structure
|
||||||
|
for embedding in response.data:
|
||||||
|
assert "embedding" in embedding
|
||||||
|
assert isinstance(embedding["embedding"], list)
|
||||||
|
assert len(embedding["embedding"]) > 0
|
||||||
|
assert all(isinstance(x, float) for x in embedding["embedding"])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue