mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-05 10:23:44 +00:00
Migrates package structure to src/ layout following Python packaging best practices. All code moved from `llama_stack/` to `src/llama_stack/`. Public API unchanged - imports remain `import llama_stack.*`. Updated build configs, pre-commit hooks, scripts, and GitHub workflows accordingly. All hooks pass, package builds cleanly. **Developer note**: Reinstall after pulling: `pip install -e .`
82 lines
2.8 KiB
Python
82 lines
2.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from openai import NOT_GIVEN
|
|
|
|
from llama_stack.apis.inference import (
|
|
OpenAIEmbeddingData,
|
|
OpenAIEmbeddingsRequestWithExtraBody,
|
|
OpenAIEmbeddingsResponse,
|
|
OpenAIEmbeddingUsage,
|
|
)
|
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|
|
|
from .config import GeminiConfig
|
|
|
|
|
|
class GeminiInferenceAdapter(OpenAIMixin):
|
|
config: GeminiConfig
|
|
|
|
provider_data_api_key_field: str = "gemini_api_key"
|
|
embedding_model_metadata: dict[str, dict[str, int]] = {
|
|
"models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
|
"models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
|
|
}
|
|
|
|
def get_base_url(self):
|
|
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
|
|
|
async def openai_embeddings(
|
|
self,
|
|
params: OpenAIEmbeddingsRequestWithExtraBody,
|
|
) -> OpenAIEmbeddingsResponse:
|
|
"""
|
|
Override embeddings method to handle Gemini's missing usage statistics.
|
|
Gemini's embedding API doesn't return usage information, so we provide default values.
|
|
"""
|
|
# Prepare request parameters
|
|
request_params = {
|
|
"model": await self._get_provider_model_id(params.model),
|
|
"input": params.input,
|
|
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
|
|
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
|
|
"user": params.user if params.user is not None else NOT_GIVEN,
|
|
}
|
|
|
|
# Add extra_body if present
|
|
extra_body = params.model_extra
|
|
if extra_body:
|
|
request_params["extra_body"] = extra_body
|
|
|
|
# Call OpenAI embeddings API with properly typed parameters
|
|
response = await self.client.embeddings.create(**request_params)
|
|
|
|
data = []
|
|
for i, embedding_data in enumerate(response.data):
|
|
data.append(
|
|
OpenAIEmbeddingData(
|
|
embedding=embedding_data.embedding,
|
|
index=i,
|
|
)
|
|
)
|
|
|
|
# Gemini doesn't return usage statistics - use default values
|
|
if hasattr(response, "usage") and response.usage:
|
|
usage = OpenAIEmbeddingUsage(
|
|
prompt_tokens=response.usage.prompt_tokens,
|
|
total_tokens=response.usage.total_tokens,
|
|
)
|
|
else:
|
|
usage = OpenAIEmbeddingUsage(
|
|
prompt_tokens=0,
|
|
total_tokens=0,
|
|
)
|
|
|
|
return OpenAIEmbeddingsResponse(
|
|
data=data,
|
|
model=params.model,
|
|
usage=usage,
|
|
)
|