Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)
feat: implement OpenAI chat completion for meta_reference provider
- Add chat_completion() method to LlamaGenerator supporting OpenAI request format
- Implement openai_chat_completion() in MetaReferenceInferenceImpl
- Fix ModelRunner task dispatch to handle chat_completion tasks
- Add convert_openai_message_to_raw_message() utility for message conversion
- Add unit tests for message conversion and model-parallel dispatch
- Remove unused CompletionRequestWithRawContent references

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
  parent 54754cb2a7
  commit 7574f147b6

8 changed files with 240 additions and 17 deletions
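For orientation, here is a minimal sketch of how the new entry point is meant to be called once the provider is initialized. It uses only types that appear in the diff below; `impl` (an already-initialized MetaReferenceInferenceImpl) and the model id are placeholders, not part of the commit.

# Hypothetical usage sketch; `impl` stands for an initialized MetaReferenceInferenceImpl.
from llama_stack.apis.inference import (
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIUserMessageParam,
)


async def demo(impl) -> str:
    params = OpenAIChatCompletionRequestWithExtraBody(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
        max_tokens=20,
    )
    # The meta_reference implementation is non-streaming here: it collects the
    # generated text and returns a single OpenAIChatCompletion.
    response = await impl.openai_chat_completion(params=params)
    return response.choices[0].message.content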
LlamaGenerator (meta_reference generator module):

@@ -13,11 +13,12 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
 from llama_stack.apis.inference import (
     GreedySamplingStrategy,
     JsonSchemaResponseFormat,
+    OpenAIChatCompletionRequestWithExtraBody,
     ResponseFormat,
     SamplingParams,
     TopPSamplingStrategy,
 )
-from llama_stack.models.llama.datatypes import QuantizationMode
+from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
 from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.generation import Llama4
@@ -142,3 +143,53 @@ class LlamaGenerator:
         self.tokenizer = self.inner_generator.tokenizer
         self.args = self.inner_generator.args
         self.formatter = self.inner_generator.formatter
+
+    def chat_completion(
+        self,
+        request: OpenAIChatCompletionRequestWithExtraBody,
+        raw_messages: list,
+    ):
+        """Generate chat completion using OpenAI request format.
+
+        Args:
+            request: OpenAI chat completion request
+            raw_messages: Pre-converted list of RawMessage objects
+        """
+        # Determine tool prompt format
+        tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
+
+        # Prepare sampling params
+        sampling_params = SamplingParams()
+        if request.temperature is not None or request.top_p is not None:
+            sampling_params.strategy = TopPSamplingStrategy(
+                temperature=request.temperature or 1.0, top_p=request.top_p or 1.0
+            )
+        if request.max_tokens:
+            sampling_params.max_tokens = request.max_tokens
+
+        max_gen_len = sampling_params.max_tokens
+        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
+            max_gen_len = self.args.max_seq_len - 1
+
+        temperature, top_p = _infer_sampling_params(sampling_params)
+
+        # Get logits processor for response format
+        logits_processor = None
+        if request.response_format:
+            if isinstance(request.response_format, dict) and request.response_format.get("type") == "json_schema":
+                json_schema_format = JsonSchemaResponseFormat(
+                    type="json_schema", json_schema=request.response_format.get("json_schema", {})
+                )
+                logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
+
+        # Generate
+        yield from self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=False,
+            echo=False,
+            logits_processor=logits_processor,
+        )
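The body of chat_completion() above folds the OpenAI request fields into the provider's native sampling configuration. The following restatement isolates that mapping for readability; it is an illustrative sketch derived from the code above, not additional code from the commit (`max_seq_len` stands in for `self.args.max_seq_len`).

from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy


def sampling_params_from_openai(temperature, top_p, max_tokens, max_seq_len):
    """Illustrative restatement of the mapping inside LlamaGenerator.chat_completion()."""
    params = SamplingParams()
    if temperature is not None or top_p is not None:
        # Setting either knob switches to nucleus sampling, defaulting the other to 1.0.
        params.strategy = TopPSamplingStrategy(temperature=temperature or 1.0, top_p=top_p or 1.0)
    if max_tokens:
        params.max_tokens = max_tokens

    # Cap the generation length to the model's context window.
    max_gen_len = params.max_tokens
    if max_gen_len is None or max_gen_len == 0 or max_gen_len >= max_seq_len:
        max_gen_len = max_seq_len - 1
    return params, max_gen_len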
MetaReferenceInferenceImpl (meta_reference inference provider):

@@ -5,11 +5,16 @@
 # the root directory of this source tree.

 import asyncio
+import time
+import uuid
 from collections.abc import AsyncIterator

 from llama_stack.apis.inference import (
     InferenceProvider,
+    OpenAIAssistantMessageParam,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAIChoice,
     OpenAICompletionRequestWithExtraBody,
 )
 from llama_stack.apis.inference.inference import (
@@ -136,11 +141,15 @@ class MetaReferenceInferenceImpl(
         self.llama_model = llama_model

         log.info("Warming up...")
+        from llama_stack.apis.inference import OpenAIUserMessageParam
+
         await self.openai_chat_completion(
+            params=OpenAIChatCompletionRequestWithExtraBody(
                 model=model_id,
-                messages=[{"role": "user", "content": "Hi how are you?"}],
+                messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
                 max_tokens=20,
             )
+        )
         log.info("Warmed up!")

     def check_model(self, request) -> None:
@@ -155,4 +164,50 @@
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
+        self.check_model(params)
+
+        # Convert OpenAI messages to RawMessages
+        from llama_stack.providers.utils.inference.prompt_adapter import convert_openai_message_to_raw_message
+
+        raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
+
+        # Call generator's chat_completion method (works for both single-GPU and model-parallel)
+        if isinstance(self.generator, LlamaGenerator):
+            generator = self.generator.chat_completion(params, raw_messages)
+        else:
+            # Model parallel: submit task to process group
+            generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
+
+        # Collect all generated text
+        generated_text = ""
+        for result_batch in generator:
+            for result in result_batch:
+                if not result.ignore_token and result.source == "output":
+                    generated_text += result.text
+
+        # Create OpenAI response
+        response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
+        created = int(time.time())
+
+        return OpenAIChatCompletion(
+            id=response_id,
+            object="chat.completion",
+            created=created,
+            model=params.model,
+            choices=[
+                OpenAIChoice(
+                    index=0,
+                    message=OpenAIAssistantMessageParam(
+                        role="assistant",
+                        content=generated_text,
+                    ),
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+            usage=OpenAIChatCompletionUsage(
+                prompt_tokens=0,  # TODO: calculate properly
+                completion_tokens=0,  # TODO: calculate properly
+                total_tokens=0,  # TODO: calculate properly
+            ),
+        )
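The usage block above is stubbed out with zeros and marked TODO. One plausible way to fill it in later is to re-encode the prompt and the collected output with the provider's tokenizer; the sketch below only illustrates that idea, assuming a tokenizer exposing encode(text, bos=..., eos=...) like the Llama 3/4 tokenizers used by this provider, and is not part of the commit.

from llama_stack.apis.inference import OpenAIChatCompletionUsage


def usage_from_text(tokenizer, prompt_text: str, generated_text: str) -> OpenAIChatCompletionUsage:
    # Sketch: approximate token accounting by re-encoding the text.
    # Assumes `tokenizer.encode(s, bos=..., eos=...)` as on the Llama tokenizers.
    prompt_tokens = len(tokenizer.encode(prompt_text, bos=False, eos=False))
    completion_tokens = len(tokenizer.encode(generated_text, bos=False, eos=False))
    return OpenAIChatCompletionUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )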
llama_stack/providers/inline/inference/meta_reference/model_parallel.py (ModelRunner):

@@ -19,7 +19,13 @@ class ModelRunner:
         self.llama = llama

     def __call__(self, task: Any):
-        raise ValueError(f"Unexpected task type {task[0]}")
+        task_type = task[0]
+        if task_type == "chat_completion":
+            # task[1] is [params, raw_messages]
+            params, raw_messages = task[1]
+            return self.llama.chat_completion(params, raw_messages)
+        else:
+            raise ValueError(f"Unexpected task type {task_type}")


 def init_model_cb(
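The worker-side contract that makes the model-parallel path work is small: a task is a (task_type, args) tuple, and for "chat_completion" the args list is [params, raw_messages]. A minimal stand-alone sketch of the dispatch with a stub generator (the unit test added below exercises the same path):

from unittest.mock import Mock

from llama_stack.providers.inline.inference.meta_reference.model_parallel import ModelRunner

# Stub generator whose chat_completion yields a single empty batch.
stub_generator = Mock()
stub_generator.chat_completion = Mock(return_value=iter([[]]))

runner = ModelRunner(stub_generator)

# Same tuple shape that ModelParallelProcessGroup.run_inference() forwards.
task = ("chat_completion", [{"model": "demo"}, []])
for batch in runner(task):
    pass  # in real use, each batch is a list of GenerationResult objects

stub_generator.chat_completion.assert_called_once()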
meta_reference model-parallel utilities (ModelParallelProcessGroup / TaskRequest):

@@ -33,9 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch

 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import GenerationResult
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    CompletionRequestWithRawContent,
-)

 log = get_logger(name=__name__, category="inference")

@@ -68,10 +65,7 @@ class CancelSentinel(BaseModel):

 class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: tuple[
-        str,
-        list[CompletionRequestWithRawContent],
-    ]
+    task: tuple[str, list]


 class TaskResponse(BaseModel):
@@ -327,10 +321,7 @@ class ModelParallelProcessGroup:

     def run_inference(
         self,
-        req: tuple[
-            str,
-            list[CompletionRequestWithRawContent],
-        ],
+        req: tuple[str, list],
     ) -> Generator:
         assert not self.running, "inference already running"

llama_stack/providers/utils/inference/prompt_adapter.py:

@@ -22,9 +22,14 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.inference import (
     CompletionRequest,
+    OpenAIAssistantMessageParam,
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
     OpenAIFile,
+    OpenAIMessageParam,
+    OpenAISystemMessageParam,
+    OpenAIToolMessageParam,
+    OpenAIUserMessageParam,
     ResponseFormat,
     ResponseFormatType,
     ToolChoice,
@@ -37,6 +42,7 @@ from llama_stack.models.llama.datatypes import (
     RawMessage,
     RawTextItem,
     StopReason,
+    ToolCall,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -128,6 +134,36 @@ async def interleaved_content_convert_to_raw(
     return await _localize_single(content)


+async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
+    """Convert OpenAI message format to RawMessage format used by Llama formatters."""
+    if isinstance(message, OpenAIUserMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
+        return RawMessage(role="user", content=content)
+    elif isinstance(message, OpenAISystemMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
+        return RawMessage(role="system", content=content)
+    elif isinstance(message, OpenAIAssistantMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content or "")  # type: ignore[arg-type]
+        tool_calls = []
+        if message.tool_calls:
+            for tc in message.tool_calls:
+                if tc.function:
+                    tool_calls.append(
+                        ToolCall(
+                            call_id=tc.id or "",
+                            tool_name=tc.function.name or "",
+                            arguments=tc.function.arguments or "{}",
+                        )
+                    )
+        return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
+    elif isinstance(message, OpenAIToolMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
+        return RawMessage(role="tool", content=content)
+    else:
+        # Handle OpenAIDeveloperMessageParam if needed
+        raise ValueError(f"Unsupported message type: {type(message)}")
+
+
 def content_has_media(content: InterleavedContent):
     def _has_media_content(c):
         return isinstance(c, ImageContentItem)
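As used by openai_chat_completion(), the helper is applied message by message across the whole conversation. A small usage sketch (limited to the message types handled above; the message contents are illustrative):

import asyncio

from llama_stack.apis.inference import OpenAISystemMessageParam, OpenAIUserMessageParam
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_openai_message_to_raw_message,
)


async def convert_conversation():
    messages = [
        OpenAISystemMessageParam(role="system", content="You are a terse assistant."),
        OpenAIUserMessageParam(role="user", content="Hi how are you?"),
    ]
    # Same pattern as MetaReferenceInferenceImpl.openai_chat_completion()
    return [await convert_openai_message_to_raw_message(m) for m in messages]


raw_messages = asyncio.run(convert_conversation())
print([m.role for m in raw_messages])  # ['system', 'user']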
tests/unit/providers/inline/inference/__init__.py (new file, 5 lines):

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
tests/unit/providers/inline/inference/test_meta_reference.py (new file, 44 lines):

@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import Mock
+
+import pytest
+
+from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
+    ModelRunner,
+)
+
+
+class TestModelRunner:
+    """Test ModelRunner task dispatching for model-parallel inference."""
+
+    def test_chat_completion_task_dispatch(self):
+        """Verify ModelRunner correctly dispatches chat_completion tasks."""
+        # Create a mock generator
+        mock_generator = Mock()
+        mock_generator.chat_completion = Mock(return_value=iter([]))
+
+        runner = ModelRunner(mock_generator)
+
+        # Create a chat_completion task
+        fake_params = {"model": "test"}
+        fake_messages = [{"role": "user", "content": "test"}]
+        task = ("chat_completion", [fake_params, fake_messages])
+
+        # Execute task
+        runner(task)
+
+        # Verify chat_completion was called with correct arguments
+        mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
+
+    def test_invalid_task_type_raises_error(self):
+        """Verify ModelRunner rejects invalid task types."""
+        mock_generator = Mock()
+        runner = ModelRunner(mock_generator)
+
+        with pytest.raises(ValueError, match="Unexpected task type"):
+            runner(("invalid_task", []))
tests/unit/providers/utils/inference/test_prompt_adapter.py (new file, 35 lines):

@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.inference import (
+    OpenAIAssistantMessageParam,
+    OpenAIUserMessageParam,
+)
+from llama_stack.models.llama.datatypes import RawTextItem
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_openai_message_to_raw_message,
+)
+
+
+class TestConvertOpenAIMessageToRawMessage:
+    """Test conversion of OpenAI message types to RawMessage format."""
+
+    async def test_user_message_conversion(self):
+        msg = OpenAIUserMessageParam(role="user", content="Hello world")
+        raw_msg = await convert_openai_message_to_raw_message(msg)
+
+        assert raw_msg.role == "user"
+        assert isinstance(raw_msg.content, RawTextItem)
+        assert raw_msg.content.text == "Hello world"
+
+    async def test_assistant_message_conversion(self):
+        msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
+        raw_msg = await convert_openai_message_to_raw_message(msg)
+
+        assert raw_msg.role == "assistant"
+        assert isinstance(raw_msg.content, RawTextItem)
+        assert raw_msg.content.text == "Hi there!"
+        assert raw_msg.tool_calls == []