Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 20:14:13 +00:00)
add streaming support for ollama inference with tests
parent 0e75e73fa7
commit 0e985648f5
4 changed files with 491 additions and 61 deletions
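
What the new tests exercise: with stream=True, chat_completion yields chunks whose event walks through start, then one or more progress events carrying deltas, then complete. Below is a minimal consumer sketched from the tests in this diff, not the adapter implementation itself; the api argument is assumed to be an initialized OllamaInference instance as built in asyncSetUp, plain-text deltas are assumed to be strings (the tests concatenate them directly), and UserMessage is assumed to come from the same datatypes module as BuiltinTool.

from llama_models.llama3_1.api.datatypes import InstructModel, UserMessage
from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest


async def stream_answer(api) -> str:
    # Ask a simple question with streaming enabled and accumulate the text deltas.
    request = ChatCompletionRequest(
        model=InstructModel.llama3_8b_chat,
        messages=[UserMessage(content="What is the capital of France?")],
        stream=True,
    )
    text = ""
    async for chunk in api.chat_completion(request):
        event = chunk.event
        # Event order asserted by the tests: start -> progress ... progress -> complete.
        if event.event_type == ChatCompletionResponseEventType.progress:
            text += event.delta
    return text

The streaming tests added below pin this same contract for the first and last events and for stop_reason handling.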
@@ -1,6 +1,6 @@
 import textwrap
-import unittest
-from datetime import datetime
+import unittest
+from datetime import datetime
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
@@ -9,7 +9,9 @@ from llama_models.llama3_1.api.datatypes import (
     StopReason,
     SystemMessage,
 )
-
+from llama_toolchain.inference.api.datatypes import (
+    ChatCompletionResponseEventType,
+)
 from llama_toolchain.inference.api.endpoints import (
     ChatCompletionRequest
 )
@@ -29,9 +31,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             url="http://localhost:11434",
         )
 
-        # setup ollama
-        self.inference = OllamaInference(ollama_config)
-        await self.inference.initialize()
+        # setup ollama
+        self.api = OllamaInference(ollama_config)
+        await self.api.initialize()
 
         current_date = datetime.now()
         formatted_date = current_date.strftime("%d %B %Y")
@@ -78,7 +80,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         )
 
     async def asyncTearDown(self):
-        await self.inference.shutdown()
+        await self.api.shutdown()
 
     async def test_text(self):
        request = ChatCompletionRequest(
@@ -90,12 +92,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             ],
             stream=False,
         )
-        iterator = self.inference.chat_completion(request)
+        iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
 
         self.assertTrue("Paris" in response.completion_message.content)
-        self.assertEquals(response.completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertEqual(response.completion_message.stop_reason, StopReason.end_of_turn)
 
     async def test_tool_call(self):
         request = ChatCompletionRequest(
@@ -108,21 +110,21 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             ],
             stream=False,
         )
-        iterator = self.inference.chat_completion(request)
+        iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
 
         completion_message = response.completion_message
 
-        self.assertEquals(completion_message.content, "")
-        self.assertEquals(completion_message.stop_reason, StopReason.end_of_message)
-
-        self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEquals(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search)
+        self.assertEqual(completion_message.content, "")
+        self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)
+
+        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
+        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search)
         self.assertTrue(
             "president" in completion_message.tool_calls[0].arguments["query"].lower()
         )
 
     async def test_code_execution(self):
         request = ChatCompletionRequest(
             model=InstructModel.llama3_8b_chat,
@@ -134,17 +136,17 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             ],
             stream=False,
         )
-        iterator = self.inference.chat_completion(request)
+        iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
 
         completion_message = response.completion_message
 
-        self.assertEquals(completion_message.content, "")
-        self.assertEquals(completion_message.stop_reason, StopReason.end_of_message)
-
-        self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEquals(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter)
+        self.assertEqual(completion_message.content, "")
+        self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)
+
+        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
+        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter)
         code = completion_message.tool_calls[0].arguments["code"]
         self.assertTrue("def " in code.lower(), code)
 
@@ -154,23 +156,135 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
-                    content="Use provided function to find the boiling point of polyjuice in fahrenheit?",
+                    content="Use provided function to find the boiling point of polyjuice?",
                 ),
             ],
             stream=False,
         )
-        iterator = self.inference.chat_completion(request)
+        iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
 
         completion_message = response.completion_message
 
         self.assertEqual(completion_message.content, "")
-        self.assertEquals(completion_message.stop_reason, StopReason.end_of_turn)
-
-        self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEquals(completion_message.tool_calls[0].tool_name, "get_boiling_point")
+        self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
+
+        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
+        self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
 
         args = completion_message.tool_calls[0].arguments
         self.assertTrue(isinstance(args, dict))
         self.assertTrue(args["liquid_name"], "polyjuice")
 
+    async def test_text_streaming(self):
+        request = ChatCompletionRequest(
+            model=InstructModel.llama3_8b_chat,
+            messages=[
+                UserMessage(
+                    content="What is the capital of France?",
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.api.chat_completion(request)
+        events = []
+        async for chunk in iterator:
+            # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
+            events.append(chunk.event)
+
+        response = ""
+        for e in events[1:-1]:
+            response += e.delta
+
+        self.assertEqual(
+            events[0].event_type,
+            ChatCompletionResponseEventType.start
+        )
+        # last event is of type "complete"
+        self.assertEqual(
+            events[-1].event_type,
+            ChatCompletionResponseEventType.complete
+        )
+        # last but 1 event should be of type "progress"
+        self.assertEqual(
+            events[-2].event_type,
+            ChatCompletionResponseEventType.progress
+        )
+        self.assertEqual(
+            events[-2].stop_reason,
+            None,
+        )
+        self.assertTrue("Paris" in response, response)
+
+    async def test_tool_call_streaming(self):
+        request = ChatCompletionRequest(
+            model=InstructModel.llama3_8b_chat,
+            messages=[
+                self.system_prompt,
+                UserMessage(
+                    content="Who is the current US President?",
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.api.chat_completion(request)
+        events = []
+        async for chunk in iterator:
+            # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
+            events.append(chunk.event)
+
+        self.assertEqual(
+            events[0].event_type,
+            ChatCompletionResponseEventType.start
+        )
+        # last event is of type "complete"
+        self.assertEqual(
+            events[-1].event_type,
+            ChatCompletionResponseEventType.complete
+        )
+
+    async def test_custom_tool_call_streaming(self):
+        request = ChatCompletionRequest(
+            model=InstructModel.llama3_8b_chat,
+            messages=[
+                self.system_prompt_with_custom_tool,
+                UserMessage(
+                    content="Use provided function to find the boiling point of polyjuice?",
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.api.chat_completion(request)
+        events = []
+        async for chunk in iterator:
+            # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
+            events.append(chunk.event)
+
+        self.assertEqual(
+            events[0].event_type,
+            ChatCompletionResponseEventType.start
+        )
+        # last event is of type "complete"
+        self.assertEqual(
+            events[-1].event_type,
+            ChatCompletionResponseEventType.complete
+        )
+        self.assertEqual(
+            events[-1].stop_reason,
+            StopReason.end_of_turn
+        )
+        # last but one event should be eom with tool call
+        self.assertEqual(
+            events[-2].event_type,
+            ChatCompletionResponseEventType.progress
+        )
+        self.assertEqual(
+            events[-2].delta.content.tool_name,
+            "get_boiling_point"
+        )
+        self.assertEqual(
+            events[-2].stop_reason,
+            StopReason.end_of_turn
+        )
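
Taken together, the three streaming tests pin a small event contract: every stream opens with a start event and closes with a complete event; plain text arrives as string deltas on progress events, while a custom tool call arrives as a structured delta on the final progress event together with a stop_reason of end_of_turn (the non-streaming builtin-tool tests expect end_of_message instead). The helper below is a rough consolidation of that contract, a sketch only: the delta shapes (str for text, an object exposing .content.tool_name for tool calls) and the per-event stop_reason attribute are inferred from the assertions above, not from the API definitions.

from typing import Any, List, Optional, Tuple

from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType


async def drain_stream(iterator) -> Tuple[str, List[Any], Optional[Any]]:
    # Collect (text, tool-call deltas, final stop_reason) from a streaming
    # chat_completion iterator; shapes inferred from the tests in this diff.
    text = ""
    tool_deltas: List[Any] = []
    stop_reason = None
    async for chunk in iterator:
        event = chunk.event
        if event.stop_reason is not None:
            stop_reason = event.stop_reason
        if event.event_type != ChatCompletionResponseEventType.progress:
            continue
        if isinstance(event.delta, str):
            text += event.delta              # plain text streaming
        else:
            tool_deltas.append(event.delta)  # e.g. delta.content.tool_name
    return text, tool_deltas, stop_reason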