forked from phoenix/litellm-mirror
(feat) Vertex AI - add support for fine tuned embedding models (#6749)
* fix use fine tuned vertex embedding models * test_vertex_embedding_url * add _transform_openai_request_to_fine_tuned_embedding_request * add _transform_openai_request_to_fine_tuned_embedding_request * add transform_openai_request_to_vertex_embedding_request * add _transform_vertex_response_to_openai_for_fine_tuned_models * test_vertexai_embedding for ft models * fix test_vertexai_embedding_finetuned * doc fine tuned / custom embedding models * fix test test_partner_models_httpx
This commit is contained in:
parent
c03351328f
commit
c119bad5f9
7 changed files with 261 additions and 5 deletions
|
@ -18,6 +18,8 @@ import json
|
|||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from respx import MockRouter
|
||||
import httpx
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode):
|
|||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"timeout": 10,
|
||||
}
|
||||
if sync_mode:
|
||||
response = litellm.completion(**data)
|
||||
|
@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode):
|
|||
assert isinstance(response._hidden_params["response_cost"], float)
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except litellm.InternalServerError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base):
|
|||
assert url == api_base + ":"
|
||||
else:
|
||||
assert url == test_endpoint
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.respx
|
||||
async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
|
||||
"""
|
||||
Tests that:
|
||||
- Request URL and body are correctly formatted for Vertex AI embeddings
|
||||
- Response is properly parsed into litellm's embedding response format
|
||||
"""
|
||||
load_vertex_ai_credentials()
|
||||
litellm.set_verbose = True
|
||||
|
||||
# Test input
|
||||
input_text = ["good morning from litellm", "this is another item"]
|
||||
|
||||
# Expected request/response
|
||||
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
|
||||
expected_request = {
|
||||
"instances": [
|
||||
{"inputs": "good morning from litellm"},
|
||||
{"inputs": "this is another item"},
|
||||
],
|
||||
"parameters": {},
|
||||
}
|
||||
|
||||
mock_response = {
|
||||
"predictions": [
|
||||
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
||||
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
||||
],
|
||||
"deployedModelId": "2275167734310371328",
|
||||
"model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
|
||||
"modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
|
||||
"modelVersionId": "1",
|
||||
}
|
||||
|
||||
# Setup mock request
|
||||
mock_request = respx_mock.post(expected_url).mock(
|
||||
return_value=httpx.Response(200, json=mock_response)
|
||||
)
|
||||
|
||||
# Make request
|
||||
response = await litellm.aembedding(
|
||||
vertex_project="633608382793",
|
||||
model="vertex_ai/1004708436694269952",
|
||||
input=input_text,
|
||||
)
|
||||
|
||||
# Assert request was made correctly
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
print("\n\nrequest_body", request_body)
|
||||
print("\n\nexpected_request", expected_request)
|
||||
assert request_body == expected_request
|
||||
|
||||
# Assert response structure
|
||||
assert response is not None
|
||||
assert hasattr(response, "data")
|
||||
assert len(response.data) == len(input_text)
|
||||
|
||||
# Assert embedding structure
|
||||
for embedding in response.data:
|
||||
assert "embedding" in embedding
|
||||
assert isinstance(embedding["embedding"], list)
|
||||
assert len(embedding["embedding"]) > 0
|
||||
assert all(isinstance(x, float) for x in embedding["embedding"])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue