add streaming support for ollama inference with tests

Hardik Shah 2024-07-31 19:33:36 -07:00
parent 0e75e73fa7
commit 0e985648f5
4 changed files with 491 additions and 61 deletions
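The diff below adds three streaming tests alongside the existing non-streaming ones. For orientation, this is roughly how a client consumes the streaming path the tests exercise: with stream=True, chat_completion yields chunks whose events run from "start" through incremental "progress" deltas to "complete". A minimal sketch, assuming the classes imported in the diff below; the stream_demo helper and its wiring are illustrative, not part of the commit:

    # Illustrative sketch only; inferred from the tests in this commit.
    from llama_models.llama3_1.api.datatypes import UserMessage  # mirrors the test module's imports
    from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
    from llama_toolchain.inference.api.endpoints import ChatCompletionRequest

    async def stream_demo(api, model) -> str:
        # `api` is an initialized OllamaInference instance and `model` an
        # InstructModel value, mirroring asyncSetUp in the test class below.
        request = ChatCompletionRequest(
            model=model,
            messages=[UserMessage(content="What is the capital of France?")],
            stream=True,
        )
        text = ""
        async for chunk in api.chat_completion(request):
            # Interior events are "progress" events; for plain text generation
            # their delta is a string fragment (tool calls carry a richer delta).
            if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                text += chunk.event.delta
        return text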

@@ -1,6 +1,6 @@
 import textwrap
-import unittest
-from datetime import datetime
+import unittest
+from datetime import datetime
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
@@ -9,7 +9,9 @@ from llama_models.llama3_1.api.datatypes import (
     StopReason,
     SystemMessage,
 )
+from llama_toolchain.inference.api.datatypes import (
+    ChatCompletionResponseEventType,
+)
 from llama_toolchain.inference.api.endpoints import (
     ChatCompletionRequest
 )
@@ -29,9 +31,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             url="http://localhost:11434",
         )
 
-        # setup ollama
-        self.inference = OllamaInference(ollama_config)
-        await self.inference.initialize()
+        # setup ollama
+        self.api = OllamaInference(ollama_config)
+        await self.api.initialize()
 
         current_date = datetime.now()
         formatted_date = current_date.strftime("%d %B %Y")
@@ -78,7 +80,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         )
 
     async def asyncTearDown(self):
-        await self.inference.shutdown()
+        await self.api.shutdown()
 
     async def test_text(self):
         request = ChatCompletionRequest(
@@ -90,12 +92,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             ],
             stream=False,
         )
-        iterator = self.inference.chat_completion(request)
+        iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
 
         self.assertTrue("Paris" in response.completion_message.content)
-        self.assertEquals(response.completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertEqual(response.completion_message.stop_reason, StopReason.end_of_turn)
 
     async def test_tool_call(self):
         request = ChatCompletionRequest(
@@ -108,21 +110,21 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             ],
             stream=False,
         )
-        iterator = self.inference.chat_completion(request)
+        iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
 
         completion_message = response.completion_message
 
-        self.assertEquals(completion_message.content, "")
-        self.assertEquals(completion_message.stop_reason, StopReason.end_of_message)
-        self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEquals(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search)
+        self.assertEqual(completion_message.content, "")
+        self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)
+        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
+        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search)
 
         self.assertTrue(
             "president" in completion_message.tool_calls[0].arguments["query"].lower()
         )
 
     async def test_code_execution(self):
         request = ChatCompletionRequest(
             model=InstructModel.llama3_8b_chat,
@@ -134,17 +136,17 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             ],
             stream=False,
         )
-        iterator = self.inference.chat_completion(request)
+        iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
 
         completion_message = response.completion_message
 
-        self.assertEquals(completion_message.content, "")
-        self.assertEquals(completion_message.stop_reason, StopReason.end_of_message)
-        self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEquals(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter)
+        self.assertEqual(completion_message.content, "")
+        self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)
+        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
+        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter)
 
         code = completion_message.tool_calls[0].arguments["code"]
         self.assertTrue("def " in code.lower(), code)
@@ -154,23 +156,135 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
-                    content="Use provided function to find the boiling point of polyjuice in fahrenheit?",
+                    content="Use provided function to find the boiling point of polyjuice?",
                 ),
             ],
             stream=False,
         )
-        iterator = self.inference.chat_completion(request)
+        iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
 
         completion_message = response.completion_message
 
         self.assertEqual(completion_message.content, "")
-        self.assertEquals(completion_message.stop_reason, StopReason.end_of_turn)
-        self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEquals(completion_message.tool_calls[0].tool_name, "get_boiling_point")
+        self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
+        self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
 
         args = completion_message.tool_calls[0].arguments
         self.assertTrue(isinstance(args, dict))
         self.assertEqual(args["liquid_name"], "polyjuice")
+
+    async def test_text_streaming(self):
+        request = ChatCompletionRequest(
+            model=InstructModel.llama3_8b_chat,
+            messages=[
+                UserMessage(
+                    content="What is the capital of France?",
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.api.chat_completion(request)
+        events = []
+        async for chunk in iterator:
+            # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
+            events.append(chunk.event)
+
+        response = ""
+        for e in events[1:-1]:
+            response += e.delta
+
+        # first event is of type "start"
+        self.assertEqual(
+            events[0].event_type,
+            ChatCompletionResponseEventType.start
+        )
+        # last event is of type "complete"
+        self.assertEqual(
+            events[-1].event_type,
+            ChatCompletionResponseEventType.complete
+        )
+        # second-to-last event should be of type "progress"
+        self.assertEqual(
+            events[-2].event_type,
+            ChatCompletionResponseEventType.progress
+        )
+        self.assertEqual(
+            events[-2].stop_reason,
+            None,
+        )
+        self.assertTrue("Paris" in response, response)
+
+    async def test_tool_call_streaming(self):
+        request = ChatCompletionRequest(
+            model=InstructModel.llama3_8b_chat,
+            messages=[
+                self.system_prompt,
+                UserMessage(
+                    content="Who is the current US President?",
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.api.chat_completion(request)
+        events = []
+        async for chunk in iterator:
+            # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
+            events.append(chunk.event)
+
+        # first event is of type "start"
+        self.assertEqual(
+            events[0].event_type,
+            ChatCompletionResponseEventType.start
+        )
+        # last event is of type "complete"
+        self.assertEqual(
+            events[-1].event_type,
+            ChatCompletionResponseEventType.complete
+        )
+
+    async def test_custom_tool_call_streaming(self):
+        request = ChatCompletionRequest(
+            model=InstructModel.llama3_8b_chat,
+            messages=[
+                self.system_prompt_with_custom_tool,
+                UserMessage(
+                    content="Use provided function to find the boiling point of polyjuice?",
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.api.chat_completion(request)
+        events = []
+        async for chunk in iterator:
+            # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
+            events.append(chunk.event)
+
+        # first event is of type "start"
+        self.assertEqual(
+            events[0].event_type,
+            ChatCompletionResponseEventType.start
+        )
+        # last event is of type "complete"
+        self.assertEqual(
+            events[-1].event_type,
+            ChatCompletionResponseEventType.complete
+        )
+        self.assertEqual(
+            events[-1].stop_reason,
+            StopReason.end_of_turn
+        )
+        # second-to-last event should be a "progress" event carrying the tool call
+        self.assertEqual(
+            events[-2].event_type,
+            ChatCompletionResponseEventType.progress
+        )
+        self.assertEqual(
+            events[-2].delta.content.tool_name,
+            "get_boiling_point"
+        )
+        self.assertEqual(
+            events[-2].stop_reason,
+            StopReason.end_of_turn
+        )
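
Taken together, the three streaming tests pin down an informal event contract for the provider: the first event is "start", the last is "complete", everything in between is "progress", and for tool calls the final progress event carries the tool call and a stop_reason. A hedged restatement of that contract as a checker; the helper is illustrative and inferred from the assertions above, not part of the commit:

    # Illustrative contract checker inferred from the streaming tests above.
    from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType

    def check_stream_contract(events) -> None:
        # the stream opens with "start" and closes with "complete"
        assert events[0].event_type == ChatCompletionResponseEventType.start
        assert events[-1].event_type == ChatCompletionResponseEventType.complete
        # every interior event is an incremental "progress" event
        for e in events[1:-1]:
            assert e.event_type == ChatCompletionResponseEventType.progress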