From 0e985648f530ce3c5c78eee7e60fa1ebfee43295 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Wed, 31 Jul 2024 19:33:36 -0700 Subject: [PATCH] add streaming support for ollama inference with tests --- llama_toolchain/inference/inference.py | 14 +- llama_toolchain/inference/ollama.py | 147 +++++++++++++++-- tests/test_inference.py | 217 +++++++++++++++++++++++-- tests/test_ollama_inference.py | 174 ++++++++++++++++---- 4 files changed, 491 insertions(+), 61 deletions(-) diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py index b3fa058fe..d7211ae65 100644 --- a/llama_toolchain/inference/inference.py +++ b/llama_toolchain/inference/inference.py @@ -103,13 +103,15 @@ class InferenceImpl(Inference): ) else: delta = text - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, + + if stop_reason is None: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=delta, + stop_reason=stop_reason, + ) ) - ) if stop_reason is None: stop_reason = StopReason.out_of_tokens diff --git a/llama_toolchain/inference/ollama.py b/llama_toolchain/inference/ollama.py index 485e5d558..91727fd62 100644 --- a/llama_toolchain/inference/ollama.py +++ b/llama_toolchain/inference/ollama.py @@ -7,14 +7,20 @@ from ollama import AsyncClient from llama_models.llama3_1.api.datatypes import ( BuiltinTool, - CompletionMessage, - Message, + CompletionMessage, + Message, StopReason, ToolCall, ) from llama_models.llama3_1.api.tool_utils import ToolUtils from .api.config import OllamaImplConfig +from .api.datatypes import ( + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, + ToolCallDelta, + ToolCallParseStatus, +) from .api.endpoints import ( ChatCompletionResponse, ChatCompletionRequest, @@ -54,28 +60,148 @@ class OllamaInference(Inference): ) return ollama_messages - + async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: if not request.stream: r = await self.client.chat( model=self.model, messages=self._messages_to_ollama_messages(request.messages), - stream=False + stream=False, + #TODO: add support for options like temp, top_p, max_seq_length, etc ) + if r['done']: + if r['done_reason'] == 'stop': + stop_reason = StopReason.end_of_turn + elif r['done_reason'] == 'length': + stop_reason = StopReason.out_of_tokens + completion_message = decode_assistant_message_from_content( - r['message']['content'] + r['message']['content'], + stop_reason, ) - yield ChatCompletionResponse( completion_message=completion_message, logprobs=None, ) else: - raise NotImplementedError() + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta="", + ) + ) + + stream = await self.client.chat( + model=self.model, + messages=self._messages_to_ollama_messages(request.messages), + stream=True + ) + + buffer = "" + ipython = False + stop_reason = None + + async for chunk in stream: + # check if ollama is done + if chunk['done']: + if chunk['done_reason'] == 'stop': + stop_reason = StopReason.end_of_turn + elif chunk['done_reason'] == 'length': + stop_reason = StopReason.out_of_tokens + break + + text = chunk['message']['content'] + + # check if its a tool call ( aka starts with <|python_tag|> ) + if not ipython and text.startswith("<|python_tag|>"): + ipython = True + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.started, + ), + ) + ) + buffer = buffer[len("<|python_tag|>") :] + continue + + if ipython: + if text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + text = "" + continue + elif text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + text = "" + continue + + buffer += text + delta = ToolCallDelta( + content=text, + parse_status=ToolCallParseStatus.in_progress, + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=delta, + stop_reason=stop_reason, + ) + ) + else: + buffer += text + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=text, + stop_reason=stop_reason, + ) + ) + + # parse tool calls and report errors + message = decode_assistant_message_from_content(buffer, stop_reason) + + parsed_tool_calls = len(message.tool_calls) > 0 + if ipython and not parsed_tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.failure, + ), + stop_reason=stop_reason, + ) + ) + + for tool_call in message.tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content=tool_call, + parse_status=ToolCallParseStatus.success, + ), + stop_reason=stop_reason, + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + stop_reason=stop_reason, + ) + ) #TODO: Consolidate this with impl in llama-models -def decode_assistant_message_from_content(content: str) -> CompletionMessage: +def decode_assistant_message_from_content( + content: str, + stop_reason: StopReason, +) -> CompletionMessage: ipython = content.startswith("<|python_tag|>") if ipython: content = content[len("<|python_tag|>") :] @@ -86,11 +212,6 @@ def decode_assistant_message_from_content(content: str) -> CompletionMessage: elif content.endswith("<|eom_id|>"): content = content[: -len("<|eom_id|>")] stop_reason = StopReason.end_of_message - else: - # Ollama does not return <|eot_id|> - # and hence we explicitly set it as the default. - #TODO: Check for StopReason.out_of_tokens - stop_reason = StopReason.end_of_turn tool_name = None tool_arguments = {} diff --git a/tests/test_inference.py b/tests/test_inference.py index 6d11ba415..bbd4a2c1a 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,22 +1,33 @@ # Run this test using the following command: # python -m unittest tests/test_inference.py +import asyncio import os +import textwrap import unittest +from datetime import datetime + from llama_models.llama3_1.api.datatypes import ( + BuiltinTool, InstructModel, - UserMessage + UserMessage, + StopReason, + SystemMessage, ) from llama_toolchain.inference.api.config import ( ImplType, InferenceConfig, InlineImplConfig, + RemoteImplConfig, ModelCheckpointConfig, PytorchCheckpoint, CheckpointQuantizationFormat, ) +from llama_toolchain.inference.api_instance import ( + get_inference_api_instance, +) from llama_toolchain.inference.api.datatypes import ( ChatCompletionResponseEventType, ) @@ -37,7 +48,13 @@ llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token < class InferenceTests(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): + @classmethod + def setUpClass(cls): + # This runs the async setup function + asyncio.run(cls.asyncSetUpClass()) + + @classmethod + async def asyncSetUpClass(cls): # assert model exists on local model_dir = os.path.expanduser("~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/") assert os.path.isdir(model_dir), HELPER_MSG @@ -45,7 +62,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): tokenizer_path = os.path.join(model_dir, "tokenizer.model") assert os.path.exists(tokenizer_path), HELPER_MSG - inference_config = InlineImplConfig( + inline_config = InlineImplConfig( checkpoint_config=ModelCheckpointConfig( checkpoint=PytorchCheckpoint( checkpoint_dir=model_dir, @@ -56,14 +73,74 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): ), max_seq_len=2048, ) + inference_config = InferenceConfig( + impl_config=inline_config + ) - self.inference = InferenceImpl(inference_config) - await self.inference.initialize() + # -- For faster testing iteration -- + # remote_config = RemoteImplConfig( + # url="http://localhost:5000" + # ) + # inference_config = InferenceConfig( + # impl_config=remote_config + # ) - async def asyncTearDown(self): - await self.inference.shutdown() + cls.api = await get_inference_api_instance(inference_config) + await cls.api.initialize() - async def test_inline_inference_no_streaming(self): + current_date = datetime.now() + formatted_date = current_date.strftime("%d %B %Y") + cls.system_prompt = SystemMessage( + content=textwrap.dedent(f""" + Environment: ipython + Tools: brave_search + + Cutting Knowledge Date: December 2023 + Today Date:{formatted_date} + + """), + ) + cls.system_prompt_with_custom_tool = SystemMessage( + content=textwrap.dedent(""" + Environment: ipython + Tools: brave_search, wolfram_alpha, photogen + + Cutting Knowledge Date: December 2023 + Today Date: 30 July 2024 + + + You have access to the following functions: + + Use the function 'get_boiling_point' to 'Get the boiling point of a imaginary liquids (eg. polyjuice)' + {"name": "get_boiling_point", "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)", "parameters": {"liquid_name": {"param_type": "string", "description": "The name of the liquid", "required": true}, "celcius": {"param_type": "boolean", "description": "Whether to return the boiling point in Celcius", "required": false}}} + + + Think very carefully before calling functions. + If you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + + """ + ), + ) + + @classmethod + def tearDownClass(cls): + # This runs the async teardown function + asyncio.run(cls.asyncTearDownClass()) + + @classmethod + async def asyncTearDownClass(cls): + await cls.api.shutdown() + + async def test_text(self): request = ChatCompletionRequest( model=InstructModel.llama3_8b_chat, messages=[ @@ -73,7 +150,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): ], stream=False, ) - iterator = self.inference.chat_completion(request) + iterator = InferenceTests.api.chat_completion(request) async for chunk in iterator: response = chunk @@ -81,7 +158,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): result = response.completion_message.content self.assertTrue("Paris" in result, result) - async def test_inline_inference_streaming(self): + async def test_text_streaming(self): request = ChatCompletionRequest( model=InstructModel.llama3_8b_chat, messages=[ @@ -91,12 +168,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): ], stream=True, ) - iterator = self.inference.chat_completion(request) + iterator = InferenceTests.api.chat_completion(request) events = [] async for chunk in iterator: events.append(chunk.event) - + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") self.assertEqual( events[0].event_type, @@ -112,3 +189,119 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): response += e.delta self.assertTrue("Paris" in response, response) + + async def test_custom_tool_call(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + InferenceTests.system_prompt_with_custom_tool, + UserMessage( + content="Use provided function to find the boiling point of polyjuice in fahrenheit?", + ), + ], + stream=False, + ) + iterator = InferenceTests.api.chat_completion(request) + async for r in iterator: + response = r + + completion_message = response.completion_message + + self.assertEqual(completion_message.content, "") + + # FIXME: This test fails since there is a bug where + # custom tool calls return incoorect stop_reason as out_of_tokens + # instead of end_of_turn + # self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) + + self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls) + self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point") + + args = completion_message.tool_calls[0].arguments + self.assertTrue(isinstance(args, dict)) + self.assertTrue(args["liquid_name"], "polyjuice") + + async def test_tool_call_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + self.system_prompt, + UserMessage( + content="Who is the current US President?", + ), + ], + stream=True, + ) + iterator = InferenceTests.api.chat_completion(request) + + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + # last but one event should be eom with tool call + self.assertEqual( + events[-2].event_type, + ChatCompletionResponseEventType.progress + ) + self.assertEqual( + events[-2].stop_reason, + StopReason.end_of_message + ) + self.assertEqual( + events[-2].delta.content.tool_name, + BuiltinTool.brave_search + ) + + async def test_custom_tool_call_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + InferenceTests.system_prompt_with_custom_tool, + UserMessage( + content="Use provided function to find the boiling point of polyjuice?", + ), + ], + stream=True, + ) + iterator = InferenceTests.api.chat_completion(request) + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + self.assertEqual( + events[-1].stop_reason, + StopReason.end_of_turn + ) + # last but one event should be eom with tool call + self.assertEqual( + events[-2].event_type, + ChatCompletionResponseEventType.progress + ) + self.assertEqual( + events[-2].stop_reason, + StopReason.end_of_turn + ) + self.assertEqual( + events[-2].delta.content.tool_name, + "get_boiling_point" + ) diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py index d37ff26c3..6f6b5d1a8 100644 --- a/tests/test_ollama_inference.py +++ b/tests/test_ollama_inference.py @@ -1,6 +1,6 @@ import textwrap -import unittest -from datetime import datetime +import unittest +from datetime import datetime from llama_models.llama3_1.api.datatypes import ( BuiltinTool, @@ -9,7 +9,9 @@ from llama_models.llama3_1.api.datatypes import ( StopReason, SystemMessage, ) - +from llama_toolchain.inference.api.datatypes import ( + ChatCompletionResponseEventType, +) from llama_toolchain.inference.api.endpoints import ( ChatCompletionRequest ) @@ -29,9 +31,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): url="http://localhost:11434", ) - # setup ollama - self.inference = OllamaInference(ollama_config) - await self.inference.initialize() + # setup ollama + self.api = OllamaInference(ollama_config) + await self.api.initialize() current_date = datetime.now() formatted_date = current_date.strftime("%d %B %Y") @@ -78,7 +80,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): ) async def asyncTearDown(self): - await self.inference.shutdown() + await self.api.shutdown() async def test_text(self): request = ChatCompletionRequest( @@ -90,12 +92,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): ], stream=False, ) - iterator = self.inference.chat_completion(request) + iterator = self.api.chat_completion(request) async for r in iterator: response = r self.assertTrue("Paris" in response.completion_message.content) - self.assertEquals(response.completion_message.stop_reason, StopReason.end_of_turn) + self.assertEqual(response.completion_message.stop_reason, StopReason.end_of_turn) async def test_tool_call(self): request = ChatCompletionRequest( @@ -108,21 +110,21 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): ], stream=False, ) - iterator = self.inference.chat_completion(request) + iterator = self.api.chat_completion(request) async for r in iterator: response = r completion_message = response.completion_message - - self.assertEquals(completion_message.content, "") - self.assertEquals(completion_message.stop_reason, StopReason.end_of_message) - - self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls) - self.assertEquals(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search) + + self.assertEqual(completion_message.content, "") + self.assertEqual(completion_message.stop_reason, StopReason.end_of_message) + + self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls) + self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search) self.assertTrue( "president" in completion_message.tool_calls[0].arguments["query"].lower() ) - + async def test_code_execution(self): request = ChatCompletionRequest( model=InstructModel.llama3_8b_chat, @@ -134,17 +136,17 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): ], stream=False, ) - iterator = self.inference.chat_completion(request) + iterator = self.api.chat_completion(request) async for r in iterator: response = r completion_message = response.completion_message - self.assertEquals(completion_message.content, "") - self.assertEquals(completion_message.stop_reason, StopReason.end_of_message) - - self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls) - self.assertEquals(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter) + self.assertEqual(completion_message.content, "") + self.assertEqual(completion_message.stop_reason, StopReason.end_of_message) + + self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls) + self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter) code = completion_message.tool_calls[0].arguments["code"] self.assertTrue("def " in code.lower(), code) @@ -154,23 +156,135 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): messages=[ self.system_prompt_with_custom_tool, UserMessage( - content="Use provided function to find the boiling point of polyjuice in fahrenheit?", + content="Use provided function to find the boiling point of polyjuice?", ), ], stream=False, ) - iterator = self.inference.chat_completion(request) + iterator = self.api.chat_completion(request) async for r in iterator: response = r completion_message = response.completion_message - + self.assertEqual(completion_message.content, "") - self.assertEquals(completion_message.stop_reason, StopReason.end_of_turn) - - self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls) - self.assertEquals(completion_message.tool_calls[0].tool_name, "get_boiling_point") + self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) + + self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls) + self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point") args = completion_message.tool_calls[0].arguments self.assertTrue(isinstance(args, dict)) self.assertTrue(args["liquid_name"], "polyjuice") + + + async def test_text_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + UserMessage( + content="What is the capital of France?", + ), + ], + stream=True, + ) + iterator = self.api.chat_completion(request) + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + response = "" + for e in events[1:-1]: + response += e.delta + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + # last but 1 event should be of type "progress" + self.assertEqual( + events[-2].event_type, + ChatCompletionResponseEventType.progress + ) + self.assertEqual( + events[-2].stop_reason, + None, + ) + self.assertTrue("Paris" in response, response) + + async def test_tool_call_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + self.system_prompt, + UserMessage( + content="Who is the current US President?", + ), + ], + stream=True, + ) + iterator = self.api.chat_completion(request) + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + + async def test_custom_tool_call_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + self.system_prompt_with_custom_tool, + UserMessage( + content="Use provided function to find the boiling point of polyjuice?", + ), + ], + stream=True, + ) + iterator = self.api.chat_completion(request) + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + self.assertEqual( + events[-1].stop_reason, + StopReason.end_of_turn + ) + # last but one event should be eom with tool call + self.assertEqual( + events[-2].event_type, + ChatCompletionResponseEventType.progress + ) + self.assertEqual( + events[-2].delta.content.tool_name, + "get_boiling_point" + ) + self.assertEqual( + events[-2].stop_reason, + StopReason.end_of_turn + )