ensure messages is of the correct type

This commit is contained in:
Matthew Farrellee 2025-08-15 11:56:00 -04:00 committed by Ashwin Bharambe
parent 829c0e25ab
commit 2b599aa9b4

View file

@ -18,7 +18,15 @@ from pydantic import BaseModel
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.models import Models
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
@ -65,6 +73,24 @@ class BatchRequest(BaseModel):
body: dict[str, Any]
def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
"""Convert a message dictionary to OpenAIMessageParam based on role."""
role = msg.get("role")
if role == "user":
return OpenAIUserMessageParam(**msg)
elif role == "system":
return OpenAISystemMessageParam(**msg)
elif role == "assistant":
return OpenAIAssistantMessageParam(**msg)
elif role == "tool":
return OpenAIToolMessageParam(**msg)
elif role == "developer":
return OpenAIDeveloperMessageParam(**msg)
else:
raise ValueError(f"Unknown message role: {role}")
class ReferenceBatchesImpl(Batches):
"""Reference implementation of the Batches API.
@ -517,6 +543,7 @@ class ReferenceBatchesImpl(Batches):
try:
# TODO(SECURITY): review body for security issues
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type