Merge branch 'main' into fix/nvidia-launch-customization

Jash Gulabrai 2025-04-25 16:01:22 -04:00
commit 6659ed995a
53 changed files with 2203 additions and 217 deletions


@@ -75,19 +75,24 @@ def openai_client(client_with_models):
return OpenAI(base_url=base_url, api_key="bar")
@pytest.fixture(params=["openai_client", "llama_stack_client"])
def compat_client(request):
return request.getfixturevalue(request.param)
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:sanity",
],
)
-def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_completion_non_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
# ollama needs more verbose prompting for some reason here...
prompt = "Respond to this question and explain your answer. " + tc["content"]
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
@@ -103,13 +108,13 @@ def test_openai_completion_non_streaming(openai_client, client_with_models, text
"inference:completion:sanity",
],
)
-def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_completion_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
# ollama needs more verbose prompting for some reason here...
prompt = "Respond to this question and explain your answer. " + tc["content"]
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=True,
@@ -127,11 +132,11 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
0,
],
)
-def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
+def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_models, text_model_id, prompt_logprobs):
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
prompt = "Hello, world!"
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
@@ -144,11 +149,11 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te
assert len(choice.prompt_logprobs) > 0
-def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
+def test_openai_completion_guided_choice(llama_stack_client, client_with_models, text_model_id):
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
prompt = "I am feeling really sad today."
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
@@ -161,6 +166,9 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
assert choice.text in ["joy", "sadness"]
# Run the chat-completion tests with both the OpenAI client and the LlamaStack client
@pytest.mark.parametrize(
"test_case",
[
@@ -168,13 +176,13 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
"inference:chat_completion:non_streaming_02",
],
)
-def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_chat_completion_non_streaming(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
-    response = openai_client.chat.completions.create(
+    response = compat_client.chat.completions.create(
model=text_model_id,
messages=[
{
@@ -196,13 +204,13 @@ def test_openai_chat_completion_non_streaming(openai_client, client_with_models,
"inference:chat_completion:streaming_02",
],
)
-def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
-    response = openai_client.chat.completions.create(
+    response = compat_client.chat.completions.create(
model=text_model_id,
messages=[{"role": "user", "content": question}],
stream=True,
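
The mechanism behind these renames is the `compat_client` fixture introduced at the top of the file: parametrizing a fixture over fixture names and resolving them with `request.getfixturevalue` runs every dependent test once per client. A minimal, self-contained sketch of the pattern (the stand-in fixture bodies are placeholders, not the real clients):

```python
import pytest

@pytest.fixture
def openai_client():
    return "openai"  # placeholder; the real fixture builds an OpenAI client

@pytest.fixture
def llama_stack_client():
    return "llama-stack"  # placeholder; the real fixture builds a LlamaStack client

@pytest.fixture(params=["openai_client", "llama_stack_client"])
def compat_client(request):
    # getfixturevalue resolves the named fixture at test setup time,
    # so each dependent test runs twice, once per client.
    return request.getfixturevalue(request.param)

def test_runs_once_per_client(compat_client):
    assert compat_client in ("openai", "llama-stack")
```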


@@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
SystemMessage,
ToolChoice,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.models import Model
-from llama_stack.models.llama.datatypes import StopReason
+from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
@@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
assert request.tool_config.tool_choice == ToolChoice.none
@pytest.mark.asyncio
async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
# Patch the call to vllm so we can verify that the arguments sent were correct
with patch.object(
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
) as mock_nonstream_completion:
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
call_id="foo",
tool_name="knowledge_search",
arguments={"query": "How many?"},
arguments_json='{"query": "How many?"}',
)
],
),
ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
]
await vllm_inference_adapter.chat_completion(
"mock-model",
messages,
stream=False,
tools=[],
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
{
"id": "foo",
"type": "function",
"function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
}
]
@pytest.mark.asyncio
async def test_tool_call_delta_empty_tool_call_buf():
"""


@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
from unittest.mock import MagicMock, patch
import pytest
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl
MOCK_DATASET_ID = "default/test-dataset"
MOCK_BENCHMARK_ID = "test-benchmark"
class TestNVIDIAEvalImpl(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
# Create mock APIs
self.datasetio_api = MagicMock()
self.datasets_api = MagicMock()
self.scoring_api = MagicMock()
self.inference_api = MagicMock()
self.agents_api = MagicMock()
self.config = NVIDIAEvalConfig(
evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
)
self.eval_impl = NVIDIAEvalImpl(
config=self.config,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
scoring_api=self.scoring_api,
inference_api=self.inference_api,
agents_api=self.agents_api,
)
# Mock the HTTP request methods
self.evaluator_get_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
)
self.evaluator_post_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
)
self.mock_evaluator_get = self.evaluator_get_patcher.start()
self.mock_evaluator_post = self.evaluator_post_patcher.start()
def tearDown(self):
"""Clean up after each test."""
self.evaluator_get_patcher.stop()
self.evaluator_post_patcher.stop()
def _assert_request_body(self, expected_json):
"""Helper method to verify request body in Evaluator POST request is correct"""
call_args = self.mock_evaluator_post.call_args
actual_json = call_args[0][1]
# Check that all expected keys are present with the expected values in the actual JSON
for key, value in expected_json.items():
assert key in actual_json, f"Key '{key}' missing in actual JSON"
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
else:
assert actual_json[key] == value, f"Value mismatch for '{key}'"
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def test_register_benchmark(self):
eval_config = {
"type": "custom",
"params": {"parallelism": 8},
"tasks": {
"qa": {
"type": "completion",
"params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
"dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
"metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
}
},
}
benchmark = Benchmark(
provider_id="nvidia",
type="benchmark",
identifier=MOCK_BENCHMARK_ID,
dataset_id=MOCK_DATASET_ID,
scoring_functions=["basic::equality"],
metadata=eval_config,
)
# Mock Evaluator API response
mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
self.mock_evaluator_post.return_value = mock_evaluator_response
# Register the benchmark
self.run_async(self.eval_impl.register_benchmark(benchmark))
# Verify the Evaluator API was called correctly
self.mock_evaluator_post.assert_called_once()
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
def test_run_eval(self):
benchmark_config = BenchmarkConfig(
eval_candidate=ModelCandidate(
type="model",
model=CoreModelId.llama3_1_8b_instruct.value,
sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
)
)
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "created"}
self.mock_evaluator_post.return_value = mock_evaluator_response
# Run the Evaluation job
result = self.run_async(
self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
)
# Verify the Evaluator API was called correctly
self.mock_evaluator_post.assert_called_once()
self._assert_request_body(
{
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
}
)
# Verify the result
assert isinstance(result, Job)
assert result.job_id == "job-123"
assert result.status == JobStatus.in_progress
def test_job_status(self):
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "completed"}
self.mock_evaluator_get.return_value = mock_evaluator_response
# Get the Evaluation job
result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
# Verify the result
assert isinstance(result, Job)
assert result.job_id == "job-123"
assert result.status == JobStatus.completed
# Verify the API was called correctly
self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
def test_job_cancel(self):
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
self.mock_evaluator_post.return_value = mock_evaluator_response
# Cancel the Evaluation job
self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
# Verify the API was called correctly
self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
def test_job_result(self):
# Mock Evaluator API responses
mock_job_status_response = {"id": "job-123", "status": "completed"}
mock_job_results_response = {
"id": "job-123",
"status": "completed",
"results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
}
self.mock_evaluator_get.side_effect = [
mock_job_status_response, # First call to retrieve job
mock_job_results_response, # Second call to retrieve job results
]
# Get the Evaluation job results
result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
# Verify the result
assert isinstance(result, EvaluateResponse)
assert MOCK_BENCHMARK_ID in result.scores
assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
# Verify the API was called correctly
assert self.mock_evaluator_get.call_count == 2
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
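
These unittest-style tests drive async provider methods through a `run_async` fixture injected via `inject_fixtures`; the fixture itself lives in the repo's conftest and is not shown in this diff. A plausible minimal implementation, offered as an assumption about its shape:

```python
import asyncio

import pytest

@pytest.fixture
def run_async():
    # Run a coroutine to completion from a synchronous unittest.TestCase method.
    def _run(coro):
        return asyncio.run(coro)

    return _run
```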


@@ -11,6 +11,7 @@ from unittest.mock import patch
import pytest
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.post_training.post_training import (
DataConfig,
DatasetFormat,
@@ -21,6 +22,7 @@ from llama_stack.apis.post_training.post_training import (
TrainingConfig,
)
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
from llama_stack.providers.remote.post_training.nvidia.post_training import (
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingAdapter,
@@ -44,8 +46,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
)
self.mock_make_request = self.make_request_patcher.start()
# Mock the inference client
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
self.mock_client = unittest.mock.MagicMock()
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
self.inference_mock_make_request = self.mock_client.chat.completions.create
self.inference_make_request_patcher = patch(
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
return_value=self.mock_client,
)
self.inference_make_request_patcher.start()
def tearDown(self):
self.make_request_patcher.stop()
self.inference_make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
@@ -316,6 +332,31 @@ class TestNvidiaPostTraining(unittest.TestCase):
expected_params={"job_id": job_id},
)
def test_inference_register_model(self):
model_id = "default/job-1234"
model_type = ModelType.llm
model = Model(
identifier=model_id,
provider_id="nvidia",
provider_model_id=model_id,
provider_resource_id=model_id,
model_type=model_type,
)
result = self.run_async(self.inference_adapter.register_model(model))
assert result == model
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
self.run_async(
self.inference_adapter.chat_completion(
model_id=model_id,
messages=[{"role": "user", "content": "Hello, model"}],
)
)
mock_chat_completion.assert_called()
if __name__ == "__main__":
unittest.main()
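
The setUp/tearDown changes above follow the standard start/stop patcher lifecycle: build a `MagicMock` client whose `chat.completions.create` is an `AsyncMock`, patch the adapter's `_get_client` to return it, and undo the patch in tearDown. A reduced sketch of the same pattern against a stand-in adapter class (`Adapter` here is illustrative, not the real `NVIDIAInferenceAdapter`):

```python
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

class Adapter:
    def _get_client(self):
        # The real adapter would construct a network-backed client here.
        raise RuntimeError("would hit the network")

class ClientPatchSketch(unittest.TestCase):
    def setUp(self):
        self.mock_client = MagicMock()
        self.mock_client.chat.completions.create = AsyncMock(return_value={"ok": True})
        self.patcher = patch.object(Adapter, "_get_client", return_value=self.mock_client)
        self.patcher.start()

    def tearDown(self):
        # Always stop patchers so mock state does not leak between tests.
        self.patcher.stop()

    def test_client_is_patched(self):
        assert Adapter()._get_client() is self.mock_client
```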


@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import CompletionMessage, UserMessage
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict():
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
assert await convert_message_to_openai_dict(message) == {
"role": "user",
"content": [{"type": "text", "text": "Hello, world!"}],
}
# Test convert_message_to_openai_dict with a tool call
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage(
content="",
tool_calls=[
ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
],
}
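
These two tests depend on the pytest-asyncio plugin for the `@pytest.mark.asyncio` marker. Outside pytest, the same conversion can be exercised directly with `asyncio.run`; a sketch assuming `llama_stack` is installed:

```python
import asyncio

from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import UserMessage
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict,
)

async def main():
    message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
    print(await convert_message_to_openai_dict(message))
    # Expected: {'role': 'user', 'content': [{'type': 'text', 'text': 'Hello, world!'}]}

asyncio.run(main())
```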


@@ -47,9 +47,45 @@ async def test_sse_generator_client_disconnected():
sse_gen = sse_generator(async_event_gen())
assert sse_gen is not None
# Start reading the events, ensuring this doesn't raise an exception
seen_events = []
async for event in sse_gen:
seen_events.append(event)
# We should see 1 event before the client disconnected
assert len(seen_events) == 1
assert seen_events[0] == create_sse_event("Test event 1")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected_before_response_starts():
# Disconnect before the response starts
async def async_event_gen():
raise asyncio.CancelledError()
sse_gen = sse_generator(async_event_gen())
assert sse_gen is not None
seen_events = []
async for event in sse_gen:
seen_events.append(event)
# No events should be seen since the client disconnected immediately
assert len(seen_events) == 0
@pytest.mark.asyncio
async def test_sse_generator_error_before_response_starts():
# Raise an error before the response starts
async def async_event_gen():
raise Exception("Test error")
sse_gen = sse_generator(async_event_gen())
assert sse_gen is not None
seen_events = []
async for event in sse_gen:
seen_events.append(event)
# We should have 1 error event
assert len(seen_events) == 1
assert 'data: {"error":' in seen_events[0]
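
For context, these assertions rely on the server-sent-events framing that `create_sse_event` produces: a `data:` field carrying a JSON payload, terminated by a blank line. A minimal stand-in, assuming the helper JSON-encodes its payload (the real implementation lives in the server code, not in this diff):

```python
import json

def create_sse_event(payload) -> str:
    # SSE framing: a "data:" field followed by a blank line ends the event.
    return f"data: {json.dumps(payload)}\n\n"

assert create_sse_event("Test event 1") == 'data: "Test event 1"\n\n'
assert 'data: {"error":' in create_sse_event({"error": {"message": "Test error"}})
```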