Bump version to 0.0.24 (#94)

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
poegej 2024-09-25 09:31:12 -07:00 committed by GitHub
parent ed8d10775a
commit 95abbf576b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 998 additions and 0 deletions

2
.gitignore vendored
View file

@ -5,4 +5,6 @@ dist
dev_requirements.txt
build
.DS_Store
.idea
*.iml
llama_stack/configs/*

View file

@ -0,0 +1,10 @@
name: local-bedrock-conda-example
distribution_spec:
description: Use Amazon Bedrock APIs.
providers:
inference: remote::bedrock
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .bedrock import BedrockInferenceAdapter
from .config import BedrockConfig
async def get_adapter_impl(config: BedrockConfig, _deps):
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
impl = BedrockInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,457 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import * # noqa: F403
import boto3
from botocore.client import BaseClient
from botocore.config import Config
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
# mapping of Model SKUs to ollama models
BEDROCK_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Meta-Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Meta-Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
}
class BedrockInferenceAdapter(Inference):
@staticmethod
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
boto3_config = Config(**config_args)
session_args = {
k: v
for k, v in dict(
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
aws_session_token=config.aws_session_token,
region_name=config.region_name,
profile_name=config.profile_name,
).items()
if v is not None
}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config)
def __init__(self, config: BedrockConfig) -> None:
self._config = config
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> BaseClient:
return self._client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
self.client.close()
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
@staticmethod
def resolve_bedrock_model(model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in BEDROCK_SUPPORTED_MODELS
), (
f"Unsupported model: {model_name}, use one of the supported models: "
f"{','.join(BEDROCK_SUPPORTED_MODELS.keys())}"
)
return BEDROCK_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
@staticmethod
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
if bedrock_stop_reason == "max_tokens":
return StopReason.out_of_tokens
return StopReason.end_of_turn
@staticmethod
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
for builtin_tool in BuiltinTool:
if builtin_tool.value == tool_name_str:
return builtin_tool
else:
return tool_name_str
@staticmethod
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
converse_api_res["stopReason"]
)
bedrock_message = converse_api_res["output"]["message"]
role = bedrock_message["role"]
contents = bedrock_message["content"]
tool_calls = []
text_content = []
for content in contents:
if "toolUse" in content:
tool_use = content["toolUse"]
tool_calls.append(
ToolCall(
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
tool_use["name"]
),
arguments=tool_use["input"] if "input" in tool_use else None,
call_id=tool_use["toolUseId"],
)
)
elif "text" in content:
text_content.append(content["text"])
return CompletionMessage(
role=role,
content=text_content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
@staticmethod
def _messages_to_bedrock_messages(
messages: List[Message],
) -> Tuple[List[Dict], Optional[List[Dict]]]:
bedrock_messages = []
system_bedrock_messages = []
user_contents = []
assistant_contents = None
for message in messages:
role = message.role
content_list = (
message.content
if isinstance(message.content, list)
else [message.content]
)
if role == "ipython" or role == "user":
if not user_contents:
user_contents = []
if role == "ipython":
user_contents.extend(
[
{
"toolResult": {
"toolUseId": message.call_id,
"content": [
{"text": content} for content in content_list
],
}
}
]
)
else:
user_contents.extend(
[{"text": content} for content in content_list]
)
if assistant_contents:
bedrock_messages.append(
{"role": "assistant", "content": assistant_contents}
)
assistant_contents = None
elif role == "system":
system_bedrock_messages.extend(
[{"text": content} for content in content_list]
)
elif role == "assistant":
if not assistant_contents:
assistant_contents = []
assistant_contents.extend(
[
{
"text": content,
}
for content in content_list
]
+ [
{
"toolUse": {
"input": tool_call.arguments,
"name": (
tool_call.tool_name
if isinstance(tool_call.tool_name, str)
else tool_call.tool_name.value
),
"toolUseId": tool_call.call_id,
}
}
for tool_call in message.tool_calls
]
)
if user_contents:
bedrock_messages.append({"role": "user", "content": user_contents})
user_contents = None
else:
# Unknown role
pass
if user_contents:
bedrock_messages.append({"role": "user", "content": user_contents})
if assistant_contents:
bedrock_messages.append(
{"role": "assistant", "content": assistant_contents}
)
if system_bedrock_messages:
return bedrock_messages, system_bedrock_messages
return bedrock_messages, None
@staticmethod
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
inference_config = {}
if sampling_params:
param_mapping = {
"max_tokens": "maxTokens",
"temperature": "temperature",
"top_p": "topP",
}
for k, v in param_mapping.items():
if getattr(sampling_params, k):
inference_config[v] = getattr(sampling_params, k)
return inference_config
@staticmethod
def _tool_parameters_to_input_schema(
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
) -> Dict:
input_schema = {"type": "object"}
if not tool_parameters:
return input_schema
json_properties = {}
required = []
for name, param in tool_parameters.items():
json_property = {
"type": param.param_type,
}
if param.description:
json_property["description"] = param.description
if param.required:
required.append(name)
json_properties[name] = json_property
input_schema["properties"] = json_properties
if required:
input_schema["required"] = required
return input_schema
@staticmethod
def _tools_to_tool_config(
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
) -> Optional[Dict]:
if not tools:
return None
bedrock_tools = []
for tool in tools:
tool_name = (
tool.tool_name
if isinstance(tool.tool_name, str)
else tool.tool_name.value
)
tool_spec = {
"toolSpec": {
"name": tool_name,
"inputSchema": {
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
tool.parameters
),
},
}
}
if tool.description:
tool_spec["toolSpec"]["description"] = tool.description
bedrock_tools.append(tool_spec)
tool_config = {
"tools": bedrock_tools,
}
if tool_choice:
tool_config["toolChoice"] = (
{"any": {}}
if tool_choice.value == ToolChoice.required
else {"auto": {}}
)
return tool_config
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> (
AsyncGenerator
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
bedrock_model = BedrockInferenceAdapter.resolve_bedrock_model(model)
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
sampling_params
)
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
bedrock_messages, system_bedrock_messages = (
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
)
converse_api_params = {
"modelId": bedrock_model,
"messages": bedrock_messages,
}
if inference_config:
converse_api_params["inferenceConfig"] = inference_config
# Tool use is not supported in streaming mode
if tool_config and not stream:
converse_api_params["toolConfig"] = tool_config
if system_bedrock_messages:
converse_api_params["system"] = system_bedrock_messages
if not stream:
converse_api_res = self.client.converse(**converse_api_params)
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
converse_api_res
)
yield ChatCompletionResponse(
completion_message=output_message,
logprobs=None,
)
else:
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
event_stream = converse_stream_api_res["stream"]
for chunk in event_stream:
if "messageStart" in chunk:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
elif "contentBlockStart" in chunk:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=ToolCall(
tool_name=chunk["contentBlockStart"]["toolUse"][
"name"
],
call_id=chunk["contentBlockStart"]["toolUse"][
"toolUseId"
],
),
parse_status=ToolCallParseStatus.started,
),
)
)
elif "contentBlockDelta" in chunk:
if "text" in chunk["contentBlockDelta"]["delta"]:
delta = chunk["contentBlockDelta"]["delta"]["text"]
else:
delta = ToolCallDelta(
content=ToolCall(
arguments=chunk["contentBlockDelta"]["delta"][
"toolUse"
]["input"]
),
parse_status=ToolCallParseStatus.success,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
)
)
elif "contentBlockStop" in chunk:
# Ignored
pass
elif "messageStop" in chunk:
stop_reason = (
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
chunk["messageStop"]["stopReason"]
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
elif "metadata" in chunk:
# Ignored
pass
else:
# Ignored
pass

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import * # noqa: F403
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class BedrockConfig(BaseModel):
aws_access_key_id: Optional[str] = Field(
default=None,
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
)
aws_secret_access_key: Optional[str] = Field(
default=None,
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
)
aws_session_token: Optional[str] = Field(
default=None,
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
)
region_name: Optional[str] = Field(
default=None,
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
"Default use environment variable: AWS_DEFAULT_REGION",
)
profile_name: Optional[str] = Field(
default=None,
description="The profile name that contains credentials to use."
"Default use environment variable: AWS_PROFILE",
)
total_max_attempts: Optional[int] = Field(
default=None,
description="An integer representing the maximum number of attempts that will be made for a single request, "
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
)
retry_mode: Optional[str] = Field(
default=None,
description="A string representing the type of retries Boto3 will perform."
"Default use environment variable: AWS_RETRY_MODE",
)
connect_timeout: Optional[float] = Field(
default=60,
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
"The default is 60 seconds.",
)
read_timeout: Optional[float] = Field(
default=60,
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
"The default is 60 seconds.",
)

View file

@ -75,4 +75,15 @@ def available_providers() -> List[ProviderSpec]:
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="bedrock",
pip_packages=[
"boto3",
],
module="llama_stack.providers.adapters.inference.bedrock",
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),
),
]

View file

@ -0,0 +1,446 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import unittest
from unittest import mock
from llama_models.llama3.api.datatypes import (
BuiltinTool,
CompletionMessage,
SamplingParams,
SamplingStrategy,
StopReason,
ToolCall,
ToolChoice,
ToolDefinition,
ToolParamDefinition,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.inference.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
)
from llama_stack.providers.adapters.inference.bedrock import get_adapter_impl
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
class BedrockInferenceTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
bedrock_config = BedrockConfig()
# setup Bedrock
self.api = await get_adapter_impl(bedrock_config, {})
await self.api.initialize()
self.custom_tool_defn = ToolDefinition(
tool_name="get_boiling_point",
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
parameters={
"liquid_name": ToolParamDefinition(
param_type="str",
description="The name of the liquid",
required=True,
),
"celcius": ToolParamDefinition(
param_type="boolean",
description="Whether to return the boiling point in Celcius",
required=False,
),
},
)
self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
async def asyncTearDown(self):
await self.api.shutdown()
async def test_text(self):
with mock.patch.object(self.api.client, "converse") as mock_converse:
mock_converse.return_value = {
"ResponseMetadata": {
"RequestId": "8ad04352-cd81-4946-b811-b434e546385d",
"HTTPStatusCode": 200,
"HTTPHeaders": {},
"RetryAttempts": 0,
},
"output": {
"message": {
"role": "assistant",
"content": [{"text": "\n\nThe capital of France is Paris."}],
}
},
"stopReason": "end_turn",
"usage": {"inputTokens": 21, "outputTokens": 9, "totalTokens": 30},
"metrics": {"latencyMs": 307},
}
request = ChatCompletionRequest(
model=self.valid_supported_model,
messages=[
UserMessage(
content="What is the capital of France?",
),
],
stream=False,
)
iterator = self.api.chat_completion(
request.model,
request.messages,
request.sampling_params,
request.tools,
request.tool_choice,
request.tool_prompt_format,
request.stream,
request.logprobs,
)
async for r in iterator:
response = r
print(response.completion_message.content)
self.assertTrue("Paris" in response.completion_message.content[0])
self.assertEqual(
response.completion_message.stop_reason, StopReason.end_of_turn
)
async def test_tool_call(self):
with mock.patch.object(self.api.client, "converse") as mock_converse:
mock_converse.return_value = {
"ResponseMetadata": {
"RequestId": "ec9da6a4-656b-4343-9e1f-71dac79cbf53",
"HTTPStatusCode": 200,
"HTTPHeaders": {},
"RetryAttempts": 0,
},
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"name": "brave_search",
"toolUseId": "tooluse_d49kUQ3rTc6K_LPM-w96MQ",
"input": {"query": "current US President"},
}
}
],
}
},
"stopReason": "end_turn",
"usage": {"inputTokens": 48, "outputTokens": 81, "totalTokens": 129},
"metrics": {"latencyMs": 1236},
}
request = ChatCompletionRequest(
model=self.valid_supported_model,
messages=[
UserMessage(
content="Who is the current US President?",
),
],
stream=False,
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
)
iterator = self.api.chat_completion(
request.model,
request.messages,
request.sampling_params,
request.tools,
request.tool_choice,
request.tool_prompt_format,
request.stream,
request.logprobs,
)
async for r in iterator:
response = r
completion_message = response.completion_message
self.assertEqual(len(completion_message.content), 0)
self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
self.assertEqual(
len(completion_message.tool_calls), 1, completion_message.tool_calls
)
self.assertEqual(
completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search
)
self.assertTrue(
"president"
in completion_message.tool_calls[0].arguments["query"].lower()
)
async def test_custom_tool(self):
with mock.patch.object(self.api.client, "converse") as mock_converse:
mock_converse.return_value = {
"ResponseMetadata": {
"RequestId": "243c4316-0965-4b79-a145-2d9ac6b4e9ad",
"HTTPStatusCode": 200,
"HTTPHeaders": {},
"RetryAttempts": 0,
},
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tooluse_7DViuqxXS6exL8Yug9Apjw",
"name": "get_boiling_point",
"input": {
"liquid_name": "polyjuice",
"celcius": "True",
},
}
}
],
}
},
"stopReason": "tool_use",
"usage": {"inputTokens": 110, "outputTokens": 37, "totalTokens": 147},
"metrics": {"latencyMs": 743},
}
request = ChatCompletionRequest(
model=self.valid_supported_model,
messages=[
UserMessage(
content="Use provided function to find the boiling point of polyjuice?",
),
],
stream=False,
tools=[self.custom_tool_defn],
tool_choice=ToolChoice.required,
)
iterator = self.api.chat_completion(
request.model,
request.messages,
request.sampling_params,
request.tools,
request.tool_choice,
request.tool_prompt_format,
request.stream,
request.logprobs,
)
async for r in iterator:
response = r
completion_message = response.completion_message
self.assertEqual(len(completion_message.content), 0)
self.assertTrue(
completion_message.stop_reason
in {
StopReason.end_of_turn,
StopReason.end_of_message,
}
)
self.assertEqual(
len(completion_message.tool_calls), 1, completion_message.tool_calls
)
self.assertEqual(
completion_message.tool_calls[0].tool_name, "get_boiling_point"
)
args = completion_message.tool_calls[0].arguments
self.assertTrue(isinstance(args, dict))
self.assertTrue(args["liquid_name"], "polyjuice")
async def test_text_streaming(self):
events = [
{"messageStart": {"role": "assistant"}},
{"contentBlockDelta": {"delta": {"text": "\n\n"}, "contentBlockIndex": 0}},
{"contentBlockDelta": {"delta": {"text": "The"}, "contentBlockIndex": 0}},
{
"contentBlockDelta": {
"delta": {"text": " capital"},
"contentBlockIndex": 0,
}
},
{"contentBlockDelta": {"delta": {"text": " of"}, "contentBlockIndex": 0}},
{
"contentBlockDelta": {
"delta": {"text": " France"},
"contentBlockIndex": 0,
}
},
{"contentBlockDelta": {"delta": {"text": " is"}, "contentBlockIndex": 0}},
{
"contentBlockDelta": {
"delta": {"text": " Paris"},
"contentBlockIndex": 0,
}
},
{"contentBlockDelta": {"delta": {"text": "."}, "contentBlockIndex": 0}},
{"contentBlockDelta": {"delta": {"text": ""}, "contentBlockIndex": 0}},
{"contentBlockStop": {"contentBlockIndex": 0}},
{"messageStop": {"stopReason": "end_turn"}},
{
"metadata": {
"usage": {"inputTokens": 21, "outputTokens": 9, "totalTokens": 30},
"metrics": {"latencyMs": 1},
}
},
]
with mock.patch.object(
self.api.client, "converse_stream"
) as mock_converse_stream:
mock_converse_stream.return_value = {"stream": events}
request = ChatCompletionRequest(
model=self.valid_supported_model,
messages=[
UserMessage(
content="What is the capital of France?",
),
],
stream=True,
)
iterator = self.api.chat_completion(
request.model,
request.messages,
request.sampling_params,
request.tools,
request.tool_choice,
request.tool_prompt_format,
request.stream,
request.logprobs,
)
events = []
async for chunk in iterator:
events.append(chunk.event)
response = ""
for e in events[1:-1]:
response += e.delta
self.assertEqual(
events[0].event_type, ChatCompletionResponseEventType.start
)
# last event is of type "complete"
self.assertEqual(
events[-1].event_type, ChatCompletionResponseEventType.complete
)
# last but 1 event should be of type "progress"
self.assertEqual(
events[-2].event_type, ChatCompletionResponseEventType.progress
)
self.assertEqual(
events[-2].stop_reason,
None,
)
self.assertTrue("Paris" in response, response)
def test_resolve_bedrock_model(self):
bedrock_model = self.api.resolve_bedrock_model(self.valid_supported_model)
self.assertEqual(bedrock_model, "meta.llama3-1-8b-instruct-v1:0")
invalid_model = "Meta-Llama3.1-8B"
with self.assertRaisesRegex(
AssertionError, f"Unsupported model: {invalid_model}"
):
self.api.resolve_bedrock_model(invalid_model)
async def test_bedrock_chat_inference_config(self):
request = ChatCompletionRequest(
model=self.valid_supported_model,
messages=[
UserMessage(
content="What is the capital of France?",
),
],
stream=False,
sampling_params=SamplingParams(
sampling_strategy=SamplingStrategy.top_p,
top_p=0.99,
temperature=1.0,
),
)
options = self.api.get_bedrock_inference_config(request.sampling_params)
self.assertEqual(
options,
{
"temperature": 1.0,
"topP": 0.99,
},
)
async def test_multi_turn_non_streaming(self):
with mock.patch.object(self.api.client, "converse") as mock_converse:
mock_converse.return_value = {
"ResponseMetadata": {
"RequestId": "4171abf1-a5f4-4eee-bb12-0e472a73bdbe",
"HTTPStatusCode": 200,
"HTTPHeaders": {},
"RetryAttempts": 0,
},
"output": {
"message": {
"role": "assistant",
"content": [
{
"text": "\nThe 44th president of the United States was Barack Obama."
}
],
}
},
"stopReason": "end_turn",
"usage": {"inputTokens": 723, "outputTokens": 15, "totalTokens": 738},
"metrics": {"latencyMs": 449},
}
request = ChatCompletionRequest(
model=self.valid_supported_model,
messages=[
UserMessage(
content="Search the web and tell me who the "
"44th president of the United States was",
),
CompletionMessage(
content=[],
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
call_id="1",
tool_name=BuiltinTool.brave_search,
arguments={
"query": "44th president of the United States"
},
)
],
),
ToolResponseMessage(
call_id="1",
tool_name=BuiltinTool.brave_search,
content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "<strong>Barack Obama</strong> served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, <strong>President Obama</strong> moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}',
),
],
stream=False,
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
)
iterator = self.api.chat_completion(
request.model,
request.messages,
request.sampling_params,
request.tools,
request.tool_choice,
request.tool_prompt_format,
request.stream,
request.logprobs,
)
async for r in iterator:
response = r
completion_message = response.completion_message
self.assertEqual(len(completion_message.content), 1)
self.assertTrue(
completion_message.stop_reason
in {
StopReason.end_of_turn,
StopReason.end_of_message,
}
)
self.assertTrue("obama" in completion_message.content[0].lower())