Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-19 03:10:03 +00:00)
Remove "routing_table" and "routing_key" concepts for the user (#201)
This PR makes several core changes to the developer experience surrounding Llama Stack.

Background: PR #92 introduced the notion of "routing" to the Llama Stack. It introduced three object types: (1) models, (2) shields and (3) memory banks. Each of these objects can be associated with a distinct provider, so you can have model A inferenced locally while models B and C are inferenced remotely, for example. However, this had a few drawbacks:

- You could not address the provider instances -- i.e., if you configured "meta-reference" with a given model, you could not assign an identifier to this instance which you could re-use later.
- As a consequence, you could not register a "routing_key" (e.g. a model) dynamically and say "please use this existing provider I have already configured" for a new model.
- The terms "routing_table" and "routing_key" were exposed directly to the user. In my view, this is way too much overhead for a new user (which almost everyone is): people come to the stack wanting to do ML and encounter a completely unexpected term.

What this PR does: it structures the run config with only a single prominent key:

- providers

Providers are instances of configured provider types. Here's an example which shows two instances of the remote::tgi provider, each serving a different model:

    providers:
      inference:
      - provider_id: foo
        provider_type: remote::tgi
        config: { ... }
      - provider_id: bar
        provider_type: remote::tgi
        config: { ... }

Secondly, the PR adds dynamic registration of { models | shields | memory_banks } to the API surface. The distribution still acts like a "routing table" (as previously) except that it asks the backing providers for a listing of these objects. For example, it asks a TGI or Ollama inference adapter which models it is serving. Only the models that are actually being served can be requested by the user for inference; otherwise, the Stack server will throw an error.

When dynamically registering these objects, you can use the provider IDs shown above. Info about providers can be obtained using the Api.inspect set of endpoints (/providers, /routes, etc.)

The above example shows the correspondence between inference providers and models registry items. Things work similarly for the safety <=> shields and memory <=> memory_banks pairs.

Registry: this PR also makes it so that providers need to implement additional methods for registering and listing objects. For example, each Inference provider is now expected to implement the ModelsProtocolPrivate protocol (naming is not great!), which consists of two methods:

- register_model
- list_models

The goal is to inform the provider that a certain model needs to be supported, so the provider can make any relevant backend changes if needed (or throw an error if the model cannot be supported). A minimal sketch of what such a provider-side implementation might look like follows this description.

There are many other cleanups included, some of which are detailed in a follow-up comment.
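For illustration, here is a hedged, minimal sketch of what an inference provider's side of this registry protocol might look like. The protocol and method names (ModelsProtocolPrivate, register_model, list_models) and the ModelDef fields (identifier, llama_model, metadata) appear in this PR; the ExampleInferenceAdapter class and everything inside it are hypothetical and only meant to show the shape of the contract, not the actual implementation.

    # A minimal, hypothetical sketch of a provider implementing the registry hooks
    # described above. Only the protocol/method names and ModelDef field names are
    # taken from the PR; the rest is illustrative.
    from dataclasses import dataclass, field
    from typing import Dict, List, Protocol


    @dataclass
    class ModelDef:
        identifier: str  # the name users address the model by
        llama_model: str  # the underlying Llama model descriptor
        metadata: Dict[str, str] = field(default_factory=dict)  # provider-specific details


    class ModelsProtocolPrivate(Protocol):
        async def register_model(self, model: ModelDef) -> None: ...

        async def list_models(self) -> List[ModelDef]: ...


    class ExampleInferenceAdapter:
        """Toy adapter that can only serve a fixed mapping of models."""

        def __init__(self, supported: Dict[str, str]) -> None:
            # maps Llama Stack model names -> backend-specific model ids
            self._supported = supported
            self._registered: Dict[str, ModelDef] = {}

        async def register_model(self, model: ModelDef) -> None:
            # Refuse models the backend cannot serve instead of silently accepting them.
            if model.identifier not in self._supported:
                raise ValueError(f"Model {model.identifier} is not served by this provider")
            self._registered[model.identifier] = model

        async def list_models(self) -> List[ModelDef]:
            # Report what the backend is actually serving so the distribution can
            # populate its routing table from it.
            return [
                ModelDef(
                    identifier=name,
                    llama_model=name,
                    metadata={"backend_model": backend_id},
                )
                for name, backend_id in self._supported.items()
            ]

In the diff itself, the Ollama adapter's list_models() asks the running Ollama server what it is serving (via client.ps()) and maps those names back to Llama Stack identifiers, while its register_model() raises an error because dynamic registration is not supported by that adapter.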
parent 8c3010553f
commit 6bb57e72a7
93 changed files with 4697 additions and 4457 deletions
@@ -1,445 +1,451 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
||||
}
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
|
||||
@staticmethod
|
||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
||||
)
|
||||
self._config = config
|
||||
|
||||
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
|
||||
if bedrock_stop_reason == "max_tokens":
|
||||
return StopReason.out_of_tokens
|
||||
return StopReason.end_of_turn
|
||||
|
||||
@staticmethod
|
||||
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
|
||||
for builtin_tool in BuiltinTool:
|
||||
if builtin_tool.value == tool_name_str:
|
||||
return builtin_tool
|
||||
else:
|
||||
return tool_name_str
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
|
||||
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
converse_api_res["stopReason"]
|
||||
)
|
||||
|
||||
bedrock_message = converse_api_res["output"]["message"]
|
||||
|
||||
role = bedrock_message["role"]
|
||||
contents = bedrock_message["content"]
|
||||
|
||||
tool_calls = []
|
||||
text_content = []
|
||||
for content in contents:
|
||||
if "toolUse" in content:
|
||||
tool_use = content["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
|
||||
tool_use["name"]
|
||||
),
|
||||
arguments=tool_use["input"] if "input" in tool_use else None,
|
||||
call_id=tool_use["toolUseId"],
|
||||
)
|
||||
)
|
||||
elif "text" in content:
|
||||
text_content.append(content["text"])
|
||||
|
||||
return CompletionMessage(
|
||||
role=role,
|
||||
content=text_content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _messages_to_bedrock_messages(
|
||||
messages: List[Message],
|
||||
) -> Tuple[List[Dict], Optional[List[Dict]]]:
|
||||
bedrock_messages = []
|
||||
system_bedrock_messages = []
|
||||
|
||||
user_contents = []
|
||||
assistant_contents = None
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content_list = (
|
||||
message.content
|
||||
if isinstance(message.content, list)
|
||||
else [message.content]
|
||||
)
|
||||
if role == "ipython" or role == "user":
|
||||
if not user_contents:
|
||||
user_contents = []
|
||||
|
||||
if role == "ipython":
|
||||
user_contents.extend(
|
||||
[
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.call_id,
|
||||
"content": [
|
||||
{"text": content} for content in content_list
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
user_contents.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
assistant_contents = None
|
||||
elif role == "system":
|
||||
system_bedrock_messages.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not assistant_contents:
|
||||
assistant_contents = []
|
||||
|
||||
assistant_contents.extend(
|
||||
[
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
for content in content_list
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"toolUse": {
|
||||
"input": tool_call.arguments,
|
||||
"name": (
|
||||
tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value
|
||||
),
|
||||
"toolUseId": tool_call.call_id,
|
||||
}
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
user_contents = None
|
||||
else:
|
||||
# Unknown role
|
||||
pass
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
|
||||
if system_bedrock_messages:
|
||||
return bedrock_messages, system_bedrock_messages
|
||||
|
||||
return bedrock_messages, None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
|
||||
inference_config = {}
|
||||
if sampling_params:
|
||||
param_mapping = {
|
||||
"max_tokens": "maxTokens",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
}
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(sampling_params, k):
|
||||
inference_config[v] = getattr(sampling_params, k)
|
||||
|
||||
return inference_config
|
||||
|
||||
@staticmethod
|
||||
def _tool_parameters_to_input_schema(
|
||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
) -> Dict:
|
||||
input_schema = {"type": "object"}
|
||||
if not tool_parameters:
|
||||
return input_schema
|
||||
|
||||
json_properties = {}
|
||||
required = []
|
||||
for name, param in tool_parameters.items():
|
||||
json_property = {
|
||||
"type": param.param_type,
|
||||
}
|
||||
|
||||
if param.description:
|
||||
json_property["description"] = param.description
|
||||
if param.required:
|
||||
required.append(name)
|
||||
json_properties[name] = json_property
|
||||
|
||||
input_schema["properties"] = json_properties
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def _tools_to_tool_config(
|
||||
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
|
||||
) -> Optional[Dict]:
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
bedrock_tools = []
|
||||
for tool in tools:
|
||||
tool_name = (
|
||||
tool.tool_name
|
||||
if isinstance(tool.tool_name, str)
|
||||
else tool.tool_name.value
|
||||
)
|
||||
|
||||
tool_spec = {
|
||||
"toolSpec": {
|
||||
"name": tool_name,
|
||||
"inputSchema": {
|
||||
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
|
||||
tool.parameters
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if tool.description:
|
||||
tool_spec["toolSpec"]["description"] = tool.description
|
||||
|
||||
bedrock_tools.append(tool_spec)
|
||||
tool_config = {
|
||||
"tools": bedrock_tools,
|
||||
}
|
||||
|
||||
if tool_choice:
|
||||
tool_config["toolChoice"] = (
|
||||
{"any": {}}
|
||||
if tool_choice.value == ToolChoice.required
|
||||
else {"auto": {}}
|
||||
)
|
||||
return tool_config
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> (
|
||||
AsyncGenerator
|
||||
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
bedrock_model = self.map_to_provider_model(model)
|
||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||
sampling_params
|
||||
)
|
||||
|
||||
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
|
||||
bedrock_messages, system_bedrock_messages = (
|
||||
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
|
||||
)
|
||||
|
||||
converse_api_params = {
|
||||
"modelId": bedrock_model,
|
||||
"messages": bedrock_messages,
|
||||
}
|
||||
if inference_config:
|
||||
converse_api_params["inferenceConfig"] = inference_config
|
||||
|
||||
# Tool use is not supported in streaming mode
|
||||
if tool_config and not stream:
|
||||
converse_api_params["toolConfig"] = tool_config
|
||||
if system_bedrock_messages:
|
||||
converse_api_params["system"] = system_bedrock_messages
|
||||
|
||||
if not stream:
|
||||
converse_api_res = self.client.converse(**converse_api_params)
|
||||
|
||||
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
|
||||
converse_api_res
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=output_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
|
||||
event_stream = converse_stream_api_res["stream"]
|
||||
|
||||
for chunk in event_stream:
|
||||
if "messageStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
elif "contentBlockStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=ToolCall(
|
||||
tool_name=chunk["contentBlockStart"]["toolUse"][
|
||||
"name"
|
||||
],
|
||||
call_id=chunk["contentBlockStart"]["toolUse"][
|
||||
"toolUseId"
|
||||
],
|
||||
),
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "contentBlockDelta" in chunk:
|
||||
if "text" in chunk["contentBlockDelta"]["delta"]:
|
||||
delta = chunk["contentBlockDelta"]["delta"]["text"]
|
||||
else:
|
||||
delta = ToolCallDelta(
|
||||
content=ToolCall(
|
||||
arguments=chunk["contentBlockDelta"]["delta"][
|
||||
"toolUse"
|
||||
]["input"]
|
||||
),
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
elif "contentBlockStop" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
elif "messageStop" in chunk:
|
||||
stop_reason = (
|
||||
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
chunk["messageStop"]["stopReason"]
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
elif "metadata" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
else:
|
||||
# Ignored
|
||||
pass
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
||||
}
|
||||
|
||||
|
||||
# NOTE: this is not quite tested after the recent refactors
|
||||
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
||||
)
|
||||
self._config = config
|
||||
|
||||
self._client = _create_bedrock_client(config)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
|
||||
if bedrock_stop_reason == "max_tokens":
|
||||
return StopReason.out_of_tokens
|
||||
return StopReason.end_of_turn
|
||||
|
||||
@staticmethod
|
||||
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
|
||||
for builtin_tool in BuiltinTool:
|
||||
if builtin_tool.value == tool_name_str:
|
||||
return builtin_tool
|
||||
else:
|
||||
return tool_name_str
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
|
||||
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
converse_api_res["stopReason"]
|
||||
)
|
||||
|
||||
bedrock_message = converse_api_res["output"]["message"]
|
||||
|
||||
role = bedrock_message["role"]
|
||||
contents = bedrock_message["content"]
|
||||
|
||||
tool_calls = []
|
||||
text_content = []
|
||||
for content in contents:
|
||||
if "toolUse" in content:
|
||||
tool_use = content["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
|
||||
tool_use["name"]
|
||||
),
|
||||
arguments=tool_use["input"] if "input" in tool_use else None,
|
||||
call_id=tool_use["toolUseId"],
|
||||
)
|
||||
)
|
||||
elif "text" in content:
|
||||
text_content.append(content["text"])
|
||||
|
||||
return CompletionMessage(
|
||||
role=role,
|
||||
content=text_content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _messages_to_bedrock_messages(
|
||||
messages: List[Message],
|
||||
) -> Tuple[List[Dict], Optional[List[Dict]]]:
|
||||
bedrock_messages = []
|
||||
system_bedrock_messages = []
|
||||
|
||||
user_contents = []
|
||||
assistant_contents = None
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content_list = (
|
||||
message.content
|
||||
if isinstance(message.content, list)
|
||||
else [message.content]
|
||||
)
|
||||
if role == "ipython" or role == "user":
|
||||
if not user_contents:
|
||||
user_contents = []
|
||||
|
||||
if role == "ipython":
|
||||
user_contents.extend(
|
||||
[
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.call_id,
|
||||
"content": [
|
||||
{"text": content} for content in content_list
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
user_contents.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
assistant_contents = None
|
||||
elif role == "system":
|
||||
system_bedrock_messages.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not assistant_contents:
|
||||
assistant_contents = []
|
||||
|
||||
assistant_contents.extend(
|
||||
[
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
for content in content_list
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"toolUse": {
|
||||
"input": tool_call.arguments,
|
||||
"name": (
|
||||
tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value
|
||||
),
|
||||
"toolUseId": tool_call.call_id,
|
||||
}
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
user_contents = None
|
||||
else:
|
||||
# Unknown role
|
||||
pass
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
|
||||
if system_bedrock_messages:
|
||||
return bedrock_messages, system_bedrock_messages
|
||||
|
||||
return bedrock_messages, None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
|
||||
inference_config = {}
|
||||
if sampling_params:
|
||||
param_mapping = {
|
||||
"max_tokens": "maxTokens",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
}
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(sampling_params, k):
|
||||
inference_config[v] = getattr(sampling_params, k)
|
||||
|
||||
return inference_config
|
||||
|
||||
@staticmethod
|
||||
def _tool_parameters_to_input_schema(
|
||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]],
|
||||
) -> Dict:
|
||||
input_schema = {"type": "object"}
|
||||
if not tool_parameters:
|
||||
return input_schema
|
||||
|
||||
json_properties = {}
|
||||
required = []
|
||||
for name, param in tool_parameters.items():
|
||||
json_property = {
|
||||
"type": param.param_type,
|
||||
}
|
||||
|
||||
if param.description:
|
||||
json_property["description"] = param.description
|
||||
if param.required:
|
||||
required.append(name)
|
||||
json_properties[name] = json_property
|
||||
|
||||
input_schema["properties"] = json_properties
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def _tools_to_tool_config(
|
||||
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
|
||||
) -> Optional[Dict]:
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
bedrock_tools = []
|
||||
for tool in tools:
|
||||
tool_name = (
|
||||
tool.tool_name
|
||||
if isinstance(tool.tool_name, str)
|
||||
else tool.tool_name.value
|
||||
)
|
||||
|
||||
tool_spec = {
|
||||
"toolSpec": {
|
||||
"name": tool_name,
|
||||
"inputSchema": {
|
||||
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
|
||||
tool.parameters
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if tool.description:
|
||||
tool_spec["toolSpec"]["description"] = tool.description
|
||||
|
||||
bedrock_tools.append(tool_spec)
|
||||
tool_config = {
|
||||
"tools": bedrock_tools,
|
||||
}
|
||||
|
||||
if tool_choice:
|
||||
tool_config["toolChoice"] = (
|
||||
{"any": {}}
|
||||
if tool_choice.value == ToolChoice.required
|
||||
else {"auto": {}}
|
||||
)
|
||||
return tool_config
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> (
|
||||
AsyncGenerator
|
||||
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
bedrock_model = self.map_to_provider_model(model)
|
||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||
sampling_params
|
||||
)
|
||||
|
||||
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
|
||||
bedrock_messages, system_bedrock_messages = (
|
||||
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
|
||||
)
|
||||
|
||||
converse_api_params = {
|
||||
"modelId": bedrock_model,
|
||||
"messages": bedrock_messages,
|
||||
}
|
||||
if inference_config:
|
||||
converse_api_params["inferenceConfig"] = inference_config
|
||||
|
||||
# Tool use is not supported in streaming mode
|
||||
if tool_config and not stream:
|
||||
converse_api_params["toolConfig"] = tool_config
|
||||
if system_bedrock_messages:
|
||||
converse_api_params["system"] = system_bedrock_messages
|
||||
|
||||
if not stream:
|
||||
converse_api_res = self.client.converse(**converse_api_params)
|
||||
|
||||
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
|
||||
converse_api_res
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=output_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
|
||||
event_stream = converse_stream_api_res["stream"]
|
||||
|
||||
for chunk in event_stream:
|
||||
if "messageStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
elif "contentBlockStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=ToolCall(
|
||||
tool_name=chunk["contentBlockStart"]["toolUse"][
|
||||
"name"
|
||||
],
|
||||
call_id=chunk["contentBlockStart"]["toolUse"][
|
||||
"toolUseId"
|
||||
],
|
||||
),
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "contentBlockDelta" in chunk:
|
||||
if "text" in chunk["contentBlockDelta"]["delta"]:
|
||||
delta = chunk["contentBlockDelta"]["delta"]["text"]
|
||||
else:
|
||||
delta = ToolCallDelta(
|
||||
content=ToolCall(
|
||||
arguments=chunk["contentBlockDelta"]["delta"][
|
||||
"toolUse"
|
||||
]["input"]
|
||||
),
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
elif "contentBlockStop" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
elif "messageStop" in chunk:
|
||||
stop_reason = (
|
||||
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
chunk["messageStop"]["stopReason"]
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
elif "metadata" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
else:
|
||||
# Ignored
|
||||
pass
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
|
|
@@ -6,39 +6,41 @@
|
|||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
|
||||
DATABRICKS_SUPPORTED_MODELS = {
|
||||
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
|
||||
}
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(Inference):
|
||||
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> OpenAI:
|
||||
return OpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
@@ -46,47 +48,10 @@ class DatabricksInferenceAdapter(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_databricks_messages(self, messages: list[Message]) -> list:
|
||||
databricks_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
databricks_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return databricks_messages
|
||||
|
||||
def resolve_databricks_model(self, model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert (
|
||||
model is not None
|
||||
and model.descriptor(shorten_default_variant=True)
|
||||
in DATABRICKS_SUPPORTED_MODELS
|
||||
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(DATABRICKS_SUPPORTED_MODELS.keys())}"
|
||||
|
||||
return DATABRICKS_SUPPORTED_MODELS.get(
|
||||
model.descriptor(shorten_default_variant=True)
|
||||
)
|
||||
|
||||
def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@@ -108,150 +73,46 @@ class DatabricksInferenceAdapter(Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
options = self.get_databricks_chat_options(request)
|
||||
databricks_model = self.resolve_databricks_model(request.model)
|
||||
|
||||
if not request.stream:
|
||||
|
||||
r = self.client.chat.completions.create(
|
||||
model=databricks_model,
|
||||
messages=self._messages_to_databricks_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if r.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
for chunk in self.client.chat.completions.create(
|
||||
model=databricks_model,
|
||||
messages=self._messages_to_databricks_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if chunk.choices[0].finish_reason:
|
||||
if (
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "stop"
|
||||
):
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif (
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "length"
|
||||
):
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
text = chunk.choices[0].delta.content
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
if text is None:
|
||||
continue
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@@ -10,14 +10,19 @@ from fireworks.client import Fireworks
|
|||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
@@ -27,21 +32,18 @@ FIREWORKS_SUPPORTED_MODELS = {
|
|||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
||||
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
|
||||
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
||||
}
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> Fireworks:
|
||||
return Fireworks(api_key=self.config.api_key)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
@@ -49,7 +51,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
@@ -59,27 +61,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
|
||||
fireworks_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
fireworks_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return fireworks_messages
|
||||
|
||||
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@@ -101,147 +83,48 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
# accumulate sampling params and other options to pass to fireworks
|
||||
options = self.get_fireworks_chat_options(request)
|
||||
fireworks_model = self.map_to_provider_model(request.model)
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if r.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
client = Fireworks(api_key=self.config.api_key)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await client.completion.acreate(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
async for chunk in self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if chunk.choices[0].finish_reason:
|
||||
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif (
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "length"
|
||||
):
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
text = chunk.choices[0].delta.content
|
||||
if text is None:
|
||||
continue
|
||||
stream = client.completion.acreate(**params)
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt = chat_completion_request_to_prompt(request, self.formatter)
|
||||
# Fireworks always prepends with BOS
|
||||
if prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
options = get_sampling_options(request)
|
||||
options.setdefault("max_tokens", 512)
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": prompt,
|
||||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@@ -9,35 +9,38 @@ from typing import AsyncGenerator
|
|||
import httpx
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
||||
# TODO: Eventually this will move to the llama cli model list command
|
||||
# mapping of Model SKUs to ollama models
|
||||
OLLAMA_SUPPORTED_SKUS = {
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
OLLAMA_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
||||
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
|
||||
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
|
||||
"Llama-Guard-3-8B": "xe/llamaguard3:latest",
|
||||
}
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
|
||||
)
|
||||
self.url = url
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
|
@@ -55,7 +58,33 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Dynamic model registration is not supported")
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
|
||||
|
||||
ret = []
|
||||
res = await self.client.ps()
|
||||
for r in res["models"]:
|
||||
if r["model"] not in ollama_to_llama:
|
||||
print(f"Ollama is running a model unknown to Llama Stack: {r['model']}")
|
||||
continue
|
||||
|
||||
llama_model = ollama_to_llama[r["model"]]
|
||||
ret.append(
|
||||
ModelDef(
|
||||
identifier=llama_model,
|
||||
llama_model=llama_model,
|
||||
metadata={
|
||||
"ollama_model": r["model"],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
@@ -65,32 +94,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
|
||||
ollama_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
ollama_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return ollama_messages
|
||||
|
||||
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
if (
|
||||
request.sampling_params.repetition_penalty is not None
|
||||
and request.sampling_params.repetition_penalty != 1.0
|
||||
):
|
||||
options["repeat_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@@ -111,156 +115,61 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
# accumulate sampling params and other options to pass to ollama
|
||||
options = self.get_ollama_chat_options(request)
|
||||
ollama_model = self.map_to_provider_model(request.model)
|
||||
|
||||
res = await self.client.ps()
|
||||
need_model_pull = True
|
||||
for r in res["models"]:
|
||||
if ollama_model == r["model"]:
|
||||
need_model_pull = False
|
||||
break
|
||||
|
||||
if need_model_pull:
|
||||
print(f"Pulling model: {ollama_model}")
|
||||
status = await self.client.pull(ollama_model)
|
||||
assert (
|
||||
status["status"] == "success"
|
||||
), f"Failed to pull model {self.model} in ollama"
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=False,
|
||||
options=options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r["done"]:
|
||||
if r["done_reason"] == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r["done_reason"] == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r["message"]["content"], stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": OLLAMA_SUPPORTED_MODELS[request.model],
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"options": get_sampling_options(request),
|
||||
"raw": True,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await self.client.generate(**params)
|
||||
assert isinstance(r, dict)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["response"],
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
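# Each raw Ollama chunk is re-wrapped in the OpenAI-compatible choice/response
# shapes so the shared process_chat_completion_stream_response helper can be reused.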
|
||||
s = await self.client.generate(**params)
|
||||
async for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["response"],
|
||||
)
|
||||
)
|
||||
stream = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=True,
|
||||
options=options,
|
||||
)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk["done"]:
|
||||
if stop_reason is None and chunk["done_reason"] == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif stop_reason is None and chunk["done_reason"] == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = chunk["message"]["content"]
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@@ -9,14 +9,12 @@ from .config import SampleConfig
|
|||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class SampleInferenceImpl(Inference, RoutableProvider):
|
||||
class SampleInferenceImpl(Inference):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
|
@@ -34,7 +34,7 @@ class InferenceEndpointImplConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class InferenceAPIImplConfig(BaseModel):
|
||||
model_id: str = Field(
|
||||
huggingface_repo: str = Field(
|
||||
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
|
|
|
@@ -6,18 +6,27 @@
|
|||
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_model_input_info,
|
||||
)
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
@@ -25,24 +34,39 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _HfAdapter(Inference, RoutableProvider):
|
||||
class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||
client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor()
|
||||
for model in all_registered_models()
|
||||
if model.huggingface_repo
|
||||
}
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Model registration is not supported for HuggingFace models")
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
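# self.model_id is the HF repo this endpoint serves; map it back to the
# canonical Llama descriptor so the registry exposes a model users can request.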
|
||||
repo = self.model_id
|
||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||
return [
|
||||
ModelDef(
|
||||
identifier=identifier,
|
||||
llama_model=identifier,
|
||||
metadata={
|
||||
"huggingface_repo": repo,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
@@ -52,16 +76,7 @@ class _HfAdapter(Inference, RoutableProvider):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@@ -83,146 +98,71 @@ class _HfAdapter(Inference, RoutableProvider):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||
prompt = self.tokenizer.decode(model_input.tokens)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
input_tokens = len(model_input.tokens)
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
text="".join(t.text for t in r.details.tokens),
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
|
||||
choice = OpenAICompatCompletionChoice(text=token_result.text)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt, input_tokens = chat_completion_request_to_model_input_info(
|
||||
request, self.formatter
|
||||
)
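# Budget the completion against the endpoint's max_total_tokens: subtract the
# prompt length, and the extra -1 keeps prompt + completion strictly under the cap.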
|
||||
max_new_tokens = min(
|
||||
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
options = get_sampling_options(request)
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
)
|
||||
|
||||
print(f"Calculated max_new_tokens: {max_new_tokens}")
|
||||
|
||||
options = self.get_chat_options(request)
|
||||
if not request.stream:
|
||||
response = await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if response.details.finish_reason:
|
||||
if response.details.finish_reason in ["stop", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif response.details.finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
response.generated_text,
|
||||
stop_reason,
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
tokens = []
|
||||
|
||||
async for response in await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
):
|
||||
token_result = response.token
|
||||
|
||||
buffer += token_result.text
|
||||
tokens.append(token_result.id)
|
||||
|
||||
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer = buffer[len("<|python_tag|>") :]
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = text
|
||||
|
||||
if stop_reason is None:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
|
@@ -236,7 +176,7 @@ class TGIAdapter(_HfAdapter):
|
|||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.model_id, token=config.api_token
|
||||
model=config.huggingface_repo, token=config.api_token
|
||||
)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
|
|
|
@@ -8,17 +8,22 @@ from typing import AsyncGenerator
|
|||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from together import Together
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
@@ -34,19 +39,14 @@ TOGETHER_SUPPORTED_MODELS = {
|
|||
|
||||
|
||||
class TogetherInferenceAdapter(
|
||||
Inference, NeedsRequestProviderData, RoutableProviderForModels
|
||||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> Together:
|
||||
return Together(api_key=self.config.api_key)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
@@ -64,27 +64,7 @@ class TogetherInferenceAdapter(
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_together_messages(self, messages: list[Message]) -> list:
|
||||
together_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
together_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return together_messages
|
||||
|
||||
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@@ -95,7 +75,6 @@ class TogetherInferenceAdapter(
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key
|
||||
|
@@ -108,7 +87,6 @@ class TogetherInferenceAdapter(
|
|||
together_api_key = provider_data.together_api_key
|
||||
|
||||
client = Together(api_key=together_api_key)
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@@ -120,146 +98,46 @@ class TogetherInferenceAdapter(
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
# accumulate sampling params and other options to pass to together
|
||||
options = self.get_together_chat_options(request)
|
||||
together_model = self.map_to_provider_model(request.model)
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
if not request.stream:
|
||||
# TODO: might need to add back an async here
|
||||
r = client.chat.completions.create(
|
||||
model=together_model,
|
||||
messages=self._messages_to_together_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if (
|
||||
r.choices[0].finish_reason == "stop"
|
||||
or r.choices[0].finish_reason == "eos"
|
||||
):
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Together
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
for chunk in client.chat.completions.create(
|
||||
model=together_model,
|
||||
messages=self._messages_to_together_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if finish_reason := chunk.choices[0].finish_reason:
|
||||
if stop_reason is None and finish_reason in ["stop", "eos"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif stop_reason is None and finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Together
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
text = chunk.choices[0].delta.content
|
||||
if text is None:
|
||||
continue
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@@ -5,16 +5,17 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import chromadb
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
|
@@ -65,7 +66,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class ChromaMemoryAdapter(Memory, RoutableProvider):
|
||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||
url = url.rstrip("/")
|
||||
|
@@ -93,56 +94,43 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
collection = await self.client.create_collection(
|
||||
name=bank_id,
|
||||
metadata={"bank": bank.json()},
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
metadata={"bank": memory_bank.json()},
|
||||
)
|
||||
bank_index = BankWithIndex(
|
||||
bank=bank, index=ChromaIndex(self.client, collection)
|
||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||
)
|
||||
self.cache[bank_id] = bank_index
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||
if bank_index is None:
|
||||
return None
|
||||
return bank_index.bank
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
collections = await self.client.list_collections()
|
||||
for collection in collections:
|
||||
if collection.name == bank_id:
|
||||
print(collection.metadata)
|
||||
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=ChromaIndex(self.client, collection),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
try:
|
||||
data = json.loads(collection.metadata["bank"])
|
||||
bank = parse_obj_as(MemoryBankDef, data)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
return None
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse bank: {collection.metadata}")
|
||||
continue
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=ChromaIndex(self.client, collection),
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@@ -150,7 +138,7 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
|
@@ -162,7 +150,7 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
|
|
|
@@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
|
||||
import psycopg2
|
||||
|
@@ -12,11 +11,11 @@ from numpy.typing import NDArray
|
|||
from psycopg2 import sql
|
||||
from psycopg2.extras import execute_values, Json
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
BankWithIndex,
|
||||
|
@@ -46,23 +45,17 @@ def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
|
|||
execute_values(cur, query, values, template="(%s, %s)")
|
||||
|
||||
|
||||
def load_models(cur, keys: List[str], cls):
|
||||
def load_models(cur, cls):
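# The per-key filter was dropped: the registry now loads every stored
# definition from metadata_store in one pass.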
|
||||
query = "SELECT key, data FROM metadata_store"
|
||||
if keys:
|
||||
placeholders = ",".join(["%s"] * len(keys))
|
||||
query += f" WHERE key IN ({placeholders})"
|
||||
cur.execute(query, keys)
|
||||
else:
|
||||
cur.execute(query)
|
||||
|
||||
cur.execute(query)
|
||||
rows = cur.fetchall()
|
||||
return [cls(**row["data"]) for row in rows]
|
||||
return [parse_obj_as(cls, row["data"]) for row in rows]
|
||||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, bank: MemoryBank, dimension: int, cursor):
|
||||
def __init__(self, bank: MemoryBankDef, dimension: int, cursor):
|
||||
self.cursor = cursor
|
||||
self.table_name = f"vector_store_{bank.name}"
|
||||
self.table_name = f"vector_store_{bank.identifier}"
|
||||
|
||||
self.cursor.execute(
|
||||
f"""
|
||||
|
@@ -119,7 +112,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
||||
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: PGVectorConfig) -> None:
|
||||
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
|
||||
self.config = config
|
||||
|
@@ -161,57 +154,37 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
upsert_models(
|
||||
self.cursor,
|
||||
[
|
||||
(bank.bank_id, bank),
|
||||
(memory_bank.identifier, memory_bank),
|
||||
],
|
||||
)
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
bank=memory_bank,
|
||||
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return bank
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||
if bank_index is None:
|
||||
return None
|
||||
return bank_index.bank
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
banks = load_models(self.cursor, [bank_id], MemoryBank)
|
||||
if not banks:
|
||||
return None
|
||||
|
||||
bank = banks[0]
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
banks = load_models(self.cursor, MemoryBankDef)
|
||||
for bank in banks:
|
||||
if bank.identifier not in self.cache:
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
return banks
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@@ -219,7 +192,7 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
|
@@ -231,7 +204,7 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
|
|
|
@@ -9,14 +9,12 @@ from .config import SampleConfig
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class SampleMemoryImpl(Memory, RoutableProvider):
|
||||
class SampleMemoryImpl(Memory):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
# these are the memory banks the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
|
@@ -1,8 +1,15 @@
|
|||
from .config import WeaviateConfig
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WeaviateConfig, _deps):
|
||||
from .weaviate import WeaviateMemoryAdapter
|
||||
|
||||
impl = WeaviateMemoryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
return impl
|
||||
|
|
|
@@ -4,15 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WeaviateRequestProviderData(BaseModel):
|
||||
# if there _is_ provider data, it must specify the API KEY
|
||||
# if you want it to be optional, use Optional[str]
|
||||
weaviate_api_key: str
|
||||
weaviate_cluster_url: str
|
||||
|
||||
@json_schema_type
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
collection: str = Field(default="MemoryBank")
|
||||
pass
|
||||
|
|
|
@@ -1,14 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Optional, Dict, Any
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import weaviate
|
||||
import weaviate.classes as wvc
|
||||
from numpy.typing import NDArray
|
||||
from weaviate.classes.init import Auth
|
||||
|
||||
from llama_stack.apis.memory import *
|
||||
from llama_stack.distribution.request_headers import get_request_provider_data
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
|
@@ -16,162 +22,154 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData
|
||||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(self, client: weaviate.Client, collection: str):
|
||||
def __init__(self, client: weaviate.Client, collection_name: str):
|
||||
self.client = client
|
||||
self.collection = collection
|
||||
self.collection_name = collection_name
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
assert len(chunks) == len(
|
||||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
data_objects = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
|
||||
data_objects.append(wvc.data.DataObject(
|
||||
properties={
|
||||
"chunk_content": chunk,
|
||||
},
|
||||
vector = embeddings[i].tolist()
|
||||
))
|
||||
data_objects.append(
|
||||
wvc.data.DataObject(
|
||||
properties={
|
||||
"chunk_content": chunk.json(),
|
||||
},
|
||||
vector=embeddings[i].tolist(),
|
||||
)
|
||||
)
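# Chunks are stored as serialized JSON in the "chunk_content" TEXT property so
# that query() can reconstruct Chunk objects from the raw string.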
|
||||
|
||||
# Inserting chunks into a prespecified Weaviate collection
|
||||
assert self.collection is not None, "Collection name must be specified"
|
||||
my_collection = self.client.collections.get(self.collection)
|
||||
|
||||
await my_collection.data.insert_many(data_objects)
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
|
||||
# TODO: make this async friendly
|
||||
collection.data.insert_many(data_objects)
|
||||
|
||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||
assert self.collection is not None, "Collection name must be specified"
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
|
||||
my_collection = self.client.collections.get(self.collection)
|
||||
|
||||
results = my_collection.query.near_vector(
|
||||
near_vector = embedding.tolist(),
|
||||
limit = k,
|
||||
return_meta_data = wvc.query.MetadataQuery(distance=True)
|
||||
results = collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(distance=True),
|
||||
)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc in results.objects:
|
||||
chunk_json = doc.properties["chunk_content"]
|
||||
try:
|
||||
chunk = doc.properties["chunk_content"]
|
||||
chunks.append(chunk)
|
||||
scores.append(1.0 / doc.metadata.distance)
|
||||
|
||||
except Exception as e:
|
||||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse document: {e}")
|
||||
print(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(1.0 / doc.metadata.distance)
|
||||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateMemoryAdapter(Memory):
|
||||
class WeaviateMemoryAdapter(
|
||||
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
|
||||
):
|
||||
def __init__(self, config: WeaviateConfig) -> None:
|
||||
self.config = config
|
||||
self.client = None
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
request_provider_data = get_request_provider_data()
|
||||
|
||||
if request_provider_data is not None:
|
||||
assert isinstance(request_provider_data, WeaviateRequestProviderData)
|
||||
|
||||
# Connect to Weaviate Cloud
|
||||
return weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url = request_provider_data.weaviate_cluster_url,
|
||||
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
|
||||
)
|
||||
provider_data = self.get_request_provider_data()
|
||||
assert provider_data is not None, "Request provider data must be set"
|
||||
assert isinstance(provider_data, WeaviateRequestProviderData)
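# Credentials arrive per request, so keep one Weaviate client per
# (cluster_url, api_key) pair and reuse it across calls.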
|
||||
|
||||
key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
|
||||
if key in self.client_cache:
|
||||
return self.client_cache[key]
|
||||
|
||||
client = weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url=provider_data.weaviate_cluster_url,
|
||||
auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
|
||||
)
|
||||
self.client_cache[key] = client
|
||||
return client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self.client = self._get_client()
|
||||
|
||||
# Create collection if it doesn't exist
|
||||
if not self.client.collections.exists(self.config.collection):
|
||||
self.client.collections.create(
|
||||
name = self.config.collection,
|
||||
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
wvc.config.Property(
|
||||
name="chunk_content",
|
||||
data_type=wvc.config.DataType.TEXT,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise RuntimeError("Could not connect to Weaviate server") from e
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client = self._get_client()
|
||||
for client in self.client_cache.values():
|
||||
client.close()
|
||||
|
||||
if self.client:
|
||||
self.client.close()
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
self.client = self._get_client()
|
||||
|
||||
# Store the bank as a new collection in Weaviate
|
||||
self.client.collections.create(
|
||||
name=bank_id
|
||||
)
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
# Create collection if it doesn't exist
|
||||
if not client.collections.exists(memory_bank.identifier):
|
||||
client.collections.create(
|
||||
name=memory_bank.identifier,
|
||||
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
wvc.config.Property(
|
||||
name="chunk_content",
|
||||
data_type=wvc.config.DataType.TEXT,
|
||||
),
|
||||
],
|
||||
)
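# Vectorizer.none() is used because embeddings are computed by the Stack and
# passed explicitly to WeaviateIndex.add_chunks().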
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=WeaviateIndex(cleint = self.client, collection = bank_id),
|
||||
bank=memory_bank,
|
||||
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return bank
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||
if bank_index is None:
|
||||
return None
|
||||
return bank_index.bank
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
# TODO: right now the Llama Stack is the source of truth for these banks. That is
|
||||
# not ideal. It should be Weaviate which is the source of truth. Unfortunately,
|
||||
# list() happens at Stack startup when the Weaviate client (credentials) is not
|
||||
# yet available. We need to figure out a way to make this work.
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
|
||||
self.client = self._get_client()
|
||||
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
collections = await self.client.collections.list_all().keys()
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
for collection in collections:
|
||||
if collection == bank_id:
|
||||
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=WeaviateIndex(self.client, collection),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
client = self._get_client()
|
||||
if not client.collections.exists(bank_id):
|
||||
raise ValueError(f"Collection with name `{bank_id}` not found")
|
||||
|
||||
return None
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=WeaviateIndex(client=client, collection_name=bank_id),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
|
@@ -189,4 +187,4 @@ class WeaviateMemoryAdapter(Memory):
|
|||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
return await index.query_documents(query, params)
|
||||
return await index.query_documents(query, params)
|
||||
|
|
|
@@ -7,14 +7,13 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
import traceback
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import boto3
|
||||
|
||||
from llama_stack.apis.safety import * # noqa
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
|
@@ -22,16 +21,17 @@ from .config import BedrockSafetyConfig
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SUPPORTED_SHIELD_TYPES = [
|
||||
"bedrock_guardrail",
|
||||
BEDROCK_SUPPORTED_SHIELDS = [
|
||||
ShieldType.generic_content_shield.value,
|
||||
]
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, RoutableProvider):
|
||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||
if not config.aws_profile:
|
||||
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
||||
self.config = config
|
||||
self.registered_shields = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
|
@@ -45,16 +45,23 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SUPPORTED_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
raise ValueError("Registering dynamic shields is not supported")
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
raise NotImplementedError(
|
||||
"""
|
||||
`list_shields` not implemented; this should read all guardrails from
|
||||
bedrock and populate guardrailId and guardrailVersion in the ShieldDef.
|
||||
"""
|
||||
)
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
if shield_type not in SUPPORTED_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
|
@@ -69,52 +76,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
|
|||
|
||||
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"run_shield::{params}::messages={messages}")
|
||||
if "guardrailIdentifier" not in params:
|
||||
raise RuntimeError(
|
||||
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
|
||||
)
|
||||
|
||||
if "guardrailVersion" not in params:
|
||||
raise RuntimeError(
|
||||
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
|
||||
)
|
||||
shield_params = shield_def.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(
|
||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(
|
||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
|
||||
response = self.boto_client.apply_guardrail(
|
||||
guardrailIdentifier=params.get("guardrailIdentifier"),
|
||||
guardrailVersion=params.get("guardrailVersion"),
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
logger.debug(f"run_shield:: response: {response}::")
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
return SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
response = self.boto_client.apply_guardrail(
|
||||
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
||||
guardrailVersion=shield_params["guardrailVersion"],
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
|
||||
except Exception:
|
||||
error_str = traceback.format_exc()
|
||||
logger.error(
|
||||
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
|
||||
return SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
|
@@ -9,14 +9,12 @@ from .config import SampleConfig
|
|||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class SampleSafetyImpl(Safety, RoutableProvider):
|
||||
class SampleSafetyImpl(Safety):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
# these are the safety shields the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
|
@@ -6,26 +6,21 @@
|
|||
from together import Together
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
from .config import TogetherSafetyConfig
|
||||
|
||||
|
||||
SAFETY_SHIELD_TYPES = {
|
||||
TOGETHER_SHIELD_MODEL_MAP = {
|
||||
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
}
|
||||
|
||||
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
@@ -35,16 +30,28 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
raise ValueError("Registering dynamic shields is not supported")
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return [
|
||||
ShieldDef(
|
||||
identifier=ShieldType.llama_guard.value,
|
||||
type=ShieldType.llama_guard.value,
|
||||
params={},
|
||||
)
|
||||
]
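# Only llama_guard is advertised statically; run_shield() picks the concrete
# Together model from the shield's "model" param via TOGETHER_SHIELD_MODEL_MAP.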
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
if shield_type not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
model = shield_def.params.get("model", "llama_guard")
|
||||
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
||||
raise ValueError(f"Unsupported safety model: {model}")
|
||||
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
|
@@ -57,8 +64,6 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
|
||||
model_name = SAFETY_SHIELD_TYPES[shield_type]
|
||||
|
||||
# messages can have role assistant or user
|
||||
api_messages = []
|
||||
for message in messages:
|
||||
|
@@ -66,7 +71,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
api_messages.append({"role": message.role, "content": message.content})
|
||||
|
||||
violation = await get_safety_response(
|
||||
together_api_key, model_name, api_messages
|
||||
together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
|
@@ -90,7 +95,6 @@ async def get_safety_response(
|
|||
if parts[0] == "unsafe":
|
||||
return SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
user_message="unsafe",
|
||||
metadata={"violation_type": parts[1]},
|
||||
)
|