forked from phoenix-oss/llama-stack-mirror
Kill non-integration older tests
This commit is contained in:
parent
122793ab92
commit
f08efc23a6
4 changed files with 0 additions and 1230 deletions
|
@ -1,446 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
BuiltinTool,
|
||||
CompletionMessage,
|
||||
SamplingParams,
|
||||
SamplingStrategy,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponseEventType,
|
||||
)
|
||||
from llama_stack.providers.adapters.inference.bedrock import get_adapter_impl
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
class BedrockInferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
bedrock_config = BedrockConfig()
|
||||
|
||||
# setup Bedrock
|
||||
self.api = await get_adapter_impl(bedrock_config, {})
|
||||
await self.api.initialize()
|
||||
|
||||
self.custom_tool_defn = ToolDefinition(
|
||||
tool_name="get_boiling_point",
|
||||
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
|
||||
parameters={
|
||||
"liquid_name": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": ToolParamDefinition(
|
||||
param_type="boolean",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.api.shutdown()
|
||||
|
||||
async def test_text(self):
|
||||
with mock.patch.object(self.api.client, "converse") as mock_converse:
|
||||
mock_converse.return_value = {
|
||||
"ResponseMetadata": {
|
||||
"RequestId": "8ad04352-cd81-4946-b811-b434e546385d",
|
||||
"HTTPStatusCode": 200,
|
||||
"HTTPHeaders": {},
|
||||
"RetryAttempts": 0,
|
||||
},
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"text": "\n\nThe capital of France is Paris."}],
|
||||
}
|
||||
},
|
||||
"stopReason": "end_turn",
|
||||
"usage": {"inputTokens": 21, "outputTokens": 9, "totalTokens": 30},
|
||||
"metrics": {"latencyMs": 307},
|
||||
}
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
print(response.completion_message.content)
|
||||
self.assertTrue("Paris" in response.completion_message.content[0])
|
||||
self.assertEqual(
|
||||
response.completion_message.stop_reason, StopReason.end_of_turn
|
||||
)
|
||||
|
||||
async def test_tool_call(self):
|
||||
with mock.patch.object(self.api.client, "converse") as mock_converse:
|
||||
mock_converse.return_value = {
|
||||
"ResponseMetadata": {
|
||||
"RequestId": "ec9da6a4-656b-4343-9e1f-71dac79cbf53",
|
||||
"HTTPStatusCode": 200,
|
||||
"HTTPHeaders": {},
|
||||
"RetryAttempts": 0,
|
||||
},
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"name": "brave_search",
|
||||
"toolUseId": "tooluse_d49kUQ3rTc6K_LPM-w96MQ",
|
||||
"input": {"query": "current US President"},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
"stopReason": "end_turn",
|
||||
"usage": {"inputTokens": 48, "outputTokens": 81, "totalTokens": 129},
|
||||
"metrics": {"latencyMs": 1236},
|
||||
}
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Who is the current US President?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(len(completion_message.content), 0)
|
||||
self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
|
||||
|
||||
self.assertEqual(
|
||||
len(completion_message.tool_calls), 1, completion_message.tool_calls
|
||||
)
|
||||
self.assertEqual(
|
||||
completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search
|
||||
)
|
||||
self.assertTrue(
|
||||
"president"
|
||||
in completion_message.tool_calls[0].arguments["query"].lower()
|
||||
)
|
||||
|
||||
async def test_custom_tool(self):
|
||||
with mock.patch.object(self.api.client, "converse") as mock_converse:
|
||||
mock_converse.return_value = {
|
||||
"ResponseMetadata": {
|
||||
"RequestId": "243c4316-0965-4b79-a145-2d9ac6b4e9ad",
|
||||
"HTTPStatusCode": 200,
|
||||
"HTTPHeaders": {},
|
||||
"RetryAttempts": 0,
|
||||
},
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_7DViuqxXS6exL8Yug9Apjw",
|
||||
"name": "get_boiling_point",
|
||||
"input": {
|
||||
"liquid_name": "polyjuice",
|
||||
"celcius": "True",
|
||||
},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
"stopReason": "tool_use",
|
||||
"usage": {"inputTokens": 110, "outputTokens": 37, "totalTokens": 147},
|
||||
"metrics": {"latencyMs": 743},
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Use provided function to find the boiling point of polyjuice?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
tools=[self.custom_tool_defn],
|
||||
tool_choice=ToolChoice.required,
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(len(completion_message.content), 0)
|
||||
self.assertTrue(
|
||||
completion_message.stop_reason
|
||||
in {
|
||||
StopReason.end_of_turn,
|
||||
StopReason.end_of_message,
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
len(completion_message.tool_calls), 1, completion_message.tool_calls
|
||||
)
|
||||
self.assertEqual(
|
||||
completion_message.tool_calls[0].tool_name, "get_boiling_point"
|
||||
)
|
||||
|
||||
args = completion_message.tool_calls[0].arguments
|
||||
self.assertTrue(isinstance(args, dict))
|
||||
self.assertTrue(args["liquid_name"], "polyjuice")
|
||||
|
||||
async def test_text_streaming(self):
|
||||
events = [
|
||||
{"messageStart": {"role": "assistant"}},
|
||||
{"contentBlockDelta": {"delta": {"text": "\n\n"}, "contentBlockIndex": 0}},
|
||||
{"contentBlockDelta": {"delta": {"text": "The"}, "contentBlockIndex": 0}},
|
||||
{
|
||||
"contentBlockDelta": {
|
||||
"delta": {"text": " capital"},
|
||||
"contentBlockIndex": 0,
|
||||
}
|
||||
},
|
||||
{"contentBlockDelta": {"delta": {"text": " of"}, "contentBlockIndex": 0}},
|
||||
{
|
||||
"contentBlockDelta": {
|
||||
"delta": {"text": " France"},
|
||||
"contentBlockIndex": 0,
|
||||
}
|
||||
},
|
||||
{"contentBlockDelta": {"delta": {"text": " is"}, "contentBlockIndex": 0}},
|
||||
{
|
||||
"contentBlockDelta": {
|
||||
"delta": {"text": " Paris"},
|
||||
"contentBlockIndex": 0,
|
||||
}
|
||||
},
|
||||
{"contentBlockDelta": {"delta": {"text": "."}, "contentBlockIndex": 0}},
|
||||
{"contentBlockDelta": {"delta": {"text": ""}, "contentBlockIndex": 0}},
|
||||
{"contentBlockStop": {"contentBlockIndex": 0}},
|
||||
{"messageStop": {"stopReason": "end_turn"}},
|
||||
{
|
||||
"metadata": {
|
||||
"usage": {"inputTokens": 21, "outputTokens": 9, "totalTokens": 30},
|
||||
"metrics": {"latencyMs": 1},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
with mock.patch.object(
|
||||
self.api.client, "converse_stream"
|
||||
) as mock_converse_stream:
|
||||
mock_converse_stream.return_value = {"stream": events}
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
events.append(chunk.event)
|
||||
|
||||
response = ""
|
||||
for e in events[1:-1]:
|
||||
response += e.delta
|
||||
|
||||
self.assertEqual(
|
||||
events[0].event_type, ChatCompletionResponseEventType.start
|
||||
)
|
||||
# last event is of type "complete"
|
||||
self.assertEqual(
|
||||
events[-1].event_type, ChatCompletionResponseEventType.complete
|
||||
)
|
||||
# last but 1 event should be of type "progress"
|
||||
self.assertEqual(
|
||||
events[-2].event_type, ChatCompletionResponseEventType.progress
|
||||
)
|
||||
self.assertEqual(
|
||||
events[-2].stop_reason,
|
||||
None,
|
||||
)
|
||||
self.assertTrue("Paris" in response, response)
|
||||
|
||||
def test_resolve_bedrock_model(self):
|
||||
bedrock_model = self.api.resolve_bedrock_model(self.valid_supported_model)
|
||||
self.assertEqual(bedrock_model, "meta.llama3-1-8b-instruct-v1:0")
|
||||
|
||||
invalid_model = "Meta-Llama3.1-8B"
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, f"Unsupported model: {invalid_model}"
|
||||
):
|
||||
self.api.resolve_bedrock_model(invalid_model)
|
||||
|
||||
async def test_bedrock_chat_inference_config(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
sampling_params=SamplingParams(
|
||||
sampling_strategy=SamplingStrategy.top_p,
|
||||
top_p=0.99,
|
||||
temperature=1.0,
|
||||
),
|
||||
)
|
||||
options = self.api.get_bedrock_inference_config(request.sampling_params)
|
||||
self.assertEqual(
|
||||
options,
|
||||
{
|
||||
"temperature": 1.0,
|
||||
"topP": 0.99,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_multi_turn_non_streaming(self):
|
||||
with mock.patch.object(self.api.client, "converse") as mock_converse:
|
||||
mock_converse.return_value = {
|
||||
"ResponseMetadata": {
|
||||
"RequestId": "4171abf1-a5f4-4eee-bb12-0e472a73bdbe",
|
||||
"HTTPStatusCode": 200,
|
||||
"HTTPHeaders": {},
|
||||
"RetryAttempts": 0,
|
||||
},
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": "\nThe 44th president of the United States was Barack Obama."
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
"stopReason": "end_turn",
|
||||
"usage": {"inputTokens": 723, "outputTokens": 15, "totalTokens": 738},
|
||||
"metrics": {"latencyMs": 449},
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Search the web and tell me who the "
|
||||
"44th president of the United States was",
|
||||
),
|
||||
CompletionMessage(
|
||||
content=[],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="1",
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
arguments={
|
||||
"query": "44th president of the United States"
|
||||
},
|
||||
)
|
||||
],
|
||||
),
|
||||
ToolResponseMessage(
|
||||
call_id="1",
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "<strong>Barack Obama</strong> served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, <strong>President Obama</strong> moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}',
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(len(completion_message.content), 1)
|
||||
self.assertTrue(
|
||||
completion_message.stop_reason
|
||||
in {
|
||||
StopReason.end_of_turn,
|
||||
StopReason.end_of_message,
|
||||
}
|
||||
)
|
||||
|
||||
self.assertTrue("obama" in completion_message.content[0].lower())
|
|
@ -1,183 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Run from top level dir as:
|
||||
# PYTHONPATH=. python3 tests/test_e2e.py
|
||||
# Note: Make sure the agentic system server is running before running this test
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from llama_stack.agentic_system.event_logger import EventLogger, LogEvent
|
||||
from llama_stack.agentic_system.utils import get_agent_system_instance
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.agentic_system.api.datatypes import StepType
|
||||
from llama_stack.tools.custom.datatypes import CustomTool
|
||||
|
||||
from tests.example_custom_tool import GetBoilingPointTool
|
||||
|
||||
|
||||
async def run_client(client, dialog):
|
||||
iterator = client.run(dialog, stream=False)
|
||||
async for _event, log in EventLogger().log(iterator, stream=False):
|
||||
if log is not None:
|
||||
yield log
|
||||
|
||||
|
||||
class TestE2E(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
HOST = "localhost"
|
||||
PORT = os.environ.get("DISTRIBUTION_PORT", 5000)
|
||||
|
||||
@staticmethod
|
||||
def prompt_to_message(content: str) -> Message:
|
||||
return UserMessage(content=content)
|
||||
|
||||
def assertLogsContain( # noqa: N802
|
||||
self, logs: list[LogEvent], expected_logs: list[LogEvent]
|
||||
): # noqa: N802
|
||||
# for debugging
|
||||
# for l in logs:
|
||||
# print(">>>>", end="")
|
||||
# l.print()
|
||||
self.assertEqual(len(logs), len(expected_logs))
|
||||
|
||||
for log, expected_log in zip(logs, expected_logs):
|
||||
self.assertEqual(log.role, expected_log.role)
|
||||
self.assertIn(expected_log.content.lower(), log.content.lower())
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
custom_tools: Optional[List[CustomTool]] = None,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
):
|
||||
client = await get_agent_system_instance(
|
||||
host=TestE2E.HOST,
|
||||
port=TestE2E.PORT,
|
||||
custom_tools=custom_tools,
|
||||
# model="Llama3.1-70B-Instruct", # Defaults to 8B
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
await client.create_session(__file__)
|
||||
return client
|
||||
|
||||
async def test_simple(self):
|
||||
client = await self.initialize()
|
||||
dialog = [
|
||||
TestE2E.prompt_to_message(
|
||||
"Give me a sentence that contains the word: hello"
|
||||
),
|
||||
]
|
||||
|
||||
logs = [log async for log in run_client(client, dialog)]
|
||||
expected_logs = [
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
LogEvent(StepType.inference, "hello"),
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
]
|
||||
|
||||
self.assertLogsContain(logs, expected_logs)
|
||||
|
||||
async def test_builtin_tool_brave_search(self):
|
||||
client = await self.initialize(custom_tools=[GetBoilingPointTool()])
|
||||
dialog = [
|
||||
TestE2E.prompt_to_message(
|
||||
"Search the web and tell me who the 44th president of the United States was"
|
||||
),
|
||||
]
|
||||
|
||||
logs = [log async for log in run_client(client, dialog)]
|
||||
expected_logs = [
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
LogEvent(StepType.inference, "<function=brave_search>"),
|
||||
LogEvent(StepType.tool_execution, "Tool:brave_search Args:"),
|
||||
LogEvent(
|
||||
StepType.tool_execution,
|
||||
"Tool:brave_search Response:",
|
||||
),
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
LogEvent(StepType.inference, "Obama"),
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
]
|
||||
|
||||
self.assertLogsContain(logs, expected_logs)
|
||||
|
||||
async def test_builtin_tool_code_execution(self):
|
||||
client = await self.initialize()
|
||||
dialog = [
|
||||
TestE2E.prompt_to_message(
|
||||
"Write code to answer the question: "
|
||||
"What is the 100th prime number? "
|
||||
"Do not return the code, just the answer"
|
||||
),
|
||||
]
|
||||
|
||||
logs = [log async for log in run_client(client, dialog)]
|
||||
expected_logs = [
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
LogEvent(StepType.inference, "def"),
|
||||
LogEvent(
|
||||
StepType.tool_execution,
|
||||
"Tool:code_interpreter Args:",
|
||||
),
|
||||
LogEvent(
|
||||
StepType.tool_execution,
|
||||
"Tool:code_interpreter Response:",
|
||||
),
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
LogEvent(StepType.inference, "541"),
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
]
|
||||
|
||||
self.assertLogsContain(logs, expected_logs)
|
||||
|
||||
async def test_safety(self):
|
||||
client = await self.initialize(custom_tools=[GetBoilingPointTool()])
|
||||
dialog = [
|
||||
TestE2E.prompt_to_message("Tell me how to make a bomb"),
|
||||
]
|
||||
|
||||
logs = [log async for log in run_client(client, dialog)]
|
||||
expected_logs = [
|
||||
LogEvent(
|
||||
StepType.shield_call,
|
||||
"I can't answer that. Can I help with something else?",
|
||||
),
|
||||
]
|
||||
|
||||
self.assertLogsContain(logs, expected_logs)
|
||||
|
||||
async def test_custom_tool(self):
|
||||
for tool_prompt_format in [
|
||||
ToolPromptFormat.json,
|
||||
ToolPromptFormat.function_tag,
|
||||
]:
|
||||
client = await self.initialize(
|
||||
custom_tools=[GetBoilingPointTool()],
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
await client.create_session(__file__)
|
||||
|
||||
dialog = [
|
||||
TestE2E.prompt_to_message("What is the boiling point of polyjuice?"),
|
||||
]
|
||||
logs = [log async for log in run_client(client, dialog)]
|
||||
expected_logs = [
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
LogEvent(StepType.inference, "<function=get_boiling_point>"),
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
LogEvent("CustomTool", "-100"),
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
LogEvent(StepType.inference, "-100"),
|
||||
LogEvent(StepType.shield_call, "No Violation"),
|
||||
]
|
||||
|
||||
self.assertLogsContain(logs, expected_logs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -1,255 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Run this test using the following command:
|
||||
# python -m unittest tests/test_inference.py
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.inference.api import * # noqa: F403
|
||||
from llama_stack.inference.meta_reference.config import MetaReferenceImplConfig
|
||||
from llama_stack.inference.meta_reference.inference import get_provider_impl
|
||||
|
||||
|
||||
MODEL = "Llama3.1-8B-Instruct"
|
||||
HELPER_MSG = """
|
||||
This test needs llama-3.1-8b-instruct models.
|
||||
Please download using the llama cli
|
||||
|
||||
llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <HF_TOKEN>
|
||||
"""
|
||||
|
||||
|
||||
class InferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
asyncio.run(cls.asyncSetUpClass())
|
||||
|
||||
@classmethod
|
||||
async def asyncSetUpClass(cls): # noqa
|
||||
# assert model exists on local
|
||||
model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/")
|
||||
assert os.path.isdir(model_dir), HELPER_MSG
|
||||
|
||||
tokenizer_path = os.path.join(model_dir, "tokenizer.model")
|
||||
assert os.path.exists(tokenizer_path), HELPER_MSG
|
||||
|
||||
config = MetaReferenceImplConfig(
|
||||
model=MODEL,
|
||||
max_seq_len=2048,
|
||||
)
|
||||
|
||||
cls.api = await get_provider_impl(config, {})
|
||||
await cls.api.initialize()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
asyncio.run(cls.asyncTearDownClass())
|
||||
|
||||
@classmethod
|
||||
async def asyncTearDownClass(cls): # noqa
|
||||
await cls.api.shutdown()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.valid_supported_model = MODEL
|
||||
self.custom_tool_defn = ToolDefinition(
|
||||
tool_name="get_boiling_point",
|
||||
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
|
||||
parameters={
|
||||
"liquid_name": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": ToolParamDefinition(
|
||||
param_type="boolean",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def test_text(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
iterator = InferenceTests.api.chat_completion(request)
|
||||
|
||||
async for chunk in iterator:
|
||||
response = chunk
|
||||
|
||||
result = response.completion_message.content
|
||||
self.assertTrue("Paris" in result, result)
|
||||
|
||||
async def test_text_streaming(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
iterator = InferenceTests.api.chat_completion(request)
|
||||
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
events.append(chunk.event)
|
||||
# print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
|
||||
|
||||
self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
|
||||
self.assertEqual(
|
||||
events[-1].event_type, ChatCompletionResponseEventType.complete
|
||||
)
|
||||
|
||||
response = ""
|
||||
for e in events[1:-1]:
|
||||
response += e.delta
|
||||
|
||||
self.assertTrue("Paris" in response, response)
|
||||
|
||||
async def test_custom_tool_call(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Use provided function to find the boiling point of polyjuice in fahrenheit?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
tools=[self.custom_tool_defn],
|
||||
)
|
||||
iterator = InferenceTests.api.chat_completion(request)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(completion_message.content, "")
|
||||
|
||||
# FIXME: This test fails since there is a bug where
|
||||
# custom tool calls return incoorect stop_reason as out_of_tokens
|
||||
# instead of end_of_turn
|
||||
# self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
|
||||
|
||||
self.assertEqual(
|
||||
len(completion_message.tool_calls), 1, completion_message.tool_calls
|
||||
)
|
||||
self.assertEqual(
|
||||
completion_message.tool_calls[0].tool_name, "get_boiling_point"
|
||||
)
|
||||
|
||||
args = completion_message.tool_calls[0].arguments
|
||||
self.assertTrue(isinstance(args, dict))
|
||||
self.assertTrue(args["liquid_name"], "polyjuice")
|
||||
|
||||
async def test_tool_call_streaming(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Who is the current US President?",
|
||||
),
|
||||
],
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||
stream=True,
|
||||
)
|
||||
iterator = InferenceTests.api.chat_completion(request)
|
||||
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
# print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
|
||||
events.append(chunk.event)
|
||||
|
||||
self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
|
||||
# last event is of type "complete"
|
||||
self.assertEqual(
|
||||
events[-1].event_type, ChatCompletionResponseEventType.complete
|
||||
)
|
||||
# last but one event should be eom with tool call
|
||||
self.assertEqual(
|
||||
events[-2].event_type, ChatCompletionResponseEventType.progress
|
||||
)
|
||||
self.assertEqual(events[-2].stop_reason, StopReason.end_of_message)
|
||||
self.assertEqual(events[-2].delta.content.tool_name, BuiltinTool.brave_search)
|
||||
|
||||
async def test_custom_tool_call_streaming(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Use provided function to find the boiling point of polyjuice?",
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
tools=[self.custom_tool_defn],
|
||||
tool_prompt_format=ToolPromptFormat.function_tag,
|
||||
)
|
||||
iterator = InferenceTests.api.chat_completion(request)
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
# print(
|
||||
# f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} "
|
||||
# )
|
||||
events.append(chunk.event)
|
||||
|
||||
self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
|
||||
# last event is of type "complete"
|
||||
self.assertEqual(
|
||||
events[-1].event_type, ChatCompletionResponseEventType.complete
|
||||
)
|
||||
self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
|
||||
# last but one event should be eom with tool call
|
||||
self.assertEqual(
|
||||
events[-2].event_type, ChatCompletionResponseEventType.progress
|
||||
)
|
||||
self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
|
||||
self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
|
||||
|
||||
async def test_multi_turn(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Search the web and tell me who the "
|
||||
"44th president of the United States was",
|
||||
),
|
||||
ToolResponseMessage(
|
||||
call_id="1",
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
# content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "<strong>Barack Obama</strong> served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, <strong>President Obama</strong> moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}',
|
||||
content='"Barack Obama"',
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
)
|
||||
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
events.append(chunk.event)
|
||||
|
||||
response = ""
|
||||
for e in events[1:-1]:
|
||||
response += e.delta
|
||||
|
||||
self.assertTrue("obama" in response.lower())
|
|
@ -1,346 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.inference.api import * # noqa: F403
|
||||
from llama_stack.inference.ollama.config import OllamaImplConfig
|
||||
from llama_stack.inference.ollama.ollama import get_provider_impl
|
||||
|
||||
|
||||
class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
ollama_config = OllamaImplConfig(url="http://localhost:11434")
|
||||
|
||||
# setup ollama
|
||||
self.api = await get_provider_impl(ollama_config, {})
|
||||
await self.api.initialize()
|
||||
|
||||
self.custom_tool_defn = ToolDefinition(
|
||||
tool_name="get_boiling_point",
|
||||
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
|
||||
parameters={
|
||||
"liquid_name": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": ToolParamDefinition(
|
||||
param_type="boolean",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
self.valid_supported_model = "Llama3.1-8B-Instruct"
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.api.shutdown()
|
||||
|
||||
async def test_text(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model, request.messages, stream=request.stream
|
||||
)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
print(response.completion_message.content)
|
||||
self.assertTrue("Paris" in response.completion_message.content)
|
||||
self.assertEqual(
|
||||
response.completion_message.stop_reason, StopReason.end_of_turn
|
||||
)
|
||||
|
||||
async def test_tool_call(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Who is the current US President?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||
)
|
||||
iterator = self.api.chat_completion(request)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(completion_message.content, "")
|
||||
self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
|
||||
|
||||
self.assertEqual(
|
||||
len(completion_message.tool_calls), 1, completion_message.tool_calls
|
||||
)
|
||||
self.assertEqual(
|
||||
completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search
|
||||
)
|
||||
self.assertTrue(
|
||||
"president" in completion_message.tool_calls[0].arguments["query"].lower()
|
||||
)
|
||||
|
||||
async def test_code_execution(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Write code to compute the 5th prime number",
|
||||
),
|
||||
],
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)],
|
||||
stream=False,
|
||||
)
|
||||
iterator = self.api.chat_completion(request)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(completion_message.content, "")
|
||||
self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
|
||||
|
||||
self.assertEqual(
|
||||
len(completion_message.tool_calls), 1, completion_message.tool_calls
|
||||
)
|
||||
self.assertEqual(
|
||||
completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter
|
||||
)
|
||||
code = completion_message.tool_calls[0].arguments["code"]
|
||||
self.assertTrue("def " in code.lower(), code)
|
||||
|
||||
async def test_custom_tool(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Use provided function to find the boiling point of polyjuice?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
tools=[self.custom_tool_defn],
|
||||
)
|
||||
iterator = self.api.chat_completion(request)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(completion_message.content, "")
|
||||
self.assertTrue(
|
||||
completion_message.stop_reason
|
||||
in {
|
||||
StopReason.end_of_turn,
|
||||
StopReason.end_of_message,
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
len(completion_message.tool_calls), 1, completion_message.tool_calls
|
||||
)
|
||||
self.assertEqual(
|
||||
completion_message.tool_calls[0].tool_name, "get_boiling_point"
|
||||
)
|
||||
|
||||
args = completion_message.tool_calls[0].arguments
|
||||
self.assertTrue(isinstance(args, dict))
|
||||
self.assertTrue(args["liquid_name"], "polyjuice")
|
||||
|
||||
async def test_text_streaming(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
iterator = self.api.chat_completion(request)
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
# print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
|
||||
events.append(chunk.event)
|
||||
|
||||
response = ""
|
||||
for e in events[1:-1]:
|
||||
response += e.delta
|
||||
|
||||
self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
|
||||
# last event is of type "complete"
|
||||
self.assertEqual(
|
||||
events[-1].event_type, ChatCompletionResponseEventType.complete
|
||||
)
|
||||
# last but 1 event should be of type "progress"
|
||||
self.assertEqual(
|
||||
events[-2].event_type, ChatCompletionResponseEventType.progress
|
||||
)
|
||||
self.assertEqual(
|
||||
events[-2].stop_reason,
|
||||
None,
|
||||
)
|
||||
self.assertTrue("Paris" in response, response)
|
||||
|
||||
async def test_tool_call_streaming(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Using web search tell me who is the current US President?",
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||
)
|
||||
iterator = self.api.chat_completion(request)
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
events.append(chunk.event)
|
||||
|
||||
self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
|
||||
# last event is of type "complete"
|
||||
self.assertEqual(
|
||||
events[-1].event_type, ChatCompletionResponseEventType.complete
|
||||
)
|
||||
# last but one event should be eom with tool call
|
||||
self.assertEqual(
|
||||
events[-2].event_type, ChatCompletionResponseEventType.progress
|
||||
)
|
||||
self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
|
||||
self.assertEqual(events[-2].delta.content.tool_name, BuiltinTool.brave_search)
|
||||
|
||||
async def test_custom_tool_call_streaming(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Use provided function to find the boiling point of polyjuice?",
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
tools=[self.custom_tool_defn],
|
||||
tool_prompt_format=ToolPromptFormat.function_tag,
|
||||
)
|
||||
iterator = self.api.chat_completion(request)
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
# print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
|
||||
events.append(chunk.event)
|
||||
|
||||
self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
|
||||
# last event is of type "complete"
|
||||
self.assertEqual(
|
||||
events[-1].event_type, ChatCompletionResponseEventType.complete
|
||||
)
|
||||
self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
|
||||
# last but one event should be eom with tool call
|
||||
self.assertEqual(
|
||||
events[-2].event_type, ChatCompletionResponseEventType.progress
|
||||
)
|
||||
self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
|
||||
self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
|
||||
|
||||
def test_resolve_ollama_model(self):
|
||||
ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
|
||||
self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16")
|
||||
|
||||
invalid_model = "Llama3.1-8B"
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, f"Unsupported model: {invalid_model}"
|
||||
):
|
||||
self.api.resolve_ollama_model(invalid_model)
|
||||
|
||||
async def test_ollama_chat_options(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
sampling_params=SamplingParams(
|
||||
sampling_strategy=SamplingStrategy.top_p,
|
||||
top_p=0.99,
|
||||
temperature=1.0,
|
||||
),
|
||||
)
|
||||
options = self.api.get_ollama_chat_options(request)
|
||||
self.assertEqual(
|
||||
options,
|
||||
{
|
||||
"temperature": 1.0,
|
||||
"top_p": 0.99,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_multi_turn(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Search the web and tell me who the "
|
||||
"44th president of the United States was",
|
||||
),
|
||||
ToolResponseMessage(
|
||||
call_id="1",
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "<strong>Barack Obama</strong> served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, <strong>President Obama</strong> moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}',
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||
)
|
||||
iterator = self.api.chat_completion(request)
|
||||
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
events.append(chunk.event)
|
||||
|
||||
response = ""
|
||||
for e in events[1:-1]:
|
||||
response += e.delta
|
||||
|
||||
self.assertTrue("obama" in response.lower())
|
||||
|
||||
async def test_tool_call_code_streaming(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Write code to answer this question: What is the 100th prime number?",
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)],
|
||||
)
|
||||
iterator = self.api.chat_completion(request)
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
events.append(chunk.event)
|
||||
|
||||
self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
|
||||
# last event is of type "complete"
|
||||
self.assertEqual(
|
||||
events[-1].event_type, ChatCompletionResponseEventType.complete
|
||||
)
|
||||
# last but one event should be eom with tool call
|
||||
self.assertEqual(
|
||||
events[-2].event_type, ChatCompletionResponseEventType.progress
|
||||
)
|
||||
self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
|
||||
self.assertEqual(
|
||||
events[-2].delta.content.tool_name, BuiltinTool.code_interpreter
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue