diff --git a/src/llama_stack/providers/inline/inference/meta_reference/generators.py b/src/llama_stack/providers/inline/inference/meta_reference/generators.py
index bed9020d5..02494120b 100644
--- a/src/llama_stack/providers/inline/inference/meta_reference/generators.py
+++ b/src/llama_stack/providers/inline/inference/meta_reference/generators.py
@@ -13,11 +13,12 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
 from llama_stack.apis.inference import (
     GreedySamplingStrategy,
     JsonSchemaResponseFormat,
+    OpenAIChatCompletionRequestWithExtraBody,
     ResponseFormat,
     SamplingParams,
     TopPSamplingStrategy,
 )
-from llama_stack.models.llama.datatypes import QuantizationMode
+from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
 from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.generation import Llama4
@@ -142,3 +143,53 @@ class LlamaGenerator:
         self.tokenizer = self.inner_generator.tokenizer
         self.args = self.inner_generator.args
         self.formatter = self.inner_generator.formatter
+
+    def chat_completion(
+        self,
+        request: OpenAIChatCompletionRequestWithExtraBody,
+        raw_messages: list,
+    ):
+        """Generate a chat completion from an OpenAI-format request.
+
+        Args:
+            request: OpenAI chat completion request
+            raw_messages: Pre-converted list of RawMessage objects
+        """
+
+        # Use the JSON tool prompt format when tools are supplied; None defers to the formatter's default
+        tool_prompt_format = ToolPromptFormat.json if request.tools else None
+
+        # Prepare sampling params
+        sampling_params = SamplingParams()
+        if request.temperature is not None or request.top_p is not None:
+            sampling_params.strategy = TopPSamplingStrategy(
+                temperature=request.temperature or 1.0, top_p=request.top_p or 1.0
+            )
+        if request.max_tokens:
+            sampling_params.max_tokens = request.max_tokens
+
+        max_gen_len = sampling_params.max_tokens
+        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
+            max_gen_len = self.args.max_seq_len - 1
+
+        temperature, top_p = _infer_sampling_params(sampling_params)
+
+        # Get logits processor for response format
+        logits_processor = None
+        if request.response_format:
+            if isinstance(request.response_format, dict) and request.response_format.get("type") == "json_schema":
+                json_schema_format = JsonSchemaResponseFormat(
+                    type="json_schema", json_schema=request.response_format.get("json_schema", {})
+                )
+                logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
+
+        # Generate
+        yield from self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=False,
+            echo=False,
+            logits_processor=logits_processor,
+        )
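
Note: `chat_completion` is a plain generator that yields batches of `GenerationResult` tokens. A minimal consumption sketch, under the assumption that `generator` is a loaded `LlamaGenerator` and `raw_messages` were already converted from OpenAI messages (the model id is a placeholder):

```python
from llama_stack.apis.inference import (
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIUserMessageParam,
)

# Hypothetical setup: `generator` is a loaded LlamaGenerator; `raw_messages`
# came from convert_openai_message_to_raw_message (see prompt_adapter.py below).
request = OpenAIChatCompletionRequestWithExtraBody(
    model="llama3-8b-instruct",  # placeholder model id
    messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
    max_tokens=64,
)

text = ""
for result_batch in generator.chat_completion(request, raw_messages):
    for result in result_batch:
        # Keep only real output tokens, mirroring the loop in inference.py below.
        if not result.ignore_token and result.source == "output":
            text += result.text
```
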
diff --git a/src/llama_stack/providers/inline/inference/meta_reference/inference.py b/src/llama_stack/providers/inline/inference/meta_reference/inference.py
index 76d3fdd50..5acf0e26d 100644
--- a/src/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/src/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -5,11 +5,16 @@
 # the root directory of this source tree.
 
 import asyncio
+import time
+import uuid
 from collections.abc import AsyncIterator
 
 from llama_stack.apis.inference import (
     InferenceProvider,
+    OpenAIAssistantMessageParam,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAIChoice,
     OpenAICompletionRequestWithExtraBody,
 )
 from llama_stack.apis.inference.inference import (
@@ -136,10 +141,14 @@ class MetaReferenceInferenceImpl(
         self.llama_model = llama_model
 
         log.info("Warming up...")
+        from llama_stack.apis.inference import OpenAIUserMessageParam
+
         await self.openai_chat_completion(
-            model=model_id,
-            messages=[{"role": "user", "content": "Hi how are you?"}],
-            max_tokens=20,
+            params=OpenAIChatCompletionRequestWithExtraBody(
+                model=model_id,
+                messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
+                max_tokens=20,
+            )
         )
         log.info("Warmed up!")
 
@@ -155,4 +164,50 @@ class MetaReferenceInferenceImpl(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
+        self.check_model(params)
+
+        # Convert OpenAI messages to RawMessages
+        from llama_stack.providers.utils.inference.prompt_adapter import convert_openai_message_to_raw_message
+
+        raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
+
+        # Call generator's chat_completion method (works for both single-GPU and model-parallel)
+        if isinstance(self.generator, LlamaGenerator):
+            generator = self.generator.chat_completion(params, raw_messages)
+        else:
+            # Model parallel: submit task to process group
+            generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
+
+        # Collect all generated text
+        generated_text = ""
+        for result_batch in generator:
+            for result in result_batch:
+                if not result.ignore_token and result.source == "output":
+                    generated_text += result.text
+
+        # Create OpenAI response
+        response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
+        created = int(time.time())
+
+        return OpenAIChatCompletion(
+            id=response_id,
+            object="chat.completion",
+            created=created,
+            model=params.model,
+            choices=[
+                OpenAIChoice(
+                    index=0,
+                    message=OpenAIAssistantMessageParam(
+                        role="assistant",
+                        content=generated_text,
+                    ),
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+            usage=OpenAIChatCompletionUsage(
+                prompt_tokens=0,  # TODO: calculate properly
+                completion_tokens=0,  # TODO: calculate properly
+                total_tokens=0,  # TODO: calculate properly
+            ),
+        )
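
Note: callers now go through the typed request object end to end, as the warmup code above already does. A hedged usage sketch, assuming `impl` is an initialized `MetaReferenceInferenceImpl` and the model id is a placeholder:

```python
from llama_stack.apis.inference import (
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIUserMessageParam,
)


async def demo(impl) -> None:
    # `impl` is assumed to be an initialized MetaReferenceInferenceImpl.
    response = await impl.openai_chat_completion(
        params=OpenAIChatCompletionRequestWithExtraBody(
            model="llama3-8b-instruct",  # placeholder model id
            messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
            max_tokens=20,
        )
    )
    print(response.choices[0].message.content)
```
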
diff --git a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index 4d3d1e078..f50b41f34 100644
--- a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -19,7 +19,13 @@ class ModelRunner:
         self.llama = llama
 
     def __call__(self, task: Any):
-        raise ValueError(f"Unexpected task type {task[0]}")
+        task_type = task[0]
+        if task_type == "chat_completion":
+            # task[1] is [params, raw_messages]
+            params, raw_messages = task[1]
+            return self.llama.chat_completion(params, raw_messages)
+        else:
+            raise ValueError(f"Unexpected task type {task_type}")
 
 
 def init_model_cb(
diff --git a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
index 450594dd8..663e4793b 100644
--- a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
+++ b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -33,9 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import GenerationResult
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    CompletionRequestWithRawContent,
-)
 
 log = get_logger(name=__name__, category="inference")
 
@@ -68,10 +65,7 @@ class CancelSentinel(BaseModel):
 
 class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: tuple[
-        str,
-        list[CompletionRequestWithRawContent],
-    ]
+    task: tuple[str, list]
 
 
 class TaskResponse(BaseModel):
@@ -327,10 +321,7 @@ class ModelParallelProcessGroup:
 
     def run_inference(
         self,
-        req: tuple[
-            str,
-            list[CompletionRequestWithRawContent],
-        ],
+        req: tuple[str, list],
     ) -> Generator:
         assert not self.running, "inference already running"
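
Note: taken together, these two files define the cross-process task protocol: `run_inference` ships a `(task_type, args)` tuple to the worker group, and `ModelRunner.__call__` unpacks it on the other side. A minimal sketch of the dispatch, assuming `group` is an already-started `ModelParallelProcessGroup` and `params`/`raw_messages` were built as in `inference.py` above:

```python
# The first tuple element selects the handler in ModelRunner.__call__;
# the second is the argument list that handler unpacks.
task = ("chat_completion", [params, raw_messages])

# run_inference streams back the same GenerationResult batches that the
# single-GPU LlamaGenerator.chat_completion path yields.
for result_batch in group.run_inference(task):
    for result in result_batch:
        if not result.ignore_token and result.source == "output":
            print(result.text, end="")
```
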
diff --git a/src/llama_stack/providers/utils/inference/prompt_adapter.py b/src/llama_stack/providers/utils/inference/prompt_adapter.py
index 0aa06af1d..35a7b3484 100644
--- a/src/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/src/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -22,9 +22,14 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.inference import (
     CompletionRequest,
+    OpenAIAssistantMessageParam,
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
     OpenAIFile,
+    OpenAIMessageParam,
+    OpenAISystemMessageParam,
+    OpenAIToolMessageParam,
+    OpenAIUserMessageParam,
     ResponseFormat,
     ResponseFormatType,
     ToolChoice,
@@ -37,6 +42,7 @@ from llama_stack.models.llama.datatypes import (
     RawMessage,
     RawTextItem,
     StopReason,
+    ToolCall,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -128,6 +134,36 @@ async def interleaved_content_convert_to_raw(
     return await _localize_single(content)
 
 
+async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
+    """Convert OpenAI message format to the RawMessage format used by Llama formatters."""
+    if isinstance(message, OpenAIUserMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
+        return RawMessage(role="user", content=content)
+    elif isinstance(message, OpenAISystemMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
+        return RawMessage(role="system", content=content)
+    elif isinstance(message, OpenAIAssistantMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content or "")  # type: ignore[arg-type]
+        tool_calls = []
+        if message.tool_calls:
+            for tc in message.tool_calls:
+                if tc.function:
+                    tool_calls.append(
+                        ToolCall(
+                            call_id=tc.id or "",
+                            tool_name=tc.function.name or "",
+                            arguments=tc.function.arguments or "{}",
+                        )
+                    )
+        return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
+    elif isinstance(message, OpenAIToolMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
+        return RawMessage(role="tool", content=content)
+    else:
+        # Handle OpenAIDeveloperMessageParam if needed
+        raise ValueError(f"Unsupported message type: {type(message)}")
+
+
 def content_has_media(content: InterleavedContent):
     def _has_media_content(c):
         return isinstance(c, ImageContentItem)
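
Note: the converter is exercised by the new unit tests below; a quick round-trip sketch of its behavior on a plain text message (the `demo` wrapper is only there because the converter is async):

```python
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_openai_message_to_raw_message,
)


async def demo() -> None:
    # A plain-text user message converts to a RawMessage whose content is a
    # RawTextItem (the unit tests below assert exactly this shape).
    raw = await convert_openai_message_to_raw_message(
        OpenAIUserMessageParam(role="user", content="Hello world")
    )
    assert raw.role == "user"
    assert raw.content.text == "Hello world"
```
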
diff --git a/tests/unit/providers/inline/inference/__init__.py b/tests/unit/providers/inline/inference/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/tests/unit/providers/inline/inference/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/tests/unit/providers/inline/inference/test_meta_reference.py b/tests/unit/providers/inline/inference/test_meta_reference.py
new file mode 100644
index 000000000..381836397
--- /dev/null
+++ b/tests/unit/providers/inline/inference/test_meta_reference.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import Mock
+
+import pytest
+
+from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
+    ModelRunner,
+)
+
+
+class TestModelRunner:
+    """Test ModelRunner task dispatching for model-parallel inference."""
+
+    def test_chat_completion_task_dispatch(self):
+        """Verify ModelRunner correctly dispatches chat_completion tasks."""
+        # Create a mock generator
+        mock_generator = Mock()
+        mock_generator.chat_completion = Mock(return_value=iter([]))
+
+        runner = ModelRunner(mock_generator)
+
+        # Create a chat_completion task
+        fake_params = {"model": "test"}
+        fake_messages = [{"role": "user", "content": "test"}]
+        task = ("chat_completion", [fake_params, fake_messages])
+
+        # Execute task
+        runner(task)
+
+        # Verify chat_completion was called with correct arguments
+        mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
+
+    def test_invalid_task_type_raises_error(self):
+        """Verify ModelRunner rejects invalid task types."""
+        mock_generator = Mock()
+        runner = ModelRunner(mock_generator)
+
+        with pytest.raises(ValueError, match="Unexpected task type"):
+            runner(("invalid_task", []))
diff --git a/tests/unit/providers/utils/inference/test_prompt_adapter.py b/tests/unit/providers/utils/inference/test_prompt_adapter.py
new file mode 100644
index 000000000..62c8db74d
--- /dev/null
+++ b/tests/unit/providers/utils/inference/test_prompt_adapter.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.inference import (
+    OpenAIAssistantMessageParam,
+    OpenAIUserMessageParam,
+)
+from llama_stack.models.llama.datatypes import RawTextItem
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_openai_message_to_raw_message,
+)
+
+
+class TestConvertOpenAIMessageToRawMessage:
+    """Test conversion of OpenAI message types to RawMessage format."""
+
+    async def test_user_message_conversion(self):
+        msg = OpenAIUserMessageParam(role="user", content="Hello world")
+        raw_msg = await convert_openai_message_to_raw_message(msg)
+
+        assert raw_msg.role == "user"
+        assert isinstance(raw_msg.content, RawTextItem)
+        assert raw_msg.content.text == "Hello world"
+
+    async def test_assistant_message_conversion(self):
+        msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
+        raw_msg = await convert_openai_message_to_raw_message(msg)
+
+        assert raw_msg.role == "assistant"
+        assert isinstance(raw_msg.content, RawTextItem)
+        assert raw_msg.content.text == "Hi there!"
+        assert raw_msg.tool_calls == []