llama-stack-mirror/src/llama_stack/providers/remote/inference/gemini/gemini.py
Ashwin Bharambe 3c81b23fbe chore: remove unused NOT_GIVEN imports
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-28 10:37:48 -07:00

81 lines
2.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
class GeminiInferenceAdapter(OpenAIMixin):
    # Adapter built on OpenAIMixin that points at Gemini's OpenAI-compatible
    # endpoint and overrides the embeddings call.
    config: GeminiConfig

    provider_data_api_key_field: str = "gemini_api_key"

    # Embedding dimension and context length for each listed embedding model.
    embedding_model_metadata: dict[str, dict[str, int]] = {
        "models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
        "models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
    }

    def get_base_url(self):
        """Return the base URL of the OpenAI-compatible Gemini endpoint."""
        return "https://generativelanguage.googleapis.com/v1beta/openai/"

    async def openai_embeddings(
        self,
        params: OpenAIEmbeddingsRequestWithExtraBody,
    ) -> OpenAIEmbeddingsResponse:
        """
        Create embeddings through the client, tolerating absent usage statistics.

        The request is forwarded to ``self.client.embeddings.create``. If the
        response carries no (or a falsy) ``usage`` attribute, token counts of
        zero are reported in its place.

        :param params: embeddings request; ``model`` is resolved through
            ``_get_provider_model_id`` before being sent.
        :returns: response whose items are re-indexed by their position in the
            returned data and whose ``model`` echoes ``params.model``.
        """
        # Forward optional arguments only when they were supplied, to avoid
        # a NotGiven/Omit type mismatch on the client call.
        request_args: dict[str, Any] = {
            "model": await self._get_provider_model_id(params.model),
            "input": params.input,
        }
        for option in ("encoding_format", "dimensions", "user"):
            supplied = getattr(params, option)
            if supplied is not None:
                request_args[option] = supplied

        extra_fields = params.model_extra
        if extra_fields:
            request_args["extra_body"] = extra_fields

        result = await self.client.embeddings.create(**request_args)

        embeddings = [
            OpenAIEmbeddingData(embedding=item.embedding, index=position)
            for position, item in enumerate(result.data)
        ]

        # Fall back to zeroed counts when the response has no usage statistics.
        reported = getattr(result, "usage", None)
        if reported:
            prompt_tokens = reported.prompt_tokens
            total_tokens = reported.total_tokens
        else:
            prompt_tokens = total_tokens = 0

        return OpenAIEmbeddingsResponse(
            data=embeddings,
            model=params.model,
            usage=OpenAIEmbeddingUsage(
                prompt_tokens=prompt_tokens,
                total_tokens=total_tokens,
            ),
        )