(feat) Vertex AI - add support for fine tuned embedding models (#6749)

* fix use of fine-tuned Vertex embedding models

* test_vertex_embedding_url

* add _transform_openai_request_to_fine_tuned_embedding_request

* add _transform_openai_request_to_fine_tuned_embedding_request

* add transform_openai_request_to_vertex_embedding_request

* add _transform_vertex_response_to_openai_for_fine_tuned_models

* test_vertexai_embedding for ft models

* fix test_vertexai_embedding_finetuned

* doc fine tuned / custom embedding models

* fix test test_partner_models_httpx
This commit is contained in:
Ishaan Jaff 2024-11-14 20:37:55 -08:00 committed by GitHub
parent cd1409c1b7
commit 0dc1ac8872
7 changed files with 261 additions and 5 deletions

View file

@ -16,6 +16,7 @@ import pytest
import litellm
from litellm import get_optional_params
from litellm.llms.custom_httpx.http_handler import HTTPHandler
import httpx
def test_completion_pydantic_obj_2():
@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url):
mock_post.assert_called_once()
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
assert mock_post.call_args.kwargs["json"] == expected_request_body
@pytest.mark.parametrize(
    "model, expected_url",
    [
        (
            "textembedding-gecko@001",
            "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict",
        ),
        (
            "123456789",
            "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict",
        ),
    ],
)
def test_vertex_embedding_url(model, expected_url):
    """
    Check the URL that `_get_vertex_url` builds in embedding mode.

    Two cases are covered: a named model, whose expected URL goes through
    `publishers/google/models/<model>`, and a numeric model ID (a fine-tuned
    model), whose expected URL goes through `endpoints/<id>` instead. In both
    cases the returned endpoint must be "predict".

    Relevant issue: https://github.com/BerriAI/litellm/issues/6482
    """
    from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
        _get_vertex_url,
    )

    # Fixed project/location/version so the expected URLs above are fully
    # determined by the parametrized model alone.
    url_kwargs = {
        "mode": "embedding",
        "model": model,
        "stream": False,
        "vertex_project": "project-id",
        "vertex_location": "us-central1",
        "vertex_api_version": "v1",
    }
    built_url, built_endpoint = _get_vertex_url(**url_kwargs)

    assert built_url == expected_url
    assert built_endpoint == "predict"