mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
feat: implement OpenAI chat completion for meta_reference provider
- Add chat_completion() method to LlamaGenerator supporting OpenAI request format - Implement openai_chat_completion() in MetaReferenceInferenceImpl - Fix ModelRunner task dispatch to handle chat_completion tasks - Add convert_openai_message_to_raw_message() utility for message conversion - Add unit tests for message conversion and model-parallel dispatch - Remove unused CompletionRequestWithRawContent references Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
54754cb2a7
commit
7574f147b6
8 changed files with 240 additions and 17 deletions
5
tests/unit/providers/inline/inference/__init__.py
Normal file
5
tests/unit/providers/inline/inference/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
44
tests/unit/providers/inline/inference/test_meta_reference.py
Normal file
44
tests/unit/providers/inline/inference/test_meta_reference.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
|
||||
ModelRunner,
|
||||
)
|
||||
|
||||
|
||||
class TestModelRunner:
|
||||
"""Test ModelRunner task dispatching for model-parallel inference."""
|
||||
|
||||
def test_chat_completion_task_dispatch(self):
|
||||
"""Verify ModelRunner correctly dispatches chat_completion tasks."""
|
||||
# Create a mock generator
|
||||
mock_generator = Mock()
|
||||
mock_generator.chat_completion = Mock(return_value=iter([]))
|
||||
|
||||
runner = ModelRunner(mock_generator)
|
||||
|
||||
# Create a chat_completion task
|
||||
fake_params = {"model": "test"}
|
||||
fake_messages = [{"role": "user", "content": "test"}]
|
||||
task = ("chat_completion", [fake_params, fake_messages])
|
||||
|
||||
# Execute task
|
||||
runner(task)
|
||||
|
||||
# Verify chat_completion was called with correct arguments
|
||||
mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
|
||||
|
||||
def test_invalid_task_type_raises_error(self):
|
||||
"""Verify ModelRunner rejects invalid task types."""
|
||||
mock_generator = Mock()
|
||||
runner = ModelRunner(mock_generator)
|
||||
|
||||
with pytest.raises(ValueError, match="Unexpected task type"):
|
||||
runner(("invalid_task", []))
|
||||
35
tests/unit/providers/utils/inference/test_prompt_adapter.py
Normal file
35
tests/unit/providers/utils/inference/test_prompt_adapter.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import RawTextItem
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_openai_message_to_raw_message,
|
||||
)
|
||||
|
||||
|
||||
class TestConvertOpenAIMessageToRawMessage:
|
||||
"""Test conversion of OpenAI message types to RawMessage format."""
|
||||
|
||||
async def test_user_message_conversion(self):
|
||||
msg = OpenAIUserMessageParam(role="user", content="Hello world")
|
||||
raw_msg = await convert_openai_message_to_raw_message(msg)
|
||||
|
||||
assert raw_msg.role == "user"
|
||||
assert isinstance(raw_msg.content, RawTextItem)
|
||||
assert raw_msg.content.text == "Hello world"
|
||||
|
||||
async def test_assistant_message_conversion(self):
|
||||
msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
|
||||
raw_msg = await convert_openai_message_to_raw_message(msg)
|
||||
|
||||
assert raw_msg.role == "assistant"
|
||||
assert isinstance(raw_msg.content, RawTextItem)
|
||||
assert raw_msg.content.text == "Hi there!"
|
||||
assert raw_msg.tool_calls == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue