feat: implement OpenAI chat completion for meta_reference provider

- Add chat_completion() method to LlamaGenerator supporting OpenAI request format
- Implement openai_chat_completion() in MetaReferenceInferenceImpl
- Fix ModelRunner task dispatch to handle chat_completion tasks
- Add convert_openai_message_to_raw_message() utility for message conversion
- Add unit tests for message conversion and model-parallel dispatch
- Remove unused CompletionRequestWithRawContent references

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-11-08 14:24:13 -05:00
parent 54754cb2a7
commit 7574f147b6
8 changed files with 240 additions and 17 deletions

View file

@ -13,11 +13,12 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.apis.inference import (
GreedySamplingStrategy,
JsonSchemaResponseFormat,
OpenAIChatCompletionRequestWithExtraBody,
ResponseFormat,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
@ -142,3 +143,53 @@ class LlamaGenerator:
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def chat_completion(
self,
request: OpenAIChatCompletionRequestWithExtraBody,
raw_messages: list,
):
"""Generate chat completion using OpenAI request format.
Args:
request: OpenAI chat completion request
raw_messages: Pre-converted list of RawMessage objects
"""
# Determine tool prompt format
tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
# Prepare sampling params
sampling_params = SamplingParams()
if request.temperature is not None or request.top_p is not None:
sampling_params.strategy = TopPSamplingStrategy(
temperature=request.temperature or 1.0, top_p=request.top_p or 1.0
)
if request.max_tokens:
sampling_params.max_tokens = request.max_tokens
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
# Get logits processor for response format
logits_processor = None
if request.response_format:
if isinstance(request.response_format, dict) and request.response_format.get("type") == "json_schema":
json_schema_format = JsonSchemaResponseFormat(
type="json_schema", json_schema=request.response_format.get("json_schema", {})
)
logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
# Generate
yield from self.inner_generator.generate(
llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
echo=False,
logits_processor=logits_processor,
)

View file

@ -5,11 +5,16 @@
# the root directory of this source tree.
import asyncio
import time
import uuid
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionUsage,
OpenAIChoice,
OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
@ -136,10 +141,14 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model
log.info("Warming up...")
from llama_stack.apis.inference import OpenAIUserMessageParam
await self.openai_chat_completion(
model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}],
max_tokens=20,
params=OpenAIChatCompletionRequestWithExtraBody(
model=model_id,
messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
max_tokens=20,
)
)
log.info("Warmed up!")
@ -155,4 +164,50 @@ class MetaReferenceInferenceImpl(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
self.check_model(params)
# Convert OpenAI messages to RawMessages
from llama_stack.providers.utils.inference.prompt_adapter import convert_openai_message_to_raw_message
raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
# Call generator's chat_completion method (works for both single-GPU and model-parallel)
if isinstance(self.generator, LlamaGenerator):
generator = self.generator.chat_completion(params, raw_messages)
else:
# Model parallel: submit task to process group
generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
# Collect all generated text
generated_text = ""
for result_batch in generator:
for result in result_batch:
if not result.ignore_token and result.source == "output":
generated_text += result.text
# Create OpenAI response
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
return OpenAIChatCompletion(
id=response_id,
object="chat.completion",
created=created,
model=params.model,
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant",
content=generated_text,
),
finish_reason="stop",
logprobs=None,
)
],
usage=OpenAIChatCompletionUsage(
prompt_tokens=0, # TODO: calculate properly
completion_tokens=0, # TODO: calculate properly
total_tokens=0, # TODO: calculate properly
),
)

View file

@ -19,7 +19,13 @@ class ModelRunner:
self.llama = llama
def __call__(self, task: Any):
raise ValueError(f"Unexpected task type {task[0]}")
task_type = task[0]
if task_type == "chat_completion":
# task[1] is [params, raw_messages]
params, raw_messages = task[1]
return self.llama.chat_completion(params, raw_messages)
else:
raise ValueError(f"Unexpected task type {task_type}")
def init_model_cb(

View file

@ -33,9 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
CompletionRequestWithRawContent,
)
log = get_logger(name=__name__, category="inference")
@ -68,10 +65,7 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: tuple[
str,
list[CompletionRequestWithRawContent],
]
task: tuple[str, list]
class TaskResponse(BaseModel):
@ -327,10 +321,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: tuple[
str,
list[CompletionRequestWithRawContent],
],
req: tuple[str, list],
) -> Generator:
assert not self.running, "inference already running"

View file

@ -22,9 +22,14 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.inference import (
CompletionRequest,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIFile,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
ResponseFormat,
ResponseFormatType,
ToolChoice,
@ -37,6 +42,7 @@ from llama_stack.models.llama.datatypes import (
RawMessage,
RawTextItem,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
@ -128,6 +134,36 @@ async def interleaved_content_convert_to_raw(
return await _localize_single(content)
async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
"""Convert OpenAI message format to RawMessage format used by Llama formatters."""
if isinstance(message, OpenAIUserMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="user", content=content)
elif isinstance(message, OpenAISystemMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="system", content=content)
elif isinstance(message, OpenAIAssistantMessageParam):
content = await interleaved_content_convert_to_raw(message.content or "") # type: ignore[arg-type]
tool_calls = []
if message.tool_calls:
for tc in message.tool_calls:
if tc.function:
tool_calls.append(
ToolCall(
call_id=tc.id or "",
tool_name=tc.function.name or "",
arguments=tc.function.arguments or "{}",
)
)
return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
elif isinstance(message, OpenAIToolMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="tool", content=content)
else:
# Handle OpenAIDeveloperMessageParam if needed
raise ValueError(f"Unsupported message type: {type(message)}")
def content_has_media(content: InterleavedContent):
def _has_media_content(c):
return isinstance(c, ImageContentItem)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import Mock
import pytest
from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
ModelRunner,
)
class TestModelRunner:
"""Test ModelRunner task dispatching for model-parallel inference."""
def test_chat_completion_task_dispatch(self):
"""Verify ModelRunner correctly dispatches chat_completion tasks."""
# Create a mock generator
mock_generator = Mock()
mock_generator.chat_completion = Mock(return_value=iter([]))
runner = ModelRunner(mock_generator)
# Create a chat_completion task
fake_params = {"model": "test"}
fake_messages = [{"role": "user", "content": "test"}]
task = ("chat_completion", [fake_params, fake_messages])
# Execute task
runner(task)
# Verify chat_completion was called with correct arguments
mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
def test_invalid_task_type_raises_error(self):
"""Verify ModelRunner rejects invalid task types."""
mock_generator = Mock()
runner = ModelRunner(mock_generator)
with pytest.raises(ValueError, match="Unexpected task type"):
runner(("invalid_task", []))

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.models.llama.datatypes import RawTextItem
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
)
class TestConvertOpenAIMessageToRawMessage:
"""Test conversion of OpenAI message types to RawMessage format."""
async def test_user_message_conversion(self):
msg = OpenAIUserMessageParam(role="user", content="Hello world")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "user"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hello world"
async def test_assistant_message_conversion(self):
msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "assistant"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hi there!"
assert raw_msg.tool_calls == []