inference registry updates

Authored by Ashwin Bharambe on 2024-10-05 22:25:48 -07:00; committed by Ashwin Bharambe
parent 4215cc9331
commit 59302a86df
12 changed files with 570 additions and 535 deletions


@@ -17,14 +17,19 @@ class DistributionInspectConfig(BaseModel):
     pass


-def get_provider_impl(*args, **kwargs):
-    return DistributionInspectImpl()
+async def get_provider_impl(*args, **kwargs):
+    impl = DistributionInspectImpl()
+    await impl.initialize()
+    return impl


 class DistributionInspectImpl(Inspect):
     def __init__(self):
         pass

+    async def initialize(self) -> None:
+        pass
+
     async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
         ret = {}
         all_providers = get_provider_registry()
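
The change above makes provider construction asynchronous: get_provider_impl now builds the implementation, awaits its initialize() hook, and only then returns it. A minimal, self-contained sketch of that pattern (ExampleProviderImpl is a hypothetical stand-in, not code from this commit):

import asyncio


class ExampleProviderImpl:
    """Hypothetical provider; stands in for DistributionInspectImpl."""

    def __init__(self) -> None:
        self.ready = False

    async def initialize(self) -> None:
        # open clients, warm caches, etc. before the impl is handed out
        self.ready = True


async def get_provider_impl(*args, **kwargs) -> ExampleProviderImpl:
    impl = ExampleProviderImpl()
    await impl.initialize()
    return impl


if __name__ == "__main__":
    print(asyncio.run(get_provider_impl()).ready)  # True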


@@ -20,6 +20,7 @@ class ProviderWithSpec(Provider):
     spec: ProviderSpec


+# TODO: this code is not very straightforward to follow and needs one more round of refactoring
 async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
     """
     Does two things:
@@ -134,7 +135,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
         print("")

     impls = {}
-    inner_impls_by_provider_id = {f"inner-{x}": {} for x in router_apis}
+    inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
         deps = {a: impls[a] for a in provider.spec.api_dependencies}


@@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403


 class MemoryRouter(Memory):
-    """Routes to an provider based on the memory bank type"""
+    """Routes to an provider based on the memory bank identifier"""

     def __init__(
         self,
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table
-        self.bank_id_to_type = {}

     async def initialize(self) -> None:
         pass
@@ -29,32 +28,14 @@ class MemoryRouter(Memory):
     async def shutdown(self) -> None:
         pass

-    def get_provider_from_bank_id(self, bank_id: str) -> Any:
-        bank_type = self.bank_id_to_type.get(bank_id)
-        if not bank_type:
-            raise ValueError(f"Could not find bank type for {bank_id}")
-
-        provider = self.routing_table.get_provider_impl(bank_type)
-        if not provider:
-            raise ValueError(f"Could not find provider for {bank_type}")
-        return provider
-
-    async def create_memory_bank(
-        self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_type = config.type
-        bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
-            name, config, url
-        )
-        self.bank_id_to_type[bank.bank_id] = bank_type
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        provider = self.get_provider_from_bank_id(bank_id)
-        return await provider.get_memory_bank(bank_id)
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
+        return self.routing_table.list_memory_banks()
+
+    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
+        return self.routing_table.get_memory_bank(identifier)
+
+    async def register_memory_bank(self, bank: MemoryBankDef) -> None:
+        await self.routing_table.register_memory_bank(bank)

     async def insert_documents(
         self,
@@ -62,7 +43,7 @@ class MemoryRouter(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        return await self.get_provider_from_bank_id(bank_id).insert_documents(
+        return await self.routing_table.get_provider_impl(bank_id).insert_documents(
             bank_id, documents, ttl_seconds
         )
@@ -72,7 +53,7 @@ class MemoryRouter(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        return await self.get_provider_from_bank_id(bank_id).query_documents(
+        return await self.routing_table.get_provider_impl(bank_id).query_documents(
             bank_id, query, params
         )
@@ -92,6 +73,15 @@ class InferenceRouter(Inference):
     async def shutdown(self) -> None:
         pass

+    async def list_models(self) -> List[ModelDef]:
+        return self.routing_table.list_models()
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        return self.routing_table.get_model(identifier)
+
+    async def register_model(self, model: ModelDef) -> None:
+        await self.routing_table.register_model(model)
+
     async def chat_completion(
         self,
         model: str,
@@ -159,6 +149,15 @@ class SafetyRouter(Safety):
     async def shutdown(self) -> None:
         pass

+    async def list_shields(self) -> List[ShieldDef]:
+        return self.routing_table.list_shields()
+
+    async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
+        return self.routing_table.get_shield(shield_type)
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        await self.routing_table.register_shield(shield)
+
     async def run_shield(
         self,
         shield_type: str,
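
With the per-bank bookkeeping gone, every router method resolves its provider through the routing table, keyed by whatever routing key the call already carries (bank_id, model, or shield_type), and simply forwards the call. A minimal sketch of that delegation, using hypothetical stand-ins (InMemoryRoutingTable, FakeMemoryProvider) rather than the real llama_stack classes:

import asyncio
from typing import Any, Dict, List, Optional


class FakeMemoryProvider:
    async def insert_documents(
        self, bank_id: str, documents: List[dict], ttl_seconds: Optional[int] = None
    ) -> None:
        print(f"inserted {len(documents)} documents into {bank_id}")


class InMemoryRoutingTable:
    def __init__(self) -> None:
        # routing key (here: the bank id) -> provider implementation
        self.routing_key_to_provider: Dict[str, Any] = {}

    def get_provider_impl(self, routing_key: str) -> Any:
        return self.routing_key_to_provider[routing_key]


class MemoryRouterSketch:
    """Mirrors the refactored MemoryRouter: no local state, pure delegation."""

    def __init__(self, routing_table: InMemoryRoutingTable) -> None:
        self.routing_table = routing_table

    async def insert_documents(
        self, bank_id: str, documents: List[dict], ttl_seconds: Optional[int] = None
    ) -> None:
        return await self.routing_table.get_provider_impl(bank_id).insert_documents(
            bank_id, documents, ttl_seconds
        )


async def _demo() -> None:
    table = InMemoryRoutingTable()
    table.routing_key_to_provider["my-docs"] = FakeMemoryProvider()
    await MemoryRouterSketch(table).insert_documents("my-docs", [{"text": "hello"}])


if __name__ == "__main__":
    asyncio.run(_demo())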


@@ -15,6 +15,8 @@ from llama_stack.apis.memory_banks import * # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403


+# TODO: this routing table maintains state in memory purely. We need to
+# add persistence to it when we add dynamic registration of objects.
 class CommonRoutingTableImpl(RoutingTable):
     def __init__(
         self,
@@ -54,7 +56,7 @@ class CommonRoutingTableImpl(RoutingTable):
                 return obj
         return None

-    def register_object(self, obj: RoutableObject) -> None:
+    async def register_object_common(self, obj: RoutableObject) -> None:
         if obj.identifier in self.routing_key_to_object:
             raise ValueError(f"Object `{obj.identifier}` already registered")
@@ -79,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         return self.get_object_by_identifier(identifier)

     async def register_model(self, model: ModelDef) -> None:
-        await self.register_object(model)
+        await self.register_object_common(model)


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@@ -93,7 +95,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
         return self.get_object_by_identifier(shield_type)

     async def register_shield(self, shield: ShieldDef) -> None:
-        await self.register_object(shield)
+        await self.register_object_common(shield)


 class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
@@ -107,4 +109,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         return self.get_object_by_identifier(identifier)

     async def register_memory_bank(self, bank: MemoryBankDef) -> None:
-        await self.register_object(bank)
+        await self.register_object_common(bank)
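
As the new TODO notes, the routing table keeps its registry purely in memory, so registrations are lost on restart. A toy sketch of the duplicate check that register_object_common performs (SimpleRoutingTable and RoutableObjectSketch are hypothetical names used only for illustration):

import asyncio
from dataclasses import dataclass
from typing import Dict


@dataclass
class RoutableObjectSketch:
    identifier: str


class SimpleRoutingTable:
    def __init__(self) -> None:
        # identifier -> registered object, held only in process memory
        self.routing_key_to_object: Dict[str, RoutableObjectSketch] = {}

    async def register_object_common(self, obj: RoutableObjectSketch) -> None:
        if obj.identifier in self.routing_key_to_object:
            raise ValueError(f"Object `{obj.identifier}` already registered")
        self.routing_key_to_object[obj.identifier] = obj


async def _demo() -> None:
    table = SimpleRoutingTable()
    await table.register_object_common(RoutableObjectSketch("Llama3.1-8B-Instruct"))
    try:
        await table.register_object_common(RoutableObjectSketch("Llama3.1-8B-Instruct"))
    except ValueError as exc:
        print(exc)  # duplicate registrations are rejected


if __name__ == "__main__":
    asyncio.run(_demo())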


@@ -1,445 +1,445 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 from typing import *  # noqa: F403

 import boto3
 from botocore.client import BaseClient
 from botocore.config import Config

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig


 BEDROCK_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
     "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
     "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
 }


-class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
+class BedrockInferenceAdapter(ModelRegistryHelper, Inference):

     @staticmethod
     def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
         retries_config = {
             k: v
             for k, v in dict(
                 total_max_attempts=config.total_max_attempts,
                 mode=config.retry_mode,
             ).items()
             if v is not None
         }

         config_args = {
             k: v
             for k, v in dict(
                 region_name=config.region_name,
                 retries=retries_config if retries_config else None,
                 connect_timeout=config.connect_timeout,
                 read_timeout=config.read_timeout,
             ).items()
             if v is not None
         }
         boto3_config = Config(**config_args)

         session_args = {
             k: v
             for k, v in dict(
                 aws_access_key_id=config.aws_access_key_id,
                 aws_secret_access_key=config.aws_secret_access_key,
                 aws_session_token=config.aws_session_token,
                 region_name=config.region_name,
                 profile_name=config.profile_name,
             ).items()
             if v is not None
         }
         boto3_session = boto3.session.Session(**session_args)

         return boto3_session.client("bedrock-runtime", config=boto3_config)

     def __init__(self, config: BedrockConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
         )
         self._config = config

         self._client = BedrockInferenceAdapter._create_bedrock_client(config)
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)

     @property
     def client(self) -> BaseClient:
         return self._client

     async def initialize(self) -> None:
         pass

     async def shutdown(self) -> None:
         self.client.close()

     async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
         raise NotImplementedError()

     @staticmethod
     def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
         if bedrock_stop_reason == "max_tokens":
             return StopReason.out_of_tokens
         return StopReason.end_of_turn

     @staticmethod
     def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
         for builtin_tool in BuiltinTool:
             if builtin_tool.value == tool_name_str:
                 return builtin_tool
         else:
             return tool_name_str

     @staticmethod
     def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
         stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
             converse_api_res["stopReason"]
         )

         bedrock_message = converse_api_res["output"]["message"]

         role = bedrock_message["role"]
         contents = bedrock_message["content"]

         tool_calls = []
         text_content = []
         for content in contents:
             if "toolUse" in content:
                 tool_use = content["toolUse"]
                 tool_calls.append(
                     ToolCall(
                         tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
                             tool_use["name"]
                         ),
                         arguments=tool_use["input"] if "input" in tool_use else None,
                         call_id=tool_use["toolUseId"],
                     )
                 )
             elif "text" in content:
                 text_content.append(content["text"])

         return CompletionMessage(
             role=role,
             content=text_content,
             stop_reason=stop_reason,
             tool_calls=tool_calls,
         )

     @staticmethod
     def _messages_to_bedrock_messages(
         messages: List[Message],
     ) -> Tuple[List[Dict], Optional[List[Dict]]]:
         bedrock_messages = []
         system_bedrock_messages = []

         user_contents = []
         assistant_contents = None
         for message in messages:
             role = message.role
             content_list = (
                 message.content
                 if isinstance(message.content, list)
                 else [message.content]
             )
             if role == "ipython" or role == "user":
                 if not user_contents:
                     user_contents = []

                 if role == "ipython":
                     user_contents.extend(
                         [
                             {
                                 "toolResult": {
                                     "toolUseId": message.call_id,
                                     "content": [
                                         {"text": content} for content in content_list
                                     ],
                                 }
                             }
                         ]
                     )
                 else:
                     user_contents.extend(
                         [{"text": content} for content in content_list]
                     )

                 if assistant_contents:
                     bedrock_messages.append(
                         {"role": "assistant", "content": assistant_contents}
                     )
                     assistant_contents = None
             elif role == "system":
                 system_bedrock_messages.extend(
                     [{"text": content} for content in content_list]
                 )
             elif role == "assistant":
                 if not assistant_contents:
                     assistant_contents = []

                 assistant_contents.extend(
                     [
                         {
                             "text": content,
                         }
                         for content in content_list
                     ]
                     + [
                         {
                             "toolUse": {
                                 "input": tool_call.arguments,
                                 "name": (
                                     tool_call.tool_name
                                     if isinstance(tool_call.tool_name, str)
                                     else tool_call.tool_name.value
                                 ),
                                 "toolUseId": tool_call.call_id,
                             }
                         }
                         for tool_call in message.tool_calls
                     ]
                 )

                 if user_contents:
                     bedrock_messages.append({"role": "user", "content": user_contents})
                     user_contents = None
             else:
                 # Unknown role
                 pass

         if user_contents:
             bedrock_messages.append({"role": "user", "content": user_contents})
         if assistant_contents:
             bedrock_messages.append(
                 {"role": "assistant", "content": assistant_contents}
             )

         if system_bedrock_messages:
             return bedrock_messages, system_bedrock_messages

         return bedrock_messages, None

     @staticmethod
     def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
         inference_config = {}
         if sampling_params:
             param_mapping = {
                 "max_tokens": "maxTokens",
                 "temperature": "temperature",
                 "top_p": "topP",
             }

             for k, v in param_mapping.items():
                 if getattr(sampling_params, k):
                     inference_config[v] = getattr(sampling_params, k)

         return inference_config

     @staticmethod
     def _tool_parameters_to_input_schema(
         tool_parameters: Optional[Dict[str, ToolParamDefinition]]
     ) -> Dict:
         input_schema = {"type": "object"}
         if not tool_parameters:
             return input_schema

         json_properties = {}
         required = []
         for name, param in tool_parameters.items():
             json_property = {
                 "type": param.param_type,
             }

             if param.description:
                 json_property["description"] = param.description

             if param.required:
                 required.append(name)

             json_properties[name] = json_property

         input_schema["properties"] = json_properties
         if required:
             input_schema["required"] = required

         return input_schema

     @staticmethod
     def _tools_to_tool_config(
         tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
     ) -> Optional[Dict]:
         if not tools:
             return None

         bedrock_tools = []
         for tool in tools:
             tool_name = (
                 tool.tool_name
                 if isinstance(tool.tool_name, str)
                 else tool.tool_name.value
             )

             tool_spec = {
                 "toolSpec": {
                     "name": tool_name,
                     "inputSchema": {
                         "json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
                             tool.parameters
                         ),
                     },
                 }
             }

             if tool.description:
                 tool_spec["toolSpec"]["description"] = tool.description

             bedrock_tools.append(tool_spec)
         tool_config = {
             "tools": bedrock_tools,
         }

         if tool_choice:
             tool_config["toolChoice"] = (
                 {"any": {}}
                 if tool_choice.value == ToolChoice.required
                 else {"auto": {}}
             )
         return tool_config

     async def chat_completion(
         self,
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> (
         AsyncGenerator
     ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
         bedrock_model = self.map_to_provider_model(model)
         inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
             sampling_params
         )

         tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
         bedrock_messages, system_bedrock_messages = (
             BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
         )

         converse_api_params = {
             "modelId": bedrock_model,
             "messages": bedrock_messages,
         }
         if inference_config:
             converse_api_params["inferenceConfig"] = inference_config

         # Tool use is not supported in streaming mode
         if tool_config and not stream:
             converse_api_params["toolConfig"] = tool_config
         if system_bedrock_messages:
             converse_api_params["system"] = system_bedrock_messages

         if not stream:
             converse_api_res = self.client.converse(**converse_api_params)

             output_message = BedrockInferenceAdapter._bedrock_message_to_message(
                 converse_api_res
             )

             yield ChatCompletionResponse(
                 completion_message=output_message,
                 logprobs=None,
             )
         else:
             converse_stream_api_res = self.client.converse_stream(**converse_api_params)
             event_stream = converse_stream_api_res["stream"]

             for chunk in event_stream:
                 if "messageStart" in chunk:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.start,
                             delta="",
                         )
                     )
                 elif "contentBlockStart" in chunk:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=ToolCallDelta(
                                 content=ToolCall(
                                     tool_name=chunk["contentBlockStart"]["toolUse"][
                                         "name"
                                     ],
                                     call_id=chunk["contentBlockStart"]["toolUse"][
                                         "toolUseId"
                                     ],
                                 ),
                                 parse_status=ToolCallParseStatus.started,
                             ),
                         )
                     )
                 elif "contentBlockDelta" in chunk:
                     if "text" in chunk["contentBlockDelta"]["delta"]:
                         delta = chunk["contentBlockDelta"]["delta"]["text"]
                     else:
                         delta = ToolCallDelta(
                             content=ToolCall(
                                 arguments=chunk["contentBlockDelta"]["delta"][
                                     "toolUse"
                                 ]["input"]
                             ),
                             parse_status=ToolCallParseStatus.success,
                         )

                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=delta,
                         )
                     )
                 elif "contentBlockStop" in chunk:
                     # Ignored
                     pass
                 elif "messageStop" in chunk:
                     stop_reason = (
                         BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
                             chunk["messageStop"]["stopReason"]
                         )
                     )

                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.complete,
                             delta="",
                             stop_reason=stop_reason,
                         )
                     )
                 elif "metadata" in chunk:
                     # Ignored
                     pass
                 else:
                     # Ignored
                     pass


@@ -13,7 +13,7 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
@@ -30,9 +30,9 @@ FIREWORKS_SUPPORTED_MODELS = {
 }


-class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
+class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
         )
         self.config = config


@@ -18,7 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
@@ -27,12 +27,13 @@ OLLAMA_SUPPORTED_SKUS = {
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
     "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
     "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
+    "Llama-Guard-3-8B": "xe/llamaguard3:latest",
 }


-class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
+class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, url: str) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
         )
         self.url = url


@@ -18,7 +18,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from .config import TogetherImplConfig
@@ -34,10 +34,10 @@ TOGETHER_SUPPORTED_MODELS = {

 class TogetherInferenceAdapter(
-    Inference, NeedsRequestProviderData, RoutableProviderForModels
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: TogetherImplConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
         )
         self.config = config


@@ -12,7 +12,6 @@ from llama_models.sku_list import resolve_model

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
@@ -25,24 +24,39 @@ from .model_parallel import LlamaModelParallelGenerator
 SEMAPHORE = asyncio.Semaphore(1)


-class MetaReferenceInferenceImpl(Inference, RoutableProvider):
+class MetaReferenceInferenceImpl(Inference):
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
         if model is None:
             raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
+        self.registered_model_defs = []
         # verify that the checkpoint actually is for this model lol

     async def initialize(self) -> None:
         self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()

-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        assert (
-            len(routing_keys) == 1
-        ), f"Only one routing key is supported {routing_keys}"
-        assert routing_keys[0] == self.config.model
+    async def register_model(self, model: ModelDef) -> None:
+        existing = await self.get_model(model.identifier)
+        if existing is not None:
+            return
+
+        if model.identifier != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
+            )
+        self.registered_model_defs = [model]
+
+    async def list_models(self) -> List[ModelDef]:
+        return self.registered_model_defs
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        for model in self.registered_model_defs:
+            if model.identifier == identifier:
+                return model
+        return None

     async def shutdown(self) -> None:
         self.generator.stop()


@@ -13,7 +13,6 @@ import numpy as np
 from numpy.typing import NDArray

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 from llama_stack.apis.memory import *  # noqa: F403

 from llama_stack.providers.utils.memory.vector_store import (
@@ -62,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class FaissMemoryImpl(Memory, RoutableProvider):
+class FaissMemoryImpl(Memory):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
         self.cache = {}
@@ -83,7 +82,6 @@ class FaissMemoryImpl(Memory, RoutableProvider):
             bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
         )
         self.cache[memory_bank.identifier] = index
-        return bank

     async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
         index = self.cache.get(identifier)


@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict, List
+
+from llama_models.sku_list import resolve_model
+
+from llama_stack.apis.models import *  # noqa: F403
+
+
+class ModelRegistryHelper:
+
+    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
+        self.stack_to_provider_models_map = stack_to_provider_models_map
+        self.registered_models = []
+
+    def map_to_provider_model(self, identifier: str) -> str:
+        model = resolve_model(identifier)
+        if not model:
+            raise ValueError(f"Unknown model: `{identifier}`")
+
+        if identifier not in self.stack_to_provider_models_map:
+            raise ValueError(
+                f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
+            )
+
+        return self.stack_to_provider_models_map[identifier]
+
+    async def register_model(self, model: ModelDef) -> None:
+        existing = await self.get_model(model.identifier)
+        if existing is not None:
+            return
+
+        if model.identifier not in self.stack_to_provider_models_map:
+            raise ValueError(
+                f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
+            )
+
+        self.registered_models.append(model)
+
+    async def list_models(self) -> List[ModelDef]:
+        return self.registered_models
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        for model in self.registered_models:
+            if model.identifier == identifier:
+                return model
+
+        return None
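
The inference adapters touched by this commit (Bedrock, Fireworks, Ollama, Together) pick this class up as a mixin, pass it their SUPPORTED_MODELS map, and call map_to_provider_model at request time. A self-contained sketch of that pattern using a simplified, hypothetical mapper in place of ModelRegistryHelper (the real helper also resolves the identifier with resolve_model and implements register_model/list_models/get_model):

from typing import Dict


class ModelMapperSketch:
    """Simplified, hypothetical stand-in for ModelRegistryHelper."""

    def __init__(self, stack_to_provider_models_map: Dict[str, str]) -> None:
        self.stack_to_provider_models_map = stack_to_provider_models_map

    def map_to_provider_model(self, identifier: str) -> str:
        if identifier not in self.stack_to_provider_models_map:
            raise ValueError(f"Unsupported model: `{identifier}`")
        return self.stack_to_provider_models_map[identifier]


class ExampleInferenceAdapter(ModelMapperSketch):
    """Mirrors how the Fireworks/Ollama adapters use the mixin."""

    def __init__(self) -> None:
        ModelMapperSketch.__init__(
            self,
            stack_to_provider_models_map={
                # hypothetical provider-side model id
                "Llama3.1-8B-Instruct": "provider/llama-v3p1-8b-instruct",
            },
        )

    def chat_completion(self, model: str) -> str:
        provider_model = self.map_to_provider_model(model)
        return f"would call the provider with model id {provider_model}"


if __name__ == "__main__":
    print(ExampleInferenceAdapter().chat_completion("Llama3.1-8B-Instruct"))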


@ -1,36 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List
from llama_models.sku_list import resolve_model
from llama_stack.distribution.datatypes import RoutableProvider
class RoutableProviderForModels(RoutableProvider):
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
self.stack_to_provider_models_map = stack_to_provider_models_map
async def validate_routing_keys(self, routing_keys: List[str]):
for routing_key in routing_keys:
if routing_key not in self.stack_to_provider_models_map:
raise ValueError(
f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
)
def map_to_provider_model(self, routing_key: str) -> str:
model = resolve_model(routing_key)
if not model:
raise ValueError(f"Unknown model: `{routing_key}`")
if routing_key not in self.stack_to_provider_models_map:
raise ValueError(
f"Model {routing_key} not found in map {self.stack_to_provider_models_map}"
)
return self.stack_to_provider_models_map[routing_key]