From 3ae2b712e84c31eb0da76bba172e58865fcbf36b Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 7 Oct 2024 15:46:16 -0700 Subject: [PATCH] Add inference test Run it as: ``` PROVIDER_ID=test-remote \ PROVIDER_CONFIG=$PWD/llama_stack/providers/tests/inference/provider_config_example.yaml \ pytest -s llama_stack/providers/tests/inference/test_inference.py \ --tb=auto \ --disable-warnings ``` --- llama_stack/apis/inference/client.py | 56 ++-- llama_stack/distribution/datatypes.py | 1 + .../adapters/inference/ollama/ollama.py | 45 ++- .../providers/adapters/inference/tgi/tgi.py | 5 +- llama_stack/providers/tests/__init__.py | 5 + .../providers/tests/inference/__init__.py | 5 + .../inference/provider_config_example.yaml | 15 + .../tests/inference/test_inference.py | 278 ++++++++++++++++++ 8 files changed, 356 insertions(+), 54 deletions(-) create mode 100644 llama_stack/providers/tests/__init__.py create mode 100644 llama_stack/providers/tests/inference/__init__.py create mode 100644 llama_stack/providers/tests/inference/provider_config_example.yaml create mode 100644 llama_stack/providers/tests/inference/test_inference.py diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py index fffcf4692..8b822058f 100644 --- a/llama_stack/apis/inference/client.py +++ b/llama_stack/apis/inference/client.py @@ -67,25 +67,26 @@ class InferenceClient(Inference): logprobs=logprobs, ) async with httpx.AsyncClient() as client: - async with client.stream( - "POST", - f"{self.base_url}/inference/chat_completion", - json=encodable_dict(request), - headers={"Content-Type": "application/json"}, - timeout=20, - ) as response: - if response.status_code != 200: - content = await response.aread() - cprint( - f"Error: HTTP {response.status_code} {content.decode()}", "red" - ) - return + if stream: + async with client.stream( + "POST", + f"{self.base_url}/inference/chat_completion", + json=encodable_dict(request), + headers={"Content-Type": "application/json"}, + timeout=20, + ) as response: + if response.status_code != 200: + content = await response.aread() + cprint( + f"Error: HTTP {response.status_code} {content.decode()}", + "red", + ) + return - async for line in response.aiter_lines(): - if line.startswith("data:"): - data = line[len("data: ") :] - try: - if request.stream: + async for line in response.aiter_lines(): + if line.startswith("data:"): + data = line[len("data: ") :] + try: if "error" in data: cprint(data, "red") continue @@ -93,11 +94,20 @@ class InferenceClient(Inference): yield ChatCompletionResponseStreamChunk( **json.loads(data) ) - else: - yield ChatCompletionResponse(**json.loads(data)) - except Exception as e: - print(data) - print(f"Error with parsing or validation: {e}") + except Exception as e: + print(data) + print(f"Error with parsing or validation: {e}") + else: + response = await client.post( + f"{self.base_url}/inference/chat_completion", + json=encodable_dict(request), + headers={"Content-Type": "application/json"}, + timeout=20, + ) + + response.raise_for_status() + j = response.json() + yield ChatCompletionResponse(**j) async def run_main( diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index c987d4c87..e09a6939c 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -109,6 +109,7 @@ this could be just a hash description="Reference to the conda environment if this package refers to a conda environment", ) apis: List[str] = Field( + default_factory=list, 
description=""" The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""", ) diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index 09af46b11..40a3f5977 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -36,8 +36,8 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS ) self.url = url - tokenizer = Tokenizer.get_instance() - self.formatter = ChatFormat(tokenizer) + self.tokenizer = Tokenizer.get_instance() + self.formatter = ChatFormat(self.tokenizer) @property def client(self) -> AsyncClient: @@ -65,17 +65,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): ) -> AsyncGenerator: raise NotImplementedError() - def _messages_to_ollama_messages(self, messages: list[Message]) -> list: - ollama_messages = [] - for message in messages: - if message.role == "ipython": - role = "tool" - else: - role = message.role - ollama_messages.append({"role": role, "content": message.content}) - - return ollama_messages - def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict: options = {} if request.sampling_params is not None: @@ -113,6 +102,9 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): ) messages = augment_messages_for_tools(request) + model_input = self.formatter.encode_dialog_prompt(messages) + prompt = self.tokenizer.decode(model_input.tokens) + # accumulate sampling params and other options to pass to ollama options = self.get_ollama_chat_options(request) ollama_model = self.map_to_provider_model(request.model) @@ -131,13 +123,16 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): status["status"] == "success" ), f"Failed to pull model {self.model} in ollama" + common_params = { + "model": ollama_model, + "prompt": prompt, + "options": options, + "raw": True, + "stream": request.stream, + } + if not request.stream: - r = await self.client.chat( - model=ollama_model, - messages=self._messages_to_ollama_messages(messages), - stream=False, - options=options, - ) + r = await self.client.generate(**common_params) stop_reason = None if r["done"]: if r["done_reason"] == "stop": @@ -146,7 +141,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): stop_reason = StopReason.out_of_tokens completion_message = self.formatter.decode_assistant_message_from_content( - r["message"]["content"], stop_reason + r["response"], stop_reason ) yield ChatCompletionResponse( completion_message=completion_message, @@ -159,12 +154,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): delta="", ) ) - stream = await self.client.chat( - model=ollama_model, - messages=self._messages_to_ollama_messages(messages), - stream=True, - options=options, - ) + stream = await self.client.generate(**common_params) buffer = "" ipython = False @@ -178,8 +168,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): stop_reason = StopReason.out_of_tokens break - text = chunk["message"]["content"] - + text = chunk["response"] # check if its a tool call ( aka starts with <|python_tag|> ) if not ipython and text.startswith("<|python_tag|>"): ipython = True diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index 24b664068..0ad20edd6 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ 
b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -100,8 +100,6 @@ class _HfAdapter(Inference): self.max_tokens - input_tokens - 1, ) - print(f"Calculated max_new_tokens: {max_new_tokens}") - options = self.get_chat_options(request) if not request.stream: response = await self.client.text_generation( @@ -119,8 +117,9 @@ class _HfAdapter(Inference): elif response.details.finish_reason == "length": stop_reason = StopReason.out_of_tokens + generated_text = "".join(t.text for t in response.details.tokens) completion_message = self.formatter.decode_assistant_message_from_content( - response.generated_text, + generated_text, stop_reason, ) yield ChatCompletionResponse( diff --git a/llama_stack/providers/tests/__init__.py b/llama_stack/providers/tests/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/tests/inference/__init__.py b/llama_stack/providers/tests/inference/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/tests/inference/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/tests/inference/provider_config_example.yaml b/llama_stack/providers/tests/inference/provider_config_example.yaml new file mode 100644 index 000000000..014ce84d4 --- /dev/null +++ b/llama_stack/providers/tests/inference/provider_config_example.yaml @@ -0,0 +1,15 @@ +providers: + - provider_id: test-ollama + provider_type: remote::ollama + config: + host: localhost + port: 11434 + - provider_id: test-tgi + provider_type: remote::tgi + config: + url: http://localhost:7001 + - provider_id: test-remote + provider_type: remote + config: + host: localhost + port: 7002 diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py new file mode 100644 index 000000000..61989b691 --- /dev/null +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import itertools +import os +from datetime import datetime + +import pytest +import pytest_asyncio +import yaml + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 + +from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.distribution.configure import parse_and_maybe_upgrade_config +from llama_stack.distribution.resolver import resolve_impls_with_routing + + +def group_chunks(response): + return { + event_type: list(group) + for event_type, group in itertools.groupby( + response, key=lambda chunk: chunk.event.event_type + ) + } + + +Llama_8B = "Llama3.1-8B-Instruct" +Llama_3B = "Llama3.2-3B-Instruct" + + +def get_expected_stop_reason(model: str): + return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn + + +async def stack_impls(model): + if "PROVIDER_CONFIG" not in os.environ: + raise ValueError( + "You must set PROVIDER_CONFIG to a YAML file containing provider config" + ) + + with open(os.environ["PROVIDER_CONFIG"], "r") as f: + config_dict = yaml.safe_load(f) + + if "providers" not in config_dict: + raise ValueError("Config file should contain a `providers` key") + + providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]} + if len(providers_by_id) == 0: + raise ValueError("No providers found in config file") + + if "PROVIDER_ID" in os.environ: + provider_id = os.environ["PROVIDER_ID"] + if provider_id not in providers_by_id: + raise ValueError(f"Provider ID {provider_id} not found in config file") + provider = providers_by_id[provider_id] + else: + provider = list(providers_by_id.values())[0] + print(f"No provider ID specified, picking first {provider['provider_id']}") + + config_dict = dict( + built_at=datetime.now(), + image_name="test-fixture", + apis=[ + Api.inference, + Api.models, + ], + providers=dict( + inference=[ + Provider(**provider), + ] + ), + models=[ + ModelDef( + identifier=model, + llama_model=model, + provider_id=provider["provider_id"], + ) + ], + shields=[], + memory_banks=[], + ) + run_config = parse_and_maybe_upgrade_config(config_dict) + impls = await resolve_impls_with_routing(run_config) + return impls + + +# This is going to create multiple Stack impls without tearing down the previous one +# Fix that! +@pytest_asyncio.fixture( + scope="session", + params=[ + {"model": Llama_8B}, + {"model": Llama_3B}, + ], +) +async def inference_settings(request): + model = request.param["model"] + impls = await stack_impls(model) + return { + "impl": impls[Api.inference], + "common_params": { + "model": model, + "tool_choice": ToolChoice.auto, + "tool_prompt_format": ( + ToolPromptFormat.json + if "Llama3.1" in model + else ToolPromptFormat.python_list + ), + }, + } + + +@pytest.fixture +def sample_messages(): + return [ + SystemMessage(content="You are a helpful assistant."), + UserMessage(content="What's the weather like today?"), + ] + + +@pytest.fixture +def sample_tool_definition(): + return ToolDefinition( + tool_name="get_weather", + description="Get the current weather", + parameters={ + "location": ToolParamDefinition( + param_type="string", + description="The city and state, e.g. 
San Francisco, CA", + ), + }, + ) + + +@pytest.mark.asyncio +async def test_chat_completion_non_streaming(inference_settings, sample_messages): + inference_impl = inference_settings["impl"] + response = [ + r + async for r in inference_impl.chat_completion( + messages=sample_messages, + stream=False, + **inference_settings["common_params"], + ) + ] + + assert len(response) == 1 + assert isinstance(response[0], ChatCompletionResponse) + assert response[0].completion_message.role == "assistant" + assert isinstance(response[0].completion_message.content, str) + assert len(response[0].completion_message.content) > 0 + + +@pytest.mark.asyncio +async def test_chat_completion_streaming(inference_settings, sample_messages): + inference_impl = inference_settings["impl"] + response = [ + r + async for r in inference_impl.chat_completion( + messages=sample_messages, + stream=True, + **inference_settings["common_params"], + ) + ] + + assert len(response) > 0 + assert all( + isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response + ) + grouped = group_chunks(response) + assert len(grouped[ChatCompletionResponseEventType.start]) == 1 + assert len(grouped[ChatCompletionResponseEventType.progress]) > 0 + assert len(grouped[ChatCompletionResponseEventType.complete]) == 1 + + end = grouped[ChatCompletionResponseEventType.complete][0] + assert end.event.stop_reason == StopReason.end_of_turn + + +@pytest.mark.asyncio +async def test_chat_completion_with_tool_calling( + inference_settings, + sample_messages, + sample_tool_definition, +): + inference_impl = inference_settings["impl"] + messages = sample_messages + [ + UserMessage( + content="What's the weather like in San Francisco?", + ) + ] + + response = [ + r + async for r in inference_impl.chat_completion( + messages=messages, + tools=[sample_tool_definition], + stream=False, + **inference_settings["common_params"], + ) + ] + + assert len(response) == 1 + assert isinstance(response[0], ChatCompletionResponse) + + message = response[0].completion_message + + stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"]) + assert message.stop_reason == stop_reason + assert message.tool_calls is not None + assert len(message.tool_calls) > 0 + + call = message.tool_calls[0] + assert call.tool_name == "get_weather" + assert "location" in call.arguments + assert "San Francisco" in call.arguments["location"] + + +@pytest.mark.asyncio +async def test_chat_completion_with_tool_calling_streaming( + inference_settings, + sample_messages, + sample_tool_definition, +): + inference_impl = inference_settings["impl"] + messages = sample_messages + [ + UserMessage( + content="What's the weather like in San Francisco?", + ) + ] + + response = [ + r + async for r in inference_impl.chat_completion( + messages=messages, + tools=[sample_tool_definition], + stream=True, + **inference_settings["common_params"], + ) + ] + + assert len(response) > 0 + assert all( + isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response + ) + grouped = group_chunks(response) + assert len(grouped[ChatCompletionResponseEventType.start]) == 1 + assert len(grouped[ChatCompletionResponseEventType.progress]) > 0 + assert len(grouped[ChatCompletionResponseEventType.complete]) == 1 + + end = grouped[ChatCompletionResponseEventType.complete][0] + expected_stop_reason = get_expected_stop_reason( + inference_settings["common_params"]["model"] + ) + assert end.event.stop_reason == expected_stop_reason + + model = 
inference_settings["common_params"]["model"] + if "Llama3.1" in model: + assert all( + isinstance(chunk.event.delta, ToolCallDelta) + for chunk in grouped[ChatCompletionResponseEventType.progress] + ) + first = grouped[ChatCompletionResponseEventType.progress][0] + assert first.event.delta.parse_status == ToolCallParseStatus.started + + last = grouped[ChatCompletionResponseEventType.progress][-1] + assert last.event.stop_reason == expected_stop_reason + assert last.event.delta.parse_status == ToolCallParseStatus.success + assert isinstance(last.event.delta.content, ToolCall) + + call = last.event.delta.content + assert call.tool_name == "get_weather" + assert "location" in call.arguments + assert "San Francisco" in call.arguments["location"]
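
---

For reference, a minimal consumption sketch of the two client paths these tests exercise (non-streaming vs. streaming `chat_completion`). This is a hypothetical usage example, not part of the diff: it assumes a stack is already serving the inference API locally, and the base URL, port, model name, and the `InferenceClient(base_url)` constructor shape are placeholders/assumptions rather than values taken from this change.

```python
# Hypothetical consumption sketch -- base URL, port, and model are placeholders.
import asyncio

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.inference.client import InferenceClient


async def main():
    client = InferenceClient("http://localhost:5000")  # assumed constructor signature
    messages = [UserMessage(content="What is the capital of France?")]

    # Non-streaming: the async generator yields exactly one ChatCompletionResponse.
    async for r in client.chat_completion(
        model="Llama3.1-8B-Instruct", messages=messages, stream=False
    ):
        assert isinstance(r, ChatCompletionResponse)
        print(r.completion_message.content)

    # Streaming: the generator yields ChatCompletionResponseStreamChunk events
    # (one start, zero or more progress deltas, one complete).
    async for chunk in client.chat_completion(
        model="Llama3.1-8B-Instruct", messages=messages, stream=True
    ):
        print(chunk.event.event_type, chunk.event.delta)


asyncio.run(main())
```

Both modes go through the same async-generator interface: the non-streaming branch added in `client.py` POSTs once and yields a single parsed `ChatCompletionResponse`, which is exactly what `test_chat_completion_non_streaming` asserts, while the streaming branch yields the start/progress/complete chunk sequence checked by `test_chat_completion_streaming`.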