Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-16 14:38:00 +00:00
ensure messages is of the correct type

commit 2b599aa9b4
parent 829c0e25ab

1 changed file with 28 additions and 1 deletion
@@ -18,7 +18,15 @@ from pydantic import BaseModel
 from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
 from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
 from llama_stack.apis.files import Files, OpenAIFilePurpose
-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import (
+    Inference,
+    OpenAIAssistantMessageParam,
+    OpenAIDeveloperMessageParam,
+    OpenAIMessageParam,
+    OpenAISystemMessageParam,
+    OpenAIToolMessageParam,
+    OpenAIUserMessageParam,
+)
 from llama_stack.apis.models import Models
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore
@@ -65,6 +73,24 @@ class BatchRequest(BaseModel):
     body: dict[str, Any]
 
 
+def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
+    """Convert a message dictionary to OpenAIMessageParam based on role."""
+    role = msg.get("role")
+
+    if role == "user":
+        return OpenAIUserMessageParam(**msg)
+    elif role == "system":
+        return OpenAISystemMessageParam(**msg)
+    elif role == "assistant":
+        return OpenAIAssistantMessageParam(**msg)
+    elif role == "tool":
+        return OpenAIToolMessageParam(**msg)
+    elif role == "developer":
+        return OpenAIDeveloperMessageParam(**msg)
+    else:
+        raise ValueError(f"Unknown message role: {role}")
+
+
 class ReferenceBatchesImpl(Batches):
     """Reference implementation of the Batches API.
 
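The new helper dispatches on the "role" key and passes the whole dict to the matching Pydantic message class, so an unrecognized role fails loudly with a ValueError instead of flowing into inference as an untyped dict. A minimal usage sketch of the behavior (the sample messages below are illustrative, not part of the commit):

    # Illustrative message dicts, shaped like decoded batch-request JSON.
    raw_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]

    typed = [convert_to_openai_message_param(m) for m in raw_messages]
    # typed[0] is an OpenAISystemMessageParam, typed[1] an OpenAIUserMessageParam

    convert_to_openai_message_param({"role": "robot", "content": "beep"})
    # -> ValueError: Unknown message role: robot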
@@ -517,6 +543,7 @@ class ReferenceBatchesImpl(Batches):
 
         try:
             # TODO(SECURITY): review body for security issues
+            request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
             chat_response = await self.inference_api.openai_chat_completion(**request.body)
 
             # this is for mypy, we don't allow streaming so we'll get the right type
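Before the added line runs, request.body["messages"] still holds the plain dicts decoded from the uploaded batch file; rewriting them in place means openai_chat_completion(**request.body) receives typed OpenAIMessageParam instances, which is what the commit title promises. A hedged sketch of the effect on one request body (the body contents are illustrative, not taken from the commit):

    # Illustrative body, shaped like one decoded line of a batch input file.
    body = {
        "model": "example-model-id",  # hypothetical id; comes from the batch file in practice
        "messages": [{"role": "user", "content": "Hi"}],
    }

    body["messages"] = [convert_to_openai_message_param(m) for m in body["messages"]]
    # body["messages"] now holds OpenAIUserMessageParam instances, so a
    # subsequent openai_chat_completion(**body) call gets typed messages
    # rather than raw dicts.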