feat: Add /v1/embeddings endpoint to batches API

This PR extends the Llama Stack Batches API to support the /v1/embeddings endpoint, enabling efficient batch processing of embedding requests alongside the existing /v1/chat/completions and /v1/completions support.

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-09-08 16:55:17 -07:00
parent aab22dc759
commit 531b1451dc
3 changed files with 122 additions and 6 deletions

View file

@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
@ -425,18 +425,23 @@ class ReferenceBatchesImpl(Batches):
valid = False
if batch.endpoint == "/v1/chat/completions":
required_params = [
required_params: list[tuple[str, Any, str]] = [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]
else: # /v1/completions
elif batch.endpoint == "/v1/completions":
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
else: # /v1/embeddings
required_params = [
("model", str, "a string"),
("input", (str, list), "a string or array of strings"),
]
for param, expected_type, type_string in required_params:
if param not in body:
@ -614,7 +619,7 @@ class ReferenceBatchesImpl(Batches):
"body": chat_response.model_dump_json(),
},
}
else: # /v1/completions
elif request.url == "/v1/completions":
completion_response = await self.inference_api.openai_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
@ -630,6 +635,20 @@ class ReferenceBatchesImpl(Batches):
"body": completion_response.model_dump_json(),
},
}
else: # /v1/embeddings
embeddings_response = await self.inference_api.openai_embeddings(**request.body)
assert hasattr(embeddings_response, "model_dump_json"), (
"Embeddings response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": embeddings_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {