Merge branch 'main' into sambanova-inferene

This commit is contained in:
snova-edwardm 2025-01-14 10:04:52 -08:00 committed by GitHub
commit 89ab2be302
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
385 changed files with 39001 additions and 9280 deletions

View file

@ -18,6 +18,12 @@ def pytest_addoption(parser):
default=None,
help="Specify the inference model to use for testing",
)
parser.addoption(
"--embedding-model",
action="store",
default=None,
help="Specify the embedding model to use for testing",
)
def pytest_configure(config):

View file

@ -9,16 +9,18 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
from llama_stack.providers.inline.inference.vllm import VLLMConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
@ -48,6 +50,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
# If embedding dimension is set, use the 8B model for testing
if os.getenv("EMBEDDING_DIMENSION"):
inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
return ProviderFixture(
providers=[
@ -86,7 +91,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
if "Llama3.1-8B-Instruct" in inference_model:
if inference_model and "Llama3.1-8B-Instruct" in inference_model:
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
return ProviderFixture(
@ -102,6 +107,26 @@ def inference_ollama(inference_model) -> ProviderFixture:
)
@pytest_asyncio.fixture(scope="session")
def inference_vllm(inference_model) -> ProviderFixture:
    """Build one inline vLLM provider per requested inference model.

    A single model name passed as a string is treated as a one-element list.
    Providers receive the ids ``vllm-0``, ``vllm-1``, ... in the order the
    models are listed.
    """
    if isinstance(inference_model, str):
        model_names = [inference_model]
    else:
        model_names = inference_model

    vllm_providers = []
    for index, model_name in enumerate(model_names):
        # Eager mode is enabled to make the test run faster.
        vllm_config = VLLMConfig(model=model_name, enforce_eager=True)
        vllm_providers.append(
            Provider(
                provider_id=f"vllm-{index}",
                provider_type="inline::vllm",
                config=vllm_config.model_dump(),
            )
        )
    return ProviderFixture(providers=vllm_providers)
@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
return ProviderFixture(
@ -111,6 +136,7 @@ def inference_vllm_remote() -> ProviderFixture:
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig(
url=get_env_or_fail("VLLM_URL"),
max_tokens=int(os.getenv("VLLM_MAX_TOKENS", 2048)),
).model_dump(),
)
],
@ -148,6 +174,22 @@ def inference_together() -> ProviderFixture:
)
@pytest.fixture(scope="session")
def inference_groq() -> ProviderFixture:
    """Provider fixture for the remote Groq inference provider.

    The API key is obtained through ``get_env_or_fail("GROQ_API_KEY")`` and
    handed over as provider data under the ``groq_api_key`` key.
    """
    groq_provider = Provider(
        provider_id="groq",
        provider_type="remote::groq",
        config=GroqConfig().model_dump(),
    )
    groq_provider_data = {"groq_api_key": get_env_or_fail("GROQ_API_KEY")}
    return ProviderFixture(
        providers=[groq_provider],
        provider_data=groq_provider_data,
    )
@pytest.fixture(scope="session")
def inference_bedrock() -> ProviderFixture:
return ProviderFixture(
@ -208,6 +250,18 @@ def inference_sambanova() -> ProviderFixture:
)
def inference_sentence_transformers() -> ProviderFixture:
    """Return a fixture holding a single inline sentence-transformers provider.

    The provider is registered with an empty configuration dict.
    """
    sentence_transformers_provider = Provider(
        provider_id="sentence_transformers",
        provider_type="inline::sentence-transformers",
        config={},
    )
    return ProviderFixture(providers=[sentence_transformers_provider])
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.
@ -238,6 +292,8 @@ INFERENCE_FIXTURES = [
"ollama",
"fireworks",
"together",
"vllm",
"groq",
"vllm_remote",
"remote",
"bedrock",
@ -252,11 +308,27 @@ INFERENCE_FIXTURES = [
async def inference_stack(request, inference_model):
fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
model_type = ModelType.llm
metadata = {}
if os.getenv("EMBEDDING_DIMENSION"):
model_type = ModelType.embedding
metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
test_stack = await construct_stack_for_test(
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[ModelInput(model_id=inference_model)],
models=[
ModelInput(
model_id=inference_model,
model_type=model_type,
metadata=metadata,
)
],
)
return test_stack.impls[Api.inference], test_stack.impls[Api.models]
# Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
# Cleanup code that runs after test case completion
await test_stack.impls[Api.inference].shutdown()

View file

@ -0,0 +1,518 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pytest
from groq.types.chat.chat_completion import ChatCompletion, Choice
from groq.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
Choice as StreamChoice,
ChoiceDelta,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from groq.types.chat.chat_completion_message import ChatCompletionMessage
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
CompletionMessage,
StopReason,
SystemMessage,
ToolCall,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.providers.remote.inference.groq.groq_utils import (
convert_chat_completion_request,
convert_chat_completion_response,
convert_chat_completion_response_stream,
)
class TestConvertChatCompletionRequest:
    """Checks how ``convert_chat_completion_request`` translates a request."""

    def _dummy_chat_completion_request(self):
        """Return a minimal request: one user message for ``Llama-3.2-3B``."""
        return ChatCompletionRequest(
            model="Llama-3.2-3B",
            messages=[UserMessage(content="Hello World")],
        )

    def test_sets_model(self):
        """The model name is carried over under the ``model`` key."""
        req = self._dummy_chat_completion_request()
        req.model = "Llama-3.2-3B"

        params = convert_chat_completion_request(req)

        assert params["model"] == "Llama-3.2-3B"

    def test_converts_user_message(self):
        """A user message becomes a ``role: user`` dict."""
        req = self._dummy_chat_completion_request()
        req.messages = [UserMessage(content="Hello World")]

        params = convert_chat_completion_request(req)

        expected_messages = [{"role": "user", "content": "Hello World"}]
        assert params["messages"] == expected_messages

    def test_converts_system_message(self):
        """A system message becomes a ``role: system`` dict."""
        req = self._dummy_chat_completion_request()
        req.messages = [SystemMessage(content="You are a helpful assistant.")]

        params = convert_chat_completion_request(req)

        expected_messages = [
            {"role": "system", "content": "You are a helpful assistant."}
        ]
        assert params["messages"] == expected_messages

    def test_converts_completion_message(self):
        """A completion message becomes a ``role: assistant`` dict."""
        user_turn = UserMessage(content="Hello World")
        assistant_turn = CompletionMessage(
            content="Hello World! How can I help you today?",
            stop_reason=StopReason.end_of_message,
        )
        req = self._dummy_chat_completion_request()
        req.messages = [user_turn, assistant_turn]

        params = convert_chat_completion_request(req)

        expected_messages = [
            {"role": "user", "content": "Hello World"},
            {"role": "assistant", "content": "Hello World! How can I help you today?"},
        ]
        assert params["messages"] == expected_messages

    def test_does_not_include_logprobs(self):
        """Requesting logprobs emits a warning and no ``logprobs`` key is set."""
        req = self._dummy_chat_completion_request()
        req.logprobs = True

        with pytest.warns(Warning) as caught:
            params = convert_chat_completion_request(req)

        first_warning_text = caught[0].message.args[0]
        assert "logprobs are not supported yet" in first_warning_text
        assert params.get("logprobs") is None

    def test_does_not_include_response_format(self):
        """A response format emits a warning and is left out of the result."""
        req = self._dummy_chat_completion_request()
        person_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "number"},
            },
        }
        req.response_format = {
            "type": "json_object",
            "json_schema": person_schema,
        }

        with pytest.warns(Warning) as caught:
            params = convert_chat_completion_request(req)

        first_warning_text = caught[0].message.args[0]
        assert "response_format is not supported yet" in first_warning_text
        assert params.get("response_format") is None

    def test_does_not_include_repetition_penalty(self):
        """A repetition penalty emits a warning and is not forwarded."""
        req = self._dummy_chat_completion_request()
        req.sampling_params.repetition_penalty = 1.5

        with pytest.warns(Warning) as caught:
            params = convert_chat_completion_request(req)

        first_warning_text = caught[0].message.args[0]
        assert "repetition_penalty is not supported" in first_warning_text
        assert params.get("repetition_penalty") is None
        assert params.get("frequency_penalty") is None

    def test_includes_stream(self):
        """The ``stream`` flag is passed through."""
        req = self._dummy_chat_completion_request()
        req.stream = True

        params = convert_chat_completion_request(req)

        assert params["stream"] is True

    def test_if_max_tokens_is_0_then_it_is_not_included(self):
        """``max_tokens`` of 0 is omitted from the result."""
        req = self._dummy_chat_completion_request()
        # 0 is the default value for max_tokens,
        # so a value of 0 is taken to mean the user didn't set it.
        req.sampling_params.max_tokens = 0

        params = convert_chat_completion_request(req)

        assert params.get("max_tokens") is None

    def test_includes_max_tokens_if_set(self):
        """A non-zero ``max_tokens`` is forwarded."""
        req = self._dummy_chat_completion_request()
        req.sampling_params.max_tokens = 100

        params = convert_chat_completion_request(req)

        assert params["max_tokens"] == 100

    def test_includes_temperature(self):
        """The sampling temperature is forwarded."""
        req = self._dummy_chat_completion_request()
        req.sampling_params.temperature = 0.5

        params = convert_chat_completion_request(req)

        assert params["temperature"] == 0.5

    def test_includes_top_p(self):
        """The ``top_p`` sampling parameter is forwarded."""
        req = self._dummy_chat_completion_request()
        req.sampling_params.top_p = 0.95

        params = convert_chat_completion_request(req)

        assert params["top_p"] == 0.95

    def test_includes_tool_choice(self):
        """``ToolChoice.required`` is rendered as the string ``"required"``."""
        req = self._dummy_chat_completion_request()
        req.tool_choice = ToolChoice.required

        params = convert_chat_completion_request(req)

        assert params["tool_choice"] == "required"

    def test_includes_tools(self):
        """Tool definitions are converted to ``function`` entries."""
        flight_tool = ToolDefinition(
            tool_name="get_flight_info",
            description="Get fight information between two destinations.",
            parameters={
                "origin": ToolParamDefinition(
                    param_type="string",
                    description="The origin airport code. E.g., AU",
                    required=True,
                ),
                "destination": ToolParamDefinition(
                    param_type="string",
                    description="The destination airport code. E.g., 'LAX'",
                    required=True,
                ),
                "passengers": ToolParamDefinition(
                    param_type="array",
                    description="The passengers",
                    required=False,
                ),
            },
        )
        log_tool = ToolDefinition(
            tool_name="log",
            description="Calulate the logarithm of a number",
            parameters={
                "number": ToolParamDefinition(
                    param_type="float",
                    description="The number to calculate the logarithm of",
                    required=True,
                ),
                "base": ToolParamDefinition(
                    param_type="integer",
                    description="The base of the logarithm",
                    required=False,
                    default=10,
                ),
            },
        )
        req = self._dummy_chat_completion_request()
        req.tools = [flight_tool, log_tool]

        params = convert_chat_completion_request(req)

        expected_flight_entry = {
            "type": "function",
            "function": FunctionDefinition(
                name="get_flight_info",
                description="Get fight information between two destinations.",
                parameters={
                    "origin": {
                        "type": "string",
                        "description": "The origin airport code. E.g., AU",
                        "required": True,
                    },
                    "destination": {
                        "type": "string",
                        "description": "The destination airport code. E.g., 'LAX'",
                        "required": True,
                    },
                    "passengers": {
                        "type": "array",
                        "description": "The passengers",
                        "required": False,
                    },
                },
            ),
        }
        expected_log_entry = {
            "type": "function",
            "function": FunctionDefinition(
                name="log",
                description="Calulate the logarithm of a number",
                parameters={
                    "number": {
                        "type": "float",
                        "description": "The number to calculate the logarithm of",
                        "required": True,
                    },
                    "base": {
                        "type": "integer",
                        "description": "The base of the logarithm",
                        "required": False,
                        "default": 10,
                    },
                },
            ),
        }
        assert params["tools"] == [expected_flight_entry, expected_log_entry]
class TestConvertNonStreamChatCompletionResponse:
    """Checks ``convert_chat_completion_response`` on non-streamed responses."""

    def _dummy_chat_completion_response(self):
        """Return a plain-text assistant response that finished with ``stop``."""
        assistant_message = ChatCompletionMessage(
            role="assistant", content="Hello World"
        )
        only_choice = Choice(
            index=0,
            message=assistant_message,
            finish_reason="stop",
        )
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[only_choice],
            created=1729382400,
            object="chat.completion",
        )

    def _dummy_chat_completion_response_with_tool_call(self):
        """Return a response carrying one ``get_flight_info`` tool call."""
        flight_call = ChatCompletionMessageToolCall(
            id="tool_call_id",
            type="function",
            function=Function(
                name="get_flight_info",
                arguments='{"origin": "AU", "destination": "LAX"}',
            ),
        )
        only_choice = Choice(
            index=0,
            message=ChatCompletionMessage(
                role="assistant",
                tool_calls=[flight_call],
            ),
            finish_reason="tool_calls",
        )
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[only_choice],
            created=1729382400,
            object="chat.completion",
        )

    def test_returns_response(self):
        """The message content ends up in ``completion_message.content``."""
        groq_response = self._dummy_chat_completion_response()
        groq_response.choices[0].message.content = "Hello World"

        result = convert_chat_completion_response(groq_response)

        assert result.completion_message.content == "Hello World"

    def test_maps_stop_to_end_of_message(self):
        """Finish reason ``stop`` maps to ``StopReason.end_of_turn``.

        Note that the test name says "end_of_message" while the assertion
        checks ``end_of_turn``.
        """
        groq_response = self._dummy_chat_completion_response()
        groq_response.choices[0].finish_reason = "stop"

        result = convert_chat_completion_response(groq_response)

        assert result.completion_message.stop_reason == StopReason.end_of_turn

    def test_maps_length_to_end_of_message(self):
        """Finish reason ``length`` maps to ``StopReason.out_of_tokens``.

        Note that the test name says "end_of_message" while the assertion
        checks ``out_of_tokens``.
        """
        groq_response = self._dummy_chat_completion_response()
        groq_response.choices[0].finish_reason = "length"

        result = convert_chat_completion_response(groq_response)

        assert result.completion_message.stop_reason == StopReason.out_of_tokens

    def test_maps_tool_call_to_end_of_message(self):
        """Finish reason ``tool_calls`` maps to ``StopReason.end_of_message``."""
        groq_response = self._dummy_chat_completion_response_with_tool_call()

        result = convert_chat_completion_response(groq_response)

        assert result.completion_message.stop_reason == StopReason.end_of_message

    def test_converts_multiple_tool_calls(self):
        """Each Groq tool call becomes a ``ToolCall`` with parsed arguments."""
        flight_call = ChatCompletionMessageToolCall(
            id="tool_call_id",
            type="function",
            function=Function(
                name="get_flight_info",
                arguments='{"origin": "AU", "destination": "LAX"}',
            ),
        )
        log_call = ChatCompletionMessageToolCall(
            id="tool_call_id_2",
            type="function",
            function=Function(
                name="log",
                arguments='{"number": 10, "base": 2}',
            ),
        )
        groq_response = self._dummy_chat_completion_response_with_tool_call()
        groq_response.choices[0].message.tool_calls = [flight_call, log_call]

        result = convert_chat_completion_response(groq_response)

        expected_calls = [
            ToolCall(
                call_id="tool_call_id",
                tool_name="get_flight_info",
                arguments={"origin": "AU", "destination": "LAX"},
            ),
            ToolCall(
                call_id="tool_call_id_2",
                tool_name="log",
                arguments={"number": 10, "base": 2},
            ),
        ]
        assert result.completion_message.tool_calls == expected_calls
class TestConvertStreamChatCompletionResponse:
    """Checks ``convert_chat_completion_response_stream`` on streamed chunks."""

    @pytest.mark.asyncio
    async def test_returns_stream(self):
        """Text chunks yield start/progress events; the ``stop`` chunk completes."""

        def chat_completion_stream():
            # Three content chunks followed by one terminating chunk.
            for message in ["Hello ", "World ", " !"]:
                chunk = self._dummy_chat_completion_chunk()
                chunk.choices[0].delta.content = message
                yield chunk

            chunk = self._dummy_chat_completion_chunk()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = chat_completion_stream()
        converted = convert_chat_completion_response_stream(stream)

        # Named ``events`` so the builtin ``iter`` is not shadowed.
        events = converted.__aiter__()

        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta == "Hello "

        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == "World "

        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == " !"

        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.complete
        assert chunk.event.delta == ""
        assert chunk.event.stop_reason == StopReason.end_of_turn

        # The converted stream must be exhausted after the complete event.
        with pytest.raises(StopAsyncIteration):
            await events.__anext__()

    @pytest.mark.asyncio
    async def test_returns_tool_calls_stream(self):
        """The first streamed tool call arrives as a start event with a ``ToolCall``."""

        def tool_call_stream():
            tool_calls = [
                ToolCall(
                    call_id="tool_call_id",
                    tool_name="get_flight_info",
                    arguments={"origin": "AU", "destination": "LAX"},
                ),
                ToolCall(
                    call_id="tool_call_id_2",
                    tool_name="log",
                    arguments={"number": 10, "base": 2},
                ),
            ]
            # One chunk per tool call, each carrying a single tool-call delta.
            for tool_call in tool_calls:
                chunk = self._dummy_chat_completion_chunk_with_tool_call()
                chunk.choices[0].delta.tool_calls = [
                    ChoiceDeltaToolCall(
                        index=0,
                        type="function",
                        id=tool_call.call_id,
                        function=ChoiceDeltaToolCallFunction(
                            name=tool_call.tool_name,
                            arguments=json.dumps(tool_call.arguments),
                        ),
                    ),
                ]
                yield chunk

            chunk = self._dummy_chat_completion_chunk_with_tool_call()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = tool_call_stream()
        converted = convert_chat_completion_response_stream(stream)

        # Named ``events`` so the builtin ``iter`` is not shadowed.
        events = converted.__aiter__()

        # Only the first converted event is inspected by this test.
        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta.content == ToolCall(
            call_id="tool_call_id",
            tool_name="get_flight_info",
            arguments={"origin": "AU", "destination": "LAX"},
        )

    def _dummy_chat_completion_chunk(self):
        """Return a chunk whose single choice has a text delta and no finish reason."""
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(role="assistant", content="Hello World"),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )

    def _dummy_chat_completion_chunk_with_tool_call(self):
        """Return a chunk whose delta carries text plus one ``get_flight_info`` tool call."""
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello World",
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                type="function",
                                function=ChoiceDeltaToolCallFunction(
                                    name="get_flight_info",
                                    arguments='{"origin": "AU", "destination": "LAX"}',
                                ),
                            )
                        ],
                    ),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.inference import Inference
from llama_stack.providers.remote.inference.groq import get_adapter_impl
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
class TestGroqInit:
    """Covers ``get_adapter_impl`` of the Groq inference provider."""

    @pytest.mark.asyncio
    async def test_raises_runtime_error_if_config_is_not_groq_config(self):
        """Passing another provider's config raises ``RuntimeError``."""
        wrong_config = OllamaImplConfig(model="llama3.1-8b-8192")

        with pytest.raises(RuntimeError):
            await get_adapter_impl(wrong_config, None)

    @pytest.mark.asyncio
    async def test_returns_groq_adapter(self):
        """A ``GroqConfig`` yields a ``GroqInferenceAdapter`` that is an ``Inference``."""
        groq_config = GroqConfig()

        groq_adapter = await get_adapter_impl(groq_config, None)

        # Exact type match is intended here, in addition to the interface check.
        assert type(groq_adapter) is GroqInferenceAdapter
        assert isinstance(groq_adapter, Inference)

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.inference import EmbeddingsResponse, ModelType
# How to run this test:
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
class TestEmbeddings:
    """Exercises ``embeddings()`` of the inference implementation under test."""

    @staticmethod
    def _skip_unless_embedding_model(model):
        """Skip the running test when ``model`` is not an embedding model."""
        if model.model_type != ModelType.embedding:
            pytest.skip("This test is only applicable for embedding models")

    @staticmethod
    def _assert_lists_of_floats(embeddings):
        """Assert every embedding is a list and every value in it is a float."""
        assert all(isinstance(vector, list) for vector in embeddings)
        assert all(
            isinstance(component, float)
            for vector in embeddings
            for component in vector
        )

    @pytest.mark.asyncio
    async def test_embeddings(self, inference_model, inference_stack):
        """A single input produces a non-empty list of float vectors."""
        inference_impl, models_impl = inference_stack
        registered_model = await models_impl.get_model(inference_model)
        self._skip_unless_embedding_model(registered_model)

        result = await inference_impl.embeddings(
            model_id=inference_model,
            contents=["Hello, world!"],
        )

        assert isinstance(result, EmbeddingsResponse)
        assert len(result.embeddings) > 0
        self._assert_lists_of_floats(result.embeddings)

    @pytest.mark.asyncio
    async def test_batch_embeddings(self, inference_model, inference_stack):
        """A batch produces one equally sized float vector per input text."""
        inference_impl, models_impl = inference_stack
        registered_model = await models_impl.get_model(inference_model)
        self._skip_unless_embedding_model(registered_model)

        batch = ["Hello, world!", "This is a test", "Testing embeddings"]
        result = await inference_impl.embeddings(
            model_id=inference_model,
            contents=batch,
        )

        assert isinstance(result, EmbeddingsResponse)
        assert len(result.embeddings) == len(batch)
        self._assert_lists_of_floats(result.embeddings)

        # All vectors must share the dimension of the first one.
        expected_dim = len(result.embeddings[0])
        assert all(len(vector) == expected_dim for vector in result.embeddings)

View file

@ -4,13 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, patch
import pytest
# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
# -m "meta_reference"
# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"
# ./llama_stack/providers/tests/inference/test_model_registration.py
class TestModelRegistration:
@ -51,16 +53,37 @@ class TestModelRegistration:
_ = await models_impl.register_model(
model_id="custom-model",
metadata={"llama_model": "meta-llama/Llama-2-7b"},
metadata={
"llama_model": "meta-llama/Llama-2-7b",
"skip_load": True,
},
)
with pytest.raises(ValueError) as exc_info:
with pytest.raises(AssertionError) as exc_info:
await models_impl.register_model(
model_id="custom-model-2",
metadata={"llama_model": "meta-llama/Llama-2-7b"},
metadata={
"llama_model": "meta-llama/Llama-2-7b",
},
provider_model_id="custom-model",
)
@pytest.mark.asyncio
async def test_initialize_model_during_registering(self, inference_stack):
_, models_impl = inference_stack
with patch(
"llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
new_callable=AsyncMock,
) as mock_load_model:
_ = await models_impl.register_model(
model_id="Llama3.1-8B-Instruct",
metadata={
"llama_model": "meta-llama/Llama-3.1-8B-Instruct",
},
)
mock_load_model.assert_called_once()
@pytest.mark.asyncio
async def test_register_with_invalid_llama_model(self, inference_stack):
_, models_impl = inference_stack

View file

@ -6,8 +6,14 @@
import unittest
from llama_models.llama3.api import * # noqa: F403
from llama_stack.apis.inference.inference import * # noqa: F403
from llama_models.llama3.api.datatypes import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)

View file

@ -7,13 +7,31 @@
import pytest
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from pydantic import BaseModel, ValidationError
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
LogProbConfig,
SystemMessage,
ToolChoice,
UserMessage,
)
from llama_stack.apis.models import Model
from .utils import group_chunks
@ -67,7 +85,9 @@ def sample_tool_definition():
class TestInference:
@pytest.mark.asyncio
# Session scope for asyncio because the tests in this class all
# share the same provider instance.
@pytest.mark.asyncio(loop_scope="session")
async def test_model_list(self, inference_model, inference_stack):
_, models_impl = inference_stack
response = await models_impl.list_models()
@ -83,7 +103,7 @@ class TestInference:
assert model_def is not None
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_completion(self, inference_model, inference_stack):
inference_impl, _ = inference_stack
@ -94,6 +114,7 @@ class TestInference:
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::cerebras",
):
pytest.skip("Other inference providers don't support completion() yet")
@ -127,19 +148,77 @@ class TestInference:
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_logprobs(self, inference_model, inference_stack):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
# "remote::nvidia", -- provider doesn't provide all logprobs
):
pytest.skip("Other inference providers don't support completion() yet")
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
stream=False,
model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=5,
),
logprobs=LogProbConfig(
top_k=3,
),
)
assert isinstance(response, CompletionResponse)
assert 1 <= len(response.logprobs) <= 5
assert response.logprobs, "Logprobs should not be empty"
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
stream=True,
model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=5,
),
logprobs=LogProbConfig(
top_k=3,
),
)
]
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
assert (
1 <= len(chunks) <= 6
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
for chunk in chunks:
if (
chunk.delta.type == "text" and chunk.delta.text
): # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
assert all(
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
)
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.skip("This test is not quite robust")
async def test_completions_structured_output(
self, inference_model, inference_stack
):
async def test_completion_structured_output(self, inference_model, inference_stack):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"inline::meta-reference",
"remote::ollama",
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::vllm",
"remote::cerebras",
):
pytest.skip(
@ -171,7 +250,7 @@ class TestInference:
assert answer.year_born == "1963"
assert answer.year_retired == "2003"
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_non_streaming(
self, inference_model, inference_stack, common_params, sample_messages
):
@ -188,7 +267,7 @@ class TestInference:
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_structured_output(
self, inference_model, inference_stack, common_params
):
@ -197,9 +276,11 @@ class TestInference:
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"inline::meta-reference",
"remote::ollama",
"remote::fireworks",
"remote::tgi",
"remote::together",
"remote::vllm",
"remote::nvidia",
):
pytest.skip("Other inference providers don't support structured output yet")
@ -257,7 +338,7 @@ class TestInference:
with pytest.raises(ValidationError):
AnswerFormat.model_validate_json(response.completion_message.content)
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_streaming(
self, inference_model, inference_stack, common_params, sample_messages
):
@ -284,7 +365,7 @@ class TestInference:
end = grouped[ChatCompletionResponseEventType.complete][0]
assert end.event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling(
self,
inference_model,
@ -294,6 +375,14 @@ class TestInference:
sample_tool_definition,
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
@ -323,7 +412,7 @@ class TestInference:
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling_streaming(
self,
inference_model,
@ -333,6 +422,14 @@ class TestInference:
sample_tool_definition,
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
@ -349,7 +446,6 @@ class TestInference:
**common_params,
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
@ -368,7 +464,7 @@ class TestInference:
if "Llama3.1" in inference_model:
assert all(
isinstance(chunk.event.delta, ToolCallDelta)
chunk.event.delta.type == "tool_call"
for chunk in grouped[ChatCompletionResponseEventType.progress]
)
first = grouped[ChatCompletionResponseEventType.progress][0]
@ -379,8 +475,8 @@ class TestInference:
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.success
assert isinstance(last.event.delta.content, ToolCall)
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
assert last.event.delta.content.type == "tool_call"
call = last.event.delta.content
assert call.tool_name == "get_weather"

View file

@ -7,16 +7,24 @@
from pathlib import Path
import pytest
from PIL import Image as PIL_Image
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
SamplingParams,
UserMessage,
)
from .utils import group_chunks
THIS_DIR = Path(__file__).parent
with open(THIS_DIR / "pasta.jpeg", "rb") as f:
PASTA_IMAGE = f.read()
class TestVisionModelInference:
@pytest.mark.asyncio
@ -24,12 +32,12 @@ class TestVisionModelInference:
"image, expected_strings",
[
(
ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
ImageContentItem(data=PASTA_IMAGE),
["spaghetti"],
),
(
ImageMedia(
image=URL(
ImageContentItem(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),
@ -59,7 +67,12 @@ class TestVisionModelInference:
model_id=inference_model,
messages=[
UserMessage(content="You are a helpful assistant."),
UserMessage(content=[image, "Describe this image in two sentences."]),
UserMessage(
content=[
image,
TextContentItem(text="Describe this image in two sentences."),
]
),
],
stream=False,
sampling_params=SamplingParams(max_tokens=100),
@ -91,8 +104,8 @@ class TestVisionModelInference:
)
images = [
ImageMedia(
image=URL(
ImageContentItem(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),
@ -108,7 +121,12 @@ class TestVisionModelInference:
messages=[
UserMessage(content="You are a helpful assistant."),
UserMessage(
content=[image, "Describe this image in two sentences."]
content=[
image,
TextContentItem(
text="Describe this image in two sentences."
),
]
),
],
stream=True,