forked from phoenix-oss/llama-stack-mirror
feat: New OpenAI compat embeddings API (#2314)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 4s
Integration Tests / test-matrix (http, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (library, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (http, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, scoring) (push) Failing after 10s
Test Llama Stack Build / generate-matrix (push) Successful in 6s
Integration Tests / test-matrix (library, providers) (push) Failing after 7s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 7s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 9s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 7s
Test Llama Stack Build / build (push) Failing after 5s
Unit Tests / unit-tests (3.10) (push) Failing after 7s
Update ReadTheDocs / update-readthedocs (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Test External Providers / test-external-providers (venv) (push) Failing after 26s
Pre-commit / pre-commit (push) Successful in 1m11s
# What does this PR do?

Adds a new OpenAI-compatible embeddings endpoint, `/openai/v1/embeddings`, and provider implementations for OpenAI, LiteLLM, and SentenceTransformers.

## Test Plan

```
LLAMA_STACK_CONFIG=http://localhost:8321 pytest -sv tests/integration/inference/test_openai_embeddings.py --embedding-model all-MiniLM-L6-v2,text-embedding-3-small,gemini/text-embedding-004
```
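For illustration, the new endpoint can be exercised with the stock `openai` Python client pointed at a running stack server, mirroring the fixture in the integration tests added below. This is a minimal sketch: the server URL and model ID are assumptions taken from the test plan, not part of the change itself.

```python
# Minimal sketch: call the new OpenAI-compatible embeddings endpoint through
# the standard `openai` client. Assumes a Llama Stack server on localhost:8321
# with the all-MiniLM-L6-v2 embedding model registered (as in the test plan).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="fake")

response = client.embeddings.create(
    model="all-MiniLM-L6-v2",
    input="Hello, world!",
    encoding_format="float",
)
print(response.object)                  # "list"
print(len(response.data[0].embedding))  # embedding dimensionality
```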
parent 277f8690ef
commit b21050935e

21 changed files with 981 additions and 0 deletions
docs/_static/llama-stack-spec.html (vendored, +176):

```diff
@@ -3607,6 +3607,49 @@
       }
     }
   },
+    "/v1/openai/v1/embeddings": {
+      "post": {
+        "responses": {
+          "200": {
+            "description": "An OpenAIEmbeddingsResponse containing the embeddings.",
+            "content": {
+              "application/json": {
+                "schema": { "$ref": "#/components/schemas/OpenAIEmbeddingsResponse" }
+              }
+            }
+          },
+          "400": { "$ref": "#/components/responses/BadRequest400" },
+          "429": { "$ref": "#/components/responses/TooManyRequests429" },
+          "500": { "$ref": "#/components/responses/InternalServerError500" },
+          "default": { "$ref": "#/components/responses/DefaultError" }
+        },
+        "tags": ["Inference"],
+        "description": "Generate OpenAI-compatible embeddings for the given input using the specified model.",
+        "parameters": [],
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": { "$ref": "#/components/schemas/OpenaiEmbeddingsRequest" }
+            }
+          },
+          "required": true
+        }
+      }
+    },
     "/v1/openai/v1/models": {
       "get": {
         "responses": {
@@ -11777,6 +11820,139 @@
       "title": "OpenAICompletionChoice",
       "description": "A choice from an OpenAI-compatible completion response."
     },
+    "OpenaiEmbeddingsRequest": {
+      "type": "object",
+      "properties": {
+        "model": {
+          "type": "string",
+          "description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
+        },
+        "input": {
+          "oneOf": [
+            { "type": "string" },
+            { "type": "array", "items": { "type": "string" } }
+          ],
+          "description": "Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings."
+        },
+        "encoding_format": {
+          "type": "string",
+          "description": "(Optional) The format to return the embeddings in. Can be either \"float\" or \"base64\". Defaults to \"float\"."
+        },
+        "dimensions": {
+          "type": "integer",
+          "description": "(Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models."
+        },
+        "user": {
+          "type": "string",
+          "description": "(Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse."
+        }
+      },
+      "additionalProperties": false,
+      "required": ["model", "input"],
+      "title": "OpenaiEmbeddingsRequest"
+    },
+    "OpenAIEmbeddingData": {
+      "type": "object",
+      "properties": {
+        "object": {
+          "type": "string",
+          "const": "embedding",
+          "default": "embedding",
+          "description": "The object type, which will be \"embedding\""
+        },
+        "embedding": {
+          "oneOf": [
+            { "type": "array", "items": { "type": "number" } },
+            { "type": "string" }
+          ],
+          "description": "The embedding vector as a list of floats (when encoding_format=\"float\") or as a base64-encoded string (when encoding_format=\"base64\")"
+        },
+        "index": {
+          "type": "integer",
+          "description": "The index of the embedding in the input list"
+        }
+      },
+      "additionalProperties": false,
+      "required": ["object", "embedding", "index"],
+      "title": "OpenAIEmbeddingData",
+      "description": "A single embedding data object from an OpenAI-compatible embeddings response."
+    },
+    "OpenAIEmbeddingUsage": {
+      "type": "object",
+      "properties": {
+        "prompt_tokens": {
+          "type": "integer",
+          "description": "The number of tokens in the input"
+        },
+        "total_tokens": {
+          "type": "integer",
+          "description": "The total number of tokens used"
+        }
+      },
+      "additionalProperties": false,
+      "required": ["prompt_tokens", "total_tokens"],
+      "title": "OpenAIEmbeddingUsage",
+      "description": "Usage information for an OpenAI-compatible embeddings response."
+    },
+    "OpenAIEmbeddingsResponse": {
+      "type": "object",
+      "properties": {
+        "object": {
+          "type": "string",
+          "const": "list",
+          "default": "list",
+          "description": "The object type, which will be \"list\""
+        },
+        "data": {
+          "type": "array",
+          "items": { "$ref": "#/components/schemas/OpenAIEmbeddingData" },
+          "description": "List of embedding data objects"
+        },
+        "model": {
+          "type": "string",
+          "description": "The model that was used to generate the embeddings"
+        },
+        "usage": {
+          "$ref": "#/components/schemas/OpenAIEmbeddingUsage",
+          "description": "Usage information"
+        }
+      },
+      "additionalProperties": false,
+      "required": ["object", "data", "model", "usage"],
+      "title": "OpenAIEmbeddingsResponse",
+      "description": "Response from an OpenAI-compatible embeddings request."
+    },
     "OpenAIModel": {
       "type": "object",
       "properties": {
```
docs/_static/llama-stack-spec.yaml (vendored, +144):

```diff
@@ -2520,6 +2520,38 @@ paths:
           schema:
             $ref: '#/components/schemas/OpenaiCompletionRequest'
       required: true
+  /v1/openai/v1/embeddings:
+    post:
+      responses:
+        '200':
+          description: >-
+            An OpenAIEmbeddingsResponse containing the embeddings.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/OpenAIEmbeddingsResponse'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Inference
+      description: >-
+        Generate OpenAI-compatible embeddings for the given input using the specified
+        model.
+      parameters: []
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/OpenaiEmbeddingsRequest'
+        required: true
   /v1/openai/v1/models:
     get:
       responses:
@@ -8197,6 +8229,118 @@ components:
       title: OpenAICompletionChoice
       description: >-
         A choice from an OpenAI-compatible completion response.
+    OpenaiEmbeddingsRequest:
+      type: object
+      properties:
+        model:
+          type: string
+          description: >-
+            The identifier of the model to use. The model must be an embedding model
+            registered with Llama Stack and available via the /models endpoint.
+        input:
+          oneOf:
+            - type: string
+            - type: array
+              items:
+                type: string
+          description: >-
+            Input text to embed, encoded as a string or array of strings. To embed
+            multiple inputs in a single request, pass an array of strings.
+        encoding_format:
+          type: string
+          description: >-
+            (Optional) The format to return the embeddings in. Can be either "float"
+            or "base64". Defaults to "float".
+        dimensions:
+          type: integer
+          description: >-
+            (Optional) The number of dimensions the resulting output embeddings should
+            have. Only supported in text-embedding-3 and later models.
+        user:
+          type: string
+          description: >-
+            (Optional) A unique identifier representing your end-user, which can help
+            OpenAI to monitor and detect abuse.
+      additionalProperties: false
+      required:
+        - model
+        - input
+      title: OpenaiEmbeddingsRequest
+    OpenAIEmbeddingData:
+      type: object
+      properties:
+        object:
+          type: string
+          const: embedding
+          default: embedding
+          description: >-
+            The object type, which will be "embedding"
+        embedding:
+          oneOf:
+            - type: array
+              items:
+                type: number
+            - type: string
+          description: >-
+            The embedding vector as a list of floats (when encoding_format="float")
+            or as a base64-encoded string (when encoding_format="base64")
+        index:
+          type: integer
+          description: >-
+            The index of the embedding in the input list
+      additionalProperties: false
+      required:
+        - object
+        - embedding
+        - index
+      title: OpenAIEmbeddingData
+      description: >-
+        A single embedding data object from an OpenAI-compatible embeddings response.
+    OpenAIEmbeddingUsage:
+      type: object
+      properties:
+        prompt_tokens:
+          type: integer
+          description: The number of tokens in the input
+        total_tokens:
+          type: integer
+          description: The total number of tokens used
+      additionalProperties: false
+      required:
+        - prompt_tokens
+        - total_tokens
+      title: OpenAIEmbeddingUsage
+      description: >-
+        Usage information for an OpenAI-compatible embeddings response.
+    OpenAIEmbeddingsResponse:
+      type: object
+      properties:
+        object:
+          type: string
+          const: list
+          default: list
+          description: The object type, which will be "list"
+        data:
+          type: array
+          items:
+            $ref: '#/components/schemas/OpenAIEmbeddingData'
+          description: List of embedding data objects
+        model:
+          type: string
+          description: >-
+            The model that was used to generate the embeddings
+        usage:
+          $ref: '#/components/schemas/OpenAIEmbeddingUsage'
+          description: Usage information
+      additionalProperties: false
+      required:
+        - object
+        - data
+        - model
+        - usage
+      title: OpenAIEmbeddingsResponse
+      description: >-
+        Response from an OpenAI-compatible embeddings request.
     OpenAIModel:
       type: object
       properties:
```
Inference API (Python) — new response models and protocol method:

```diff
@@ -783,6 +783,48 @@ class OpenAICompletion(BaseModel):
     object: Literal["text_completion"] = "text_completion"
 
 
+@json_schema_type
+class OpenAIEmbeddingData(BaseModel):
+    """A single embedding data object from an OpenAI-compatible embeddings response.
+
+    :param object: The object type, which will be "embedding"
+    :param embedding: The embedding vector as a list of floats (when encoding_format="float") or as a base64-encoded string (when encoding_format="base64")
+    :param index: The index of the embedding in the input list
+    """
+
+    object: Literal["embedding"] = "embedding"
+    embedding: list[float] | str
+    index: int
+
+
+@json_schema_type
+class OpenAIEmbeddingUsage(BaseModel):
+    """Usage information for an OpenAI-compatible embeddings response.
+
+    :param prompt_tokens: The number of tokens in the input
+    :param total_tokens: The total number of tokens used
+    """
+
+    prompt_tokens: int
+    total_tokens: int
+
+
+@json_schema_type
+class OpenAIEmbeddingsResponse(BaseModel):
+    """Response from an OpenAI-compatible embeddings request.
+
+    :param object: The object type, which will be "list"
+    :param data: List of embedding data objects
+    :param model: The model that was used to generate the embeddings
+    :param usage: Usage information
+    """
+
+    object: Literal["list"] = "list"
+    data: list[OpenAIEmbeddingData]
+    model: str
+    usage: OpenAIEmbeddingUsage
+
+
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...
 
@@ -1076,6 +1118,26 @@ class InferenceProvider(Protocol):
         """
         ...
 
+    @webmethod(route="/openai/v1/embeddings", method="POST")
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        """Generate OpenAI-compatible embeddings for the given input using the specified model.
+
+        :param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
+        :param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
+        :param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
+        :param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
+        :param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+        :returns: An OpenAIEmbeddingsResponse containing the embeddings.
+        """
+        ...
+
 
 class Inference(InferenceProvider):
     """Llama Stack Inference API for generating completions, chat completions, and embeddings.
```
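The same contract can also be hit as a raw HTTP request. A hedged sketch using `requests`: the server URL and model ID are assumptions from the test plan, and the JSON body follows the OpenaiEmbeddingsRequest schema above.

```python
# Hypothetical raw request against the route defined above. Assumes a local
# Llama Stack server with a registered embedding model; not part of this PR.
import requests

resp = requests.post(
    "http://localhost:8321/v1/openai/v1/embeddings",
    json={
        "model": "all-MiniLM-L6-v2",
        "input": ["first text", "second text"],
        "encoding_format": "float",
    },
)
resp.raise_for_status()
body = resp.json()
assert body["object"] == "list"
for item in body["data"]:
    print(item["index"], len(item["embedding"]))
```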
InferenceRouter — validate the model and dispatch to the provider backing it:

```diff
@@ -45,6 +45,7 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
@@ -546,6 +547,34 @@ class InferenceRouter(Inference):
             await self.store.store_chat_completion(response, messages)
         return response
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        logger.debug(
+            f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
+        )
+        model_obj = await self.routing_table.get_model(model)
+        if model_obj is None:
+            raise ValueError(f"Model '{model}' not found")
+        if model_obj.model_type != ModelType.embedding:
+            raise ValueError(f"Model '{model}' is not an embedding model")
+
+        params = dict(
+            model=model_obj.identifier,
+            input=input,
+            encoding_format=encoding_format,
+            dimensions=dimensions,
+            user=user,
+        )
+
+        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        return await provider.openai_embeddings(**params)
+
     async def list_chat_completions(
         self,
         after: str | None = None,
```
Inline vLLM provider — stub:

```diff
@@ -40,6 +40,7 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -410,6 +411,16 @@ class VLLMInferenceImpl(
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def chat_completion(
         self,
         model_id: str,
```
Bedrock adapter — stub:

```diff
@@ -22,6 +22,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -197,3 +198,13 @@ class BedrockInferenceAdapter(
             response_body = json.loads(response.get("body").read())
             embeddings.append(response_body.get("embedding"))
         return EmbeddingsResponse(embeddings=embeddings)
+
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
```
Cerebras adapter — stub:

```diff
@@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -194,3 +195,13 @@ class CerebrasInferenceAdapter(
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
```
Databricks adapter — stub:

```diff
@@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -152,3 +153,13 @@ class DatabricksInferenceAdapter(
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
```
Fireworks adapter — stub:

```diff
@@ -37,6 +37,7 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
@@ -286,6 +287,16 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def openai_completion(
         self,
         model: str,
```
NVIDIA adapter — stub:

```diff
@@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -238,6 +239,16 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         #
         return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def chat_completion(
         self,
         model_id: str,
```
Ollama adapter — stub:

```diff
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -370,6 +371,16 @@ class OllamaInferenceAdapter(
 
         return model
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def openai_completion(
         self,
         model: str,
```
OpenAI adapter — full implementation via the AsyncOpenAI client:

```diff
@@ -14,6 +14,9 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
@@ -38,6 +41,7 @@ logger = logging.getLogger(__name__)
 # | batch_chat_completion  | LiteLLMOpenAIMixin |
 # | openai_completion      | AsyncOpenAI        |
 # | openai_chat_completion | AsyncOpenAI        |
+# | openai_embeddings      | AsyncOpenAI        |
 #
 class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
     def __init__(self, config: OpenAIConfig) -> None:
@@ -171,3 +175,51 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             user=user,
         )
         return await self._openai_client.chat.completions.create(**params)
+
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        model_id = (await self.model_store.get_model(model)).provider_resource_id
+        if model_id.startswith("openai/"):
+            model_id = model_id[len("openai/") :]
+
+        # Prepare parameters for OpenAI embeddings API
+        params = {
+            "model": model_id,
+            "input": input,
+        }
+
+        if encoding_format is not None:
+            params["encoding_format"] = encoding_format
+        if dimensions is not None:
+            params["dimensions"] = dimensions
+        if user is not None:
+            params["user"] = user
+
+        # Call OpenAI embeddings API
+        response = await self._openai_client.embeddings.create(**params)
+
+        data = []
+        for i, embedding_data in enumerate(response.data):
+            data.append(
+                OpenAIEmbeddingData(
+                    embedding=embedding_data.embedding,
+                    index=i,
+                )
+            )
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response.usage.prompt_tokens,
+            total_tokens=response.usage.total_tokens,
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=response.model,
+            usage=usage,
+        )
```
Passthrough adapter — stub:

```diff
@@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -210,6 +211,16 @@ class PassthroughInferenceAdapter(Inference):
             task_type=task_type,
         )
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def openai_completion(
         self,
         model: str,
```
Runpod adapter — stub:

```diff
@@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.inference.inference import OpenAIEmbeddingsResponse
 
 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@@ -134,3 +135,13 @@ class RunpodInferenceAdapter(
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
```
TGI / HuggingFace adapter — stub:

```diff
@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
@@ -291,6 +292,16 @@ class _HfAdapter(
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
 
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
```
Together adapter — stub:

```diff
@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
@@ -267,6 +268,16 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         embeddings = [item.embedding for item in r.data]
         return EmbeddingsResponse(embeddings=embeddings)
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def openai_completion(
         self,
         model: str,
```
Remote vLLM adapter — stub:

```diff
@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -507,6 +508,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def openai_completion(
         self,
         model: str,
```
WatsonX adapter — stub:

```diff
@@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -260,6 +261,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     ) -> EmbeddingsResponse:
         raise NotImplementedError("embedding is not supported for watsonx")
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def openai_completion(
         self,
         model: str,
```
SentenceTransformers embedding mixin — full implementation:

```diff
@@ -4,7 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
 import logging
+import struct
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -15,6 +17,9 @@ from llama_stack.apis.inference import (
     EmbeddingTaskType,
     InterleavedContentItem,
     ModelStore,
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
     TextTruncation,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@@ -43,6 +48,50 @@ class SentenceTransformerEmbeddingMixin:
         )
         return EmbeddingsResponse(embeddings=embeddings)
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        # Convert input to list format if it's a single string
+        input_list = [input] if isinstance(input, str) else input
+        if not input_list:
+            raise ValueError("Empty list not supported")
+
+        # Get the model and generate embeddings
+        model_obj = await self.model_store.get_model(model)
+        embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embeddings = embedding_model.encode(input_list, show_progress_bar=False)
+
+        # Convert embeddings to the requested format
+        data = []
+        for i, embedding in enumerate(embeddings):
+            if encoding_format == "base64":
+                # Convert float array to base64 string
+                float_bytes = struct.pack(f"{len(embedding)}f", *embedding)
+                embedding_value = base64.b64encode(float_bytes).decode("ascii")
+            else:
+                # Default to float format
+                embedding_value = embedding.tolist()
+
+            data.append(
+                OpenAIEmbeddingData(
+                    embedding=embedding_value,
+                    index=i,
+                )
+            )
+
+        # Not returning actual token usage
+        usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
+        )
+
     def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS
```
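The mixin above packs float32 values with `struct` and base64-encodes them when `encoding_format="base64"` is requested; the integration tests below reverse this with their `decode_base64_to_floats` helper. A standalone round-trip sketch of that convention (plain standard-library Python, independent of Llama Stack):

```python
# Round-trip sketch of the base64 embedding convention: pack as float32,
# base64-encode, then decode and unpack. The tolerance accounts for float32
# rounding of the original Python floats.
import base64
import struct

floats = [0.1, -0.25, 3.5]

# Encode (as in SentenceTransformerEmbeddingMixin.openai_embeddings above)
packed = struct.pack(f"{len(floats)}f", *floats)
encoded = base64.b64encode(packed).decode("ascii")

# Decode (as in the tests' decode_base64_to_floats helper below)
raw = base64.b64decode(encoded)
decoded = list(struct.unpack(f"{len(raw) // 4}f", raw))

assert all(abs(a - b) < 1e-6 for a, b in zip(floats, decoded))
```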
LiteLLM OpenAI mixin — full implementation:

```diff
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
+import struct
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
@@ -35,6 +37,9 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
@@ -264,6 +269,52 @@ class LiteLLMOpenAIMixin(
         embeddings = [data["embedding"] for data in response["data"]]
         return EmbeddingsResponse(embeddings=embeddings)
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        model_obj = await self.model_store.get_model(model)
+
+        # Convert input to list if it's a string
+        input_list = [input] if isinstance(input, str) else input
+
+        # Call litellm embedding function
+        # litellm.drop_params = True
+        response = litellm.embedding(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            input=input_list,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            dimensions=dimensions,
+        )
+
+        # Convert response to OpenAI format
+        data = []
+        for i, embedding_data in enumerate(response["data"]):
+            # we encode to base64 if the encoding format is base64 in the request
+            if encoding_format == "base64":
+                byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
+                embedding = base64.b64encode(byte_data).decode("utf-8")
+            else:
+                embedding = embedding_data["embedding"]
+
+            data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response["usage"]["prompt_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
+        )
+
     async def openai_completion(
         self,
         model: str,
```
tests/integration/inference/test_openai_embeddings.py (new file, +275):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import struct

import pytest
from openai import OpenAI

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient


def decode_base64_to_floats(base64_string: str) -> list[float]:
    """Helper function to decode base64 string to list of float32 values."""
    embedding_bytes = base64.b64decode(base64_string)
    float_count = len(embedding_bytes) // 4  # 4 bytes per float32
    embedding_floats = struct.unpack(f"{float_count}f", embedding_bytes)
    return list(embedding_floats)


def provider_from_model(client_with_models, model_id):
    models = {m.identifier: m for m in client_with_models.models.list()}
    models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
    provider_id = models[model_id].provider_id
    providers = {p.provider_id: p for p in client_with_models.providers.list()}
    return providers[provider_id]


def skip_if_model_doesnt_support_variable_dimensions(model_id):
    if "text-embedding-3" not in model_id:
        pytest.skip(f"{model_id} does not support variable output embedding dimensions")


def skip_if_model_doesnt_support_openai_embeddings(client_with_models, model_id):
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI embeddings are not supported when testing with library client yet.")

    provider = provider_from_model(client_with_models, model_id)
    if provider.provider_type in (
        "inline::meta-reference",
        "remote::bedrock",
        "remote::cerebras",
        "remote::databricks",
        "remote::runpod",
        "remote::sambanova",
        "remote::tgi",
        "remote::ollama",
    ):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")


@pytest.fixture
def openai_client(client_with_models):
    base_url = f"{client_with_models.base_url}/v1/openai/v1"
    return OpenAI(base_url=base_url, api_key="fake")


def test_openai_embeddings_single_string(openai_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with a single string input."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_text = "Hello, world!"

    response = openai_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        encoding_format="float",
    )

    assert response.object == "list"
    assert response.model == embedding_model_id
    assert len(response.data) == 1
    assert response.data[0].object == "embedding"
    assert response.data[0].index == 0
    assert isinstance(response.data[0].embedding, list)
    assert len(response.data[0].embedding) > 0
    assert all(isinstance(x, float) for x in response.data[0].embedding)


def test_openai_embeddings_multiple_strings(openai_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with multiple string inputs."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_texts = ["Hello, world!", "How are you today?", "This is a test."]

    response = openai_client.embeddings.create(
        model=embedding_model_id,
        input=input_texts,
    )

    assert response.object == "list"
    assert response.model == embedding_model_id
    assert len(response.data) == len(input_texts)

    for i, embedding_data in enumerate(response.data):
        assert embedding_data.object == "embedding"
        assert embedding_data.index == i
        assert isinstance(embedding_data.embedding, list)
        assert len(embedding_data.embedding) > 0
        assert all(isinstance(x, float) for x in embedding_data.embedding)


def test_openai_embeddings_with_encoding_format_float(openai_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with float encoding format."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_text = "Test encoding format"

    response = openai_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        encoding_format="float",
    )

    assert response.object == "list"
    assert len(response.data) == 1
    assert isinstance(response.data[0].embedding, list)
    assert all(isinstance(x, float) for x in response.data[0].embedding)


def test_openai_embeddings_with_dimensions(openai_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with custom dimensions parameter."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
    skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)

    input_text = "Test dimensions parameter"
    dimensions = 16

    response = openai_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        dimensions=dimensions,
    )

    assert response.object == "list"
    assert len(response.data) == 1
    # Note: Not all models support custom dimensions, so we don't assert the exact dimension
    assert isinstance(response.data[0].embedding, list)
    assert len(response.data[0].embedding) > 0


def test_openai_embeddings_with_user_parameter(openai_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with user parameter."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_text = "Test user parameter"
    user_id = "test-user-123"

    response = openai_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        user=user_id,
    )

    assert response.object == "list"
    assert len(response.data) == 1
    assert isinstance(response.data[0].embedding, list)
    assert len(response.data[0].embedding) > 0


def test_openai_embeddings_empty_list_error(openai_client, client_with_models, embedding_model_id):
    """Test that empty list input raises an appropriate error."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    with pytest.raises(Exception):  # noqa: B017
        openai_client.embeddings.create(
            model=embedding_model_id,
            input=[],
        )


def test_openai_embeddings_invalid_model_error(openai_client, client_with_models, embedding_model_id):
    """Test that invalid model ID raises an appropriate error."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    with pytest.raises(Exception):  # noqa: B017
        openai_client.embeddings.create(
            model="invalid-model-id",
            input="Test text",
        )


def test_openai_embeddings_different_inputs_different_outputs(openai_client, client_with_models, embedding_model_id):
    """Test that different inputs produce different embeddings."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_text1 = "This is the first text"
    input_text2 = "This is completely different content"

    response1 = openai_client.embeddings.create(
        model=embedding_model_id,
        input=input_text1,
    )

    response2 = openai_client.embeddings.create(
        model=embedding_model_id,
        input=input_text2,
    )

    embedding1 = response1.data[0].embedding
    embedding2 = response2.data[0].embedding

    assert len(embedding1) == len(embedding2)
    # Embeddings should be different for different inputs
    assert embedding1 != embedding2


def test_openai_embeddings_with_encoding_format_base64(openai_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with base64 encoding format."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
    skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)

    input_text = "Test base64 encoding format"
    dimensions = 12

    response = openai_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        encoding_format="base64",
        dimensions=dimensions,
    )

    # Validate response structure
    assert response.object == "list"
    assert len(response.data) == 1

    # With base64 encoding, embedding should be a string, not a list
    embedding_data = response.data[0]
    assert embedding_data.object == "embedding"
    assert embedding_data.index == 0
    assert isinstance(embedding_data.embedding, str)

    # Verify it's valid base64 and decode to floats
    embedding_floats = decode_base64_to_floats(embedding_data.embedding)

    # Verify we got valid floats
    assert len(embedding_floats) == dimensions, f"Got embedding length {len(embedding_floats)}, expected {dimensions}"
    assert all(isinstance(x, float) for x in embedding_floats)


def test_openai_embeddings_base64_batch_processing(openai_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with base64 encoding for batch processing."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]

    response = openai_client.embeddings.create(
        model=embedding_model_id,
        input=input_texts,
        encoding_format="base64",
    )

    # Validate response structure
    assert response.object == "list"
    assert response.model == embedding_model_id
    assert len(response.data) == len(input_texts)

    # Validate each embedding in the batch
    embedding_dimensions = []
    for i, embedding_data in enumerate(response.data):
        assert embedding_data.object == "embedding"
        assert embedding_data.index == i

        # With base64 encoding, embedding should be a string, not a list
        assert isinstance(embedding_data.embedding, str)
        embedding_floats = decode_base64_to_floats(embedding_data.embedding)
        assert len(embedding_floats) > 0
        assert all(isinstance(x, float) for x in embedding_floats)
        embedding_dimensions.append(len(embedding_floats))

    # All embeddings should have the same dimensionality
    assert all(dim == embedding_dimensions[0] for dim in embedding_dimensions)
```