Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-12-13 16:32:37 +00:00
feat: Implement openai_embeddings method in GeminiInferenceAdapter
This method works around the absence of usage statistics in Gemini's embedding API by substituting default (zero) token counts when the response omits them.
parent a0c0e5933a
commit 50fd998a49
1 changed file with 60 additions and 0 deletions
@@ -4,6 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from openai import NOT_GIVEN
+
+from llama_stack.apis.inference import (
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
+)
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import GeminiConfig
@@ -20,3 +28,55 @@ class GeminiInferenceAdapter(OpenAIMixin):
 
     def get_base_url(self):
         return "https://generativelanguage.googleapis.com/v1beta/openai/"
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """
+        Override embeddings method to handle Gemini's missing usage statistics.
+        Gemini's embedding API doesn't return usage information, so we provide default values.
+        """
+        # Prepare request parameters
+        request_params = {
+            "model": await self._get_provider_model_id(params.model),
+            "input": params.input,
+            "encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
+            "dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
+            "user": params.user if params.user is not None else NOT_GIVEN,
+        }
+
+        # Add extra_body if present
+        extra_body = params.model_extra
+        if extra_body:
+            request_params["extra_body"] = extra_body
+
+        # Call OpenAI embeddings API with properly typed parameters
+        response = await self.client.embeddings.create(**request_params)
+
+        data = []
+        for i, embedding_data in enumerate(response.data):
+            data.append(
+                OpenAIEmbeddingData(
+                    embedding=embedding_data.embedding,
+                    index=i,
+                )
+            )
+
+        # Gemini doesn't return usage statistics - use default values
+        if hasattr(response, "usage") and response.usage:
+            usage = OpenAIEmbeddingUsage(
+                prompt_tokens=response.usage.prompt_tokens,
+                total_tokens=response.usage.total_tokens,
+            )
+        else:
+            usage = OpenAIEmbeddingUsage(
+                prompt_tokens=0,
+                total_tokens=0,
+            )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=params.model,
+            usage=usage,
+        )
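For context, a minimal sketch of how a caller might exercise the new method. This is an illustration under stated assumptions, not part of the commit: the adapter module paths, the GeminiConfig constructor arguments, and the model id are guesses, while OpenAIEmbeddingsRequestWithExtraBody, openai_embeddings, and the response fields come from the diff above.

import asyncio

from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody

# Assumed import paths; the changed file's location is not shown on this page.
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter


async def main() -> None:
    # Assumed constructor: adapters built on OpenAIMixin typically take their config.
    adapter = GeminiInferenceAdapter(config=GeminiConfig(api_key="YOUR_GEMINI_API_KEY"))

    params = OpenAIEmbeddingsRequestWithExtraBody(
        model="gemini-embedding-001",  # assumed model id, for illustration only
        input=["hello world", "goodbye world"],
    )
    response = await adapter.openai_embeddings(params)

    # One OpenAIEmbeddingData per input string, indexed in request order.
    print(len(response.data), response.data[0].index)

    # If Gemini omitted usage statistics, the override defaults both counts to 0.
    print(response.usage.prompt_tokens, response.usage.total_tokens)


asyncio.run(main())

Note the hasattr guard in the override: if Gemini later starts returning usage data, the real token counts will be passed through without further changes.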