diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 984ef5a90..1ff554e70 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -18,7 +18,15 @@ from pydantic import BaseModel from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError from llama_stack.apis.files import Files, OpenAIFilePurpose -from llama_stack.apis.inference import Inference +from llama_stack.apis.inference import ( + Inference, + OpenAIAssistantMessageParam, + OpenAIDeveloperMessageParam, + OpenAIMessageParam, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, +) from llama_stack.apis.models import Models from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore @@ -65,6 +73,24 @@ class BatchRequest(BaseModel): body: dict[str, Any] +def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam: + """Convert a message dictionary to OpenAIMessageParam based on role.""" + role = msg.get("role") + + if role == "user": + return OpenAIUserMessageParam(**msg) + elif role == "system": + return OpenAISystemMessageParam(**msg) + elif role == "assistant": + return OpenAIAssistantMessageParam(**msg) + elif role == "tool": + return OpenAIToolMessageParam(**msg) + elif role == "developer": + return OpenAIDeveloperMessageParam(**msg) + else: + raise ValueError(f"Unknown message role: {role}") + + class ReferenceBatchesImpl(Batches): """Reference implementation of the Batches API. @@ -517,6 +543,7 @@ class ReferenceBatchesImpl(Batches): try: # TODO(SECURITY): review body for security issues + request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] chat_response = await self.inference_api.openai_chat_completion(**request.body) # this is for mypy, we don't allow streaming so we'll get the right type