mirror of https://github.com/meta-llama/llama-stack.git
inference registry updates
commit 59302a86df (parent 4215cc9331)

12 changed files with 570 additions and 535 deletions
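This commit swaps the old RoutableProviderForModels mixin for a new ModelRegistryHelper, makes get_provider_impl async so implementations can be initialized before use, and teaches the routers and routing tables to list, get, and register models, shields, and memory banks. As a reading aid before the hunks below, here is a minimal, hedged sketch of the registration pattern the new code follows; ModelDef here is an illustrative stand-in dataclass, not the real llama_stack.apis.models type.

# Hedged sketch of the provider-side model registry introduced by this commit.
# ModelDef is an illustrative stand-in, not the actual llama_stack type.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ModelDef:
    identifier: str


class RegistrySketch:
    def __init__(self, stack_to_provider_models_map: Dict[str, str]) -> None:
        # Static map of stack identifiers to provider-specific model IDs.
        self.stack_to_provider_models_map = stack_to_provider_models_map
        self.registered_models: List[ModelDef] = []

    async def register_model(self, model: ModelDef) -> None:
        # Registration is idempotent and rejects models the provider cannot serve.
        if await self.get_model(model.identifier) is not None:
            return
        if model.identifier not in self.stack_to_provider_models_map:
            raise ValueError(f"Unsupported model {model.identifier}")
        self.registered_models.append(model)

    async def list_models(self) -> List[ModelDef]:
        return self.registered_models

    async def get_model(self, identifier: str) -> Optional[ModelDef]:
        return next(
            (m for m in self.registered_models if m.identifier == identifier), None
        )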
@@ -17,14 +17,19 @@ class DistributionInspectConfig(BaseModel):
     pass


-def get_provider_impl(*args, **kwargs):
-    return DistributionInspectImpl()
+async def get_provider_impl(*args, **kwargs):
+    impl = DistributionInspectImpl()
+    await impl.initialize()
+    return impl


 class DistributionInspectImpl(Inspect):
     def __init__(self):
         pass

+    async def initialize(self) -> None:
+        pass
+
     async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
         ret = {}
         all_providers = get_provider_registry()
@@ -20,6 +20,7 @@ class ProviderWithSpec(Provider):
     spec: ProviderSpec


+# TODO: this code is not very straightforward to follow and needs one more round of refactoring
 async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
     """
     Does two things:
@@ -134,7 +135,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
     print("")

     impls = {}
-    inner_impls_by_provider_id = {f"inner-{x}": {} for x in router_apis}
+    inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
         deps = {a: impls[a] for a in provider.spec.api_dependencies}

@@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403


 class MemoryRouter(Memory):
-    """Routes to an provider based on the memory bank type"""
+    """Routes to an provider based on the memory bank identifier"""

     def __init__(
         self,
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table
-        self.bank_id_to_type = {}

     async def initialize(self) -> None:
         pass
@@ -29,32 +28,14 @@ class MemoryRouter(Memory):
     async def shutdown(self) -> None:
         pass

-    def get_provider_from_bank_id(self, bank_id: str) -> Any:
-        bank_type = self.bank_id_to_type.get(bank_id)
-        if not bank_type:
-            raise ValueError(f"Could not find bank type for {bank_id}")
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
+        return self.routing_table.list_memory_banks()

-        provider = self.routing_table.get_provider_impl(bank_type)
-        if not provider:
-            raise ValueError(f"Could not find provider for {bank_type}")
-        return provider
+    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
+        return self.routing_table.get_memory_bank(identifier)

-    async def create_memory_bank(
-        self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_type = config.type
-        bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
-            name, config, url
-        )
-        self.bank_id_to_type[bank.bank_id] = bank_type
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        provider = self.get_provider_from_bank_id(bank_id)
-        return await provider.get_memory_bank(bank_id)
+    async def register_memory_bank(self, bank: MemoryBankDef) -> None:
+        await self.routing_table.register_memory_bank(bank)

     async def insert_documents(
         self,
@@ -62,7 +43,7 @@ class MemoryRouter(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        return await self.get_provider_from_bank_id(bank_id).insert_documents(
+        return await self.routing_table.get_provider_impl(bank_id).insert_documents(
             bank_id, documents, ttl_seconds
         )

@@ -72,7 +53,7 @@ class MemoryRouter(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        return await self.get_provider_from_bank_id(bank_id).query_documents(
+        return await self.routing_table.get_provider_impl(bank_id).query_documents(
             bank_id, query, params
         )

@@ -92,6 +73,15 @@ class InferenceRouter(Inference):
     async def shutdown(self) -> None:
         pass

+    async def list_models(self) -> List[ModelDef]:
+        return self.routing_table.list_models()
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        return self.routing_table.get_model(identifier)
+
+    async def register_model(self, model: ModelDef) -> None:
+        await self.routing_table.register_model(model)
+
     async def chat_completion(
         self,
         model: str,
@@ -159,6 +149,15 @@ class SafetyRouter(Safety):
     async def shutdown(self) -> None:
         pass

+    async def list_shields(self) -> List[ShieldDef]:
+        return self.routing_table.list_shields()
+
+    async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
+        return self.routing_table.get_shield(shield_type)
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        await self.routing_table.register_shield(shield)
+
     async def run_shield(
         self,
         shield_type: str,
@@ -15,6 +15,8 @@ from llama_stack.apis.memory_banks import * # noqa: F403
 from llama_stack.distribution.datatypes import * # noqa: F403


+# TODO: this routing table maintains state in memory purely. We need to
+# add persistence to it when we add dynamic registration of objects.
 class CommonRoutingTableImpl(RoutingTable):
     def __init__(
         self,
@@ -54,7 +56,7 @@ class CommonRoutingTableImpl(RoutingTable):
                 return obj
         return None

-    def register_object(self, obj: RoutableObject) -> None:
+    async def register_object_common(self, obj: RoutableObject) -> None:
         if obj.identifier in self.routing_key_to_object:
             raise ValueError(f"Object `{obj.identifier}` already registered")

@@ -79,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         return self.get_object_by_identifier(identifier)

     async def register_model(self, model: ModelDef) -> None:
-        await self.register_object(model)
+        await self.register_object_common(model)


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@@ -93,7 +95,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
         return self.get_object_by_identifier(shield_type)

     async def register_shield(self, shield: ShieldDef) -> None:
-        await self.register_object(shield)
+        await self.register_object_common(shield)


 class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
@@ -107,4 +109,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         return self.get_object_by_identifier(identifier)

     async def register_memory_bank(self, bank: MemoryBankDef) -> None:
-        await self.register_object(bank)
+        await self.register_object_common(bank)
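The router hunks above show each router (memory, inference, safety) forwarding list/get/register calls to its routing table, which owns the registered objects and maps identifiers to provider implementations. A rough, self-contained sketch of that delegation follows; ShieldDef and the in-memory table are simplified stand-ins rather than llama-stack's real interfaces.

# Rough sketch of the router -> routing table delegation shown in the hunks above.
# ShieldDef and the dict-backed table are simplified stand-ins, not llama-stack types.
import asyncio
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class ShieldDef:
    identifier: str


class RoutingTableSketch:
    def __init__(self) -> None:
        self.routing_key_to_object: Dict[str, ShieldDef] = {}

    async def register_shield(self, shield: ShieldDef) -> None:
        # Mirrors register_object_common: refuse duplicate identifiers.
        if shield.identifier in self.routing_key_to_object:
            raise ValueError(f"Object `{shield.identifier}` already registered")
        self.routing_key_to_object[shield.identifier] = shield

    def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
        return self.routing_key_to_object.get(shield_type)


class SafetyRouterSketch:
    """The router keeps no state of its own; every call goes to the table."""

    def __init__(self, routing_table: RoutingTableSketch) -> None:
        self.routing_table = routing_table

    async def register_shield(self, shield: ShieldDef) -> None:
        await self.routing_table.register_shield(shield)

    async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
        return self.routing_table.get_shield(shield_type)


async def _demo() -> None:
    router = SafetyRouterSketch(RoutingTableSketch())
    await router.register_shield(ShieldDef(identifier="llama_guard"))
    print(await router.get_shield("llama_guard"))


asyncio.run(_demo())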
@@ -1,445 +1,445 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 from typing import * # noqa: F403

 import boto3
 from botocore.client import BaseClient
 from botocore.config import Config

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig


 BEDROCK_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
     "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
     "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
 }


-class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
+class BedrockInferenceAdapter(ModelRegistryHelper, Inference):

     @staticmethod
     def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
         retries_config = {
             k: v
             for k, v in dict(
                 total_max_attempts=config.total_max_attempts,
                 mode=config.retry_mode,
             ).items()
             if v is not None
         }

         config_args = {
             k: v
             for k, v in dict(
                 region_name=config.region_name,
                 retries=retries_config if retries_config else None,
                 connect_timeout=config.connect_timeout,
                 read_timeout=config.read_timeout,
             ).items()
             if v is not None
         }

         boto3_config = Config(**config_args)

         session_args = {
             k: v
             for k, v in dict(
                 aws_access_key_id=config.aws_access_key_id,
                 aws_secret_access_key=config.aws_secret_access_key,
                 aws_session_token=config.aws_session_token,
                 region_name=config.region_name,
                 profile_name=config.profile_name,
             ).items()
             if v is not None
         }

         boto3_session = boto3.session.Session(**session_args)

         return boto3_session.client("bedrock-runtime", config=boto3_config)

     def __init__(self, config: BedrockConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
         )
         self._config = config

         self._client = BedrockInferenceAdapter._create_bedrock_client(config)
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)

     @property
     def client(self) -> BaseClient:
         return self._client

     async def initialize(self) -> None:
         pass

     async def shutdown(self) -> None:
         self.client.close()

     async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
         raise NotImplementedError()

     @staticmethod
     def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
         if bedrock_stop_reason == "max_tokens":
             return StopReason.out_of_tokens
         return StopReason.end_of_turn

     @staticmethod
     def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
         for builtin_tool in BuiltinTool:
             if builtin_tool.value == tool_name_str:
                 return builtin_tool
         else:
             return tool_name_str

     @staticmethod
     def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
         stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
             converse_api_res["stopReason"]
         )

         bedrock_message = converse_api_res["output"]["message"]

         role = bedrock_message["role"]
         contents = bedrock_message["content"]

         tool_calls = []
         text_content = []
         for content in contents:
             if "toolUse" in content:
                 tool_use = content["toolUse"]
                 tool_calls.append(
                     ToolCall(
                         tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
                             tool_use["name"]
                         ),
                         arguments=tool_use["input"] if "input" in tool_use else None,
                         call_id=tool_use["toolUseId"],
                     )
                 )
             elif "text" in content:
                 text_content.append(content["text"])

         return CompletionMessage(
             role=role,
             content=text_content,
             stop_reason=stop_reason,
             tool_calls=tool_calls,
         )

     @staticmethod
     def _messages_to_bedrock_messages(
         messages: List[Message],
     ) -> Tuple[List[Dict], Optional[List[Dict]]]:
         bedrock_messages = []
         system_bedrock_messages = []

         user_contents = []
         assistant_contents = None
         for message in messages:
             role = message.role
             content_list = (
                 message.content
                 if isinstance(message.content, list)
                 else [message.content]
             )
             if role == "ipython" or role == "user":
                 if not user_contents:
                     user_contents = []

                 if role == "ipython":
                     user_contents.extend(
                         [
                             {
                                 "toolResult": {
                                     "toolUseId": message.call_id,
                                     "content": [
                                         {"text": content} for content in content_list
                                     ],
                                 }
                             }
                         ]
                     )
                 else:
                     user_contents.extend(
                         [{"text": content} for content in content_list]
                     )

                 if assistant_contents:
                     bedrock_messages.append(
                         {"role": "assistant", "content": assistant_contents}
                     )
                     assistant_contents = None
             elif role == "system":
                 system_bedrock_messages.extend(
                     [{"text": content} for content in content_list]
                 )
             elif role == "assistant":
                 if not assistant_contents:
                     assistant_contents = []

                 assistant_contents.extend(
                     [
                         {
                             "text": content,
                         }
                         for content in content_list
                     ]
                     + [
                         {
                             "toolUse": {
                                 "input": tool_call.arguments,
                                 "name": (
                                     tool_call.tool_name
                                     if isinstance(tool_call.tool_name, str)
                                     else tool_call.tool_name.value
                                 ),
                                 "toolUseId": tool_call.call_id,
                             }
                         }
                         for tool_call in message.tool_calls
                     ]
                 )

                 if user_contents:
                     bedrock_messages.append({"role": "user", "content": user_contents})
                     user_contents = None
             else:
                 # Unknown role
                 pass

         if user_contents:
             bedrock_messages.append({"role": "user", "content": user_contents})
         if assistant_contents:
             bedrock_messages.append(
                 {"role": "assistant", "content": assistant_contents}
             )

         if system_bedrock_messages:
             return bedrock_messages, system_bedrock_messages

         return bedrock_messages, None

     @staticmethod
     def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
         inference_config = {}
         if sampling_params:
             param_mapping = {
                 "max_tokens": "maxTokens",
                 "temperature": "temperature",
                 "top_p": "topP",
             }

             for k, v in param_mapping.items():
                 if getattr(sampling_params, k):
                     inference_config[v] = getattr(sampling_params, k)

         return inference_config

     @staticmethod
     def _tool_parameters_to_input_schema(
         tool_parameters: Optional[Dict[str, ToolParamDefinition]]
     ) -> Dict:
         input_schema = {"type": "object"}
         if not tool_parameters:
             return input_schema

         json_properties = {}
         required = []
         for name, param in tool_parameters.items():
             json_property = {
                 "type": param.param_type,
             }

             if param.description:
                 json_property["description"] = param.description
             if param.required:
                 required.append(name)
             json_properties[name] = json_property

         input_schema["properties"] = json_properties
         if required:
             input_schema["required"] = required
         return input_schema

     @staticmethod
     def _tools_to_tool_config(
         tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
     ) -> Optional[Dict]:
         if not tools:
             return None

         bedrock_tools = []
         for tool in tools:
             tool_name = (
                 tool.tool_name
                 if isinstance(tool.tool_name, str)
                 else tool.tool_name.value
             )

             tool_spec = {
                 "toolSpec": {
                     "name": tool_name,
                     "inputSchema": {
                         "json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
                             tool.parameters
                         ),
                     },
                 }
             }

             if tool.description:
                 tool_spec["toolSpec"]["description"] = tool.description

             bedrock_tools.append(tool_spec)
         tool_config = {
             "tools": bedrock_tools,
         }

         if tool_choice:
             tool_config["toolChoice"] = (
                 {"any": {}}
                 if tool_choice.value == ToolChoice.required
                 else {"auto": {}}
             )
         return tool_config

     async def chat_completion(
         self,
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> (
         AsyncGenerator
     ): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
         bedrock_model = self.map_to_provider_model(model)
         inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
             sampling_params
         )

         tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
         bedrock_messages, system_bedrock_messages = (
             BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
         )

         converse_api_params = {
             "modelId": bedrock_model,
             "messages": bedrock_messages,
         }
         if inference_config:
             converse_api_params["inferenceConfig"] = inference_config

         # Tool use is not supported in streaming mode
         if tool_config and not stream:
             converse_api_params["toolConfig"] = tool_config
         if system_bedrock_messages:
             converse_api_params["system"] = system_bedrock_messages

         if not stream:
             converse_api_res = self.client.converse(**converse_api_params)

             output_message = BedrockInferenceAdapter._bedrock_message_to_message(
                 converse_api_res
             )

             yield ChatCompletionResponse(
                 completion_message=output_message,
                 logprobs=None,
             )
         else:
             converse_stream_api_res = self.client.converse_stream(**converse_api_params)
             event_stream = converse_stream_api_res["stream"]

             for chunk in event_stream:
                 if "messageStart" in chunk:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.start,
                             delta="",
                         )
                     )
                 elif "contentBlockStart" in chunk:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=ToolCallDelta(
                                 content=ToolCall(
                                     tool_name=chunk["contentBlockStart"]["toolUse"][
                                         "name"
                                     ],
                                     call_id=chunk["contentBlockStart"]["toolUse"][
                                         "toolUseId"
                                     ],
                                 ),
                                 parse_status=ToolCallParseStatus.started,
                             ),
                         )
                     )
                 elif "contentBlockDelta" in chunk:
                     if "text" in chunk["contentBlockDelta"]["delta"]:
                         delta = chunk["contentBlockDelta"]["delta"]["text"]
                     else:
                         delta = ToolCallDelta(
                             content=ToolCall(
                                 arguments=chunk["contentBlockDelta"]["delta"][
                                     "toolUse"
                                 ]["input"]
                             ),
                             parse_status=ToolCallParseStatus.success,
                         )

                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=delta,
                         )
                     )
                 elif "contentBlockStop" in chunk:
                     # Ignored
                     pass
                 elif "messageStop" in chunk:
                     stop_reason = (
                         BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
                             chunk["messageStop"]["stopReason"]
                         )
                     )

                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.complete,
                             delta="",
                             stop_reason=stop_reason,
                         )
                     )
                 elif "metadata" in chunk:
                     # Ignored
                     pass
                 else:
                     # Ignored
                     pass
@@ -13,7 +13,7 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
@@ -30,9 +30,9 @@ FIREWORKS_SUPPORTED_MODELS = {
 }


-class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
+class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
         )
         self.config = config
@@ -18,7 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
@@ -27,12 +27,13 @@ OLLAMA_SUPPORTED_SKUS = {
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
     "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
     "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
+    "Llama-Guard-3-8B": "xe/llamaguard3:latest",
 }


-class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
+class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, url: str) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
         )
         self.url = url
@@ -18,7 +18,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from .config import TogetherImplConfig

@@ -34,10 +34,10 @@ TOGETHER_SUPPORTED_MODELS = {


 class TogetherInferenceAdapter(
-    Inference, NeedsRequestProviderData, RoutableProviderForModels
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: TogetherImplConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
         )
         self.config = config
@@ -12,7 +12,6 @@ from llama_models.sku_list import resolve_model

 from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
@@ -25,24 +24,39 @@ from .model_parallel import LlamaModelParallelGenerator
 SEMAPHORE = asyncio.Semaphore(1)


-class MetaReferenceInferenceImpl(Inference, RoutableProvider):
+class MetaReferenceInferenceImpl(Inference):
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
         if model is None:
             raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
+        self.registered_model_defs = []
         # verify that the checkpoint actually is for this model lol

     async def initialize(self) -> None:
         self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()

-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        assert (
-            len(routing_keys) == 1
-        ), f"Only one routing key is supported {routing_keys}"
-        assert routing_keys[0] == self.config.model
+    async def register_model(self, model: ModelDef) -> None:
+        existing = await self.get_model(model.identifier)
+        if existing is not None:
+            return
+
+        if model.identifier != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
+            )
+        self.registered_model_defs = [model]
+
+    async def list_models(self) -> List[ModelDef]:
+        return self.registered_model_defs
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        for model in self.registered_model_defs:
+            if model.identifier == identifier:
+                return model
+        return None

     async def shutdown(self) -> None:
         self.generator.stop()
@@ -13,7 +13,6 @@ import numpy as np
 from numpy.typing import NDArray

 from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider

 from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.providers.utils.memory.vector_store import (
@@ -62,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class FaissMemoryImpl(Memory, RoutableProvider):
+class FaissMemoryImpl(Memory):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
         self.cache = {}
@@ -83,7 +82,6 @@ class FaissMemoryImpl(Memory, RoutableProvider):
             bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
         )
         self.cache[memory_bank.identifier] = index
-        return bank

     async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
         index = self.cache.get(identifier)
llama_stack/providers/utils/inference/model_registry.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict, List
+
+from llama_models.sku_list import resolve_model
+
+from llama_stack.apis.models import * # noqa: F403
+
+
+class ModelRegistryHelper:
+
+    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
+        self.stack_to_provider_models_map = stack_to_provider_models_map
+        self.registered_models = []
+
+    def map_to_provider_model(self, identifier: str) -> str:
+        model = resolve_model(identifier)
+        if not model:
+            raise ValueError(f"Unknown model: `{identifier}`")
+
+        if identifier not in self.stack_to_provider_models_map:
+            raise ValueError(
+                f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
+            )
+
+        return self.stack_to_provider_models_map[identifier]
+
+    async def register_model(self, model: ModelDef) -> None:
+        existing = await self.get_model(model.identifier)
+        if existing is not None:
+            return
+
+        if model.identifier not in self.stack_to_provider_models_map:
+            raise ValueError(
+                f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
+            )
+
+        self.registered_models.append(model)
+
+    async def list_models(self) -> List[ModelDef]:
+        return self.registered_models
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        for model in self.registered_models:
+            if model.identifier == identifier:
+                return model
+        return None
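To show how the new helper above is meant to be mixed into an adapter (as the Bedrock, Fireworks, Ollama, and Together hunks do), here is a hedged sketch; ExampleAdapter and its one-entry model map are illustrative only, and the real ModelRegistryHelper additionally validates identifiers via resolve_model, which this sketch omits.

# Hedged sketch of mixing a registry helper into an inference adapter.
# ExampleAdapter and its model map are illustrative, not taken from the repo.
from typing import Dict


class RegistryHelperSketch:
    def __init__(self, stack_to_provider_models_map: Dict[str, str]) -> None:
        self.stack_to_provider_models_map = stack_to_provider_models_map

    def map_to_provider_model(self, identifier: str) -> str:
        if identifier not in self.stack_to_provider_models_map:
            raise ValueError(
                f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
            )
        return self.stack_to_provider_models_map[identifier]


class ExampleAdapter(RegistryHelperSketch):
    def __init__(self) -> None:
        RegistryHelperSketch.__init__(
            self,
            stack_to_provider_models_map={
                "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0"
            },
        )

    def provider_model_for(self, model: str) -> str:
        # Adapters translate the stack identifier right before calling the backend.
        return self.map_to_provider_model(model)


print(ExampleAdapter().provider_model_for("Llama3.1-8B-Instruct"))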
@@ -1,36 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Dict, List
-
-from llama_models.sku_list import resolve_model
-
-from llama_stack.distribution.datatypes import RoutableProvider
-
-
-class RoutableProviderForModels(RoutableProvider):
-
-    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
-        self.stack_to_provider_models_map = stack_to_provider_models_map
-
-    async def validate_routing_keys(self, routing_keys: List[str]):
-        for routing_key in routing_keys:
-            if routing_key not in self.stack_to_provider_models_map:
-                raise ValueError(
-                    f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
-                )
-
-    def map_to_provider_model(self, routing_key: str) -> str:
-        model = resolve_model(routing_key)
-        if not model:
-            raise ValueError(f"Unknown model: `{routing_key}`")
-
-        if routing_key not in self.stack_to_provider_models_map:
-            raise ValueError(
-                f"Model {routing_key} not found in map {self.stack_to_provider_models_map}"
-            )
-
-        return self.stack_to_provider_models_map[routing_key]