forked from phoenix-oss/llama-stack-mirror
feat: New OpenAI compat embeddings API (#2314)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 4s
Integration Tests / test-matrix (http, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (library, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (http, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, scoring) (push) Failing after 10s
Test Llama Stack Build / generate-matrix (push) Successful in 6s
Integration Tests / test-matrix (library, providers) (push) Failing after 7s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 7s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 9s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 7s
Test Llama Stack Build / build (push) Failing after 5s
Unit Tests / unit-tests (3.10) (push) Failing after 7s
Update ReadTheDocs / update-readthedocs (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Test External Providers / test-external-providers (venv) (push) Failing after 26s
Pre-commit / pre-commit (push) Successful in 1m11s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 4s
Integration Tests / test-matrix (http, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (library, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (http, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, scoring) (push) Failing after 10s
Test Llama Stack Build / generate-matrix (push) Successful in 6s
Integration Tests / test-matrix (library, providers) (push) Failing after 7s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 7s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 9s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 7s
Test Llama Stack Build / build (push) Failing after 5s
Unit Tests / unit-tests (3.10) (push) Failing after 7s
Update ReadTheDocs / update-readthedocs (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Test External Providers / test-external-providers (venv) (push) Failing after 26s
Pre-commit / pre-commit (push) Successful in 1m11s
# What does this PR do? Adds a new endpoint that is compatible with OpenAI for embeddings api. `/openai/v1/embeddings` Added providers for OpenAI, LiteLLM and SentenceTransformer. ## Test Plan ``` LLAMA_STACK_CONFIG=http://localhost:8321 pytest -sv tests/integration/inference/test_openai_embeddings.py --embedding-model all-MiniLM-L6-v2,text-embedding-3-small,gemini/text-embedding-004 ```
This commit is contained in:
parent
277f8690ef
commit
b21050935e
21 changed files with 981 additions and 0 deletions
|
@ -22,6 +22,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -197,3 +198,13 @@ class BedrockInferenceAdapter(
|
|||
response_body = json.loads(response.get("body").read())
|
||||
embeddings.append(response_body.get("embedding"))
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -194,3 +195,13 @@ class CerebrasInferenceAdapter(
|
|||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -152,3 +153,13 @@ class DatabricksInferenceAdapter(
|
|||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -37,6 +37,7 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
|
@ -286,6 +287,16 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -238,6 +239,16 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
#
|
||||
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -370,6 +371,16 @@ class OllamaInferenceAdapter(
|
|||
|
||||
return model
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -14,6 +14,9 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
|
@ -38,6 +41,7 @@ logger = logging.getLogger(__name__)
|
|||
# | batch_chat_completion | LiteLLMOpenAIMixin |
|
||||
# | openai_completion | AsyncOpenAI |
|
||||
# | openai_chat_completion | AsyncOpenAI |
|
||||
# | openai_embeddings | AsyncOpenAI |
|
||||
#
|
||||
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
|
@ -171,3 +175,51 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
user=user,
|
||||
)
|
||||
return await self._openai_client.chat.completions.create(**params)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
|
||||
# Prepare parameters for OpenAI embeddings API
|
||||
params = {
|
||||
"model": model_id,
|
||||
"input": input,
|
||||
}
|
||||
|
||||
if encoding_format is not None:
|
||||
params["encoding_format"] = encoding_format
|
||||
if dimensions is not None:
|
||||
params["dimensions"] = dimensions
|
||||
if user is not None:
|
||||
params["user"] = user
|
||||
|
||||
# Call OpenAI embeddings API
|
||||
response = await self._openai_client.embeddings.create(**params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -210,6 +211,16 @@ class PassthroughInferenceAdapter(Inference):
|
|||
task_type=task_type,
|
||||
)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
|
|||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingsResponse
|
||||
|
||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
@ -134,3 +135,13 @@ class RunpodInferenceAdapter(
|
|||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
|
@ -291,6 +292,16 @@ class _HfAdapter(
|
|||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
|
@ -267,6 +268,16 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
embeddings = [item.embedding for item in r.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -507,6 +508,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -260,6 +261,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError("embedding is not supported for watsonx")
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue