diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 07a851e78..9963fffd8 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -17,14 +17,19 @@ class DistributionInspectConfig(BaseModel): pass -def get_provider_impl(*args, **kwargs): - return DistributionInspectImpl() +async def get_provider_impl(*args, **kwargs): + impl = DistributionInspectImpl() + await impl.initialize() + return impl class DistributionInspectImpl(Inspect): def __init__(self): pass + async def initialize(self) -> None: + pass + async def list_providers(self) -> Dict[str, List[ProviderInfo]]: ret = {} all_providers = get_provider_registry() diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 660d84fc8..2c383587c 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -20,6 +20,7 @@ class ProviderWithSpec(Provider): spec: ProviderSpec +# TODO: this code is not very straightforward to follow and needs one more round of refactoring async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]: """ Does two things: @@ -134,7 +135,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An print("") impls = {} - inner_impls_by_provider_id = {f"inner-{x}": {} for x in router_apis} + inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis} for api_str, provider in sorted_providers: deps = {a: impls[a] for a in provider.spec.api_dependencies} diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index c360bcfb0..c56b33f21 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403 class MemoryRouter(Memory): - """Routes to an provider based on the memory bank type""" + """Routes to an provider based on the memory bank identifier""" def __init__( self, routing_table: RoutingTable, ) -> None: self.routing_table = routing_table - self.bank_id_to_type = {} async def initialize(self) -> None: pass @@ -29,32 +28,14 @@ class MemoryRouter(Memory): async def shutdown(self) -> None: pass - def get_provider_from_bank_id(self, bank_id: str) -> Any: - bank_type = self.bank_id_to_type.get(bank_id) - if not bank_type: - raise ValueError(f"Could not find bank type for {bank_id}") + async def list_memory_banks(self) -> List[MemoryBankDef]: + return self.routing_table.list_memory_banks() - provider = self.routing_table.get_provider_impl(bank_type) - if not provider: - raise ValueError(f"Could not find provider for {bank_type}") - return provider + async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: + return self.routing_table.get_memory_bank(identifier) - async def create_memory_bank( - self, - name: str, - config: MemoryBankConfig, - url: Optional[URL] = None, - ) -> MemoryBank: - bank_type = config.type - bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank( - name, config, url - ) - self.bank_id_to_type[bank.bank_id] = bank_type - return bank - - async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: - provider = self.get_provider_from_bank_id(bank_id) - return await provider.get_memory_bank(bank_id) + async def register_memory_bank(self, bank: MemoryBankDef) -> None: + await self.routing_table.register_memory_bank(bank) async def insert_documents( self, @@ -62,7 +43,7 @@ class 
MemoryRouter(Memory): documents: List[MemoryBankDocument], ttl_seconds: Optional[int] = None, ) -> None: - return await self.get_provider_from_bank_id(bank_id).insert_documents( + return await self.routing_table.get_provider_impl(bank_id).insert_documents( bank_id, documents, ttl_seconds ) @@ -72,7 +53,7 @@ class MemoryRouter(Memory): query: InterleavedTextMedia, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: - return await self.get_provider_from_bank_id(bank_id).query_documents( + return await self.routing_table.get_provider_impl(bank_id).query_documents( bank_id, query, params ) @@ -92,6 +73,15 @@ class InferenceRouter(Inference): async def shutdown(self) -> None: pass + async def list_models(self) -> List[ModelDef]: + return self.routing_table.list_models() + + async def get_model(self, identifier: str) -> Optional[ModelDef]: + return self.routing_table.get_model(identifier) + + async def register_model(self, model: ModelDef) -> None: + await self.routing_table.register_model(model) + async def chat_completion( self, model: str, @@ -159,6 +149,15 @@ class SafetyRouter(Safety): async def shutdown(self) -> None: pass + async def list_shields(self) -> List[ShieldDef]: + return self.routing_table.list_shields() + + async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: + return self.routing_table.get_shield(shield_type) + + async def register_shield(self, shield: ShieldDef) -> None: + await self.routing_table.register_shield(shield) + async def run_shield( self, shield_type: str, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index fbc3eae32..350ab05fe 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -15,6 +15,8 @@ from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +# TODO: this routing table maintains state in memory purely. We need to +# add persistence to it when we add dynamic registration of objects. 
class CommonRoutingTableImpl(RoutingTable): def __init__( self, @@ -54,7 +56,7 @@ class CommonRoutingTableImpl(RoutingTable): return obj return None - def register_object(self, obj: RoutableObject) -> None: + async def register_object_common(self, obj: RoutableObject) -> None: if obj.identifier in self.routing_key_to_object: raise ValueError(f"Object `{obj.identifier}` already registered") @@ -79,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return self.get_object_by_identifier(identifier) async def register_model(self, model: ModelDef) -> None: - await self.register_object(model) + await self.register_object_common(model) class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): @@ -93,7 +95,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): return self.get_object_by_identifier(shield_type) async def register_shield(self, shield: ShieldDef) -> None: - await self.register_object(shield) + await self.register_object_common(shield) class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): @@ -107,4 +109,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): return self.get_object_by_identifier(identifier) async def register_memory_bank(self, bank: MemoryBankDef) -> None: - await self.register_object(bank) + await self.register_object_common(bank) diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index 9c1db4bdb..7f51894bc 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -1,445 +1,445 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import * # noqa: F403 - -import boto3 -from botocore.client import BaseClient -from botocore.config import Config - -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer - -from llama_stack.providers.utils.inference.routable import RoutableProviderForModels - -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig - - -BEDROCK_SUPPORTED_MODELS = { - "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0", - "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0", - "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0", -} - - -class BedrockInferenceAdapter(Inference, RoutableProviderForModels): - - @staticmethod - def _create_bedrock_client(config: BedrockConfig) -> BaseClient: - retries_config = { - k: v - for k, v in dict( - total_max_attempts=config.total_max_attempts, - mode=config.retry_mode, - ).items() - if v is not None - } - - config_args = { - k: v - for k, v in dict( - region_name=config.region_name, - retries=retries_config if retries_config else None, - connect_timeout=config.connect_timeout, - read_timeout=config.read_timeout, - ).items() - if v is not None - } - - boto3_config = Config(**config_args) - - session_args = { - k: v - for k, v in dict( - aws_access_key_id=config.aws_access_key_id, - aws_secret_access_key=config.aws_secret_access_key, - aws_session_token=config.aws_session_token, - region_name=config.region_name, - profile_name=config.profile_name, - ).items() - if v is not None - } - - boto3_session = boto3.session.Session(**session_args) - - return boto3_session.client("bedrock-runtime", config=boto3_config) - - def __init__(self, config: BedrockConfig) -> None: - RoutableProviderForModels.__init__( - self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS - ) - self._config = config - - self._client = BedrockInferenceAdapter._create_bedrock_client(config) - tokenizer = Tokenizer.get_instance() - self.formatter = ChatFormat(tokenizer) - - @property - def client(self) -> BaseClient: - return self._client - - async def initialize(self) -> None: - pass - - async def shutdown(self) -> None: - self.client.close() - - async def completion( - self, - model: str, - content: InterleavedTextMedia, - sampling_params: Optional[SamplingParams] = SamplingParams(), - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: - raise NotImplementedError() - - @staticmethod - def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason: - if bedrock_stop_reason == "max_tokens": - return StopReason.out_of_tokens - return StopReason.end_of_turn - - @staticmethod - def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]: - for builtin_tool in BuiltinTool: - if builtin_tool.value == tool_name_str: - return builtin_tool - else: - return tool_name_str - - @staticmethod - def _bedrock_message_to_message(converse_api_res: Dict) -> Message: - stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason( - converse_api_res["stopReason"] - ) - - bedrock_message = converse_api_res["output"]["message"] - - role = bedrock_message["role"] - contents = bedrock_message["content"] - - tool_calls = [] - text_content = [] - for content in contents: - if "toolUse" in content: - tool_use = content["toolUse"] - tool_calls.append( - ToolCall( - 
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum( - tool_use["name"] - ), - arguments=tool_use["input"] if "input" in tool_use else None, - call_id=tool_use["toolUseId"], - ) - ) - elif "text" in content: - text_content.append(content["text"]) - - return CompletionMessage( - role=role, - content=text_content, - stop_reason=stop_reason, - tool_calls=tool_calls, - ) - - @staticmethod - def _messages_to_bedrock_messages( - messages: List[Message], - ) -> Tuple[List[Dict], Optional[List[Dict]]]: - bedrock_messages = [] - system_bedrock_messages = [] - - user_contents = [] - assistant_contents = None - for message in messages: - role = message.role - content_list = ( - message.content - if isinstance(message.content, list) - else [message.content] - ) - if role == "ipython" or role == "user": - if not user_contents: - user_contents = [] - - if role == "ipython": - user_contents.extend( - [ - { - "toolResult": { - "toolUseId": message.call_id, - "content": [ - {"text": content} for content in content_list - ], - } - } - ] - ) - else: - user_contents.extend( - [{"text": content} for content in content_list] - ) - - if assistant_contents: - bedrock_messages.append( - {"role": "assistant", "content": assistant_contents} - ) - assistant_contents = None - elif role == "system": - system_bedrock_messages.extend( - [{"text": content} for content in content_list] - ) - elif role == "assistant": - if not assistant_contents: - assistant_contents = [] - - assistant_contents.extend( - [ - { - "text": content, - } - for content in content_list - ] - + [ - { - "toolUse": { - "input": tool_call.arguments, - "name": ( - tool_call.tool_name - if isinstance(tool_call.tool_name, str) - else tool_call.tool_name.value - ), - "toolUseId": tool_call.call_id, - } - } - for tool_call in message.tool_calls - ] - ) - - if user_contents: - bedrock_messages.append({"role": "user", "content": user_contents}) - user_contents = None - else: - # Unknown role - pass - - if user_contents: - bedrock_messages.append({"role": "user", "content": user_contents}) - if assistant_contents: - bedrock_messages.append( - {"role": "assistant", "content": assistant_contents} - ) - - if system_bedrock_messages: - return bedrock_messages, system_bedrock_messages - - return bedrock_messages, None - - @staticmethod - def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict: - inference_config = {} - if sampling_params: - param_mapping = { - "max_tokens": "maxTokens", - "temperature": "temperature", - "top_p": "topP", - } - - for k, v in param_mapping.items(): - if getattr(sampling_params, k): - inference_config[v] = getattr(sampling_params, k) - - return inference_config - - @staticmethod - def _tool_parameters_to_input_schema( - tool_parameters: Optional[Dict[str, ToolParamDefinition]] - ) -> Dict: - input_schema = {"type": "object"} - if not tool_parameters: - return input_schema - - json_properties = {} - required = [] - for name, param in tool_parameters.items(): - json_property = { - "type": param.param_type, - } - - if param.description: - json_property["description"] = param.description - if param.required: - required.append(name) - json_properties[name] = json_property - - input_schema["properties"] = json_properties - if required: - input_schema["required"] = required - return input_schema - - @staticmethod - def _tools_to_tool_config( - tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice] - ) -> Optional[Dict]: - if not tools: - return None - - bedrock_tools = [] - for tool in 
tools: - tool_name = ( - tool.tool_name - if isinstance(tool.tool_name, str) - else tool.tool_name.value - ) - - tool_spec = { - "toolSpec": { - "name": tool_name, - "inputSchema": { - "json": BedrockInferenceAdapter._tool_parameters_to_input_schema( - tool.parameters - ), - }, - } - } - - if tool.description: - tool_spec["toolSpec"]["description"] = tool.description - - bedrock_tools.append(tool_spec) - tool_config = { - "tools": bedrock_tools, - } - - if tool_choice: - tool_config["toolChoice"] = ( - {"any": {}} - if tool_choice.value == ToolChoice.required - else {"auto": {}} - ) - return tool_config - - async def chat_completion( - self, - model: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = SamplingParams(), - # zero-shot tool definitions as input to the model - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> ( - AsyncGenerator - ): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: - bedrock_model = self.map_to_provider_model(model) - inference_config = BedrockInferenceAdapter.get_bedrock_inference_config( - sampling_params - ) - - tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice) - bedrock_messages, system_bedrock_messages = ( - BedrockInferenceAdapter._messages_to_bedrock_messages(messages) - ) - - converse_api_params = { - "modelId": bedrock_model, - "messages": bedrock_messages, - } - if inference_config: - converse_api_params["inferenceConfig"] = inference_config - - # Tool use is not supported in streaming mode - if tool_config and not stream: - converse_api_params["toolConfig"] = tool_config - if system_bedrock_messages: - converse_api_params["system"] = system_bedrock_messages - - if not stream: - converse_api_res = self.client.converse(**converse_api_params) - - output_message = BedrockInferenceAdapter._bedrock_message_to_message( - converse_api_res - ) - - yield ChatCompletionResponse( - completion_message=output_message, - logprobs=None, - ) - else: - converse_stream_api_res = self.client.converse_stream(**converse_api_params) - event_stream = converse_stream_api_res["stream"] - - for chunk in event_stream: - if "messageStart" in chunk: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta="", - ) - ) - elif "contentBlockStart" in chunk: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content=ToolCall( - tool_name=chunk["contentBlockStart"]["toolUse"][ - "name" - ], - call_id=chunk["contentBlockStart"]["toolUse"][ - "toolUseId" - ], - ), - parse_status=ToolCallParseStatus.started, - ), - ) - ) - elif "contentBlockDelta" in chunk: - if "text" in chunk["contentBlockDelta"]["delta"]: - delta = chunk["contentBlockDelta"]["delta"]["text"] - else: - delta = ToolCallDelta( - content=ToolCall( - arguments=chunk["contentBlockDelta"]["delta"][ - "toolUse" - ]["input"] - ), - parse_status=ToolCallParseStatus.success, - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - ) - ) - elif "contentBlockStop" in chunk: - # Ignored - pass - elif "messageStop" in chunk: - stop_reason = ( - 
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason( - chunk["messageStop"]["stopReason"] - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta="", - stop_reason=stop_reason, - ) - ) - elif "metadata" in chunk: - # Ignored - pass - else: - # Ignored - pass +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import * # noqa: F403 + +import boto3 +from botocore.client import BaseClient +from botocore.config import Config + +from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.tokenizer import Tokenizer + +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper + +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig + + +BEDROCK_SUPPORTED_MODELS = { + "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0", + "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0", + "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0", +} + + +class BedrockInferenceAdapter(ModelRegistryHelper, Inference): + + @staticmethod + def _create_bedrock_client(config: BedrockConfig) -> BaseClient: + retries_config = { + k: v + for k, v in dict( + total_max_attempts=config.total_max_attempts, + mode=config.retry_mode, + ).items() + if v is not None + } + + config_args = { + k: v + for k, v in dict( + region_name=config.region_name, + retries=retries_config if retries_config else None, + connect_timeout=config.connect_timeout, + read_timeout=config.read_timeout, + ).items() + if v is not None + } + + boto3_config = Config(**config_args) + + session_args = { + k: v + for k, v in dict( + aws_access_key_id=config.aws_access_key_id, + aws_secret_access_key=config.aws_secret_access_key, + aws_session_token=config.aws_session_token, + region_name=config.region_name, + profile_name=config.profile_name, + ).items() + if v is not None + } + + boto3_session = boto3.session.Session(**session_args) + + return boto3_session.client("bedrock-runtime", config=boto3_config) + + def __init__(self, config: BedrockConfig) -> None: + ModelRegistryHelper.__init__( + self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS + ) + self._config = config + + self._client = BedrockInferenceAdapter._create_bedrock_client(config) + tokenizer = Tokenizer.get_instance() + self.formatter = ChatFormat(tokenizer) + + @property + def client(self) -> BaseClient: + return self._client + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + self.client.close() + + async def completion( + self, + model: str, + content: InterleavedTextMedia, + sampling_params: Optional[SamplingParams] = SamplingParams(), + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: + raise NotImplementedError() + + @staticmethod + def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason: + if bedrock_stop_reason == "max_tokens": + return StopReason.out_of_tokens + return StopReason.end_of_turn + + @staticmethod + def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]: + for builtin_tool in BuiltinTool: + if builtin_tool.value == tool_name_str: + return builtin_tool + else: + return 
tool_name_str + + @staticmethod + def _bedrock_message_to_message(converse_api_res: Dict) -> Message: + stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason( + converse_api_res["stopReason"] + ) + + bedrock_message = converse_api_res["output"]["message"] + + role = bedrock_message["role"] + contents = bedrock_message["content"] + + tool_calls = [] + text_content = [] + for content in contents: + if "toolUse" in content: + tool_use = content["toolUse"] + tool_calls.append( + ToolCall( + tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum( + tool_use["name"] + ), + arguments=tool_use["input"] if "input" in tool_use else None, + call_id=tool_use["toolUseId"], + ) + ) + elif "text" in content: + text_content.append(content["text"]) + + return CompletionMessage( + role=role, + content=text_content, + stop_reason=stop_reason, + tool_calls=tool_calls, + ) + + @staticmethod + def _messages_to_bedrock_messages( + messages: List[Message], + ) -> Tuple[List[Dict], Optional[List[Dict]]]: + bedrock_messages = [] + system_bedrock_messages = [] + + user_contents = [] + assistant_contents = None + for message in messages: + role = message.role + content_list = ( + message.content + if isinstance(message.content, list) + else [message.content] + ) + if role == "ipython" or role == "user": + if not user_contents: + user_contents = [] + + if role == "ipython": + user_contents.extend( + [ + { + "toolResult": { + "toolUseId": message.call_id, + "content": [ + {"text": content} for content in content_list + ], + } + } + ] + ) + else: + user_contents.extend( + [{"text": content} for content in content_list] + ) + + if assistant_contents: + bedrock_messages.append( + {"role": "assistant", "content": assistant_contents} + ) + assistant_contents = None + elif role == "system": + system_bedrock_messages.extend( + [{"text": content} for content in content_list] + ) + elif role == "assistant": + if not assistant_contents: + assistant_contents = [] + + assistant_contents.extend( + [ + { + "text": content, + } + for content in content_list + ] + + [ + { + "toolUse": { + "input": tool_call.arguments, + "name": ( + tool_call.tool_name + if isinstance(tool_call.tool_name, str) + else tool_call.tool_name.value + ), + "toolUseId": tool_call.call_id, + } + } + for tool_call in message.tool_calls + ] + ) + + if user_contents: + bedrock_messages.append({"role": "user", "content": user_contents}) + user_contents = None + else: + # Unknown role + pass + + if user_contents: + bedrock_messages.append({"role": "user", "content": user_contents}) + if assistant_contents: + bedrock_messages.append( + {"role": "assistant", "content": assistant_contents} + ) + + if system_bedrock_messages: + return bedrock_messages, system_bedrock_messages + + return bedrock_messages, None + + @staticmethod + def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict: + inference_config = {} + if sampling_params: + param_mapping = { + "max_tokens": "maxTokens", + "temperature": "temperature", + "top_p": "topP", + } + + for k, v in param_mapping.items(): + if getattr(sampling_params, k): + inference_config[v] = getattr(sampling_params, k) + + return inference_config + + @staticmethod + def _tool_parameters_to_input_schema( + tool_parameters: Optional[Dict[str, ToolParamDefinition]] + ) -> Dict: + input_schema = {"type": "object"} + if not tool_parameters: + return input_schema + + json_properties = {} + required = [] + for name, param in tool_parameters.items(): + json_property = { + "type": 
param.param_type, + } + + if param.description: + json_property["description"] = param.description + if param.required: + required.append(name) + json_properties[name] = json_property + + input_schema["properties"] = json_properties + if required: + input_schema["required"] = required + return input_schema + + @staticmethod + def _tools_to_tool_config( + tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice] + ) -> Optional[Dict]: + if not tools: + return None + + bedrock_tools = [] + for tool in tools: + tool_name = ( + tool.tool_name + if isinstance(tool.tool_name, str) + else tool.tool_name.value + ) + + tool_spec = { + "toolSpec": { + "name": tool_name, + "inputSchema": { + "json": BedrockInferenceAdapter._tool_parameters_to_input_schema( + tool.parameters + ), + }, + } + } + + if tool.description: + tool_spec["toolSpec"]["description"] = tool.description + + bedrock_tools.append(tool_spec) + tool_config = { + "tools": bedrock_tools, + } + + if tool_choice: + tool_config["toolChoice"] = ( + {"any": {}} + if tool_choice.value == ToolChoice.required + else {"auto": {}} + ) + return tool_config + + async def chat_completion( + self, + model: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + # zero-shot tool definitions as input to the model + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> ( + AsyncGenerator + ): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: + bedrock_model = self.map_to_provider_model(model) + inference_config = BedrockInferenceAdapter.get_bedrock_inference_config( + sampling_params + ) + + tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice) + bedrock_messages, system_bedrock_messages = ( + BedrockInferenceAdapter._messages_to_bedrock_messages(messages) + ) + + converse_api_params = { + "modelId": bedrock_model, + "messages": bedrock_messages, + } + if inference_config: + converse_api_params["inferenceConfig"] = inference_config + + # Tool use is not supported in streaming mode + if tool_config and not stream: + converse_api_params["toolConfig"] = tool_config + if system_bedrock_messages: + converse_api_params["system"] = system_bedrock_messages + + if not stream: + converse_api_res = self.client.converse(**converse_api_params) + + output_message = BedrockInferenceAdapter._bedrock_message_to_message( + converse_api_res + ) + + yield ChatCompletionResponse( + completion_message=output_message, + logprobs=None, + ) + else: + converse_stream_api_res = self.client.converse_stream(**converse_api_params) + event_stream = converse_stream_api_res["stream"] + + for chunk in event_stream: + if "messageStart" in chunk: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta="", + ) + ) + elif "contentBlockStart" in chunk: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content=ToolCall( + tool_name=chunk["contentBlockStart"]["toolUse"][ + "name" + ], + call_id=chunk["contentBlockStart"]["toolUse"][ + "toolUseId" + ], + ), + parse_status=ToolCallParseStatus.started, + ), + ) + ) + elif "contentBlockDelta" in chunk: + if "text" in chunk["contentBlockDelta"]["delta"]: + 
delta = chunk["contentBlockDelta"]["delta"]["text"] + else: + delta = ToolCallDelta( + content=ToolCall( + arguments=chunk["contentBlockDelta"]["delta"][ + "toolUse" + ]["input"] + ), + parse_status=ToolCallParseStatus.success, + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=delta, + ) + ) + elif "contentBlockStop" in chunk: + # Ignored + pass + elif "messageStop" in chunk: + stop_reason = ( + BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason( + chunk["messageStop"]["stopReason"] + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + stop_reason=stop_reason, + ) + ) + elif "metadata" in chunk: + # Ignored + pass + else: + # Ignored + pass diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index f6949cbdc..061e281be 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -13,7 +13,7 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.providers.utils.inference.routable import RoutableProviderForModels +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.utils.inference.augment_messages import ( @@ -30,9 +30,9 @@ FIREWORKS_SUPPORTED_MODELS = { } -class FireworksInferenceAdapter(Inference, RoutableProviderForModels): +class FireworksInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: FireworksImplConfig) -> None: - RoutableProviderForModels.__init__( + ModelRegistryHelper.__init__( self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS ) self.config = config diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index bd267a5f8..bc1b3d103 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -18,7 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, ) -from llama_stack.providers.utils.inference.routable import RoutableProviderForModels +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper # TODO: Eventually this will move to the llama cli model list command # mapping of Model SKUs to ollama models @@ -27,12 +27,13 @@ OLLAMA_SUPPORTED_SKUS = { "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16", "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16", "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16", + "Llama-Guard-3-8B": "xe/llamaguard3:latest", } -class OllamaInferenceAdapter(Inference, RoutableProviderForModels): +class OllamaInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, url: str) -> None: - RoutableProviderForModels.__init__( + ModelRegistryHelper.__init__( self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS ) self.url = url diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 
9f73a81d1..2ee90d8e3 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -18,7 +18,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, ) -from llama_stack.providers.utils.inference.routable import RoutableProviderForModels +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from .config import TogetherImplConfig @@ -34,10 +34,10 @@ TOGETHER_SUPPORTED_MODELS = { class TogetherInferenceAdapter( - Inference, NeedsRequestProviderData, RoutableProviderForModels + ModelRegistryHelper, Inference, NeedsRequestProviderData ): def __init__(self, config: TogetherImplConfig) -> None: - RoutableProviderForModels.__init__( + ModelRegistryHelper.__init__( self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS ) self.config = config diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py index dca4ea6fb..9c6654ad1 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -12,7 +12,6 @@ from llama_models.sku_list import resolve_model from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, ) @@ -25,24 +24,39 @@ from .model_parallel import LlamaModelParallelGenerator SEMAPHORE = asyncio.Semaphore(1) -class MetaReferenceInferenceImpl(Inference, RoutableProvider): +class MetaReferenceInferenceImpl(Inference): def __init__(self, config: MetaReferenceImplConfig) -> None: self.config = config model = resolve_model(config.model) if model is None: raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") self.model = model + self.registered_model_defs = [] # verify that the checkpoint actually is for this model lol async def initialize(self) -> None: self.generator = LlamaModelParallelGenerator(self.config) self.generator.start() - async def validate_routing_keys(self, routing_keys: List[str]) -> None: - assert ( - len(routing_keys) == 1 - ), f"Only one routing key is supported {routing_keys}" - assert routing_keys[0] == self.config.model + async def register_model(self, model: ModelDef) -> None: + existing = await self.get_model(model.identifier) + if existing is not None: + return + + if model.identifier != self.model.descriptor(): + raise RuntimeError( + f"Model mismatch: {model.identifier} != {self.model.descriptor()}" + ) + self.registered_model_defs = [model] + + async def list_models(self) -> List[ModelDef]: + return self.registered_model_defs + + async def get_model(self, identifier: str) -> Optional[ModelDef]: + for model in self.registered_model_defs: + if model.identifier == identifier: + return model + return None async def shutdown(self) -> None: self.generator.stop() diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index 4f592e5e0..1534971cd 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -13,7 +13,6 @@ import numpy as np from numpy.typing import 
NDArray from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.apis.memory import * # noqa: F403 from llama_stack.providers.utils.memory.vector_store import ( @@ -62,7 +61,7 @@ class FaissIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class FaissMemoryImpl(Memory, RoutableProvider): +class FaissMemoryImpl(Memory): def __init__(self, config: FaissImplConfig) -> None: self.config = config self.cache = {} @@ -83,7 +82,6 @@ class FaissMemoryImpl(Memory, RoutableProvider): bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) ) self.cache[memory_bank.identifier] = index - return bank async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: index = self.cache.get(identifier) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py new file mode 100644 index 000000000..dabf698d4 --- /dev/null +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict, List + +from llama_models.sku_list import resolve_model + +from llama_stack.apis.models import * # noqa: F403 + + +class ModelRegistryHelper: + + def __init__(self, stack_to_provider_models_map: Dict[str, str]): + self.stack_to_provider_models_map = stack_to_provider_models_map + self.registered_models = [] + + def map_to_provider_model(self, identifier: str) -> str: + model = resolve_model(identifier) + if not model: + raise ValueError(f"Unknown model: `{identifier}`") + + if identifier not in self.stack_to_provider_models_map: + raise ValueError( + f"Model {identifier} not found in map {self.stack_to_provider_models_map}" + ) + + return self.stack_to_provider_models_map[identifier] + + async def register_model(self, model: ModelDef) -> None: + existing = await self.get_model(model.identifier) + if existing is not None: + return + + if model.identifier not in self.stack_to_provider_models_map: + raise ValueError( + f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}" + ) + + self.registered_models.append(model) + + async def list_models(self) -> List[ModelDef]: + return self.registered_models + + async def get_model(self, identifier: str) -> Optional[ModelDef]: + for model in self.registered_models: + if model.identifier == identifier: + return model + return None diff --git a/llama_stack/providers/utils/inference/routable.py b/llama_stack/providers/utils/inference/routable.py deleted file mode 100644 index a36631208..000000000 --- a/llama_stack/providers/utils/inference/routable.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Dict, List - -from llama_models.sku_list import resolve_model - -from llama_stack.distribution.datatypes import RoutableProvider - - -class RoutableProviderForModels(RoutableProvider): - - def __init__(self, stack_to_provider_models_map: Dict[str, str]): - self.stack_to_provider_models_map = stack_to_provider_models_map - - async def validate_routing_keys(self, routing_keys: List[str]): - for routing_key in routing_keys: - if routing_key not in self.stack_to_provider_models_map: - raise ValueError( - f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}" - ) - - def map_to_provider_model(self, routing_key: str) -> str: - model = resolve_model(routing_key) - if not model: - raise ValueError(f"Unknown model: `{routing_key}`") - - if routing_key not in self.stack_to_provider_models_map: - raise ValueError( - f"Model {routing_key} not found in map {self.stack_to_provider_models_map}" - ) - - return self.stack_to_provider_models_map[routing_key]
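
Note (illustrative only, not part of the patch): a minimal sketch of how the new ModelRegistryHelper mixin introduced above behaves, assuming the llama_stack and llama_models packages from this change are importable. SimpleNamespace is a hypothetical stand-in for a real ModelDef, whose full constructor is not shown in this diff; register_model() only reads the identifier attribute.

# Illustrative sketch only -- not part of the patch.
# Assumes llama_stack / llama_models are importable. SimpleNamespace stands in
# for a real ModelDef (its full constructor is not shown in this diff);
# register_model() only reads the .identifier attribute.
import asyncio
from types import SimpleNamespace

from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper


async def main() -> None:
    helper = ModelRegistryHelper(
        stack_to_provider_models_map={
            "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
        }
    )

    # Registration is validated against the provider's supported-model map and
    # is a no-op if the identifier is already registered.
    await helper.register_model(SimpleNamespace(identifier="Llama3.1-8B-Instruct"))
    print([m.identifier for m in await helper.list_models()])

    # Routing-key style lookup, as used by the adapters' chat_completion paths.
    print(helper.map_to_provider_model("Llama3.1-8B-Instruct"))


asyncio.run(main())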
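
Note (illustrative only, not part of the patch): a sketch of the router-to-routing-table delegation that MemoryRouter now uses, under the same installation assumption. FakeRoutingTable, FakeProvider, and the SimpleNamespace bank are hypothetical stand-ins for this sketch; in the stack itself the routing table is a CommonRoutingTableImpl subclass and banks are MemoryBankDef objects.

# Illustrative sketch only -- not part of the patch.
# FakeRoutingTable and FakeProvider are hypothetical stand-ins that implement
# just the methods MemoryRouter calls on the routing table.
import asyncio
from types import SimpleNamespace

from llama_stack.distribution.routers.routers import MemoryRouter


class FakeProvider:
    async def insert_documents(self, bank_id, documents, ttl_seconds=None):
        print(f"insert into {bank_id}: {len(documents)} documents")


class FakeRoutingTable:
    def __init__(self):
        self.banks = {}
        self.provider = FakeProvider()

    async def register_memory_bank(self, bank):
        self.banks[bank.identifier] = bank

    def get_provider_impl(self, routing_key):
        # The real table resolves a provider per routing key (bank identifier).
        return self.provider

    def list_memory_banks(self):
        return list(self.banks.values())


async def main() -> None:
    router = MemoryRouter(FakeRoutingTable())

    # Banks are now registered through the routing table rather than created
    # ad hoc by the router, and per-bank calls route by the bank identifier.
    await router.register_memory_bank(SimpleNamespace(identifier="docs"))
    await router.insert_documents("docs", documents=[], ttl_seconds=None)
    print([b.identifier for b in await router.list_memory_banks()])


asyncio.run(main())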