From 88cc3c8fdc33bc925369c8a247d8f94f86a1fb91 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 14 Nov 2024 16:02:27 -0800
Subject: [PATCH] add _transform_vertex_response_to_openai_for_fine_tuned_models

---
 .../vertex_embeddings/transformation.py      | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py
index d59a916a5..6f4b25cef 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py
@@ -203,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel):
         """
         Transforms a vertex embedding response to an openai response.
         """
+        if model.isdigit():
+            return self._transform_vertex_response_to_openai_for_fine_tuned_models(
+                response, model, model_response
+            )
+
         _predictions = response["predictions"]
 
         embedding_response = []
@@ -227,3 +232,35 @@
         )
         setattr(model_response, "usage", usage)
         return model_response
+
+    def _transform_vertex_response_to_openai_for_fine_tuned_models(
+        self, response: dict, model: str, model_response: litellm.EmbeddingResponse
+    ) -> litellm.EmbeddingResponse:
+        """
+        Transforms a vertex fine-tuned model embedding response to an openai response format.
+        """
+        _predictions = response["predictions"]
+
+        embedding_response = []
+        # For fine-tuned models, we don't get token counts in the response
+        input_tokens = 0
+
+        for idx, embedding_values in enumerate(_predictions):
+            embedding_response.append(
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": embedding_values[
+                        0
+                    ],  # The embedding values are nested one level deeper
+                }
+            )
+
+        model_response.object = "list"
+        model_response.data = embedding_response
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
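
Note on the dispatch: the new branch relies on fine-tuned Vertex AI models being
addressed by a purely numeric endpoint ID rather than a model name, which is what
model.isdigit() detects. A quick sketch of the check; the endpoint ID below is
made up for illustration:

    # str.isdigit() is True only for all-numeric strings, which is how the
    # patch tells fine-tuned endpoint IDs apart from named foundation models.
    assert "4965075652664360960".isdigit()          # hypothetical endpoint ID
    assert not "textembedding-gecko@003".isdigit()  # named foundation model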
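
The embedding_values[0] indexing assumes each prediction from a fine-tuned
endpoint wraps the embedding vector one level deeper than the standard
text-embedding response; that shape is inferred from the patch itself, not from
Vertex AI documentation. A minimal sketch with fabricated vectors:

    # Hypothetical raw response from a fine-tuned Vertex AI embedding endpoint.
    raw_response = {
        "predictions": [
            [[0.12, -0.34, 0.56]],  # prediction 0: vector nested in an extra list
            [[0.78, -0.90, 0.11]],  # prediction 1
        ]
    }

    # OpenAI embedding-list entries the new helper would build from it:
    expected = [
        {"object": "embedding", "index": 0, "embedding": [0.12, -0.34, 0.56]},
        {"object": "embedding", "index": 1, "embedding": [0.78, -0.90, 0.11]},
    ]

    for idx, embedding_values in enumerate(raw_response["predictions"]):
        assert embedding_values[0] == expected[idx]["embedding"]

Since fine-tuned endpoints return no token statistics, prompt_tokens and
total_tokens are reported as 0 on these responses.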