OpenAI compat embeddings API

This commit is contained in:
Hardik Shah 2025-05-29 15:27:59 -07:00
parent 2603f10f95
commit f2c2a05f58
20 changed files with 706 additions and 0 deletions

View file

@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import logging
import struct
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@ -15,6 +17,9 @@ from llama_stack.apis.inference import (
EmbeddingTaskType,
InterleavedContentItem,
ModelStore,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
TextTruncation,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@ -43,6 +48,50 @@ class SentenceTransformerEmbeddingMixin:
)
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
# Convert input to list format if it's a single string
input_list = [input] if isinstance(input, str) else input
if not input_list:
raise ValueError("Empty list not supported")
# Get the model and generate embeddings
model_obj = await self.model_store.get_model(model)
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
# Convert embeddings to the requested format
data = []
for i, embedding in enumerate(embeddings):
if encoding_format == "base64":
# Convert float array to base64 string
float_bytes = struct.pack(f"{len(embedding)}f", *embedding)
embedding_value = base64.b64encode(float_bytes).decode("ascii")
else:
# Default to float format
embedding_value = embedding.tolist()
data.append(
OpenAIEmbeddingData(
embedding=embedding_value,
index=i,
)
)
# Not returning actual token usage
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
usage=usage,
)
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS