Mirror of https://github.com/meta-llama/llama-stack.git, synced 2026-01-02 18:24:30 +00:00

Merge branch 'main' into sambanova-inferene
This commit is contained in: commit 89ab2be302
385 changed files with 39001 additions and 9280 deletions
@@ -18,6 +18,12 @@ def pytest_addoption(parser):
        default=None,
        help="Specify the inference model to use for testing",
    )
    parser.addoption(
        "--embedding-model",
        action="store",
        default=None,
        help="Specify the embedding model to use for testing",
    )


def pytest_configure(config):
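Note: the new --embedding-model option sits alongside the existing --inference-model flag. A fixture can read it through pytest's standard config API; the sketch below is illustrative only (the actual fixture wiring in this conftest is not part of the hunk shown):

@pytest.fixture(scope="session")
def embedding_model(request):
    # Value passed via --embedding-model, or None if the flag was omitted
    return request.config.getoption("--embedding-model")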
@@ -9,16 +9,18 @@ import os

import pytest
import pytest_asyncio

-from llama_stack.apis.models import ModelInput
+from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider

from llama_stack.providers.inline.inference.meta_reference import (
    MetaReferenceInferenceConfig,
)
from llama_stack.providers.inline.inference.vllm import VLLMConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
@@ -48,6 +50,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
    inference_model = (
        [inference_model] if isinstance(inference_model, str) else inference_model
    )
    # If embedding dimension is set, use the 8B model for testing
    if os.getenv("EMBEDDING_DIMENSION"):
        inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]

    return ProviderFixture(
        providers=[
@@ -86,7 +91,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
    inference_model = (
        [inference_model] if isinstance(inference_model, str) else inference_model
    )
-    if "Llama3.1-8B-Instruct" in inference_model:
+    if inference_model and "Llama3.1-8B-Instruct" in inference_model:
        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")

    return ProviderFixture(
@@ -102,6 +107,26 @@ def inference_ollama(inference_model) -> ProviderFixture:
    )


@pytest_asyncio.fixture(scope="session")
def inference_vllm(inference_model) -> ProviderFixture:
    inference_model = (
        [inference_model] if isinstance(inference_model, str) else inference_model
    )
    return ProviderFixture(
        providers=[
            Provider(
                provider_id=f"vllm-{i}",
                provider_type="inline::vllm",
                config=VLLMConfig(
                    model=m,
                    enforce_eager=True,  # Make test run faster
                ).model_dump(),
            )
            for i, m in enumerate(inference_model)
        ]
    )


@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
    return ProviderFixture(
@@ -111,6 +136,7 @@ def inference_vllm_remote() -> ProviderFixture:
                provider_type="remote::vllm",
                config=VLLMInferenceAdapterConfig(
                    url=get_env_or_fail("VLLM_URL"),
                    max_tokens=int(os.getenv("VLLM_MAX_TOKENS", 2048)),
                ).model_dump(),
            )
        ],
@@ -148,6 +174,22 @@ def inference_together() -> ProviderFixture:
    )


@pytest.fixture(scope="session")
def inference_groq() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="groq",
                provider_type="remote::groq",
                config=GroqConfig().model_dump(),
            )
        ],
        provider_data=dict(
            groq_api_key=get_env_or_fail("GROQ_API_KEY"),
        ),
    )


@pytest.fixture(scope="session")
def inference_bedrock() -> ProviderFixture:
    return ProviderFixture(
@@ -208,6 +250,18 @@ def inference_sambanova() -> ProviderFixture:
    )


def inference_sentence_transformers() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="sentence_transformers",
                provider_type="inline::sentence-transformers",
                config={},
            )
        ]
    )


def get_model_short_name(model_name: str) -> str:
    """Convert model name to a short test identifier.
@@ -238,6 +292,8 @@ INFERENCE_FIXTURES = [
    "ollama",
    "fireworks",
    "together",
    "vllm",
    "groq",
    "vllm_remote",
    "remote",
    "bedrock",
@@ -252,11 +308,27 @@ INFERENCE_FIXTURES = [
async def inference_stack(request, inference_model):
    fixture_name = request.param
    inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
    model_type = ModelType.llm
    metadata = {}
    if os.getenv("EMBEDDING_DIMENSION"):
        model_type = ModelType.embedding
        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")

    test_stack = await construct_stack_for_test(
        [Api.inference],
        {"inference": inference_fixture.providers},
        inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
+            )
+        ],
    )

-    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
+    # Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
+    yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
+
+    # Cleanup code that runs after test case completion
+    await test_stack.impls[Api.inference].shutdown()
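The inference_stack fixture now yields its implementations and shuts the inference provider down after the session, instead of returning them without cleanup. A minimal standalone sketch of the same async yield-fixture pattern (the helper and resource names here are made up for illustration, not taken from the diff):

@pytest_asyncio.fixture(scope="session")
async def stack():
    stack = await build_stack()  # hypothetical async setup helper
    yield stack                  # everything before the yield is setup
    await stack.shutdown()       # everything after the yield runs as teardown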
llama_stack/providers/tests/inference/groq/test_groq_utils.py (new file, 518 lines)
@@ -0,0 +1,518 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json

import pytest
from groq.types.chat.chat_completion import ChatCompletion, Choice
from groq.types.chat.chat_completion_chunk import (
    ChatCompletionChunk,
    Choice as StreamChoice,
    ChoiceDelta,
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)
from groq.types.chat.chat_completion_message import ChatCompletionMessage
from groq.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)
from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponseEventType,
    CompletionMessage,
    StopReason,
    SystemMessage,
    ToolCall,
    ToolChoice,
    ToolDefinition,
    UserMessage,
)
from llama_stack.providers.remote.inference.groq.groq_utils import (
    convert_chat_completion_request,
    convert_chat_completion_response,
    convert_chat_completion_response_stream,
)


class TestConvertChatCompletionRequest:
    def test_sets_model(self):
        request = self._dummy_chat_completion_request()
        request.model = "Llama-3.2-3B"

        converted = convert_chat_completion_request(request)

        assert converted["model"] == "Llama-3.2-3B"

    def test_converts_user_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [UserMessage(content="Hello World")]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
        ]

    def test_converts_system_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [SystemMessage(content="You are a helpful assistant.")]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "system", "content": "You are a helpful assistant."},
        ]

    def test_converts_completion_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [
            UserMessage(content="Hello World"),
            CompletionMessage(
                content="Hello World! How can I help you today?",
                stop_reason=StopReason.end_of_message,
            ),
        ]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
            {"role": "assistant", "content": "Hello World! How can I help you today?"},
        ]

    def test_does_not_include_logprobs(self):
        request = self._dummy_chat_completion_request()
        request.logprobs = True

        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)

        assert "logprobs are not supported yet" in warnings[0].message.args[0]
        assert converted.get("logprobs") is None

    def test_does_not_include_response_format(self):
        request = self._dummy_chat_completion_request()
        request.response_format = {
            "type": "json_object",
            "json_schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "number"},
                },
            },
        }

        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)

        assert "response_format is not supported yet" in warnings[0].message.args[0]
        assert converted.get("response_format") is None

    def test_does_not_include_repetition_penalty(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.repetition_penalty = 1.5

        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)

        assert "repetition_penalty is not supported" in warnings[0].message.args[0]
        assert converted.get("repetition_penalty") is None
        assert converted.get("frequency_penalty") is None

    def test_includes_stream(self):
        request = self._dummy_chat_completion_request()
        request.stream = True

        converted = convert_chat_completion_request(request)

        assert converted["stream"] is True

    def test_if_max_tokens_is_0_then_it_is_not_included(self):
        request = self._dummy_chat_completion_request()
        # 0 is the default value for max_tokens
        # So we assume that if it's 0, the user didn't set it
        request.sampling_params.max_tokens = 0

        converted = convert_chat_completion_request(request)

        assert converted.get("max_tokens") is None

    def test_includes_max_tokens_if_set(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.max_tokens = 100

        converted = convert_chat_completion_request(request)

        assert converted["max_tokens"] == 100

    def test_includes_temperature(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.temperature = 0.5

        converted = convert_chat_completion_request(request)

        assert converted["temperature"] == 0.5

    def test_includes_top_p(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.top_p = 0.95

        converted = convert_chat_completion_request(request)

        assert converted["top_p"] == 0.95

    def test_includes_tool_choice(self):
        request = self._dummy_chat_completion_request()
        request.tool_choice = ToolChoice.required

        converted = convert_chat_completion_request(request)

        assert converted["tool_choice"] == "required"

    def test_includes_tools(self):
        request = self._dummy_chat_completion_request()
        request.tools = [
            ToolDefinition(
                tool_name="get_flight_info",
                description="Get fight information between two destinations.",
                parameters={
                    "origin": ToolParamDefinition(
                        param_type="string",
                        description="The origin airport code. E.g., AU",
                        required=True,
                    ),
                    "destination": ToolParamDefinition(
                        param_type="string",
                        description="The destination airport code. E.g., 'LAX'",
                        required=True,
                    ),
                    "passengers": ToolParamDefinition(
                        param_type="array",
                        description="The passengers",
                        required=False,
                    ),
                },
            ),
            ToolDefinition(
                tool_name="log",
                description="Calulate the logarithm of a number",
                parameters={
                    "number": ToolParamDefinition(
                        param_type="float",
                        description="The number to calculate the logarithm of",
                        required=True,
                    ),
                    "base": ToolParamDefinition(
                        param_type="integer",
                        description="The base of the logarithm",
                        required=False,
                        default=10,
                    ),
                },
            ),
        ]

        converted = convert_chat_completion_request(request)

        assert converted["tools"] == [
            {
                "type": "function",
                "function": FunctionDefinition(
                    name="get_flight_info",
                    description="Get fight information between two destinations.",
                    parameters={
                        "origin": {
                            "type": "string",
                            "description": "The origin airport code. E.g., AU",
                            "required": True,
                        },
                        "destination": {
                            "type": "string",
                            "description": "The destination airport code. E.g., 'LAX'",
                            "required": True,
                        },
                        "passengers": {
                            "type": "array",
                            "description": "The passengers",
                            "required": False,
                        },
                    },
                ),
            },
            {
                "type": "function",
                "function": FunctionDefinition(
                    name="log",
                    description="Calulate the logarithm of a number",
                    parameters={
                        "number": {
                            "type": "float",
                            "description": "The number to calculate the logarithm of",
                            "required": True,
                        },
                        "base": {
                            "type": "integer",
                            "description": "The base of the logarithm",
                            "required": False,
                            "default": 10,
                        },
                    },
                ),
            },
        ]

    def _dummy_chat_completion_request(self):
        return ChatCompletionRequest(
            model="Llama-3.2-3B",
            messages=[UserMessage(content="Hello World")],
        )

class TestConvertNonStreamChatCompletionResponse:
    def test_returns_response(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].message.content = "Hello World"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.content == "Hello World"

    def test_maps_stop_to_end_of_message(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "stop"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.end_of_turn

    def test_maps_length_to_end_of_message(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "length"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.out_of_tokens

    def test_maps_tool_call_to_end_of_message(self):
        response = self._dummy_chat_completion_response_with_tool_call()

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.end_of_message

    def test_converts_multiple_tool_calls(self):
        response = self._dummy_chat_completion_response_with_tool_call()
        response.choices[0].message.tool_calls = [
            ChatCompletionMessageToolCall(
                id="tool_call_id",
                type="function",
                function=Function(
                    name="get_flight_info",
                    arguments='{"origin": "AU", "destination": "LAX"}',
                ),
            ),
            ChatCompletionMessageToolCall(
                id="tool_call_id_2",
                type="function",
                function=Function(
                    name="log",
                    arguments='{"number": 10, "base": 2}',
                ),
            ),
        ]

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.tool_calls == [
            ToolCall(
                call_id="tool_call_id",
                tool_name="get_flight_info",
                arguments={"origin": "AU", "destination": "LAX"},
            ),
            ToolCall(
                call_id="tool_call_id_2",
                tool_name="log",
                arguments={"number": 10, "base": 2},
            ),
        ]

    def _dummy_chat_completion_response(self):
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        role="assistant", content="Hello World"
                    ),
                    finish_reason="stop",
                )
            ],
            created=1729382400,
            object="chat.completion",
        )

    def _dummy_chat_completion_response_with_tool_call(self):
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        role="assistant",
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="tool_call_id",
                                type="function",
                                function=Function(
                                    name="get_flight_info",
                                    arguments='{"origin": "AU", "destination": "LAX"}',
                                ),
                            )
                        ],
                    ),
                    finish_reason="tool_calls",
                )
            ],
            created=1729382400,
            object="chat.completion",
        )

class TestConvertStreamChatCompletionResponse:
    @pytest.mark.asyncio
    async def test_returns_stream(self):
        def chat_completion_stream():
            messages = ["Hello ", "World ", " !"]
            for i, message in enumerate(messages):
                chunk = self._dummy_chat_completion_chunk()
                chunk.choices[0].delta.content = message
                yield chunk

            chunk = self._dummy_chat_completion_chunk()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = chat_completion_stream()
        converted = convert_chat_completion_response_stream(stream)

        iter = converted.__aiter__()
        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta == "Hello "

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == "World "

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == " !"

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.complete
        assert chunk.event.delta == ""
        assert chunk.event.stop_reason == StopReason.end_of_turn

        with pytest.raises(StopAsyncIteration):
            await iter.__anext__()

    @pytest.mark.asyncio
    async def test_returns_tool_calls_stream(self):
        def tool_call_stream():
            tool_calls = [
                ToolCall(
                    call_id="tool_call_id",
                    tool_name="get_flight_info",
                    arguments={"origin": "AU", "destination": "LAX"},
                ),
                ToolCall(
                    call_id="tool_call_id_2",
                    tool_name="log",
                    arguments={"number": 10, "base": 2},
                ),
            ]
            for i, tool_call in enumerate(tool_calls):
                chunk = self._dummy_chat_completion_chunk_with_tool_call()
                chunk.choices[0].delta.tool_calls = [
                    ChoiceDeltaToolCall(
                        index=0,
                        type="function",
                        id=tool_call.call_id,
                        function=ChoiceDeltaToolCallFunction(
                            name=tool_call.tool_name,
                            arguments=json.dumps(tool_call.arguments),
                        ),
                    ),
                ]
                yield chunk

            chunk = self._dummy_chat_completion_chunk_with_tool_call()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = tool_call_stream()
        converted = convert_chat_completion_response_stream(stream)

        iter = converted.__aiter__()
        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta.content == ToolCall(
            call_id="tool_call_id",
            tool_name="get_flight_info",
            arguments={"origin": "AU", "destination": "LAX"},
        )

    def _dummy_chat_completion_chunk(self):
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(role="assistant", content="Hello World"),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )

    def _dummy_chat_completion_chunk_with_tool_call(self):
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello World",
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                type="function",
                                function=ChoiceDeltaToolCallFunction(
                                    name="get_flight_info",
                                    arguments='{"origin": "AU", "destination": "LAX"}',
                                ),
                            )
                        ],
                    ),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )
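Taken together, the request-conversion tests above pin down the expected mapping from a llama-stack ChatCompletionRequest to Groq keyword arguments: model, messages, stream, temperature, top_p, tool_choice and tools are forwarded; max_tokens is forwarded only when non-zero; logprobs, response_format and repetition_penalty are dropped with a warning. A rough sketch of that shape, inferred from the assertions and not copied from groq_utils:

import warnings

def convert_request_sketch(request):
    # Illustrative only; the real convert_chat_completion_request lives in
    # llama_stack/providers/remote/inference/groq/groq_utils.py and may differ in detail.
    params = {
        "model": request.model,
        "messages": [{"role": m.role, "content": m.content} for m in request.messages],
        "stream": request.stream,
        "temperature": request.sampling_params.temperature,
        "top_p": request.sampling_params.top_p,
    }
    if request.sampling_params.max_tokens:  # 0 means "not set" per the tests above
        params["max_tokens"] = request.sampling_params.max_tokens
    if request.logprobs:
        warnings.warn("logprobs are not supported yet")
    return params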
llama_stack/providers/tests/inference/groq/test_init.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack.apis.inference import Inference
from llama_stack.providers.remote.inference.groq import get_adapter_impl
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter

from llama_stack.providers.remote.inference.ollama import OllamaImplConfig


class TestGroqInit:
    @pytest.mark.asyncio
    async def test_raises_runtime_error_if_config_is_not_groq_config(self):
        config = OllamaImplConfig(model="llama3.1-8b-8192")

        with pytest.raises(RuntimeError):
            await get_adapter_impl(config, None)

    @pytest.mark.asyncio
    async def test_returns_groq_adapter(self):
        config = GroqConfig()
        adapter = await get_adapter_impl(config, None)
        assert type(adapter) is GroqInferenceAdapter
        assert isinstance(adapter, Inference)
llama_stack/providers/tests/inference/test_embeddings.py (new file, 62 lines)
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.apis.inference import EmbeddingsResponse, ModelType

# How to run this test:
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py


class TestEmbeddings:
    @pytest.mark.asyncio
    async def test_embeddings(self, inference_model, inference_stack):
        inference_impl, models_impl = inference_stack
        model = await models_impl.get_model(inference_model)

        if model.model_type != ModelType.embedding:
            pytest.skip("This test is only applicable for embedding models")

        response = await inference_impl.embeddings(
            model_id=inference_model,
            contents=["Hello, world!"],
        )
        assert isinstance(response, EmbeddingsResponse)
        assert len(response.embeddings) > 0
        assert all(isinstance(embedding, list) for embedding in response.embeddings)
        assert all(
            isinstance(value, float)
            for embedding in response.embeddings
            for value in embedding
        )

    @pytest.mark.asyncio
    async def test_batch_embeddings(self, inference_model, inference_stack):
        inference_impl, models_impl = inference_stack
        model = await models_impl.get_model(inference_model)

        if model.model_type != ModelType.embedding:
            pytest.skip("This test is only applicable for embedding models")

        texts = ["Hello, world!", "This is a test", "Testing embeddings"]

        response = await inference_impl.embeddings(
            model_id=inference_model,
            contents=texts,
        )

        assert isinstance(response, EmbeddingsResponse)
        assert len(response.embeddings) == len(texts)
        assert all(isinstance(embedding, list) for embedding in response.embeddings)
        assert all(
            isinstance(value, float)
            for embedding in response.embeddings
            for value in embedding
        )

        embedding_dim = len(response.embeddings[0])
        assert all(len(embedding) == embedding_dim for embedding in response.embeddings)
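These tests only run when the registered model is an embedding model, which the fixtures shown earlier key off the EMBEDDING_DIMENSION environment variable (the --embedding-model flag added in conftest.py presumably selects the model itself). A hypothetical invocation, with the model name and dimension chosen only as examples:

# EMBEDDING_DIMENSION=384 pytest -v -s \
#     llama_stack/providers/tests/inference/test_embeddings.py \
#     --embedding-model=all-MiniLM-L6-v2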
@@ -4,13 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, patch

import pytest


# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
#   -m "meta_reference"
# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"
#   ./llama_stack/providers/tests/inference/test_model_registration.py


class TestModelRegistration:
@@ -51,16 +53,37 @@ class TestModelRegistration:

        _ = await models_impl.register_model(
            model_id="custom-model",
-            metadata={"llama_model": "meta-llama/Llama-2-7b"},
+            metadata={
+                "llama_model": "meta-llama/Llama-2-7b",
+                "skip_load": True,
+            },
        )

-        with pytest.raises(ValueError) as exc_info:
+        with pytest.raises(AssertionError) as exc_info:
            await models_impl.register_model(
                model_id="custom-model-2",
-                metadata={"llama_model": "meta-llama/Llama-2-7b"},
+                metadata={
+                    "llama_model": "meta-llama/Llama-2-7b",
+                },
                provider_model_id="custom-model",
            )

    @pytest.mark.asyncio
    async def test_initialize_model_during_registering(self, inference_stack):
        _, models_impl = inference_stack

        with patch(
            "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
            new_callable=AsyncMock,
        ) as mock_load_model:
            _ = await models_impl.register_model(
                model_id="Llama3.1-8B-Instruct",
                metadata={
                    "llama_model": "meta-llama/Llama-3.1-8B-Instruct",
                },
            )
            mock_load_model.assert_called_once()

    @pytest.mark.asyncio
    async def test_register_with_invalid_llama_model(self, inference_stack):
        _, models_impl = inference_stack
@@ -6,8 +6,14 @@

import unittest

-from llama_models.llama3.api import *  # noqa: F403
-from llama_stack.apis.inference.inference import *  # noqa: F403
+from llama_models.llama3.api.datatypes import (
+    BuiltinTool,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolPromptFormat,
+)
+
+from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)
@@ -7,13 +7,31 @@

import pytest

from llama_models.llama3.api.datatypes import (
    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)

from pydantic import BaseModel, ValidationError

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403

from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    JsonSchemaResponseFormat,
    LogProbConfig,
    SystemMessage,
    ToolChoice,
    UserMessage,
)
from llama_stack.apis.models import Model
from .utils import group_chunks
@@ -67,7 +85,9 @@ def sample_tool_definition():


class TestInference:
-    @pytest.mark.asyncio
+    # Session scope for asyncio because the tests in this class all
+    # share the same provider instance.
+    @pytest.mark.asyncio(loop_scope="session")
    async def test_model_list(self, inference_model, inference_stack):
        _, models_impl = inference_stack
        response = await models_impl.list_models()
@@ -83,7 +103,7 @@ class TestInference:

        assert model_def is not None

-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
    async def test_completion(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack
@@ -94,6 +114,7 @@ class TestInference:
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::nvidia",
            "remote::cerebras",
        ):
            pytest.skip("Other inference providers don't support completion() yet")
@@ -127,19 +148,77 @@ class TestInference:
        last = chunks[-1]
        assert last.stop_reason == StopReason.out_of_tokens

-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
    async def test_completion_logprobs(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            # "remote::nvidia",  -- provider doesn't provide all logprobs
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        response = await inference_impl.completion(
            content="Micheael Jordan is born in ",
            stream=False,
            model_id=inference_model,
            sampling_params=SamplingParams(
                max_tokens=5,
            ),
            logprobs=LogProbConfig(
                top_k=3,
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert 1 <= len(response.logprobs) <= 5
        assert response.logprobs, "Logprobs should not be empty"
        assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)

        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=5,
                ),
                logprobs=LogProbConfig(
                    top_k=3,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert (
            1 <= len(chunks) <= 6
        )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
        for chunk in chunks:
            if (
                chunk.delta.type == "text" and chunk.delta.text
            ):  # if there's a token, we expect logprobs
                assert chunk.logprobs, "Logprobs should not be empty"
                assert all(
                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
                )
            else:  # no token, no logprobs
                assert not chunk.logprobs, "Logprobs should be empty"

    @pytest.mark.asyncio(loop_scope="session")
    @pytest.mark.skip("This test is not quite robust")
-    async def test_completion_structured_output(self, inference_model, inference_stack):
+    async def test_completions_structured_output(
+        self, inference_model, inference_stack
+    ):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::nvidia",
            "remote::vllm",
            "remote::cerebras",
        ):
            pytest.skip(
@@ -171,7 +250,7 @@ class TestInference:
        assert answer.year_born == "1963"
        assert answer.year_retired == "2003"

-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
    async def test_chat_completion_non_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
@@ -188,7 +267,7 @@ class TestInference:
        assert isinstance(response.completion_message.content, str)
        assert len(response.completion_message.content) > 0

-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
    async def test_structured_output(
        self, inference_model, inference_stack, common_params
    ):
@@ -197,9 +276,11 @@ class TestInference:
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::fireworks",
            "remote::tgi",
            "remote::together",
            "remote::vllm",
            "remote::nvidia",
        ):
            pytest.skip("Other inference providers don't support structured output yet")
@@ -257,7 +338,7 @@ class TestInference:
        with pytest.raises(ValidationError):
            AnswerFormat.model_validate_json(response.completion_message.content)

-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
    async def test_chat_completion_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
@@ -284,7 +365,7 @@ class TestInference:
        end = grouped[ChatCompletionResponseEventType.complete][0]
        assert end.event.stop_reason == StopReason.end_of_turn

-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
    async def test_chat_completion_with_tool_calling(
        self,
        inference_model,
@@ -294,6 +375,14 @@ class TestInference:
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if (
            provider.__provider_spec__.provider_type == "remote::groq"
            and "Llama-3.2" in inference_model
        ):
            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")

        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
@@ -323,7 +412,7 @@ class TestInference:
            assert "location" in call.arguments
            assert "San Francisco" in call.arguments["location"]

-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
    async def test_chat_completion_with_tool_calling_streaming(
        self,
        inference_model,
@@ -333,6 +422,14 @@ class TestInference:
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if (
            provider.__provider_spec__.provider_type == "remote::groq"
            and "Llama-3.2" in inference_model
        ):
            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")

        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
@@ -349,7 +446,6 @@ class TestInference:
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
@@ -368,7 +464,7 @@ class TestInference:

        if "Llama3.1" in inference_model:
            assert all(
-                isinstance(chunk.event.delta, ToolCallDelta)
+                chunk.event.delta.type == "tool_call"
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            first = grouped[ChatCompletionResponseEventType.progress][0]
@@ -379,8 +475,8 @@ class TestInference:

        last = grouped[ChatCompletionResponseEventType.progress][-1]
        # assert last.event.stop_reason == expected_stop_reason
-        assert last.event.delta.parse_status == ToolCallParseStatus.success
-        assert isinstance(last.event.delta.content, ToolCall)
+        assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
+        assert last.event.delta.content.type == "tool_call"

        call = last.event.delta.content
        assert call.tool_name == "get_weather"
@@ -7,16 +7,24 @@
from pathlib import Path

import pytest
from PIL import Image as PIL_Image

from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    SamplingParams,
    UserMessage,
)

from .utils import group_chunks

THIS_DIR = Path(__file__).parent

with open(THIS_DIR / "pasta.jpeg", "rb") as f:
    PASTA_IMAGE = f.read()


class TestVisionModelInference:
    @pytest.mark.asyncio
@@ -24,12 +32,12 @@ class TestVisionModelInference:
        "image, expected_strings",
        [
            (
-                ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
+                ImageContentItem(data=PASTA_IMAGE),
                ["spaghetti"],
            ),
            (
-                ImageMedia(
-                    image=URL(
+                ImageContentItem(
+                    url=URL(
                        uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                    )
                ),
@@ -59,7 +67,12 @@ class TestVisionModelInference:
                model_id=inference_model,
                messages=[
                    UserMessage(content="You are a helpful assistant."),
-                    UserMessage(content=[image, "Describe this image in two sentences."]),
+                    UserMessage(
+                        content=[
+                            image,
+                            TextContentItem(text="Describe this image in two sentences."),
+                        ]
+                    ),
                ],
                stream=False,
                sampling_params=SamplingParams(max_tokens=100),
@@ -91,8 +104,8 @@ class TestVisionModelInference:
        )

        images = [
-            ImageMedia(
-                image=URL(
+            ImageContentItem(
+                url=URL(
                    uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                )
            ),
@@ -108,7 +121,12 @@ class TestVisionModelInference:
                messages=[
                    UserMessage(content="You are a helpful assistant."),
                    UserMessage(
-                        content=[image, "Describe this image in two sentences."]
+                        content=[
+                            image,
+                            TextContentItem(
+                                text="Describe this image in two sentences."
+                            ),
+                        ]
                    ),
                ],
                stream=True,
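Pulling these vision hunks together: under the new content-types API an image request combines an ImageContentItem and a TextContentItem inside a single UserMessage. A condensed sketch based only on the code shown above (the surrounding chat_completion call is outside the visible hunks, so treat it as illustrative):

message = UserMessage(
    content=[
        ImageContentItem(data=PASTA_IMAGE),  # raw image bytes read from disk
        TextContentItem(text="Describe this image in two sentences."),
    ]
)
response = await inference_impl.chat_completion(
    model_id=inference_model,
    messages=[message],
    stream=False,
    sampling_params=SamplingParams(max_tokens=100),
)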