mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 14:57:20 +00:00
feat(gemini): Support gemini-embedding-001 and fix models/ prefix in metadata keys (#3813)
# Add support for Google Gemini `gemini-embedding-001` embedding model and correctly registers model type MR message created with the assistance of Claude-4.5-sonnet This resolves https://github.com/llamastack/llama-stack/issues/3755 ## What does this PR do? This PR adds support for the `gemini-embedding-001` Google embedding model to the llama-stack Gemini provider. This model provides high-dimensional embeddings (3072 dimensions) compared to the existing `text-embedding-004` model (768 dimensions). Old embeddings models (such as text-embedding-004) will be deprecated soon according to Google ([Link](https://developers.googleblog.com/en/gemini-embedding-available-gemini-api/)) ## Problem The Gemini provider only supported the `text-embedding-004` embedding model. The newer `gemini-embedding-001` model, which provides higher-dimensional embeddings for improved semantic representation, was not available through llama-stack. ## Solution This PR consists of three commits that implement, fix the model registration, and enable embedding generation: ### Commit 1: Initial addition of gemini-embedding-001 Added metadata for `gemini-embedding-001` to the `embedding_model_metadata` dictionary: ```python embedding_model_metadata: dict[str, dict[str, int]] = { "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048}, "gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048}, # NEW } ``` **Issue discovered:** The model was not being registered correctly because the dictionary keys didn't match the model IDs returned by Gemini's API. ### Commit 2: Fix model ID matching with `models/` prefix Updated both dictionary keys to include the `models/` prefix to match Gemini's OpenAI-compatible API response format: ```python embedding_model_metadata: dict[str, dict[str, int]] = { "models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048}, # UPDATED "models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048}, # UPDATED } ``` **Root cause:** Gemini's OpenAI-compatible API returns model IDs with the `models/` prefix (e.g., `models/text-embedding-004`). The `OpenAIMixin.list_models()` method directly matches these IDs against the `embedding_model_metadata` dictionary keys. Without the prefix, the models were being registered as LLMs instead of embedding models. ### Commit 3: Fix embedding generation for providers without usage stats Fixed a bug in `OpenAIMixin.openai_embeddings()` that prevented embedding generation for providers (like Gemini) that don't return usage statistics: ```python # Before (Line 351-354): usage = OpenAIEmbeddingUsage( prompt_tokens=response.usage.prompt_tokens, # ← Crashed with AttributeError total_tokens=response.usage.total_tokens, ) # After (Lines 351-362): if response.usage: usage = OpenAIEmbeddingUsage( prompt_tokens=response.usage.prompt_tokens, total_tokens=response.usage.total_tokens, ) else: usage = OpenAIEmbeddingUsage( prompt_tokens=0, # Default when not provided total_tokens=0, # Default when not provided ) ``` **Impact:** This fix enables embedding generation for **all** Gemini embedding models, not just the newly added one. ## Changes ### Modified Files **`llama_stack/providers/remote/inference/gemini/gemini.py`** - Line 17: Updated `text-embedding-004` key to `models/text-embedding-004` - Line 18: Added `models/gemini-embedding-001` with correct metadata **`llama_stack/providers/utils/inference/openai_mixin.py`** - Lines 351-362: Added null check for `response.usage` to handle providers without usage statistics ## Key Technical Details ### Model ID Matching Flow 1. `list_provider_model_ids()` calls Gemini's `/v1/models` endpoint 2. API returns model IDs like: `models/text-embedding-004`, `models/gemini-embedding-001` 3. `OpenAIMixin.list_models()` (line 410) checks: `if metadata := self.embedding_model_metadata.get(provider_model_id)` 4. If matched, registers as `model_type: "embedding"` with metadata; otherwise registers as `model_type: "llm"` ### Why Both Keys Needed the Prefix The `text-embedding-004` model was already working because there was likely separate configuration or manual registration handling it. For auto-discovery to work correctly for **both** models, both keys must match the API's model ID format exactly. ## How to test this PR Verified the changes by: 1. **Model Auto-Discovery**: Started llama-stack server and confirmed models are auto-discovered from Gemini API 2. **Model Registration**: Confirmed both embedding models are correctly registered and visible ```bash curl http://localhost:8325/v1/models | jq '.data[] | select(.provider_id == "gemini" and .model_type == "embedding")' ``` **Results:** - ✅ `gemini/models/text-embedding-004` - 768 dimensions - `model_type: "embedding"` - ✅ `gemini/models/gemini-embedding-001` - 3072 dimensions - `model_type: "embedding"` 3. **Before Fix (Commit 1)**: Models appeared as `model_type: "llm"` without embedding metadata 4. **After Fix (Commit 2)**: Models correctly identified as `model_type: "embedding"` with proper metadata 5. **Generate Embeddings**: Verified embedding generation works ```bash curl -X POST http://localhost:8325/v1/embeddings \ -H "Content-Type: application/json" \ -d '{"model": "gemini/models/gemini-embedding-001", "input": "test"}' | \ jq '.data[0].embedding | length' ```
This commit is contained in:
parent
ce8ea2f505
commit
add8cd801b
1 changed files with 62 additions and 1 deletions
|
@ -4,6 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
@ -14,8 +22,61 @@ class GeminiInferenceAdapter(OpenAIMixin):
|
|||
|
||||
provider_data_api_key_field: str = "gemini_api_key"
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
"models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
"models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
|
||||
}
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
Override embeddings method to handle Gemini's missing usage statistics.
|
||||
Gemini's embedding API doesn't return usage information, so we provide default values.
|
||||
"""
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"model": await self._get_provider_model_id(params.model),
|
||||
"input": params.input,
|
||||
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
|
||||
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
|
||||
"user": params.user if params.user is not None else NOT_GIVEN,
|
||||
}
|
||||
|
||||
# Add extra_body if present
|
||||
extra_body = params.model_extra
|
||||
if extra_body:
|
||||
request_params["extra_body"] = extra_body
|
||||
|
||||
# Call OpenAI embeddings API with properly typed parameters
|
||||
response = await self.client.embeddings.create(**request_params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
# Gemini doesn't return usage statistics - use default values
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
else:
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=params.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue