mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Bump version to 0.0.24 (#94)
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
ed8d10775a
commit
95abbf576b
7 changed files with 998 additions and 0 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -5,4 +5,6 @@ dist
|
|||
dev_requirements.txt
|
||||
build
|
||||
.DS_Store
|
||||
.idea
|
||||
*.iml
|
||||
llama_stack/configs/*
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
name: local-bedrock-conda-example
|
||||
distribution_spec:
|
||||
description: Use Amazon Bedrock APIs.
|
||||
providers:
|
||||
inference: remote::bedrock
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
17
llama_stack/providers/adapters/inference/bedrock/__init__.py
Normal file
17
llama_stack/providers/adapters/inference/bedrock/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from .bedrock import BedrockInferenceAdapter
|
||||
from .config import BedrockConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BedrockConfig, _deps):
|
||||
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = BedrockInferenceAdapter(config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
return impl
|
457
llama_stack/providers/adapters/inference/bedrock/bedrock.py
Normal file
457
llama_stack/providers/adapters/inference/bedrock/bedrock.py
Normal file
|
@ -0,0 +1,457 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
# mapping of Model SKUs to ollama models
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
"Meta-Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"Meta-Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"Meta-Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
||||
}
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(Inference):
|
||||
|
||||
@staticmethod
|
||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
self._config = config
|
||||
|
||||
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def resolve_bedrock_model(model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert (
|
||||
model is not None
|
||||
and model.descriptor(shorten_default_variant=True)
|
||||
in BEDROCK_SUPPORTED_MODELS
|
||||
), (
|
||||
f"Unsupported model: {model_name}, use one of the supported models: "
|
||||
f"{','.join(BEDROCK_SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
return BEDROCK_SUPPORTED_MODELS.get(
|
||||
model.descriptor(shorten_default_variant=True)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
|
||||
if bedrock_stop_reason == "max_tokens":
|
||||
return StopReason.out_of_tokens
|
||||
return StopReason.end_of_turn
|
||||
|
||||
@staticmethod
|
||||
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
|
||||
for builtin_tool in BuiltinTool:
|
||||
if builtin_tool.value == tool_name_str:
|
||||
return builtin_tool
|
||||
else:
|
||||
return tool_name_str
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
|
||||
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
converse_api_res["stopReason"]
|
||||
)
|
||||
|
||||
bedrock_message = converse_api_res["output"]["message"]
|
||||
|
||||
role = bedrock_message["role"]
|
||||
contents = bedrock_message["content"]
|
||||
|
||||
tool_calls = []
|
||||
text_content = []
|
||||
for content in contents:
|
||||
if "toolUse" in content:
|
||||
tool_use = content["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
|
||||
tool_use["name"]
|
||||
),
|
||||
arguments=tool_use["input"] if "input" in tool_use else None,
|
||||
call_id=tool_use["toolUseId"],
|
||||
)
|
||||
)
|
||||
elif "text" in content:
|
||||
text_content.append(content["text"])
|
||||
|
||||
return CompletionMessage(
|
||||
role=role,
|
||||
content=text_content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _messages_to_bedrock_messages(
|
||||
messages: List[Message],
|
||||
) -> Tuple[List[Dict], Optional[List[Dict]]]:
|
||||
bedrock_messages = []
|
||||
system_bedrock_messages = []
|
||||
|
||||
user_contents = []
|
||||
assistant_contents = None
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content_list = (
|
||||
message.content
|
||||
if isinstance(message.content, list)
|
||||
else [message.content]
|
||||
)
|
||||
if role == "ipython" or role == "user":
|
||||
if not user_contents:
|
||||
user_contents = []
|
||||
|
||||
if role == "ipython":
|
||||
user_contents.extend(
|
||||
[
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.call_id,
|
||||
"content": [
|
||||
{"text": content} for content in content_list
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
user_contents.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
assistant_contents = None
|
||||
elif role == "system":
|
||||
system_bedrock_messages.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not assistant_contents:
|
||||
assistant_contents = []
|
||||
|
||||
assistant_contents.extend(
|
||||
[
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
for content in content_list
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"toolUse": {
|
||||
"input": tool_call.arguments,
|
||||
"name": (
|
||||
tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value
|
||||
),
|
||||
"toolUseId": tool_call.call_id,
|
||||
}
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
user_contents = None
|
||||
else:
|
||||
# Unknown role
|
||||
pass
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
|
||||
if system_bedrock_messages:
|
||||
return bedrock_messages, system_bedrock_messages
|
||||
|
||||
return bedrock_messages, None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
|
||||
inference_config = {}
|
||||
if sampling_params:
|
||||
param_mapping = {
|
||||
"max_tokens": "maxTokens",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
}
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(sampling_params, k):
|
||||
inference_config[v] = getattr(sampling_params, k)
|
||||
|
||||
return inference_config
|
||||
|
||||
@staticmethod
|
||||
def _tool_parameters_to_input_schema(
|
||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
) -> Dict:
|
||||
input_schema = {"type": "object"}
|
||||
if not tool_parameters:
|
||||
return input_schema
|
||||
|
||||
json_properties = {}
|
||||
required = []
|
||||
for name, param in tool_parameters.items():
|
||||
json_property = {
|
||||
"type": param.param_type,
|
||||
}
|
||||
|
||||
if param.description:
|
||||
json_property["description"] = param.description
|
||||
if param.required:
|
||||
required.append(name)
|
||||
json_properties[name] = json_property
|
||||
|
||||
input_schema["properties"] = json_properties
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def _tools_to_tool_config(
|
||||
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
|
||||
) -> Optional[Dict]:
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
bedrock_tools = []
|
||||
for tool in tools:
|
||||
tool_name = (
|
||||
tool.tool_name
|
||||
if isinstance(tool.tool_name, str)
|
||||
else tool.tool_name.value
|
||||
)
|
||||
|
||||
tool_spec = {
|
||||
"toolSpec": {
|
||||
"name": tool_name,
|
||||
"inputSchema": {
|
||||
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
|
||||
tool.parameters
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if tool.description:
|
||||
tool_spec["toolSpec"]["description"] = tool.description
|
||||
|
||||
bedrock_tools.append(tool_spec)
|
||||
tool_config = {
|
||||
"tools": bedrock_tools,
|
||||
}
|
||||
|
||||
if tool_choice:
|
||||
tool_config["toolChoice"] = (
|
||||
{"any": {}}
|
||||
if tool_choice.value == ToolChoice.required
|
||||
else {"auto": {}}
|
||||
)
|
||||
return tool_config
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> (
|
||||
AsyncGenerator
|
||||
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
bedrock_model = BedrockInferenceAdapter.resolve_bedrock_model(model)
|
||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||
sampling_params
|
||||
)
|
||||
|
||||
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
|
||||
bedrock_messages, system_bedrock_messages = (
|
||||
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
|
||||
)
|
||||
|
||||
converse_api_params = {
|
||||
"modelId": bedrock_model,
|
||||
"messages": bedrock_messages,
|
||||
}
|
||||
if inference_config:
|
||||
converse_api_params["inferenceConfig"] = inference_config
|
||||
|
||||
# Tool use is not supported in streaming mode
|
||||
if tool_config and not stream:
|
||||
converse_api_params["toolConfig"] = tool_config
|
||||
if system_bedrock_messages:
|
||||
converse_api_params["system"] = system_bedrock_messages
|
||||
|
||||
if not stream:
|
||||
converse_api_res = self.client.converse(**converse_api_params)
|
||||
|
||||
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
|
||||
converse_api_res
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=output_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
|
||||
event_stream = converse_stream_api_res["stream"]
|
||||
|
||||
for chunk in event_stream:
|
||||
if "messageStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
elif "contentBlockStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=ToolCall(
|
||||
tool_name=chunk["contentBlockStart"]["toolUse"][
|
||||
"name"
|
||||
],
|
||||
call_id=chunk["contentBlockStart"]["toolUse"][
|
||||
"toolUseId"
|
||||
],
|
||||
),
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "contentBlockDelta" in chunk:
|
||||
if "text" in chunk["contentBlockDelta"]["delta"]:
|
||||
delta = chunk["contentBlockDelta"]["delta"]["text"]
|
||||
else:
|
||||
delta = ToolCallDelta(
|
||||
content=ToolCall(
|
||||
arguments=chunk["contentBlockDelta"]["delta"][
|
||||
"toolUse"
|
||||
]["input"]
|
||||
),
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
elif "contentBlockStop" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
elif "messageStop" in chunk:
|
||||
stop_reason = (
|
||||
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
chunk["messageStop"]["stopReason"]
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
elif "metadata" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
else:
|
||||
# Ignored
|
||||
pass
|
55
llama_stack/providers/adapters/inference/bedrock/config.py
Normal file
55
llama_stack/providers/adapters/inference/bedrock/config.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import * # noqa: F403
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BedrockConfig(BaseModel):
|
||||
aws_access_key_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
)
|
||||
aws_secret_access_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||
)
|
||||
aws_session_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||
)
|
||||
region_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||
)
|
||||
profile_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The profile name that contains credentials to use."
|
||||
"Default use environment variable: AWS_PROFILE",
|
||||
)
|
||||
total_max_attempts: Optional[int] = Field(
|
||||
default=None,
|
||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
||||
)
|
||||
retry_mode: Optional[str] = Field(
|
||||
default=None,
|
||||
description="A string representing the type of retries Boto3 will perform."
|
||||
"Default use environment variable: AWS_RETRY_MODE",
|
||||
)
|
||||
connect_timeout: Optional[float] = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
read_timeout: Optional[float] = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||
"The default is 60 seconds.",
|
||||
)
|
|
@ -75,4 +75,15 @@ def available_providers() -> List[ProviderSpec]:
|
|||
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_id="bedrock",
|
||||
pip_packages=[
|
||||
"boto3",
|
||||
],
|
||||
module="llama_stack.providers.adapters.inference.bedrock",
|
||||
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
446
tests/test_bedrock_inference.py
Normal file
446
tests/test_bedrock_inference.py
Normal file
|
@ -0,0 +1,446 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
BuiltinTool,
|
||||
CompletionMessage,
|
||||
SamplingParams,
|
||||
SamplingStrategy,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponseEventType,
|
||||
)
|
||||
from llama_stack.providers.adapters.inference.bedrock import get_adapter_impl
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
class BedrockInferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
bedrock_config = BedrockConfig()
|
||||
|
||||
# setup Bedrock
|
||||
self.api = await get_adapter_impl(bedrock_config, {})
|
||||
await self.api.initialize()
|
||||
|
||||
self.custom_tool_defn = ToolDefinition(
|
||||
tool_name="get_boiling_point",
|
||||
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
|
||||
parameters={
|
||||
"liquid_name": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": ToolParamDefinition(
|
||||
param_type="boolean",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.api.shutdown()
|
||||
|
||||
async def test_text(self):
|
||||
with mock.patch.object(self.api.client, "converse") as mock_converse:
|
||||
mock_converse.return_value = {
|
||||
"ResponseMetadata": {
|
||||
"RequestId": "8ad04352-cd81-4946-b811-b434e546385d",
|
||||
"HTTPStatusCode": 200,
|
||||
"HTTPHeaders": {},
|
||||
"RetryAttempts": 0,
|
||||
},
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"text": "\n\nThe capital of France is Paris."}],
|
||||
}
|
||||
},
|
||||
"stopReason": "end_turn",
|
||||
"usage": {"inputTokens": 21, "outputTokens": 9, "totalTokens": 30},
|
||||
"metrics": {"latencyMs": 307},
|
||||
}
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
print(response.completion_message.content)
|
||||
self.assertTrue("Paris" in response.completion_message.content[0])
|
||||
self.assertEqual(
|
||||
response.completion_message.stop_reason, StopReason.end_of_turn
|
||||
)
|
||||
|
||||
async def test_tool_call(self):
|
||||
with mock.patch.object(self.api.client, "converse") as mock_converse:
|
||||
mock_converse.return_value = {
|
||||
"ResponseMetadata": {
|
||||
"RequestId": "ec9da6a4-656b-4343-9e1f-71dac79cbf53",
|
||||
"HTTPStatusCode": 200,
|
||||
"HTTPHeaders": {},
|
||||
"RetryAttempts": 0,
|
||||
},
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"name": "brave_search",
|
||||
"toolUseId": "tooluse_d49kUQ3rTc6K_LPM-w96MQ",
|
||||
"input": {"query": "current US President"},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
"stopReason": "end_turn",
|
||||
"usage": {"inputTokens": 48, "outputTokens": 81, "totalTokens": 129},
|
||||
"metrics": {"latencyMs": 1236},
|
||||
}
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Who is the current US President?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(len(completion_message.content), 0)
|
||||
self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
|
||||
|
||||
self.assertEqual(
|
||||
len(completion_message.tool_calls), 1, completion_message.tool_calls
|
||||
)
|
||||
self.assertEqual(
|
||||
completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search
|
||||
)
|
||||
self.assertTrue(
|
||||
"president"
|
||||
in completion_message.tool_calls[0].arguments["query"].lower()
|
||||
)
|
||||
|
||||
async def test_custom_tool(self):
|
||||
with mock.patch.object(self.api.client, "converse") as mock_converse:
|
||||
mock_converse.return_value = {
|
||||
"ResponseMetadata": {
|
||||
"RequestId": "243c4316-0965-4b79-a145-2d9ac6b4e9ad",
|
||||
"HTTPStatusCode": 200,
|
||||
"HTTPHeaders": {},
|
||||
"RetryAttempts": 0,
|
||||
},
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_7DViuqxXS6exL8Yug9Apjw",
|
||||
"name": "get_boiling_point",
|
||||
"input": {
|
||||
"liquid_name": "polyjuice",
|
||||
"celcius": "True",
|
||||
},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
"stopReason": "tool_use",
|
||||
"usage": {"inputTokens": 110, "outputTokens": 37, "totalTokens": 147},
|
||||
"metrics": {"latencyMs": 743},
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Use provided function to find the boiling point of polyjuice?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
tools=[self.custom_tool_defn],
|
||||
tool_choice=ToolChoice.required,
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(len(completion_message.content), 0)
|
||||
self.assertTrue(
|
||||
completion_message.stop_reason
|
||||
in {
|
||||
StopReason.end_of_turn,
|
||||
StopReason.end_of_message,
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
len(completion_message.tool_calls), 1, completion_message.tool_calls
|
||||
)
|
||||
self.assertEqual(
|
||||
completion_message.tool_calls[0].tool_name, "get_boiling_point"
|
||||
)
|
||||
|
||||
args = completion_message.tool_calls[0].arguments
|
||||
self.assertTrue(isinstance(args, dict))
|
||||
self.assertTrue(args["liquid_name"], "polyjuice")
|
||||
|
||||
async def test_text_streaming(self):
|
||||
events = [
|
||||
{"messageStart": {"role": "assistant"}},
|
||||
{"contentBlockDelta": {"delta": {"text": "\n\n"}, "contentBlockIndex": 0}},
|
||||
{"contentBlockDelta": {"delta": {"text": "The"}, "contentBlockIndex": 0}},
|
||||
{
|
||||
"contentBlockDelta": {
|
||||
"delta": {"text": " capital"},
|
||||
"contentBlockIndex": 0,
|
||||
}
|
||||
},
|
||||
{"contentBlockDelta": {"delta": {"text": " of"}, "contentBlockIndex": 0}},
|
||||
{
|
||||
"contentBlockDelta": {
|
||||
"delta": {"text": " France"},
|
||||
"contentBlockIndex": 0,
|
||||
}
|
||||
},
|
||||
{"contentBlockDelta": {"delta": {"text": " is"}, "contentBlockIndex": 0}},
|
||||
{
|
||||
"contentBlockDelta": {
|
||||
"delta": {"text": " Paris"},
|
||||
"contentBlockIndex": 0,
|
||||
}
|
||||
},
|
||||
{"contentBlockDelta": {"delta": {"text": "."}, "contentBlockIndex": 0}},
|
||||
{"contentBlockDelta": {"delta": {"text": ""}, "contentBlockIndex": 0}},
|
||||
{"contentBlockStop": {"contentBlockIndex": 0}},
|
||||
{"messageStop": {"stopReason": "end_turn"}},
|
||||
{
|
||||
"metadata": {
|
||||
"usage": {"inputTokens": 21, "outputTokens": 9, "totalTokens": 30},
|
||||
"metrics": {"latencyMs": 1},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
with mock.patch.object(
|
||||
self.api.client, "converse_stream"
|
||||
) as mock_converse_stream:
|
||||
mock_converse_stream.return_value = {"stream": events}
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
events = []
|
||||
async for chunk in iterator:
|
||||
events.append(chunk.event)
|
||||
|
||||
response = ""
|
||||
for e in events[1:-1]:
|
||||
response += e.delta
|
||||
|
||||
self.assertEqual(
|
||||
events[0].event_type, ChatCompletionResponseEventType.start
|
||||
)
|
||||
# last event is of type "complete"
|
||||
self.assertEqual(
|
||||
events[-1].event_type, ChatCompletionResponseEventType.complete
|
||||
)
|
||||
# last but 1 event should be of type "progress"
|
||||
self.assertEqual(
|
||||
events[-2].event_type, ChatCompletionResponseEventType.progress
|
||||
)
|
||||
self.assertEqual(
|
||||
events[-2].stop_reason,
|
||||
None,
|
||||
)
|
||||
self.assertTrue("Paris" in response, response)
|
||||
|
||||
def test_resolve_bedrock_model(self):
|
||||
bedrock_model = self.api.resolve_bedrock_model(self.valid_supported_model)
|
||||
self.assertEqual(bedrock_model, "meta.llama3-1-8b-instruct-v1:0")
|
||||
|
||||
invalid_model = "Meta-Llama3.1-8B"
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, f"Unsupported model: {invalid_model}"
|
||||
):
|
||||
self.api.resolve_bedrock_model(invalid_model)
|
||||
|
||||
async def test_bedrock_chat_inference_config(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="What is the capital of France?",
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
sampling_params=SamplingParams(
|
||||
sampling_strategy=SamplingStrategy.top_p,
|
||||
top_p=0.99,
|
||||
temperature=1.0,
|
||||
),
|
||||
)
|
||||
options = self.api.get_bedrock_inference_config(request.sampling_params)
|
||||
self.assertEqual(
|
||||
options,
|
||||
{
|
||||
"temperature": 1.0,
|
||||
"topP": 0.99,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_multi_turn_non_streaming(self):
|
||||
with mock.patch.object(self.api.client, "converse") as mock_converse:
|
||||
mock_converse.return_value = {
|
||||
"ResponseMetadata": {
|
||||
"RequestId": "4171abf1-a5f4-4eee-bb12-0e472a73bdbe",
|
||||
"HTTPStatusCode": 200,
|
||||
"HTTPHeaders": {},
|
||||
"RetryAttempts": 0,
|
||||
},
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": "\nThe 44th president of the United States was Barack Obama."
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
"stopReason": "end_turn",
|
||||
"usage": {"inputTokens": 723, "outputTokens": 15, "totalTokens": 738},
|
||||
"metrics": {"latencyMs": 449},
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=self.valid_supported_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content="Search the web and tell me who the "
|
||||
"44th president of the United States was",
|
||||
),
|
||||
CompletionMessage(
|
||||
content=[],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="1",
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
arguments={
|
||||
"query": "44th president of the United States"
|
||||
},
|
||||
)
|
||||
],
|
||||
),
|
||||
ToolResponseMessage(
|
||||
call_id="1",
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "<strong>Barack Obama</strong> served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, <strong>President Obama</strong> moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}',
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||
)
|
||||
iterator = self.api.chat_completion(
|
||||
request.model,
|
||||
request.messages,
|
||||
request.sampling_params,
|
||||
request.tools,
|
||||
request.tool_choice,
|
||||
request.tool_prompt_format,
|
||||
request.stream,
|
||||
request.logprobs,
|
||||
)
|
||||
async for r in iterator:
|
||||
response = r
|
||||
|
||||
completion_message = response.completion_message
|
||||
|
||||
self.assertEqual(len(completion_message.content), 1)
|
||||
self.assertTrue(
|
||||
completion_message.stop_reason
|
||||
in {
|
||||
StopReason.end_of_turn,
|
||||
StopReason.end_of_message,
|
||||
}
|
||||
)
|
||||
|
||||
self.assertTrue("obama" in completion_message.content[0].lower())
|
Loading…
Add table
Add a link
Reference in a new issue