diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py
index bb66b94d5..27fea8b32 100644
--- a/llama_stack/providers/remote/inference/gemini/gemini.py
+++ b/llama_stack/providers/remote/inference/gemini/gemini.py
@@ -4,6 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from openai import NOT_GIVEN
+
+from llama_stack.apis.inference import (
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
+)
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import GeminiConfig
@@ -14,8 +22,61 @@ class GeminiInferenceAdapter(OpenAIMixin):
     provider_data_api_key_field: str = "gemini_api_key"
 
     embedding_model_metadata: dict[str, dict[str, int]] = {
-        "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
+        "models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
+        "models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
     }
 
     def get_base_url(self):
         return "https://generativelanguage.googleapis.com/v1beta/openai/"
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """
+        Override embeddings method to handle Gemini's missing usage statistics.
+        Gemini's embedding API doesn't return usage information, so we provide default values.
+        """
+        # Prepare request parameters
+        request_params = {
+            "model": await self._get_provider_model_id(params.model),
+            "input": params.input,
+            "encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
+            "dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
+            "user": params.user if params.user is not None else NOT_GIVEN,
+        }
+
+        # Add extra_body if present
+        extra_body = params.model_extra
+        if extra_body:
+            request_params["extra_body"] = extra_body
+
+        # Call OpenAI embeddings API with properly typed parameters
+        response = await self.client.embeddings.create(**request_params)
+
+        data = []
+        for i, embedding_data in enumerate(response.data):
+            data.append(
+                OpenAIEmbeddingData(
+                    embedding=embedding_data.embedding,
+                    index=i,
+                )
+            )
+
+        # Gemini doesn't return usage statistics - use default values
+        if hasattr(response, "usage") and response.usage:
+            usage = OpenAIEmbeddingUsage(
+                prompt_tokens=response.usage.prompt_tokens,
+                total_tokens=response.usage.total_tokens,
+            )
+        else:
+            usage = OpenAIEmbeddingUsage(
+                prompt_tokens=0,
+                total_tokens=0,
+            )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=params.model,
+            usage=usage,
+        )
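
For anyone verifying this locally, here is a rough usage sketch of the new override. The `GeminiConfig(api_key=...)` constructor shape and the direct adapter instantiation are assumptions for illustration only; in practice the stack wires the adapter up from provider config. Note the model IDs now carry the `models/` prefix:

```python
# Hypothetical local smoke test for the openai_embeddings override.
# Assumes GeminiConfig accepts an api_key kwarg and that the adapter can be
# instantiated directly; both are illustrative guesses, not part of this diff.
import asyncio
import os

from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter


async def main() -> None:
    adapter = GeminiInferenceAdapter(
        config=GeminiConfig(api_key=os.environ["GEMINI_API_KEY"])  # assumed shape
    )
    response = await adapter.openai_embeddings(
        OpenAIEmbeddingsRequestWithExtraBody(
            model="models/gemini-embedding-001",
            input=["hello world"],
        )
    )
    # Gemini's OpenAI-compat endpoint omits usage, so the override reports zeros.
    print(response.usage)                   # prompt_tokens=0, total_tokens=0
    print(len(response.data[0].embedding))  # 3072 for gemini-embedding-001


asyncio.run(main())
```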