## What does this PR do?

This is a long-pending change and particularly important to get done now.

Specifically:

- we cannot "localize" (aka download) any URLs from media attachments anywhere near our modeling code. It must be done within llama-stack.
- `PIL.Image` is infesting all our APIs via `ImageMedia -> InterleavedTextMedia`, and that cannot be right at all. Anything in the API surface must be "naturally serializable". We need a standard `{ type: "image", image_url: "<...>" }` shape, which is more extensible (an illustrative sketch of this shape follows the Test Plan below).
- `UserMessage`, `SystemMessage`, etc. are moved completely to llama-stack from the llama-models repository. See https://github.com/meta-llama/llama-models/pull/244 for the corresponding PR in llama-models.

## Test Plan

```bash
cd llama_stack/providers/tests

pytest -s -v -k "fireworks or ollama or together" inference/test_vision_inference.py

pytest -s -v -k "(fireworks or ollama or together) and llama_3b" inference/test_text_inference.py

pytest -s -v -k chroma memory/test_memory.py \
  --env EMBEDDING_DIMENSION=384 --env CHROMA_DB_PATH=/tmp/foobar

pytest -s -v -k fireworks agents/test_agents.py \
  --safety-shield=meta-llama/Llama-Guard-3-8B \
  --inference-model=meta-llama/Llama-3.1-8B-Instruct
```

Updated the client SDK (see PR ...), installed the SDK in the same environment, and then ran the SDK tests:

```bash
cd tests/client-sdk

LLAMA_STACK_CONFIG=together pytest -s -v agents/test_agents.py

LLAMA_STACK_CONFIG=ollama pytest -s -v memory/test_memory.py

# this one needed a bit of hacking in the run.yaml to ensure I could register the vision model correctly
INFERENCE_MODEL=llama3.2-vision:latest LLAMA_STACK_CONFIG=ollama pytest -s -v inference/test_inference.py
```
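To make the serialization point concrete, here is a minimal sketch of the kind of payload this change enables. The field names follow the `{ type: "image", image_url: "<...>" }` shape described above and are illustrative assumptions only; the exact content-item types in the final llama-stack API may differ.

```python
import json

# Hypothetical example of "naturally serializable" interleaved content:
# plain dicts and strings only, no PIL.Image objects leaking into the API surface.
# Field names mirror the shape described in the PR text and are assumptions,
# not the final llama-stack types.
user_message = {
    "role": "user",
    "content": [
        {"type": "image", "image_url": "https://example.com/cat.jpg"},
        {"type": "text", "text": "What is in this image?"},
    ],
}

# The whole message round-trips through JSON, so any localization (downloading
# the image bytes) can happen inside llama-stack rather than near modeling code.
print(json.dumps(user_message))
```

Because the content is plain JSON, any downloading of the `image_url` can then be handled inside llama-stack (for example in provider-side prompt adapters) instead of in the modeling code.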
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import *  # noqa: F403

import json
import uuid

from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    content_has_media,
    interleaved_content_as_str,
)

from llama_stack.apis.inference import *  # noqa: F403

from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client

MODEL_ALIASES = [
    build_model_alias(
        "meta.llama3-1-8b-instruct-v1:0",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "meta.llama3-1-70b-instruct-v1:0",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "meta.llama3-1-405b-instruct-v1:0",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
]

# NOTE: this is not quite tested after the recent refactors
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: BedrockConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
        self._config = config

        self._client = create_bedrock_client(config)
        self.formatter = ChatFormat(Tokenizer.get_instance())

    @property
    def client(self) -> BaseClient:
        return self._client

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        self.client.close()

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    @staticmethod
    def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
        if bedrock_stop_reason == "max_tokens":
            return StopReason.out_of_tokens
        return StopReason.end_of_turn

    @staticmethod
    def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
        for builtin_tool in BuiltinTool:
            if builtin_tool.value == tool_name_str:
                return builtin_tool
        else:
            # for/else: no builtin tool matched, so return the raw tool name string
            return tool_name_str

    @staticmethod
    def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
        # Convert a Bedrock Converse API response into a CompletionMessage,
        # collecting tool-use blocks and concatenating text blocks.
        stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
            converse_api_res["stopReason"]
        )

        bedrock_message = converse_api_res["output"]["message"]

        role = bedrock_message["role"]
        contents = bedrock_message["content"]

        tool_calls = []
        text_content = ""
        for content in contents:
            if "toolUse" in content:
                tool_use = content["toolUse"]
                tool_calls.append(
                    ToolCall(
                        tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
                            tool_use["name"]
                        ),
                        arguments=tool_use["input"] if "input" in tool_use else None,
                        call_id=tool_use["toolUseId"],
                    )
                )
            elif "text" in content:
                text_content += content["text"]

        return CompletionMessage(
            role=role,
            content=text_content,
            stop_reason=stop_reason,
            tool_calls=tool_calls,
        )

    @staticmethod
    def _messages_to_bedrock_messages(
        messages: List[Message],
    ) -> Tuple[List[Dict], Optional[List[Dict]]]:
        # Convert llama-stack messages into Bedrock Converse API messages.
        # Consecutive user/ipython contents are merged into a single user turn,
        # assistant contents (including tool calls) into a single assistant turn,
        # and system messages are returned as a separate list.
        bedrock_messages = []
        system_bedrock_messages = []

        user_contents = []
        assistant_contents = None
        for message in messages:
            role = message.role
            content_list = (
                message.content
                if isinstance(message.content, list)
                else [message.content]
            )
            if role == "ipython" or role == "user":
                if not user_contents:
                    user_contents = []

                if role == "ipython":
                    user_contents.extend(
                        [
                            {
                                "toolResult": {
                                    "toolUseId": message.call_id or str(uuid.uuid4()),
                                    "content": [
                                        {"text": content} for content in content_list
                                    ],
                                }
                            }
                        ]
                    )
                else:
                    user_contents.extend(
                        [{"text": content} for content in content_list]
                    )

                if assistant_contents:
                    bedrock_messages.append(
                        {"role": "assistant", "content": assistant_contents}
                    )
                    assistant_contents = None
            elif role == "system":
                system_bedrock_messages.extend(
                    [{"text": content} for content in content_list]
                )
            elif role == "assistant":
                if not assistant_contents:
                    assistant_contents = []

                assistant_contents.extend(
                    [
                        {
                            "text": content,
                        }
                        for content in content_list
                    ]
                    + [
                        {
                            "toolUse": {
                                "input": tool_call.arguments,
                                "name": (
                                    tool_call.tool_name
                                    if isinstance(tool_call.tool_name, str)
                                    else tool_call.tool_name.value
                                ),
                                "toolUseId": tool_call.call_id,
                            }
                        }
                        for tool_call in message.tool_calls
                    ]
                )

                if user_contents:
                    bedrock_messages.append({"role": "user", "content": user_contents})
                    user_contents = None
            else:
                # Unknown role
                pass

        if user_contents:
            bedrock_messages.append({"role": "user", "content": user_contents})
        if assistant_contents:
            bedrock_messages.append(
                {"role": "assistant", "content": assistant_contents}
            )

        if system_bedrock_messages:
            return bedrock_messages, system_bedrock_messages

        return bedrock_messages, None

    @staticmethod
    def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
        inference_config = {}
        if sampling_params:
            param_mapping = {
                "max_tokens": "maxTokens",
                "temperature": "temperature",
                "top_p": "topP",
            }

            for k, v in param_mapping.items():
                if getattr(sampling_params, k):
                    inference_config[v] = getattr(sampling_params, k)

        return inference_config

    @staticmethod
    def _tool_parameters_to_input_schema(
        tool_parameters: Optional[Dict[str, ToolParamDefinition]],
    ) -> Dict:
        input_schema = {"type": "object"}
        if not tool_parameters:
            return input_schema

        json_properties = {}
        required = []
        for name, param in tool_parameters.items():
            json_property = {
                "type": param.param_type,
            }

            if param.description:
                json_property["description"] = param.description
            if param.required:
                required.append(name)
            json_properties[name] = json_property

        input_schema["properties"] = json_properties
        if required:
            input_schema["required"] = required
        return input_schema

    @staticmethod
    def _tools_to_tool_config(
        tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
    ) -> Optional[Dict]:
        if not tools:
            return None

        bedrock_tools = []
        for tool in tools:
            tool_name = (
                tool.tool_name
                if isinstance(tool.tool_name, str)
                else tool.tool_name.value
            )

            tool_spec = {
                "toolSpec": {
                    "name": tool_name,
                    "inputSchema": {
                        "json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
                            tool.parameters
                        ),
                    },
                }
            }

            if tool.description:
                tool_spec["toolSpec"]["description"] = tool.description

            bedrock_tools.append(tool_spec)
        tool_config = {
            "tools": bedrock_tools,
        }

        if tool_choice:
            # Compare enum values: comparing the raw .value string against the enum
            # member itself would always be False for a plain Enum.
            tool_config["toolChoice"] = (
                {"any": {}}
                if tool_choice.value == ToolChoice.required.value
                else {"auto": {}}
            )
        return tool_config

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )

        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        params = self._get_params_for_chat_completion(request)
        converse_api_res = self.client.converse(**params)

        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
            converse_api_res
        )

        return ChatCompletionResponse(
            completion_message=output_message,
            logprobs=None,
        )

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        params = self._get_params_for_chat_completion(request)
        converse_stream_api_res = self.client.converse_stream(**params)
        event_stream = converse_stream_api_res["stream"]

        # Map Bedrock ConverseStream events onto llama-stack stream chunks:
        # messageStart -> start, contentBlock* -> progress, messageStop -> complete.
        for chunk in event_stream:
            if "messageStart" in chunk:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.start,
                        delta="",
                    )
                )
            elif "contentBlockStart" in chunk:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content=ToolCall(
                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
                                call_id=chunk["contentBlockStart"]["toolUse"][
                                    "toolUseId"
                                ],
                            ),
                            parse_status=ToolCallParseStatus.started,
                        ),
                    )
                )
            elif "contentBlockDelta" in chunk:
                if "text" in chunk["contentBlockDelta"]["delta"]:
                    delta = chunk["contentBlockDelta"]["delta"]["text"]
                else:
                    delta = ToolCallDelta(
                        content=ToolCall(
                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
                                "input"
                            ]
                        ),
                        parse_status=ToolCallParseStatus.success,
                    )

                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=delta,
                    )
                )
            elif "contentBlockStop" in chunk:
                # Ignored
                pass
            elif "messageStop" in chunk:
                stop_reason = (
                    BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
                        chunk["messageStop"]["stopReason"]
                    )
                )

                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.complete,
                        delta="",
                        stop_reason=stop_reason,
                    )
                )
            elif "metadata" in chunk:
                # Ignored
                pass
            else:
                # Ignored
                pass

    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
        bedrock_model = request.model
        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
            request.sampling_params
        )

        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
            request.tools, request.tool_choice
        )
        bedrock_messages, system_bedrock_messages = (
            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
        )

        converse_api_params = {
            "modelId": bedrock_model,
            "messages": bedrock_messages,
        }
        if inference_config:
            converse_api_params["inferenceConfig"] = inference_config

        # Tool use is not supported in streaming mode
        if tool_config and not request.stream:
            converse_api_params["toolConfig"] = tool_config
        if system_bedrock_messages:
            converse_api_params["system"] = system_bedrock_messages

        return converse_api_params

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)
        embeddings = []
        for content in contents:
            assert not content_has_media(
                content
            ), "Bedrock does not support media for embeddings"
            input_text = interleaved_content_as_str(content)
            input_body = {"inputText": input_text}
            body = json.dumps(input_body)
            response = self.client.invoke_model(
                body=body,
                modelId=model.provider_resource_id,
                accept="application/json",
                contentType="application/json",
            )
            response_body = json.loads(response.get("body").read())
            embeddings.append(response_body.get("embedding"))
        return EmbeddingsResponse(embeddings=embeddings)