Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)
feat: Add /v1/embeddings endpoint to batches API
This PR extends the Llama Stack Batches API to support the /v1/embeddings endpoint, enabling efficient batch processing of embedding requests alongside the existing /v1/chat/completions and /v1/completions support.

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
parent: aab22dc759
commit: 531b1451dc
3 changed files with 122 additions and 6 deletions
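For context, here is a hedged sketch of how a client could exercise the new endpoint end to end, assuming an OpenAI-compatible client pointed at a Llama Stack server; the base URL, API key, file name, and model id are illustrative placeholders, not part of this change:

```python
# Sketch: submitting an embeddings batch through the Batches API.
# Assumes a Llama Stack distribution serving OpenAI-compatible routes at
# http://localhost:8321/v1; the API key, file name, and model id are placeholders.
import json

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# One request per line; per the validation below, /v1/embeddings bodies must
# carry "model" (a string) and "input" (a string or an array of strings).
requests = [
    {
        "custom_id": f"embed-{i}",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {"model": "my-embedding-model", "input": text},
    }
    for i, text in enumerate(["first document", "second document"])
]

with open("embeddings_batch.jsonl", "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

batch_file = client.files.create(file=open("embeddings_batch.jsonl", "rb"), purpose="batch")
batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/embeddings",  # newly accepted alongside chat/completions and completions
    completion_window="24h",    # the only window the reference provider accepts
)
print(batch.id, batch.status)
```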
@@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
         # TODO: set expiration time for garbage collection

-        if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
+        if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
             raise ValueError(
-                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
+                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
             )

         if completion_window != "24h":
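From the client side, this means batch creation for /v1/embeddings now passes the endpoint whitelist, while anything else is still rejected up front. A hedged sketch, reusing client and batch_file from the first example (how the provider's ValueError is surfaced as an HTTP error is server plumbing outside this diff):

```python
# Sketch: an endpoint outside the whitelist is still rejected at creation time.
try:
    client.batches.create(
        input_file_id=batch_file.id,
        endpoint="/v1/rerank",  # hypothetical endpoint, not in the whitelist
        completion_window="24h",
    )
except Exception as err:  # exact exception type depends on the server's error mapping
    print("rejected:", err)
```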
@@ -425,18 +425,23 @@ class ReferenceBatchesImpl(Batches):
                     valid = False

             if batch.endpoint == "/v1/chat/completions":
-                required_params = [
+                required_params: list[tuple[str, Any, str]] = [
                     ("model", str, "a string"),
                     # messages is specific to /v1/chat/completions
                     # we could skip validating messages here and let inference fail. however,
                     # that would be a very expensive way to find out messages is wrong.
                     ("messages", list, "an array"),  # TODO: allow messages to be a string?
                 ]
-            else:  # /v1/completions
+            elif batch.endpoint == "/v1/completions":
                 required_params = [
                     ("model", str, "a string"),
                     ("prompt", str, "a string"),  # TODO: allow prompt to be a list of strings??
                 ]
+            else:  # /v1/embeddings
+                required_params = [
+                    ("model", str, "a string"),
+                    ("input", (str, list), "a string or array of strings"),
+                ]

             for param, expected_type, type_string in required_params:
                 if param not in body:
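Outside the diff context, the loop that consumes these (param, expected_type, type_string) tuples is a plain presence-and-type check; a standalone sketch of the same idea (not the provider's exact code):

```python
# Standalone sketch of the per-endpoint required-parameter check, mirroring the
# (param, expected_type, type_string) tuples above; not the provider's exact code.
from typing import Any


def required_params_for(endpoint: str) -> list[tuple[str, Any, str]]:
    if endpoint == "/v1/chat/completions":
        return [("model", str, "a string"), ("messages", list, "an array")]
    elif endpoint == "/v1/completions":
        return [("model", str, "a string"), ("prompt", str, "a string")]
    else:  # /v1/embeddings
        return [("model", str, "a string"), ("input", (str, list), "a string or array of strings")]


def validate_body(endpoint: str, body: dict) -> list[str]:
    errors = []
    for param, expected_type, type_string in required_params_for(endpoint):
        if param not in body:
            errors.append(f"missing required parameter: {param}")
        elif not isinstance(body[param], expected_type):
            errors.append(f"{param} must be {type_string}")
    return errors


assert validate_body("/v1/embeddings", {"model": "m", "input": ["a", "b"]}) == []
assert validate_body("/v1/embeddings", {"model": "m"}) == ["missing required parameter: input"]
```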
@@ -614,7 +619,7 @@ class ReferenceBatchesImpl(Batches):
                         "body": chat_response.model_dump_json(),
                     },
                 }
-            else:  # /v1/completions
+            elif request.url == "/v1/completions":
                 completion_response = await self.inference_api.openai_completion(**request.body)

                 # this is for mypy, we don't allow streaming so we'll get the right type
@@ -630,6 +635,20 @@ class ReferenceBatchesImpl(Batches):
                         "body": completion_response.model_dump_json(),
                     },
                 }
+            else:  # /v1/embeddings
+                embeddings_response = await self.inference_api.openai_embeddings(**request.body)
+                assert hasattr(embeddings_response, "model_dump_json"), (
+                    "Embeddings response must have model_dump_json method"
+                )
+                return {
+                    "id": request_id,
+                    "custom_id": request.custom_id,
+                    "response": {
+                        "status_code": 200,
+                        "request_id": request_id,  # TODO: should this be different?
+                        "body": embeddings_response.model_dump_json(),
+                    },
+                }
         except Exception as e:
             logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
             return {
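Once the batch finishes, each input line yields one output record whose response body is the serialized embeddings response built above. A hedged sketch of reading the vectors back, continuing with the client and batch from the first example and assuming float-encoded embeddings in the standard OpenAI response shape:

```python
# Sketch: polling the batch and reading the embeddings back from the output file.
import json
import time

while (batch := client.batches.retrieve(batch.id)).status in ("validating", "in_progress", "finalizing"):
    time.sleep(5)

if batch.status == "completed" and batch.output_file_id:
    output_text = client.files.content(batch.output_file_id).text
    for line in output_text.splitlines():
        record = json.loads(line)
        body = json.loads(record["response"]["body"])  # the embeddings response serialized above
        vectors = [item["embedding"] for item in body["data"]]
        print(record["custom_id"], f"{len(vectors)} embedding(s)")
```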