Add inference test

Run it as:

```
PROVIDER_ID=test-remote \
 PROVIDER_CONFIG=$PWD/llama_stack/providers/tests/inference/provider_config_example.yaml \
 pytest -s llama_stack/providers/tests/inference/test_inference.py \
 --tb=auto \
 --disable-warnings
```
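The example config referenced above (provider_config_example.yaml, included below) also defines test-ollama and test-tgi providers, so the same command can target a different backend by switching PROVIDER_ID, assuming that backend is actually running on the host/port configured there. For example:

```
PROVIDER_ID=test-ollama \
 PROVIDER_CONFIG=$PWD/llama_stack/providers/tests/inference/provider_config_example.yaml \
 pytest -s llama_stack/providers/tests/inference/test_inference.py \
 --tb=auto \
 --disable-warnings
```

If PROVIDER_ID is unset, the test fixture falls back to the first provider listed in the config file.
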
Ashwin Bharambe authored and committed, 2024-10-07 15:46:16 -07:00
commit 3ae2b712e8 (parent 4fa467731e)
8 changed files with 356 additions and 54 deletions


```diff
@@ -67,6 +67,7 @@ class InferenceClient(Inference):
                 logprobs=logprobs,
             )
         async with httpx.AsyncClient() as client:
+            if stream:
                 async with client.stream(
                     "POST",
                     f"{self.base_url}/inference/chat_completion",
@@ -77,7 +78,8 @@ class InferenceClient(Inference):
                     if response.status_code != 200:
                         content = await response.aread()
                         cprint(
-                            f"Error: HTTP {response.status_code} {content.decode()}", "red"
+                            f"Error: HTTP {response.status_code} {content.decode()}",
+                            "red",
                         )
                         return
@@ -85,7 +87,6 @@ class InferenceClient(Inference):
                         if line.startswith("data:"):
                             data = line[len("data: ") :]
                             try:
-                                if request.stream:
                                 if "error" in data:
                                     cprint(data, "red")
                                     continue
@@ -93,11 +94,20 @@ class InferenceClient(Inference):
                                 yield ChatCompletionResponseStreamChunk(
                                     **json.loads(data)
                                 )
-                                else:
-                                    yield ChatCompletionResponse(**json.loads(data))
                             except Exception as e:
                                 print(data)
                                 print(f"Error with parsing or validation: {e}")
+            else:
+                response = await client.post(
+                    f"{self.base_url}/inference/chat_completion",
+                    json=encodable_dict(request),
+                    headers={"Content-Type": "application/json"},
+                    timeout=20,
+                )
+                response.raise_for_status()
+                j = response.json()
+                yield ChatCompletionResponse(**j)


 async def run_main(
```


```diff
@@ -109,6 +109,7 @@ this could be just a hash
         description="Reference to the conda environment if this package refers to a conda environment",
     )
     apis: List[str] = Field(
+        default_factory=list,
         description="""
 The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
     )
```


```diff
@@ -36,8 +36,8 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
             self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
         )
         self.url = url
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)

     @property
     def client(self) -> AsyncClient:
@@ -65,17 +65,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
-        ollama_messages = []
-        for message in messages:
-            if message.role == "ipython":
-                role = "tool"
-            else:
-                role = message.role
-            ollama_messages.append({"role": role, "content": message.content})
-        return ollama_messages
-
     def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -113,6 +102,9 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
         )

         messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
+
         # accumulate sampling params and other options to pass to ollama
         options = self.get_ollama_chat_options(request)
         ollama_model = self.map_to_provider_model(request.model)
@@ -131,13 +123,16 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
             status["status"] == "success"
         ), f"Failed to pull model {self.model} in ollama"

+        common_params = {
+            "model": ollama_model,
+            "prompt": prompt,
+            "options": options,
+            "raw": True,
+            "stream": request.stream,
+        }
+
         if not request.stream:
-            r = await self.client.chat(
-                model=ollama_model,
-                messages=self._messages_to_ollama_messages(messages),
-                stream=False,
-                options=options,
-            )
+            r = await self.client.generate(**common_params)
             stop_reason = None
             if r["done"]:
                 if r["done_reason"] == "stop":
@@ -146,7 +141,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                     stop_reason = StopReason.out_of_tokens

             completion_message = self.formatter.decode_assistant_message_from_content(
-                r["message"]["content"], stop_reason
+                r["response"], stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -159,12 +154,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                     delta="",
                 )
             )
-            stream = await self.client.chat(
-                model=ollama_model,
-                messages=self._messages_to_ollama_messages(messages),
-                stream=True,
-                options=options,
-            )
+            stream = await self.client.generate(**common_params)

             buffer = ""
             ipython = False
@@ -178,8 +168,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                     stop_reason = StopReason.out_of_tokens
                     break

-                text = chunk["message"]["content"]
+                text = chunk["response"]

                 # check if its a tool call ( aka starts with <|python_tag|> )
                 if not ipython and text.startswith("<|python_tag|>"):
                     ipython = True
```


```diff
@@ -100,8 +100,6 @@ class _HfAdapter(Inference):
             self.max_tokens - input_tokens - 1,
         )

-        print(f"Calculated max_new_tokens: {max_new_tokens}")
-
         options = self.get_chat_options(request)
         if not request.stream:
             response = await self.client.text_generation(
@@ -119,8 +117,9 @@ class _HfAdapter(Inference):
             elif response.details.finish_reason == "length":
                 stop_reason = StopReason.out_of_tokens

+            generated_text = "".join(t.text for t in response.details.tokens)
             completion_message = self.formatter.decode_assistant_message_from_content(
-                response.generated_text,
+                generated_text,
                 stop_reason,
             )
             yield ChatCompletionResponse(
```


```diff
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
```


```diff
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
```


llama_stack/providers/tests/inference/provider_config_example.yaml (new file):

```yaml
providers:
  - provider_id: test-ollama
    provider_type: remote::ollama
    config:
      host: localhost
      port: 11434
  - provider_id: test-tgi
    provider_type: remote::tgi
    config:
      url: http://localhost:7001
  - provider_id: test-remote
    provider_type: remote
    config:
      host: localhost
      port: 7002
```


llama_stack/providers/tests/inference/test_inference.py (new file):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import itertools
import os
from datetime import datetime

import pytest
import pytest_asyncio
import yaml

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.datatypes import *  # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.resolver import resolve_impls_with_routing


def group_chunks(response):
    return {
        event_type: list(group)
        for event_type, group in itertools.groupby(
            response, key=lambda chunk: chunk.event.event_type
        )
    }


Llama_8B = "Llama3.1-8B-Instruct"
Llama_3B = "Llama3.2-3B-Instruct"


def get_expected_stop_reason(model: str):
    return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn


async def stack_impls(model):
    if "PROVIDER_CONFIG" not in os.environ:
        raise ValueError(
            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
        )

    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
        config_dict = yaml.safe_load(f)

    if "providers" not in config_dict:
        raise ValueError("Config file should contain a `providers` key")

    providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
    if len(providers_by_id) == 0:
        raise ValueError("No providers found in config file")

    if "PROVIDER_ID" in os.environ:
        provider_id = os.environ["PROVIDER_ID"]
        if provider_id not in providers_by_id:
            raise ValueError(f"Provider ID {provider_id} not found in config file")
        provider = providers_by_id[provider_id]
    else:
        provider = list(providers_by_id.values())[0]
        print(f"No provider ID specified, picking first {provider['provider_id']}")

    config_dict = dict(
        built_at=datetime.now(),
        image_name="test-fixture",
        apis=[
            Api.inference,
            Api.models,
        ],
        providers=dict(
            inference=[
                Provider(**provider),
            ]
        ),
        models=[
            ModelDef(
                identifier=model,
                llama_model=model,
                provider_id=provider["provider_id"],
            )
        ],
        shields=[],
        memory_banks=[],
    )
    run_config = parse_and_maybe_upgrade_config(config_dict)
    impls = await resolve_impls_with_routing(run_config)
    return impls


# This is going to create multiple Stack impls without tearing down the previous one
# Fix that!
@pytest_asyncio.fixture(
    scope="session",
    params=[
        {"model": Llama_8B},
        {"model": Llama_3B},
    ],
)
async def inference_settings(request):
    model = request.param["model"]
    impls = await stack_impls(model)
    return {
        "impl": impls[Api.inference],
        "common_params": {
            "model": model,
            "tool_choice": ToolChoice.auto,
            "tool_prompt_format": (
                ToolPromptFormat.json
                if "Llama3.1" in model
                else ToolPromptFormat.python_list
            ),
        },
    }


@pytest.fixture
def sample_messages():
    return [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What's the weather like today?"),
    ]


@pytest.fixture
def sample_tool_definition():
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={
            "location": ToolParamDefinition(
                param_type="string",
                description="The city and state, e.g. San Francisco, CA",
            ),
        },
    )


@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
    inference_impl = inference_settings["impl"]
    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=sample_messages,
            stream=False,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) == 1
    assert isinstance(response[0], ChatCompletionResponse)
    assert response[0].completion_message.role == "assistant"
    assert isinstance(response[0].completion_message.content, str)
    assert len(response[0].completion_message.content) > 0


@pytest.mark.asyncio
async def test_chat_completion_streaming(inference_settings, sample_messages):
    inference_impl = inference_settings["impl"]
    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=sample_messages,
            stream=True,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) > 0
    assert all(
        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
    )
    grouped = group_chunks(response)
    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

    end = grouped[ChatCompletionResponseEventType.complete][0]
    assert end.event.stop_reason == StopReason.end_of_turn


@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    inference_impl = inference_settings["impl"]
    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
        )
    ]

    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=messages,
            tools=[sample_tool_definition],
            stream=False,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) == 1
    assert isinstance(response[0], ChatCompletionResponse)

    message = response[0].completion_message

    stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
    assert message.stop_reason == stop_reason
    assert message.tool_calls is not None
    assert len(message.tool_calls) > 0

    call = message.tool_calls[0]
    assert call.tool_name == "get_weather"
    assert "location" in call.arguments
    assert "San Francisco" in call.arguments["location"]


@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling_streaming(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    inference_impl = inference_settings["impl"]
    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
        )
    ]

    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=messages,
            tools=[sample_tool_definition],
            stream=True,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) > 0
    assert all(
        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
    )
    grouped = group_chunks(response)
    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

    end = grouped[ChatCompletionResponseEventType.complete][0]
    expected_stop_reason = get_expected_stop_reason(
        inference_settings["common_params"]["model"]
    )
    assert end.event.stop_reason == expected_stop_reason

    model = inference_settings["common_params"]["model"]
    if "Llama3.1" in model:
        assert all(
            isinstance(chunk.event.delta, ToolCallDelta)
            for chunk in grouped[ChatCompletionResponseEventType.progress]
        )
        first = grouped[ChatCompletionResponseEventType.progress][0]
        assert first.event.delta.parse_status == ToolCallParseStatus.started

    last = grouped[ChatCompletionResponseEventType.progress][-1]
    assert last.event.stop_reason == expected_stop_reason
    assert last.event.delta.parse_status == ToolCallParseStatus.success
    assert isinstance(last.event.delta.content, ToolCall)

    call = last.event.delta.content
    assert call.tool_name == "get_weather"
    assert "location" in call.arguments
    assert "San Francisco" in call.arguments["location"]
```