(feat) Vertex AI - add support for fine tuned embedding models (#6749)

* fix use fine tuned vertex embedding models

* test_vertex_embedding_url

* add _transform_openai_request_to_fine_tuned_embedding_request

* add _transform_openai_request_to_fine_tuned_embedding_request

* add transform_openai_request_to_vertex_embedding_request

* add _transform_vertex_response_to_openai_for_fine_tuned_models

* test_vertexai_embedding for ft models

* fix test_vertexai_embedding_finetuned

* doc fine tuned / custom embedding models

* fix test test_partner_models_httpx
This commit is contained in:
Ishaan Jaff 2024-11-14 20:37:55 -08:00 committed by GitHub
parent c03351328f
commit c119bad5f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 261 additions and 5 deletions

View file

@ -18,6 +18,8 @@ import json
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
from respx import MockRouter
import httpx
import pytest
@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode):
data = {
"model": model,
"messages": messages,
"timeout": 10,
}
if sync_mode:
response = litellm.completion(**data)
@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode):
assert isinstance(response._hidden_params["response_cost"], float)
except litellm.RateLimitError as e:
pass
except litellm.Timeout as e:
pass
except litellm.InternalServerError as e:
pass
except Exception as e:
@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base):
assert url == api_base + ":"
else:
assert url == test_endpoint
@pytest.mark.asyncio
@pytest.mark.respx
async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
"""
Tests that:
- Request URL and body are correctly formatted for Vertex AI embeddings
- Response is properly parsed into litellm's embedding response format
"""
load_vertex_ai_credentials()
litellm.set_verbose = True
# Test input
input_text = ["good morning from litellm", "this is another item"]
# Expected request/response
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
expected_request = {
"instances": [
{"inputs": "good morning from litellm"},
{"inputs": "this is another item"},
],
"parameters": {},
}
mock_response = {
"predictions": [
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
],
"deployedModelId": "2275167734310371328",
"model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
"modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
"modelVersionId": "1",
}
# Setup mock request
mock_request = respx_mock.post(expected_url).mock(
return_value=httpx.Response(200, json=mock_response)
)
# Make request
response = await litellm.aembedding(
vertex_project="633608382793",
model="vertex_ai/1004708436694269952",
input=input_text,
)
# Assert request was made correctly
assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
print("\n\nrequest_body", request_body)
print("\n\nexpected_request", expected_request)
assert request_body == expected_request
# Assert response structure
assert response is not None
assert hasattr(response, "data")
assert len(response.data) == len(input_text)
# Assert embedding structure
for embedding in response.data:
assert "embedding" in embedding
assert isinstance(embedding["embedding"], list)
assert len(embedding["embedding"]) > 0
assert all(isinstance(x, float) for x in embedding["embedding"])