mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
Add inference test
Run it as: ``` PROVIDER_ID=test-remote \ PROVIDER_CONFIG=$PWD/llama_stack/providers/tests/inference/provider_config_example.yaml \ pytest -s llama_stack/providers/tests/inference/test_inference.py \ --tb=auto \ --disable-warnings ```
This commit is contained in:
parent
4fa467731e
commit
3ae2b712e8
8 changed files with 356 additions and 54 deletions
|
@ -67,25 +67,26 @@ class InferenceClient(Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
content = await response.aread()
|
||||
cprint(
|
||||
f"Error: HTTP {response.status_code} {content.decode()}", "red"
|
||||
)
|
||||
return
|
||||
if stream:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
content = await response.aread()
|
||||
cprint(
|
||||
f"Error: HTTP {response.status_code} {content.decode()}",
|
||||
"red",
|
||||
)
|
||||
return
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
if request.stream:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
if "error" in data:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
|
@ -93,11 +94,20 @@ class InferenceClient(Inference):
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
**json.loads(data)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponse(**json.loads(data))
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
else:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
yield ChatCompletionResponse(**j)
|
||||
|
||||
|
||||
async def run_main(
|
||||
|
|
|
@ -109,6 +109,7 @@ this could be just a hash
|
|||
description="Reference to the conda environment if this package refers to a conda environment",
|
||||
)
|
||||
apis: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
|
|
|
@ -36,8 +36,8 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
|
||||
)
|
||||
self.url = url
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
|
@ -65,17 +65,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
|
||||
ollama_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
ollama_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return ollama_messages
|
||||
|
||||
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
|
@ -113,6 +102,9 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||
prompt = self.tokenizer.decode(model_input.tokens)
|
||||
|
||||
# accumulate sampling params and other options to pass to ollama
|
||||
options = self.get_ollama_chat_options(request)
|
||||
ollama_model = self.map_to_provider_model(request.model)
|
||||
|
@ -131,13 +123,16 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
status["status"] == "success"
|
||||
), f"Failed to pull model {self.model} in ollama"
|
||||
|
||||
common_params = {
|
||||
"model": ollama_model,
|
||||
"prompt": prompt,
|
||||
"options": options,
|
||||
"raw": True,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=False,
|
||||
options=options,
|
||||
)
|
||||
r = await self.client.generate(**common_params)
|
||||
stop_reason = None
|
||||
if r["done"]:
|
||||
if r["done_reason"] == "stop":
|
||||
|
@ -146,7 +141,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r["message"]["content"], stop_reason
|
||||
r["response"], stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
|
@ -159,12 +154,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
delta="",
|
||||
)
|
||||
)
|
||||
stream = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=True,
|
||||
options=options,
|
||||
)
|
||||
stream = await self.client.generate(**common_params)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
|
@ -178,8 +168,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = chunk["message"]["content"]
|
||||
|
||||
text = chunk["response"]
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
|
|
|
@ -100,8 +100,6 @@ class _HfAdapter(Inference):
|
|||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
|
||||
print(f"Calculated max_new_tokens: {max_new_tokens}")
|
||||
|
||||
options = self.get_chat_options(request)
|
||||
if not request.stream:
|
||||
response = await self.client.text_generation(
|
||||
|
@ -119,8 +117,9 @@ class _HfAdapter(Inference):
|
|||
elif response.details.finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
generated_text = "".join(t.text for t in response.details.tokens)
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
response.generated_text,
|
||||
generated_text,
|
||||
stop_reason,
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
|
|
5
llama_stack/providers/tests/__init__.py
Normal file
5
llama_stack/providers/tests/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
5
llama_stack/providers/tests/inference/__init__.py
Normal file
5
llama_stack/providers/tests/inference/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,15 @@
|
|||
providers:
|
||||
- provider_id: test-ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 11434
|
||||
- provider_id: test-tgi
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: http://localhost:7001
|
||||
- provider_id: test-remote
|
||||
provider_type: remote
|
||||
config:
|
||||
host: localhost
|
||||
port: 7002
|
278
llama_stack/providers/tests/inference/test_inference.py
Normal file
278
llama_stack/providers/tests/inference/test_inference.py
Normal file
|
@ -0,0 +1,278 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import yaml
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
||||
|
||||
|
||||
def group_chunks(response):
|
||||
return {
|
||||
event_type: list(group)
|
||||
for event_type, group in itertools.groupby(
|
||||
response, key=lambda chunk: chunk.event.event_type
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
Llama_8B = "Llama3.1-8B-Instruct"
|
||||
Llama_3B = "Llama3.2-3B-Instruct"
|
||||
|
||||
|
||||
def get_expected_stop_reason(model: str):
|
||||
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
|
||||
|
||||
|
||||
async def stack_impls(model):
|
||||
if "PROVIDER_CONFIG" not in os.environ:
|
||||
raise ValueError(
|
||||
"You must set PROVIDER_CONFIG to a YAML file containing provider config"
|
||||
)
|
||||
|
||||
with open(os.environ["PROVIDER_CONFIG"], "r") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
if "providers" not in config_dict:
|
||||
raise ValueError("Config file should contain a `providers` key")
|
||||
|
||||
providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
|
||||
if len(providers_by_id) == 0:
|
||||
raise ValueError("No providers found in config file")
|
||||
|
||||
if "PROVIDER_ID" in os.environ:
|
||||
provider_id = os.environ["PROVIDER_ID"]
|
||||
if provider_id not in providers_by_id:
|
||||
raise ValueError(f"Provider ID {provider_id} not found in config file")
|
||||
provider = providers_by_id[provider_id]
|
||||
else:
|
||||
provider = list(providers_by_id.values())[0]
|
||||
print(f"No provider ID specified, picking first {provider['provider_id']}")
|
||||
|
||||
config_dict = dict(
|
||||
built_at=datetime.now(),
|
||||
image_name="test-fixture",
|
||||
apis=[
|
||||
Api.inference,
|
||||
Api.models,
|
||||
],
|
||||
providers=dict(
|
||||
inference=[
|
||||
Provider(**provider),
|
||||
]
|
||||
),
|
||||
models=[
|
||||
ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
provider_id=provider["provider_id"],
|
||||
)
|
||||
],
|
||||
shields=[],
|
||||
memory_banks=[],
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(config_dict)
|
||||
impls = await resolve_impls_with_routing(run_config)
|
||||
return impls
|
||||
|
||||
|
||||
# This is going to create multiple Stack impls without tearing down the previous one
|
||||
# Fix that!
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session",
|
||||
params=[
|
||||
{"model": Llama_8B},
|
||||
{"model": Llama_3B},
|
||||
],
|
||||
)
|
||||
async def inference_settings(request):
|
||||
model = request.param["model"]
|
||||
impls = await stack_impls(model)
|
||||
return {
|
||||
"impl": impls[Api.inference],
|
||||
"common_params": {
|
||||
"model": model,
|
||||
"tool_choice": ToolChoice.auto,
|
||||
"tool_prompt_format": (
|
||||
ToolPromptFormat.json
|
||||
if "Llama3.1" in model
|
||||
else ToolPromptFormat.python_list
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages():
|
||||
return [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content="What's the weather like today?"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tool_definition():
|
||||
return ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get the current weather",
|
||||
parameters={
|
||||
"location": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The city and state, e.g. San Francisco, CA",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
messages=sample_messages,
|
||||
stream=False,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) == 1
|
||||
assert isinstance(response[0], ChatCompletionResponse)
|
||||
assert response[0].completion_message.role == "assistant"
|
||||
assert isinstance(response[0].completion_message.content, str)
|
||||
assert len(response[0].completion_message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
assert end.event.stop_reason == StopReason.end_of_turn
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling(
|
||||
inference_settings,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl = inference_settings["impl"]
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
)
|
||||
]
|
||||
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=False,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) == 1
|
||||
assert isinstance(response[0], ChatCompletionResponse)
|
||||
|
||||
message = response[0].completion_message
|
||||
|
||||
stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
|
||||
assert message.stop_reason == stop_reason
|
||||
assert message.tool_calls is not None
|
||||
assert len(message.tool_calls) > 0
|
||||
|
||||
call = message.tool_calls[0]
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling_streaming(
|
||||
inference_settings,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl = inference_settings["impl"]
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
)
|
||||
]
|
||||
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=True,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
expected_stop_reason = get_expected_stop_reason(
|
||||
inference_settings["common_params"]["model"]
|
||||
)
|
||||
assert end.event.stop_reason == expected_stop_reason
|
||||
|
||||
model = inference_settings["common_params"]["model"]
|
||||
if "Llama3.1" in model:
|
||||
assert all(
|
||||
isinstance(chunk.event.delta, ToolCallDelta)
|
||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||
)
|
||||
first = grouped[ChatCompletionResponseEventType.progress][0]
|
||||
assert first.event.delta.parse_status == ToolCallParseStatus.started
|
||||
|
||||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||
assert last.event.stop_reason == expected_stop_reason
|
||||
assert last.event.delta.parse_status == ToolCallParseStatus.success
|
||||
assert isinstance(last.event.delta.content, ToolCall)
|
||||
|
||||
call = last.event.delta.content
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
Loading…
Add table
Add a link
Reference in a new issue