mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
rebase on top of registry
This commit is contained in:
commit
6abef716dd
107 changed files with 4813 additions and 3587 deletions
|
|
@ -1,445 +1,445 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
||||
}
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
|
||||
@staticmethod
|
||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
||||
)
|
||||
self._config = config
|
||||
|
||||
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
|
||||
if bedrock_stop_reason == "max_tokens":
|
||||
return StopReason.out_of_tokens
|
||||
return StopReason.end_of_turn
|
||||
|
||||
@staticmethod
|
||||
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
|
||||
for builtin_tool in BuiltinTool:
|
||||
if builtin_tool.value == tool_name_str:
|
||||
return builtin_tool
|
||||
else:
|
||||
return tool_name_str
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
|
||||
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
converse_api_res["stopReason"]
|
||||
)
|
||||
|
||||
bedrock_message = converse_api_res["output"]["message"]
|
||||
|
||||
role = bedrock_message["role"]
|
||||
contents = bedrock_message["content"]
|
||||
|
||||
tool_calls = []
|
||||
text_content = []
|
||||
for content in contents:
|
||||
if "toolUse" in content:
|
||||
tool_use = content["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
|
||||
tool_use["name"]
|
||||
),
|
||||
arguments=tool_use["input"] if "input" in tool_use else None,
|
||||
call_id=tool_use["toolUseId"],
|
||||
)
|
||||
)
|
||||
elif "text" in content:
|
||||
text_content.append(content["text"])
|
||||
|
||||
return CompletionMessage(
|
||||
role=role,
|
||||
content=text_content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _messages_to_bedrock_messages(
|
||||
messages: List[Message],
|
||||
) -> Tuple[List[Dict], Optional[List[Dict]]]:
|
||||
bedrock_messages = []
|
||||
system_bedrock_messages = []
|
||||
|
||||
user_contents = []
|
||||
assistant_contents = None
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content_list = (
|
||||
message.content
|
||||
if isinstance(message.content, list)
|
||||
else [message.content]
|
||||
)
|
||||
if role == "ipython" or role == "user":
|
||||
if not user_contents:
|
||||
user_contents = []
|
||||
|
||||
if role == "ipython":
|
||||
user_contents.extend(
|
||||
[
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.call_id,
|
||||
"content": [
|
||||
{"text": content} for content in content_list
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
user_contents.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
assistant_contents = None
|
||||
elif role == "system":
|
||||
system_bedrock_messages.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not assistant_contents:
|
||||
assistant_contents = []
|
||||
|
||||
assistant_contents.extend(
|
||||
[
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
for content in content_list
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"toolUse": {
|
||||
"input": tool_call.arguments,
|
||||
"name": (
|
||||
tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value
|
||||
),
|
||||
"toolUseId": tool_call.call_id,
|
||||
}
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
user_contents = None
|
||||
else:
|
||||
# Unknown role
|
||||
pass
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
|
||||
if system_bedrock_messages:
|
||||
return bedrock_messages, system_bedrock_messages
|
||||
|
||||
return bedrock_messages, None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
|
||||
inference_config = {}
|
||||
if sampling_params:
|
||||
param_mapping = {
|
||||
"max_tokens": "maxTokens",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
}
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(sampling_params, k):
|
||||
inference_config[v] = getattr(sampling_params, k)
|
||||
|
||||
return inference_config
|
||||
|
||||
@staticmethod
|
||||
def _tool_parameters_to_input_schema(
|
||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
) -> Dict:
|
||||
input_schema = {"type": "object"}
|
||||
if not tool_parameters:
|
||||
return input_schema
|
||||
|
||||
json_properties = {}
|
||||
required = []
|
||||
for name, param in tool_parameters.items():
|
||||
json_property = {
|
||||
"type": param.param_type,
|
||||
}
|
||||
|
||||
if param.description:
|
||||
json_property["description"] = param.description
|
||||
if param.required:
|
||||
required.append(name)
|
||||
json_properties[name] = json_property
|
||||
|
||||
input_schema["properties"] = json_properties
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def _tools_to_tool_config(
|
||||
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
|
||||
) -> Optional[Dict]:
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
bedrock_tools = []
|
||||
for tool in tools:
|
||||
tool_name = (
|
||||
tool.tool_name
|
||||
if isinstance(tool.tool_name, str)
|
||||
else tool.tool_name.value
|
||||
)
|
||||
|
||||
tool_spec = {
|
||||
"toolSpec": {
|
||||
"name": tool_name,
|
||||
"inputSchema": {
|
||||
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
|
||||
tool.parameters
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if tool.description:
|
||||
tool_spec["toolSpec"]["description"] = tool.description
|
||||
|
||||
bedrock_tools.append(tool_spec)
|
||||
tool_config = {
|
||||
"tools": bedrock_tools,
|
||||
}
|
||||
|
||||
if tool_choice:
|
||||
tool_config["toolChoice"] = (
|
||||
{"any": {}}
|
||||
if tool_choice.value == ToolChoice.required
|
||||
else {"auto": {}}
|
||||
)
|
||||
return tool_config
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> (
|
||||
AsyncGenerator
|
||||
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
bedrock_model = self.map_to_provider_model(model)
|
||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||
sampling_params
|
||||
)
|
||||
|
||||
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
|
||||
bedrock_messages, system_bedrock_messages = (
|
||||
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
|
||||
)
|
||||
|
||||
converse_api_params = {
|
||||
"modelId": bedrock_model,
|
||||
"messages": bedrock_messages,
|
||||
}
|
||||
if inference_config:
|
||||
converse_api_params["inferenceConfig"] = inference_config
|
||||
|
||||
# Tool use is not supported in streaming mode
|
||||
if tool_config and not stream:
|
||||
converse_api_params["toolConfig"] = tool_config
|
||||
if system_bedrock_messages:
|
||||
converse_api_params["system"] = system_bedrock_messages
|
||||
|
||||
if not stream:
|
||||
converse_api_res = self.client.converse(**converse_api_params)
|
||||
|
||||
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
|
||||
converse_api_res
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=output_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
|
||||
event_stream = converse_stream_api_res["stream"]
|
||||
|
||||
for chunk in event_stream:
|
||||
if "messageStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
elif "contentBlockStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=ToolCall(
|
||||
tool_name=chunk["contentBlockStart"]["toolUse"][
|
||||
"name"
|
||||
],
|
||||
call_id=chunk["contentBlockStart"]["toolUse"][
|
||||
"toolUseId"
|
||||
],
|
||||
),
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "contentBlockDelta" in chunk:
|
||||
if "text" in chunk["contentBlockDelta"]["delta"]:
|
||||
delta = chunk["contentBlockDelta"]["delta"]["text"]
|
||||
else:
|
||||
delta = ToolCallDelta(
|
||||
content=ToolCall(
|
||||
arguments=chunk["contentBlockDelta"]["delta"][
|
||||
"toolUse"
|
||||
]["input"]
|
||||
),
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
elif "contentBlockStop" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
elif "messageStop" in chunk:
|
||||
stop_reason = (
|
||||
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
chunk["messageStop"]["stopReason"]
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
elif "metadata" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
else:
|
||||
# Ignored
|
||||
pass
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
||||
}
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
|
||||
@staticmethod
|
||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
||||
)
|
||||
self._config = config
|
||||
|
||||
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
|
||||
if bedrock_stop_reason == "max_tokens":
|
||||
return StopReason.out_of_tokens
|
||||
return StopReason.end_of_turn
|
||||
|
||||
@staticmethod
|
||||
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
|
||||
for builtin_tool in BuiltinTool:
|
||||
if builtin_tool.value == tool_name_str:
|
||||
return builtin_tool
|
||||
else:
|
||||
return tool_name_str
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
|
||||
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
converse_api_res["stopReason"]
|
||||
)
|
||||
|
||||
bedrock_message = converse_api_res["output"]["message"]
|
||||
|
||||
role = bedrock_message["role"]
|
||||
contents = bedrock_message["content"]
|
||||
|
||||
tool_calls = []
|
||||
text_content = []
|
||||
for content in contents:
|
||||
if "toolUse" in content:
|
||||
tool_use = content["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
|
||||
tool_use["name"]
|
||||
),
|
||||
arguments=tool_use["input"] if "input" in tool_use else None,
|
||||
call_id=tool_use["toolUseId"],
|
||||
)
|
||||
)
|
||||
elif "text" in content:
|
||||
text_content.append(content["text"])
|
||||
|
||||
return CompletionMessage(
|
||||
role=role,
|
||||
content=text_content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _messages_to_bedrock_messages(
|
||||
messages: List[Message],
|
||||
) -> Tuple[List[Dict], Optional[List[Dict]]]:
|
||||
bedrock_messages = []
|
||||
system_bedrock_messages = []
|
||||
|
||||
user_contents = []
|
||||
assistant_contents = None
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content_list = (
|
||||
message.content
|
||||
if isinstance(message.content, list)
|
||||
else [message.content]
|
||||
)
|
||||
if role == "ipython" or role == "user":
|
||||
if not user_contents:
|
||||
user_contents = []
|
||||
|
||||
if role == "ipython":
|
||||
user_contents.extend(
|
||||
[
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.call_id,
|
||||
"content": [
|
||||
{"text": content} for content in content_list
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
user_contents.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
assistant_contents = None
|
||||
elif role == "system":
|
||||
system_bedrock_messages.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not assistant_contents:
|
||||
assistant_contents = []
|
||||
|
||||
assistant_contents.extend(
|
||||
[
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
for content in content_list
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"toolUse": {
|
||||
"input": tool_call.arguments,
|
||||
"name": (
|
||||
tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value
|
||||
),
|
||||
"toolUseId": tool_call.call_id,
|
||||
}
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
user_contents = None
|
||||
else:
|
||||
# Unknown role
|
||||
pass
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
|
||||
if system_bedrock_messages:
|
||||
return bedrock_messages, system_bedrock_messages
|
||||
|
||||
return bedrock_messages, None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
|
||||
inference_config = {}
|
||||
if sampling_params:
|
||||
param_mapping = {
|
||||
"max_tokens": "maxTokens",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
}
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(sampling_params, k):
|
||||
inference_config[v] = getattr(sampling_params, k)
|
||||
|
||||
return inference_config
|
||||
|
||||
@staticmethod
|
||||
def _tool_parameters_to_input_schema(
|
||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
) -> Dict:
|
||||
input_schema = {"type": "object"}
|
||||
if not tool_parameters:
|
||||
return input_schema
|
||||
|
||||
json_properties = {}
|
||||
required = []
|
||||
for name, param in tool_parameters.items():
|
||||
json_property = {
|
||||
"type": param.param_type,
|
||||
}
|
||||
|
||||
if param.description:
|
||||
json_property["description"] = param.description
|
||||
if param.required:
|
||||
required.append(name)
|
||||
json_properties[name] = json_property
|
||||
|
||||
input_schema["properties"] = json_properties
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def _tools_to_tool_config(
|
||||
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
|
||||
) -> Optional[Dict]:
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
bedrock_tools = []
|
||||
for tool in tools:
|
||||
tool_name = (
|
||||
tool.tool_name
|
||||
if isinstance(tool.tool_name, str)
|
||||
else tool.tool_name.value
|
||||
)
|
||||
|
||||
tool_spec = {
|
||||
"toolSpec": {
|
||||
"name": tool_name,
|
||||
"inputSchema": {
|
||||
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
|
||||
tool.parameters
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if tool.description:
|
||||
tool_spec["toolSpec"]["description"] = tool.description
|
||||
|
||||
bedrock_tools.append(tool_spec)
|
||||
tool_config = {
|
||||
"tools": bedrock_tools,
|
||||
}
|
||||
|
||||
if tool_choice:
|
||||
tool_config["toolChoice"] = (
|
||||
{"any": {}}
|
||||
if tool_choice.value == ToolChoice.required
|
||||
else {"auto": {}}
|
||||
)
|
||||
return tool_config
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> (
|
||||
AsyncGenerator
|
||||
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
bedrock_model = self.map_to_provider_model(model)
|
||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||
sampling_params
|
||||
)
|
||||
|
||||
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
|
||||
bedrock_messages, system_bedrock_messages = (
|
||||
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
|
||||
)
|
||||
|
||||
converse_api_params = {
|
||||
"modelId": bedrock_model,
|
||||
"messages": bedrock_messages,
|
||||
}
|
||||
if inference_config:
|
||||
converse_api_params["inferenceConfig"] = inference_config
|
||||
|
||||
# Tool use is not supported in streaming mode
|
||||
if tool_config and not stream:
|
||||
converse_api_params["toolConfig"] = tool_config
|
||||
if system_bedrock_messages:
|
||||
converse_api_params["system"] = system_bedrock_messages
|
||||
|
||||
if not stream:
|
||||
converse_api_res = self.client.converse(**converse_api_params)
|
||||
|
||||
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
|
||||
converse_api_res
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=output_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
|
||||
event_stream = converse_stream_api_res["stream"]
|
||||
|
||||
for chunk in event_stream:
|
||||
if "messageStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
elif "contentBlockStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=ToolCall(
|
||||
tool_name=chunk["contentBlockStart"]["toolUse"][
|
||||
"name"
|
||||
],
|
||||
call_id=chunk["contentBlockStart"]["toolUse"][
|
||||
"toolUseId"
|
||||
],
|
||||
),
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "contentBlockDelta" in chunk:
|
||||
if "text" in chunk["contentBlockDelta"]["delta"]:
|
||||
delta = chunk["contentBlockDelta"]["delta"]["text"]
|
||||
else:
|
||||
delta = ToolCallDelta(
|
||||
content=ToolCall(
|
||||
arguments=chunk["contentBlockDelta"]["delta"][
|
||||
"toolUse"
|
||||
]["input"]
|
||||
),
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
elif "contentBlockStop" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
elif "messageStop" in chunk:
|
||||
stop_reason = (
|
||||
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
chunk["messageStop"]["stopReason"]
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
elif "metadata" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
else:
|
||||
# Ignored
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
assert isinstance(
|
||||
config, DatabricksImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: str = Field(
|
||||
default=None,
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
|
||||
DATABRICKS_SUPPORTED_MODELS = {
|
||||
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
|
||||
}
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
||||
|
|
@ -10,14 +10,19 @@ from fireworks.client import Fireworks
|
|||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
|
@ -27,21 +32,18 @@ FIREWORKS_SUPPORTED_MODELS = {
|
|||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
||||
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
|
||||
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
||||
}
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> Fireworks:
|
||||
return Fireworks(api_key=self.config.api_key)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
|
@ -49,7 +51,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -59,27 +61,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
|
||||
fireworks_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
fireworks_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return fireworks_messages
|
||||
|
||||
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -101,147 +83,41 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
# accumulate sampling params and other options to pass to fireworks
|
||||
options = self.get_fireworks_chat_options(request)
|
||||
fireworks_model = self.map_to_provider_model(request.model)
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if r.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
client = Fireworks(api_key=self.config.api_key)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await client.completion.acreate(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
async for chunk in self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if chunk.choices[0].finish_reason:
|
||||
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif (
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "length"
|
||||
):
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
text = chunk.choices[0].delta.content
|
||||
if text is None:
|
||||
continue
|
||||
stream = client.completion.acreate(**params)
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt = chat_completion_request_to_prompt(request, self.formatter)
|
||||
# Fireworks always prepends with BOS
|
||||
if prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
options = get_sampling_options(request)
|
||||
options.setdefault("max_tokens", 512)
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": prompt,
|
||||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,10 @@
|
|||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
|
||||
class OllamaImplConfig(RemoteProviderConfig):
|
||||
port: int = 11434
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -9,34 +9,36 @@ from typing import AsyncGenerator
|
|||
import httpx
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
# TODO: Eventually this will move to the llama cli model list command
|
||||
# mapping of Model SKUs to ollama models
|
||||
OLLAMA_SUPPORTED_SKUS = {
|
||||
# "Llama3.1-8B-Instruct": "llama3.1",
|
||||
OLLAMA_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
||||
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
|
||||
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
|
||||
"Llama-Guard-3-8B": "xe/llamaguard3:latest",
|
||||
}
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
class OllamaInferenceAdapter(Inference):
|
||||
def __init__(self, url: str) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
|
||||
)
|
||||
self.url = url
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
|
|
@ -54,7 +56,29 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
if model.identifier not in OLLAMA_SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model {model.identifier}. Supported models: {OLLAMA_SUPPORTED_MODELS.keys()}"
|
||||
)
|
||||
|
||||
ollama_model = OLLAMA_SUPPORTED_MODELS[model.identifier]
|
||||
res = await self.client.ps()
|
||||
need_model_pull = True
|
||||
for r in res["models"]:
|
||||
if ollama_model == r["model"]:
|
||||
need_model_pull = False
|
||||
break
|
||||
|
||||
print(f"Ollama model `{ollama_model}` needs pull -> {need_model_pull}")
|
||||
if need_model_pull:
|
||||
print(f"Pulling model: {ollama_model}")
|
||||
status = await self.client.pull(ollama_model)
|
||||
assert (
|
||||
status["status"] == "success"
|
||||
), f"Failed to pull model {self.model} in ollama"
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -64,32 +88,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
|
||||
ollama_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
ollama_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return ollama_messages
|
||||
|
||||
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
if (
|
||||
request.sampling_params.repetition_penalty is not None
|
||||
and request.sampling_params.repetition_penalty != 1.0
|
||||
):
|
||||
options["repeat_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -110,156 +109,54 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
# accumulate sampling params and other options to pass to ollama
|
||||
options = self.get_ollama_chat_options(request)
|
||||
ollama_model = self.map_to_provider_model(request.model)
|
||||
|
||||
res = await self.client.ps()
|
||||
need_model_pull = True
|
||||
for r in res["models"]:
|
||||
if ollama_model == r["model"]:
|
||||
need_model_pull = False
|
||||
break
|
||||
|
||||
if need_model_pull:
|
||||
print(f"Pulling model: {ollama_model}")
|
||||
status = await self.client.pull(ollama_model)
|
||||
assert (
|
||||
status["status"] == "success"
|
||||
), f"Failed to pull model {self.model} in ollama"
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=False,
|
||||
options=options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r["done"]:
|
||||
if r["done_reason"] == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r["done_reason"] == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r["message"]["content"], stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": OLLAMA_SUPPORTED_MODELS[request.model],
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"options": get_sampling_options(request),
|
||||
"raw": True,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await self.client.generate(**params)
|
||||
assert isinstance(r, dict)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["response"],
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.generate(**params)
|
||||
async for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["response"],
|
||||
)
|
||||
)
|
||||
stream = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=True,
|
||||
options=options,
|
||||
)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk["done"]:
|
||||
if stop_reason is None and chunk["done_reason"] == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif stop_reason is None and chunk["done_reason"] == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = chunk["message"]["content"]
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -9,14 +9,12 @@ from .config import SampleConfig
|
|||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class SampleInferenceImpl(Inference, RoutableProvider):
|
||||
class SampleInferenceImpl(Inference):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class InferenceEndpointImplConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class InferenceAPIImplConfig(BaseModel):
|
||||
model_id: str = Field(
|
||||
huggingface_repo: str = Field(
|
||||
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
|
|
|
|||
|
|
@ -10,14 +10,19 @@ from typing import AsyncGenerator
|
|||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_model_input_info,
|
||||
)
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
|
@ -25,24 +30,31 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _HfAdapter(Inference, RoutableProvider):
|
||||
class _HfAdapter(Inference):
|
||||
client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
resolved_model = resolve_model(model.identifier)
|
||||
if resolved_model is None:
|
||||
raise ValueError(f"Unknown model: {model.identifier}")
|
||||
|
||||
if not resolved_model.huggingface_repo:
|
||||
raise ValueError(
|
||||
f"Model {model.identifier} does not have a HuggingFace repo"
|
||||
)
|
||||
|
||||
if self.model_id != resolved_model.huggingface_repo:
|
||||
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -52,16 +64,7 @@ class _HfAdapter(Inference, RoutableProvider):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -83,146 +86,64 @@ class _HfAdapter(Inference, RoutableProvider):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||
prompt = self.tokenizer.decode(model_input.tokens)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
input_tokens = len(model_input.tokens)
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
text="".join(t.text for t in r.details.tokens),
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
|
||||
choice = OpenAICompatCompletionChoice(text=token_result.text)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt, input_tokens = chat_completion_request_to_model_input_info(
|
||||
request, self.formatter
|
||||
)
|
||||
max_new_tokens = min(
|
||||
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
|
||||
print(f"Calculated max_new_tokens: {max_new_tokens}")
|
||||
|
||||
options = self.get_chat_options(request)
|
||||
if not request.stream:
|
||||
response = await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if response.details.finish_reason:
|
||||
if response.details.finish_reason in ["stop", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif response.details.finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
response.generated_text,
|
||||
stop_reason,
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
tokens = []
|
||||
|
||||
async for response in await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
):
|
||||
token_result = response.token
|
||||
|
||||
buffer += token_result.text
|
||||
tokens.append(token_result.id)
|
||||
|
||||
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer = buffer[len("<|python_tag|>") :]
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = text
|
||||
|
||||
if stop_reason is None:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
options = get_sampling_options(request)
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
)
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
|
|
@ -236,7 +157,7 @@ class TGIAdapter(_HfAdapter):
|
|||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.model_id, token=config.api_token
|
||||
model=config.huggingface_repo, token=config.api_token
|
||||
)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
|
|
|
|||
|
|
@ -8,17 +8,22 @@ from typing import AsyncGenerator
|
|||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from together import Together
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
|
@ -34,19 +39,14 @@ TOGETHER_SUPPORTED_MODELS = {
|
|||
|
||||
|
||||
class TogetherInferenceAdapter(
|
||||
Inference, NeedsRequestProviderData, RoutableProviderForModels
|
||||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> Together:
|
||||
return Together(api_key=self.config.api_key)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
|
@ -64,27 +64,7 @@ class TogetherInferenceAdapter(
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_together_messages(self, messages: list[Message]) -> list:
|
||||
together_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
together_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return together_messages
|
||||
|
||||
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -95,7 +75,6 @@ class TogetherInferenceAdapter(
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key
|
||||
|
|
@ -108,7 +87,6 @@ class TogetherInferenceAdapter(
|
|||
together_api_key = provider_data.together_api_key
|
||||
|
||||
client = Together(api_key=together_api_key)
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
@ -120,146 +98,39 @@ class TogetherInferenceAdapter(
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
# accumulate sampling params and other options to pass to together
|
||||
options = self.get_together_chat_options(request)
|
||||
together_model = self.map_to_provider_model(request.model)
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
if not request.stream:
|
||||
# TODO: might need to add back an async here
|
||||
r = client.chat.completions.create(
|
||||
model=together_model,
|
||||
messages=self._messages_to_together_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if (
|
||||
r.choices[0].finish_reason == "stop"
|
||||
or r.choices[0].finish_reason == "eos"
|
||||
):
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Together
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
for chunk in client.chat.completions.create(
|
||||
model=together_model,
|
||||
messages=self._messages_to_together_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if finish_reason := chunk.choices[0].finish_reason:
|
||||
if stop_reason is None and finish_reason in ["stop", "eos"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif stop_reason is None and finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Together
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
text = chunk.choices[0].delta.content
|
||||
if text is None:
|
||||
continue
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
|
@ -13,7 +12,6 @@ import chromadb
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
|
|
@ -65,7 +63,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class ChromaMemoryAdapter(Memory, RoutableProvider):
|
||||
class ChromaMemoryAdapter(Memory):
|
||||
def __init__(self, url: str) -> None:
|
||||
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||
url = url.rstrip("/")
|
||||
|
|
@ -93,48 +91,33 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
collection = await self.client.create_collection(
|
||||
name=bank_id,
|
||||
metadata={"bank": bank.json()},
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
)
|
||||
bank_index = BankWithIndex(
|
||||
bank=bank, index=ChromaIndex(self.client, collection)
|
||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||
)
|
||||
self.cache[bank_id] = bank_index
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||
if bank_index is None:
|
||||
return None
|
||||
return bank_index.bank
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if bank is None:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
collections = await self.client.list_collections()
|
||||
for collection in collections:
|
||||
if collection.name == bank_id:
|
||||
print(collection.metadata)
|
||||
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=ChromaIndex(self.client, collection),
|
||||
|
|
|
|||
|
|
@ -4,18 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
from typing import List
|
||||
|
||||
import psycopg2
|
||||
from numpy.typing import NDArray
|
||||
from psycopg2 import sql
|
||||
from psycopg2.extras import execute_values, Json
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
|
|
@ -32,33 +28,6 @@ def check_extension_version(cur):
|
|||
return result[0] if result else None
|
||||
|
||||
|
||||
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
|
||||
query = sql.SQL(
|
||||
"""
|
||||
INSERT INTO metadata_store (key, data)
|
||||
VALUES %s
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET data = EXCLUDED.data
|
||||
"""
|
||||
)
|
||||
|
||||
values = [(key, Json(model.dict())) for key, model in keys_models]
|
||||
execute_values(cur, query, values, template="(%s, %s)")
|
||||
|
||||
|
||||
def load_models(cur, keys: List[str], cls):
|
||||
query = "SELECT key, data FROM metadata_store"
|
||||
if keys:
|
||||
placeholders = ",".join(["%s"] * len(keys))
|
||||
query += f" WHERE key IN ({placeholders})"
|
||||
cur.execute(query, keys)
|
||||
else:
|
||||
cur.execute(query)
|
||||
|
||||
rows = cur.fetchall()
|
||||
return [cls(**row["data"]) for row in rows]
|
||||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, bank: MemoryBank, dimension: int, cursor):
|
||||
self.cursor = cursor
|
||||
|
|
@ -119,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
||||
class PGVectorMemoryAdapter(Memory):
|
||||
def __init__(self, config: PGVectorConfig) -> None:
|
||||
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
|
||||
self.config = config
|
||||
|
|
@ -144,14 +113,6 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
|||
else:
|
||||
raise RuntimeError("Vector extension is not installed.")
|
||||
|
||||
self.cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS metadata_store (
|
||||
key TEXT PRIMARY KEY,
|
||||
data JSONB
|
||||
)
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
|
|
@ -161,51 +122,28 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
upsert_models(
|
||||
self.cursor,
|
||||
[
|
||||
(bank.bank_id, bank),
|
||||
],
|
||||
)
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
bank=memory_bank,
|
||||
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||
if bank_index is None:
|
||||
return None
|
||||
return bank_index.bank
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
banks = load_models(self.cursor, [bank_id], MemoryBank)
|
||||
if not banks:
|
||||
return None
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
bank = banks[0]
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
|
|
|
|||
|
|
@ -9,14 +9,12 @@ from .config import SampleConfig
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class SampleMemoryImpl(Memory, RoutableProvider):
|
||||
class SampleMemoryImpl(Memory):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
# these are the memory banks the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
|
|||
15
llama_stack/providers/adapters/memory/weaviate/__init__.py
Normal file
15
llama_stack/providers/adapters/memory/weaviate/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WeaviateConfig, _deps):
|
||||
from .weaviate import WeaviateMemoryAdapter
|
||||
|
||||
impl = WeaviateMemoryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
16
llama_stack/providers/adapters/memory/weaviate/config.py
Normal file
16
llama_stack/providers/adapters/memory/weaviate/config.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WeaviateRequestProviderData(BaseModel):
|
||||
weaviate_api_key: str
|
||||
weaviate_cluster_url: str
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
pass
|
||||
180
llama_stack/providers/adapters/memory/weaviate/weaviate.py
Normal file
180
llama_stack/providers/adapters/memory/weaviate/weaviate.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import weaviate
|
||||
import weaviate.classes as wvc
|
||||
from numpy.typing import NDArray
|
||||
from weaviate.classes.init import Auth
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData
|
||||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(self, client: weaviate.Client, collection_name: str):
|
||||
self.client = client
|
||||
self.collection_name = collection_name
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(
|
||||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
data_objects = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
data_objects.append(
|
||||
wvc.data.DataObject(
|
||||
properties={
|
||||
"chunk_content": chunk.json(),
|
||||
},
|
||||
vector=embeddings[i].tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
# Inserting chunks into a prespecified Weaviate collection
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
|
||||
# TODO: make this async friendly
|
||||
collection.data.insert_many(data_objects)
|
||||
|
||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
|
||||
results = collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(distance=True),
|
||||
)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc in results.objects:
|
||||
chunk_json = doc.properties["chunk_content"]
|
||||
try:
|
||||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(1.0 / doc.metadata.distance)
|
||||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
|
||||
def __init__(self, config: WeaviateConfig) -> None:
|
||||
self.config = config
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
provider_data = self.get_request_provider_data()
|
||||
assert provider_data is not None, "Request provider data must be set"
|
||||
assert isinstance(provider_data, WeaviateRequestProviderData)
|
||||
|
||||
key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
|
||||
if key in self.client_cache:
|
||||
return self.client_cache[key]
|
||||
|
||||
client = weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url=provider_data.weaviate_cluster_url,
|
||||
auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
|
||||
)
|
||||
self.client_cache[key] = client
|
||||
return client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for client in self.client_cache.values():
|
||||
client.close()
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
# Create collection if it doesn't exist
|
||||
if not client.collections.exists(memory_bank.identifier):
|
||||
client.collections.create(
|
||||
name=memory_bank.identifier,
|
||||
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
wvc.config.Property(
|
||||
name="chunk_content",
|
||||
data_type=wvc.config.DataType.TEXT,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
client = self._get_client()
|
||||
if not client.collections.exists(bank_id):
|
||||
raise ValueError(f"Collection with name `{bank_id}` not found")
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=WeaviateIndex(client=client, collection_name=bank_id),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
return await index.query_documents(query, params)
|
||||
|
|
@ -7,14 +7,12 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
import traceback
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import boto3
|
||||
|
||||
from llama_stack.apis.safety import * # noqa
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
|
|
@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SUPPORTED_SHIELD_TYPES = [
|
||||
"bedrock_guardrail",
|
||||
BEDROCK_SUPPORTED_SHIELDS = [
|
||||
ShieldType.generic_content_shield.value,
|
||||
]
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, RoutableProvider):
|
||||
class BedrockSafetyAdapter(Safety):
|
||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||
if not config.aws_profile:
|
||||
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
||||
self.config = config
|
||||
self.registered_shields = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
|
|
@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SUPPORTED_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
shield_params = shield.params
|
||||
if "guardrailIdentifier" not in shield_params:
|
||||
raise ValueError(
|
||||
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
|
||||
)
|
||||
|
||||
if "guardrailVersion" not in shield_params:
|
||||
raise ValueError(
|
||||
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
|
||||
)
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
if shield_type not in SUPPORTED_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
|
|
@ -69,52 +79,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
|
|||
|
||||
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"run_shield::{params}::messages={messages}")
|
||||
if "guardrailIdentifier" not in params:
|
||||
raise RuntimeError(
|
||||
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
|
||||
)
|
||||
|
||||
if "guardrailVersion" not in params:
|
||||
raise RuntimeError(
|
||||
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
|
||||
)
|
||||
shield_params = shield_def.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(
|
||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(
|
||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
|
||||
response = self.boto_client.apply_guardrail(
|
||||
guardrailIdentifier=params.get("guardrailIdentifier"),
|
||||
guardrailVersion=params.get("guardrailVersion"),
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
logger.debug(f"run_shield:: response: {response}::")
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
return SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
response = self.boto_client.apply_guardrail(
|
||||
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
||||
guardrailVersion=shield_params["guardrailVersion"],
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
|
||||
except Exception:
|
||||
error_str = traceback.format_exc()
|
||||
logger.error(
|
||||
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
|
||||
return SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -9,14 +9,12 @@ from .config import SampleConfig
|
|||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class SampleSafetyImpl(Safety, RoutableProvider):
|
||||
class SampleSafetyImpl(Safety):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
# these are the safety shields the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -6,26 +6,20 @@
|
|||
from together import Together
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
|
||||
from .config import TogetherSafetyConfig
|
||||
|
||||
|
||||
SAFETY_SHIELD_TYPES = {
|
||||
TOGETHER_SHIELD_MODEL_MAP = {
|
||||
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
}
|
||||
|
||||
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
|
@ -35,16 +29,20 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type != ShieldType.llama_guard.value:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
if shield_type not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
model = shield_def.params.get("model", "llama_guard")
|
||||
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
||||
raise ValueError(f"Unsupported safety model: {model}")
|
||||
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
|
|
@ -57,8 +55,6 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
|
||||
model_name = SAFETY_SHIELD_TYPES[shield_type]
|
||||
|
||||
# messages can have role assistant or user
|
||||
api_messages = []
|
||||
for message in messages:
|
||||
|
|
@ -66,7 +62,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
api_messages.append({"role": message.role, "content": message.content})
|
||||
|
||||
violation = await get_safety_response(
|
||||
together_api_key, model_name, api_messages
|
||||
together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
|
|
|
|||
|
|
@ -43,23 +43,14 @@ class ProviderSpec(BaseModel):
|
|||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
deps__: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RoutingTable(Protocol):
|
||||
def get_routing_keys(self) -> List[str]: ...
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
class RoutableProvider(Protocol):
|
||||
"""
|
||||
A provider which sits behind the RoutingTable and can get routed to.
|
||||
|
||||
All Inference / Safety / Memory providers fall into this bucket.
|
||||
"""
|
||||
|
||||
async def validate_routing_keys(self, keys: List[str]) -> None: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_type: str = Field(
|
||||
|
|
@ -156,6 +147,10 @@ as being "Llama Stack compatible"
|
|||
return None
|
||||
|
||||
|
||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||
|
||||
|
||||
# Can avoid this by using Pydantic computed_field
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: Optional[AdapterSpec] = None
|
||||
|
|
|
|||
|
|
@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
|
@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
if session_info.memory_bank_id is None:
|
||||
memory_bank = await self.memory_api.create_memory_bank(
|
||||
name=f"memory_bank_{session_id}",
|
||||
config=VectorMemoryBankConfig(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
memory_bank = VectorMemoryBankDef(
|
||||
identifier=bank_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
bank_id = memory_bank.bank_id
|
||||
await self.memory_api.register_memory_bank(memory_bank)
|
||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
bank_id = session_info.memory_bank_id
|
||||
|
|
@ -673,7 +674,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _retrieve_context(
|
||||
self, session_id: str, messages: List[Message], attachments: List[Attachment]
|
||||
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
|
||||
) -> Tuple[Optional[List[str]], Optional[List[int]]]: # (rag_context, bank_ids)
|
||||
bank_ids = []
|
||||
|
||||
memory = self._memory_tool_definition()
|
||||
|
|
@ -722,12 +723,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
chunks = [c for r in results for c in r.chunks]
|
||||
scores = [s for r in results for s in r.scores]
|
||||
|
||||
if not chunks:
|
||||
return None, bank_ids
|
||||
|
||||
# sort by score
|
||||
chunks, scores = zip(
|
||||
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
||||
)
|
||||
if not chunks:
|
||||
return None, bank_ids
|
||||
|
||||
tokens = 0
|
||||
picked = []
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def create_agent_turn(
|
||||
def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
|
|
@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
attachments: Optional[List[Attachment]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(agent_id)
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
attachments=attachments,
|
||||
stream=stream,
|
||||
stream=True,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _create_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import CodeShieldConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeShieldConfig, deps):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||
from termcolor import cprint
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
|
||||
class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||
def __init__(self, config: CodeScannerConfig, deps) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type != ShieldType.code_scanner.value:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
|
||||
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
|
||||
result = await CodeShield.scan_code(text)
|
||||
|
||||
violation = None
|
||||
if result.is_insecure:
|
||||
violation = SafetyViolation(
|
||||
violation_level=(ViolationLevel.ERROR),
|
||||
user_message="Sorry, I found security concerns in the code.",
|
||||
metadata={
|
||||
"violation_type": ",".join(
|
||||
[issue.pattern_id for issue in result.issues_found]
|
||||
)
|
||||
},
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeShieldConfig(BaseModel):
|
||||
pass
|
||||
|
|
@ -43,13 +43,12 @@ class MetaReferenceEvalsImpl(Evals):
|
|||
print("generation start")
|
||||
for msg in x1[:5]:
|
||||
print("generation for msg: ", msg)
|
||||
response = self.inference_api.chat_completion(
|
||||
response = await self.inference_api.chat_completion(
|
||||
model=model,
|
||||
messages=[msg],
|
||||
stream=False,
|
||||
)
|
||||
async for x in response:
|
||||
generation_outputs.append(x.completion_message.content)
|
||||
generation_outputs.append(response.completion_message.content)
|
||||
|
||||
x2 = task_impl.postprocess(generation_outputs)
|
||||
eval_results = task_impl.score(x2)
|
||||
|
|
|
|||
|
|
@ -297,7 +297,7 @@ class Llama:
|
|||
token=next_token[0].item(),
|
||||
text=self.tokenizer.decode(next_token.tolist()),
|
||||
logprobs=(
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
|
||||
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
|
||||
if logprobs
|
||||
else None
|
||||
),
|
||||
|
|
|
|||
|
|
@ -6,15 +6,14 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncIterator, List, Union
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
|
@ -25,7 +24,7 @@ from .model_parallel import LlamaModelParallelGenerator
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
||||
class MetaReferenceInferenceImpl(Inference):
|
||||
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
|
|
@ -35,21 +34,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
# verify that the checkpoint actually is for this model lol
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print(f"Loading model `{self.model.descriptor()}`")
|
||||
self.generator = LlamaModelParallelGenerator(self.config)
|
||||
self.generator.start()
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
assert (
|
||||
len(routing_keys) == 1
|
||||
), f"Only one routing key is supported {routing_keys}"
|
||||
assert routing_keys[0] == self.config.model
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
if model.identifier != self.model.descriptor():
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.generator.stop()
|
||||
|
||||
# hm, when stream=False, we should not be doing SSE :/ which is what the
|
||||
# top-level server is going to do. make the typing more specific here
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -59,9 +57,10 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncIterator[
|
||||
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
|
||||
]:
|
||||
) -> AsyncGenerator:
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
|
|
@ -74,7 +73,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
model = resolve_model(request.model)
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
|
|
@ -88,21 +86,74 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
async with SEMAPHORE:
|
||||
if request.stream:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
|
||||
stop_reason = None
|
||||
|
||||
buffer = ""
|
||||
for token_result in self.generator.chat_completion(
|
||||
messages=messages,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
max_gen_len=request.sampling_params.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
tool_prompt_format=request.tool_prompt_format,
|
||||
):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
message = self.generator.formatter.decode_assistant_message(
|
||||
tokens, stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=message,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
async with SEMAPHORE:
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
ipython = False
|
||||
|
||||
for token_result in self.generator.chat_completion(
|
||||
|
|
@ -113,10 +164,9 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
logprobs=request.logprobs,
|
||||
tool_prompt_format=request.tool_prompt_format,
|
||||
):
|
||||
buffer += token_result.text
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
|
@ -127,13 +177,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
),
|
||||
)
|
||||
)
|
||||
buffer = buffer[len("<|python_tag|>") :]
|
||||
continue
|
||||
|
||||
if not request.stream:
|
||||
if request.logprobs:
|
||||
logprobs.append(token_result.logprob)
|
||||
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
|
|
@ -154,59 +197,61 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
delta = text
|
||||
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
# TODO(ashwin): parse tool calls separately here and report errors?
|
||||
# if someone breaks the iteration before coming here we are toast
|
||||
message = self.generator.formatter.decode_assistant_message(
|
||||
tokens, stop_reason
|
||||
)
|
||||
if request.stream:
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(ashwin): what else do we need to send out here when everything finishes?
|
||||
else:
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=message,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,15 +13,15 @@ from typing import Optional
|
|||
import torch
|
||||
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from llama_models.llama3.api.model import Transformer, TransformerBlock
|
||||
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from termcolor import cprint
|
||||
from torch import Tensor
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.apis.inference.config import (
|
||||
CheckpointQuantizationFormat,
|
||||
from llama_stack.providers.impls.meta_reference.inference.config import (
|
||||
MetaReferenceImplConfig,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
|
@ -14,7 +13,6 @@ import numpy as np
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
|
@ -63,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class FaissMemoryImpl(Memory, RoutableProvider):
|
||||
class FaissMemoryImpl(Memory):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
|
|
@ -72,37 +70,18 @@ class FaissMemoryImpl(Memory, RoutableProvider):
|
|||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
assert url is None, "URL is not supported for this implementation"
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
config.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {config.type}"
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
|
||||
)
|
||||
index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
|
||||
self.cache[bank_id] = index
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
index = self.cache.get(bank_id)
|
||||
if index is None:
|
||||
return None
|
||||
return index.bank
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ def message_content_as_str(message: Message) -> str:
|
|||
return interleaved_text_media_as_str(message.content)
|
||||
|
||||
|
||||
# For shields that operate on simple strings
|
||||
class TextShield(ShieldBase):
|
||||
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
||||
return "\n".join([message_content_as_str(m) for m in messages])
|
||||
|
|
@ -56,9 +55,3 @@ class TextShield(ShieldBase):
|
|||
@abstractmethod
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DummyShield(TextShield):
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
# Dummy return LOW to test e2e
|
||||
return ShieldResponse(is_violation=False)
|
||||
|
|
@ -9,23 +9,19 @@ from typing import List, Optional
|
|||
|
||||
from llama_models.sku_list import CoreModelId, safety_models
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class MetaReferenceShieldType(Enum):
|
||||
llama_guard = "llama_guard"
|
||||
code_scanner_guard = "code_scanner_guard"
|
||||
injection_shield = "injection_shield"
|
||||
jailbreak_shield = "jailbreak_shield"
|
||||
class PromptGuardType(Enum):
|
||||
injection = "injection"
|
||||
jailbreak = "jailbreak"
|
||||
|
||||
|
||||
class LlamaGuardShieldConfig(BaseModel):
|
||||
model: str = "Llama-Guard-3-1B"
|
||||
excluded_categories: List[str] = []
|
||||
disable_input_check: bool = False
|
||||
disable_output_check: bool = False
|
||||
|
||||
@validator("model")
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = [
|
||||
|
|
@ -47,10 +43,6 @@ class LlamaGuardShieldConfig(BaseModel):
|
|||
return model
|
||||
|
||||
|
||||
class PromptGuardShieldConfig(BaseModel):
|
||||
model: str = "Prompt-Guard-86M"
|
||||
|
||||
|
||||
class SafetyConfig(BaseModel):
|
||||
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
|
||||
prompt_guard_shield: Optional[PromptGuardShieldConfig] = None
|
||||
enable_prompt_guard: Optional[bool] = False
|
||||
|
|
|
|||
|
|
@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
|
@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.disable_input_check = disable_input_check
|
||||
self.disable_output_check = disable_output_check
|
||||
|
||||
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
|
|
@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
return ShieldResponse(is_violation=False)
|
||||
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
|
||||
return ShieldResponse(
|
||||
is_violation=False,
|
||||
)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
shield_input_message = self.build_vision_shield_input(messages)
|
||||
|
|
@ -6,56 +6,43 @@
|
|||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api, RoutableProvider
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceShieldType, SafetyConfig
|
||||
|
||||
from .shields import (
|
||||
CodeScannerShield,
|
||||
InjectionShield,
|
||||
JailbreakShield,
|
||||
LlamaGuardShield,
|
||||
PromptGuardShield,
|
||||
ShieldBase,
|
||||
)
|
||||
from .base import OnViolationAction, ShieldBase
|
||||
from .config import SafetyConfig
|
||||
from .llama_guard import LlamaGuardShield
|
||||
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
|
||||
|
||||
|
||||
def resolve_and_get_path(model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert model is not None, f"Could not resolve model {model_name}"
|
||||
model_dir = model_local_dir(model.descriptor())
|
||||
return model_dir
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
def __init__(self, config: SafetyConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
self.available_shields = []
|
||||
if config.llama_guard_shield:
|
||||
self.available_shields.append(ShieldType.llama_guard.value)
|
||||
if config.enable_prompt_guard:
|
||||
self.available_shields.append(ShieldType.prompt_guard.value)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
shield_cfg = self.config.prompt_guard_shield
|
||||
if shield_cfg is not None:
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
if self.config.enable_prompt_guard:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
_ = PromptGuardShield.instance(model_dir)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||
for key in routing_keys:
|
||||
if key not in available_shields:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type not in self.available_shields:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
@ -63,10 +50,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||
assert shield_type in available_shields, f"Unknown shield {shield_type}"
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
|
||||
shield = self.get_shield_impl(shield_def)
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
|
|
@ -92,34 +80,22 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
cfg = self.config
|
||||
if typ == MetaReferenceShieldType.llama_guard:
|
||||
cfg = cfg.llama_guard_shield
|
||||
assert (
|
||||
cfg is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
|
||||
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
|
||||
if shield.type == ShieldType.llama_guard.value:
|
||||
cfg = self.config.llama_guard_shield
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
disable_input_check=cfg.disable_input_check,
|
||||
disable_output_check=cfg.disable_output_check,
|
||||
)
|
||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||
assert (
|
||||
cfg.prompt_guard_shield is not None
|
||||
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
||||
return JailbreakShield.instance(model_dir)
|
||||
elif typ == MetaReferenceShieldType.injection_shield:
|
||||
assert (
|
||||
cfg.prompt_guard_shield is not None
|
||||
), "Cannot use PromptGuardShield since not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif typ == MetaReferenceShieldType.code_scanner_guard:
|
||||
return CodeScannerShield.instance()
|
||||
elif shield.type == ShieldType.prompt_guard.value:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
subtype = shield.params.get("prompt_guard_type", "injection")
|
||||
if subtype == "injection":
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif subtype == "jailbreak":
|
||||
return JailbreakShield.instance(model_dir)
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
||||
else:
|
||||
raise ValueError(f"Unknown shield type: {typ}")
|
||||
raise ValueError(f"Unknown shield type: {shield.type}")
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# supress warnings and spew of logs from hugging face
|
||||
import transformers
|
||||
|
||||
from .base import ( # noqa: F401
|
||||
DummyShield,
|
||||
OnViolationAction,
|
||||
ShieldBase,
|
||||
ShieldResponse,
|
||||
TextShield,
|
||||
)
|
||||
from .code_scanner import CodeScannerShield # noqa: F401
|
||||
from .llama_guard import LlamaGuardShield # noqa: F401
|
||||
from .prompt_guard import ( # noqa: F401
|
||||
InjectionShield,
|
||||
JailbreakShield,
|
||||
PromptGuardShield,
|
||||
)
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
import os
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from .base import ShieldResponse, TextShield
|
||||
|
||||
|
||||
class CodeScannerShield(TextShield):
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
|
||||
result = await CodeShield.scan_code(text)
|
||||
if result.is_insecure:
|
||||
return ShieldResponse(
|
||||
is_violation=True,
|
||||
violation_type=",".join(
|
||||
[issue.pattern_id for issue in result.issues_found]
|
||||
),
|
||||
violation_return_message="Sorry, I found security concerns in the code.",
|
||||
)
|
||||
else:
|
||||
return ShieldResponse(is_violation=False)
|
||||
11
llama_stack/providers/impls/vllm/__init__.py
Normal file
11
llama_stack/providers/impls/vllm/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from typing import Any
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
35
llama_stack/providers/impls/vllm/config.py
Normal file
35
llama_stack/providers/impls/vllm/config.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMConfig(BaseModel):
|
||||
"""Configuration for the vLLM inference provider."""
|
||||
|
||||
model: str = Field(
|
||||
default="Llama3.1-8B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
)
|
||||
tensor_parallel_size: int = Field(
|
||||
default=1,
|
||||
description="Number of tensor parallel replicas (number of GPUs to use).",
|
||||
)
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = supported_inference_models()
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
raise ValueError(
|
||||
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
|
||||
)
|
||||
return model
|
||||
241
llama_stack/providers/impls/vllm/vllm.py
Normal file
241
llama_stack/providers/impls/vllm/vllm.py
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _random_uuid() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
|
||||
"""Convert sampling params to vLLM sampling params."""
|
||||
if sampling_params is None:
|
||||
return SamplingParams()
|
||||
|
||||
# TODO convert what I saw in my first test ... but surely there's more to do here
|
||||
kwargs = {
|
||||
"temperature": sampling_params.temperature,
|
||||
}
|
||||
if sampling_params.top_k >= 1:
|
||||
kwargs["top_k"] = sampling_params.top_k
|
||||
if sampling_params.top_p:
|
||||
kwargs["top_p"] = sampling_params.top_p
|
||||
if sampling_params.max_tokens >= 1:
|
||||
kwargs["max_tokens"] = sampling_params.max_tokens
|
||||
if sampling_params.repetition_penalty > 0:
|
||||
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
return SamplingParams(**kwargs)
|
||||
|
||||
|
||||
class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
||||
"""Inference implementation for vLLM."""
|
||||
|
||||
HF_MODEL_MAPPINGS = {
|
||||
# TODO: seems like we should be able to build this table dynamically ...
|
||||
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
|
||||
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
|
||||
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
|
||||
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
|
||||
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
|
||||
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
|
||||
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
|
||||
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
|
||||
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
|
||||
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
|
||||
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
|
||||
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
|
||||
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
|
||||
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
|
||||
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
|
||||
}
|
||||
|
||||
def __init__(self, config: VLLMConfig):
|
||||
Inference.__init__(self)
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
|
||||
)
|
||||
self.config = config
|
||||
self.engine = None
|
||||
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the vLLM inference adapter."""
|
||||
|
||||
log.info("Initializing vLLM inference adapter")
|
||||
|
||||
# Disable usage stats reporting. This would be a surprising thing for most
|
||||
# people to find out was on by default.
|
||||
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
|
||||
if "VLLM_NO_USAGE_STATS" not in os.environ:
|
||||
os.environ["VLLM_NO_USAGE_STATS"] = "1"
|
||||
|
||||
hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model)
|
||||
|
||||
# TODO -- there are a ton of options supported here ...
|
||||
engine_args = AsyncEngineArgs()
|
||||
engine_args.model = hf_model
|
||||
# We will need a new config item for this in the future if model support is more broad
|
||||
# than it is today (llama only)
|
||||
engine_args.tokenizer = hf_model
|
||||
engine_args.tensor_parallel_size = self.config.tensor_parallel_size
|
||||
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown the vLLM inference adapter."""
|
||||
log.info("Shutting down vLLM inference adapter")
|
||||
if self.engine:
|
||||
self.engine.shutdown_background_loop()
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Any | None = ...,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
log.info("vLLM completion")
|
||||
messages = [UserMessage(content=content)]
|
||||
return self.chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[Message],
|
||||
sampling_params: Any | None = ...,
|
||||
tools: list[ToolDefinition] | None = ...,
|
||||
tool_choice: ToolChoice | None = ...,
|
||||
tool_prompt_format: ToolPromptFormat | None = ...,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
log.info("vLLM chat completion")
|
||||
|
||||
assert self.engine is not None
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
log.info("Sampling params: %s", sampling_params)
|
||||
request_id = _random_uuid()
|
||||
|
||||
prompt = chat_completion_request_to_prompt(request, self.formatter)
|
||||
vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
|
||||
results_generator = self.engine.generate(
|
||||
prompt, vllm_sampling_params, request_id
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, results_generator)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, results_generator)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> ChatCompletionResponse:
|
||||
outputs = [o async for o in results_generator]
|
||||
final_output = outputs[-1]
|
||||
|
||||
assert final_output is not None
|
||||
outputs = final_output.outputs
|
||||
finish_reason = outputs[-1].stop_reason
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=finish_reason,
|
||||
text="".join([output.text for output in outputs]),
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> AsyncGenerator:
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
async for chunk in results_generator:
|
||||
if not chunk.outputs:
|
||||
log.warning("Empty chunk received")
|
||||
continue
|
||||
|
||||
text = "".join([output.text for output in chunk.outputs])
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk.outputs[-1].stop_reason,
|
||||
text=text,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self, model: str, contents: list[InterleavedTextMedia]
|
||||
) -> EmbeddingsResponse:
|
||||
log.info("vLLM embeddings")
|
||||
# TODO
|
||||
raise NotImplementedError()
|
||||
|
|
@ -41,6 +41,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="ollama",
|
||||
pip_packages=["ollama"],
|
||||
config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.adapters.inference.ollama",
|
||||
),
|
||||
),
|
||||
|
|
@ -103,4 +104,24 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="databricks",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.adapters.inference.databricks",
|
||||
config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="vllm",
|
||||
pip_packages=[
|
||||
"vllm",
|
||||
],
|
||||
module="llama_stack.providers.impls.vllm",
|
||||
config_class="llama_stack.providers.impls.vllm.VLLMConfig",
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -56,6 +56,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
AdapterSpec(
|
||||
adapter_type="weaviate",
|
||||
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
|
||||
module="llama_stack.providers.adapters.memory.weaviate",
|
||||
config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig",
|
||||
provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.memory,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.safety,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
|
|
@ -61,4 +60,14 @@ def available_providers() -> List[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="meta-reference/codeshield",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
],
|
||||
module="llama_stack.providers.impls.meta_reference.codeshield",
|
||||
config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig",
|
||||
api_dependencies=[],
|
||||
),
|
||||
]
|
||||
|
|
|
|||
5
llama_stack/providers/tests/__init__.py
Normal file
5
llama_stack/providers/tests/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
5
llama_stack/providers/tests/agents/__init__.py
Normal file
5
llama_stack/providers/tests/agents/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
5
llama_stack/providers/tests/inference/__init__.py
Normal file
5
llama_stack/providers/tests/inference/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
providers:
|
||||
- provider_id: test-ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 11434
|
||||
- provider_id: test-tgi
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: http://localhost:7001
|
||||
- provider_id: test-remote
|
||||
provider_type: remote
|
||||
config:
|
||||
host: localhost
|
||||
port: 7002
|
||||
- provider_id: test-together
|
||||
provider_type: remote::together
|
||||
config: {}
|
||||
# if a provider needs private keys from the client, they use the
|
||||
# "get_request_provider_data" function (see distribution/request_headers.py)
|
||||
# this is a place to provide such data.
|
||||
provider_data:
|
||||
"test-together":
|
||||
together_api_key:
|
||||
0xdeadbeefputrealapikeyhere
|
||||
243
llama_stack/providers/tests/inference/test_inference.py
Normal file
243
llama_stack/providers/tests/inference/test_inference.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
|
||||
# since it depends on the provider you are testing. On top of that you need
|
||||
# `pytest` and `pytest-asyncio` installed.
|
||||
#
|
||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||
#
|
||||
# 3. Run:
|
||||
#
|
||||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/memory/test_inference.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
|
||||
def group_chunks(response):
|
||||
return {
|
||||
event_type: list(group)
|
||||
for event_type, group in itertools.groupby(
|
||||
response, key=lambda chunk: chunk.event.event_type
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
Llama_8B = "Llama3.1-8B-Instruct"
|
||||
Llama_3B = "Llama3.2-3B-Instruct"
|
||||
|
||||
|
||||
def get_expected_stop_reason(model: str):
|
||||
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
|
||||
|
||||
|
||||
# This is going to create multiple Stack impls without tearing down the previous one
|
||||
# Fix that!
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session",
|
||||
params=[
|
||||
{"model": Llama_8B},
|
||||
{"model": Llama_3B},
|
||||
],
|
||||
ids=lambda d: d["model"],
|
||||
)
|
||||
async def inference_settings(request):
|
||||
model = request.param["model"]
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.inference,
|
||||
models=[
|
||||
ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return {
|
||||
"impl": impls[Api.inference],
|
||||
"common_params": {
|
||||
"model": model,
|
||||
"tool_choice": ToolChoice.auto,
|
||||
"tool_prompt_format": (
|
||||
ToolPromptFormat.json
|
||||
if "Llama3.1" in model
|
||||
else ToolPromptFormat.python_list
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages():
|
||||
return [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content="What's the weather like today?"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tool_definition():
|
||||
return ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get the current weather",
|
||||
parameters={
|
||||
"location": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The city and state, e.g. San Francisco, CA",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
response = await inference_impl.chat_completion(
|
||||
messages=sample_messages,
|
||||
stream=False,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert response.completion_message.role == "assistant"
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
assert len(response.completion_message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
assert end.event.stop_reason == StopReason.end_of_turn
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling(
|
||||
inference_settings,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl = inference_settings["impl"]
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
)
|
||||
]
|
||||
|
||||
response = await inference_impl.chat_completion(
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=False,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
|
||||
message = response.completion_message
|
||||
|
||||
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
||||
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
|
||||
# assert message.stop_reason == stop_reason
|
||||
assert message.tool_calls is not None
|
||||
assert len(message.tool_calls) > 0
|
||||
|
||||
call = message.tool_calls[0]
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling_streaming(
|
||||
inference_settings,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl = inference_settings["impl"]
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
)
|
||||
]
|
||||
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=True,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
||||
# expected_stop_reason = get_expected_stop_reason(
|
||||
# inference_settings["common_params"]["model"]
|
||||
# )
|
||||
# end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
# assert end.event.stop_reason == expected_stop_reason
|
||||
|
||||
model = inference_settings["common_params"]["model"]
|
||||
if "Llama3.1" in model:
|
||||
assert all(
|
||||
isinstance(chunk.event.delta, ToolCallDelta)
|
||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||
)
|
||||
first = grouped[ChatCompletionResponseEventType.progress][0]
|
||||
assert first.event.delta.parse_status == ToolCallParseStatus.started
|
||||
|
||||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||
# assert last.event.stop_reason == expected_stop_reason
|
||||
assert last.event.delta.parse_status == ToolCallParseStatus.success
|
||||
assert isinstance(last.event.delta.content, ToolCall)
|
||||
|
||||
call = last.event.delta.content
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
126
llama_stack/providers/tests/inference/test_prompt_adapter.py
Normal file
126
llama_stack/providers/tests/inference/test_prompt_adapter.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
from llama_models.llama3.api import * # noqa: F403
|
||||
from llama_stack.inference.api import * # noqa: F403
|
||||
from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages
|
||||
|
||||
MODEL = "Llama3.1-8B-Instruct"
|
||||
|
||||
|
||||
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_system_default(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
|
||||
async def test_system_builtin_only(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
|
||||
async def test_system_custom_only(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 3)
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_system_custom_and_builtin(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_user_provided_system_message(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
5
llama_stack/providers/tests/memory/__init__.py
Normal file
5
llama_stack/providers/tests/memory/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
providers:
|
||||
- provider_id: test-faiss
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
- provider_id: test-chroma
|
||||
provider_type: remote::chroma
|
||||
config:
|
||||
host: localhost
|
||||
port: 6001
|
||||
- provider_id: test-remote
|
||||
provider_type: remote
|
||||
config:
|
||||
host: localhost
|
||||
port: 7002
|
||||
- provider_id: test-weaviate
|
||||
provider_type: remote::weaviate
|
||||
config: {}
|
||||
# if a provider needs private keys from the client, they use the
|
||||
# "get_request_provider_data" function (see distribution/request_headers.py)
|
||||
# this is a place to provide such data.
|
||||
provider_data:
|
||||
"test-weaviate":
|
||||
weaviate_api_key: 0xdeadbeefputrealapikeyhere
|
||||
weaviate_cluster_url: http://foobarbaz
|
||||
119
llama_stack/providers/tests/memory/test_memory.py
Normal file
119
llama_stack/providers/tests/memory/test_memory.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
|
||||
# since it depends on the provider you are testing. On top of that you need
|
||||
# `pytest` and `pytest-asyncio` installed.
|
||||
#
|
||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||
#
|
||||
# 3. Run:
|
||||
#
|
||||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/memory/test_memory.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def memory_impl():
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.memory,
|
||||
memory_banks=[],
|
||||
)
|
||||
return impls[Api.memory]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_documents():
|
||||
return [
|
||||
MemoryBankDocument(
|
||||
document_id="doc1",
|
||||
content="Python is a high-level programming language.",
|
||||
metadata={"category": "programming", "difficulty": "beginner"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc2",
|
||||
content="Machine learning is a subset of artificial intelligence.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc3",
|
||||
content="Data structures are fundamental to computer science.",
|
||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc4",
|
||||
content="Neural networks are inspired by biological neural networks.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def register_memory_bank(memory_impl: Memory):
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await memory_impl.register_memory_bank(bank)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(memory_impl, sample_documents):
|
||||
with pytest.raises(ValueError):
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
await register_memory_bank(memory_impl)
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
query1 = "programming language"
|
||||
response1 = await memory_impl.query_documents("test_bank", query1)
|
||||
assert_valid_response(response1)
|
||||
assert any("Python" in chunk.content for chunk in response1.chunks)
|
||||
|
||||
# Test case 3: Query with semantic similarity
|
||||
query3 = "AI and brain-inspired computing"
|
||||
response3 = await memory_impl.query_documents("test_bank", query3)
|
||||
assert_valid_response(response3)
|
||||
assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
|
||||
|
||||
# Test case 4: Query with limit on number of results
|
||||
query4 = "computer"
|
||||
params4 = {"max_chunks": 2}
|
||||
response4 = await memory_impl.query_documents("test_bank", query4, params4)
|
||||
assert_valid_response(response4)
|
||||
assert len(response4.chunks) <= 2
|
||||
|
||||
# Test case 5: Query with threshold on similarity score
|
||||
query5 = "quantum computing" # Not directly related to any document
|
||||
params5 = {"score_threshold": 0.5}
|
||||
response5 = await memory_impl.query_documents("test_bank", query5, params5)
|
||||
assert_valid_response(response5)
|
||||
assert all(score >= 0.5 for score in response5.scores)
|
||||
|
||||
|
||||
def assert_valid_response(response: QueryDocumentsResponse):
|
||||
assert isinstance(response, QueryDocumentsResponse)
|
||||
assert len(response.chunks) > 0
|
||||
assert len(response.scores) > 0
|
||||
assert len(response.chunks) == len(response.scores)
|
||||
for chunk in response.chunks:
|
||||
assert isinstance(chunk.content, str)
|
||||
assert chunk.document_id is not None
|
||||
100
llama_stack/providers/tests/resolver.py
Normal file
100
llama_stack/providers/tests/resolver.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
||||
|
||||
|
||||
async def resolve_impls_for_test(
|
||||
api: Api,
|
||||
models: List[ModelDef] = None,
|
||||
memory_banks: List[MemoryBankDef] = None,
|
||||
shields: List[ShieldDef] = None,
|
||||
):
|
||||
if "PROVIDER_CONFIG" not in os.environ:
|
||||
raise ValueError(
|
||||
"You must set PROVIDER_CONFIG to a YAML file containing provider config"
|
||||
)
|
||||
|
||||
with open(os.environ["PROVIDER_CONFIG"], "r") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
if "providers" not in config_dict:
|
||||
raise ValueError("Config file should contain a `providers` key")
|
||||
|
||||
providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
|
||||
if len(providers_by_id) == 0:
|
||||
raise ValueError("No providers found in config file")
|
||||
|
||||
if "PROVIDER_ID" in os.environ:
|
||||
provider_id = os.environ["PROVIDER_ID"]
|
||||
if provider_id not in providers_by_id:
|
||||
raise ValueError(f"Provider ID {provider_id} not found in config file")
|
||||
provider = providers_by_id[provider_id]
|
||||
else:
|
||||
provider = list(providers_by_id.values())[0]
|
||||
provider_id = provider["provider_id"]
|
||||
print(f"No provider ID specified, picking first `{provider_id}`")
|
||||
|
||||
models = models or []
|
||||
shields = shields or []
|
||||
memory_banks = memory_banks or []
|
||||
|
||||
models = [
|
||||
ModelDef(
|
||||
**{
|
||||
**m.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for m in models
|
||||
]
|
||||
shields = [
|
||||
ShieldDef(
|
||||
**{
|
||||
**s.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for s in shields
|
||||
]
|
||||
memory_banks = [
|
||||
MemoryBankDef(
|
||||
**{
|
||||
**m.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for m in memory_banks
|
||||
]
|
||||
run_config = dict(
|
||||
built_at=datetime.now(),
|
||||
image_name="test-fixture",
|
||||
apis=[api],
|
||||
providers={api.value: [Provider(**provider)]},
|
||||
models=models,
|
||||
memory_banks=memory_banks,
|
||||
shields=shields,
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
impls = await resolve_impls_with_routing(run_config)
|
||||
|
||||
if "provider_data" in config_dict:
|
||||
provider_data = config_dict["provider_data"].get(provider_id, {})
|
||||
if provider_data:
|
||||
set_request_provider_data(
|
||||
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
|
||||
)
|
||||
|
||||
return impls
|
||||
35
llama_stack/providers/utils/inference/model_registry.py
Normal file
35
llama_stack/providers/utils/inference/model_registry.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
|
||||
class ModelRegistryHelper:
|
||||
|
||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||
|
||||
def map_to_provider_model(self, identifier: str) -> str:
|
||||
model = resolve_model(identifier)
|
||||
if not model:
|
||||
raise ValueError(f"Unknown model: `{identifier}`")
|
||||
|
||||
if identifier not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
|
||||
)
|
||||
|
||||
return self.stack_to_provider_models_map[identifier]
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
if model.identifier not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
|
||||
)
|
||||
189
llama_stack/providers/utils/inference/openai_compat.py
Normal file
189
llama_stack/providers/utils/inference/openai_compat.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OpenAICompatCompletionChoiceDelta(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class OpenAICompatCompletionChoice(BaseModel):
|
||||
finish_reason: Optional[str] = None
|
||||
text: Optional[str] = None
|
||||
delta: Optional[OpenAICompatCompletionChoiceDelta] = None
|
||||
|
||||
|
||||
class OpenAICompatCompletionResponse(BaseModel):
|
||||
choices: List[OpenAICompatCompletionChoice]
|
||||
|
||||
|
||||
def get_sampling_options(request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if params := request.sampling_params:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(params, attr):
|
||||
options[attr] = getattr(params, attr)
|
||||
|
||||
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
||||
options["repeat_penalty"] = params.repetition_penalty
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def text_from_choice(choice) -> str:
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
return choice.delta.content
|
||||
|
||||
return choice.text
|
||||
|
||||
|
||||
def process_chat_completion_response(
|
||||
request: ChatCompletionRequest,
|
||||
response: OpenAICompatCompletionResponse,
|
||||
formatter: ChatFormat,
|
||||
) -> ChatCompletionResponse:
|
||||
choice = response.choices[0]
|
||||
|
||||
stop_reason = None
|
||||
if reason := choice.finish_reason:
|
||||
if reason in ["stop", "eos"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif reason == "eom":
|
||||
stop_reason = StopReason.end_of_message
|
||||
elif reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = formatter.decode_assistant_message_from_content(
|
||||
text_from_choice(choice), stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
|
||||
|
||||
async def process_chat_completion_stream_response(
|
||||
request: ChatCompletionRequest,
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
|
||||
formatter: ChatFormat,
|
||||
) -> AsyncGenerator:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
if finish_reason:
|
||||
if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif stop_reason is None and finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = text_from_choice(choice)
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
if ipython:
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
|
@ -3,7 +3,11 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Tuple
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_models.datatypes import ModelFamily
|
||||
|
|
@ -19,7 +23,28 @@ from llama_models.sku_list import resolve_model
|
|||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
|
||||
def chat_completion_request_to_prompt(
|
||||
request: ChatCompletionRequest, formatter: ChatFormat
|
||||
) -> str:
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
model_input = formatter.encode_dialog_prompt(messages)
|
||||
return formatter.tokenizer.decode(model_input.tokens)
|
||||
|
||||
|
||||
def chat_completion_request_to_model_input_info(
|
||||
request: ChatCompletionRequest, formatter: ChatFormat
|
||||
) -> Tuple[str, int]:
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
model_input = formatter.encode_dialog_prompt(messages)
|
||||
return (
|
||||
formatter.tokenizer.decode(model_input.tokens),
|
||||
len(model_input.tokens),
|
||||
)
|
||||
|
||||
|
||||
def chat_completion_request_to_messages(
|
||||
request: ChatCompletionRequest,
|
||||
) -> List[Message]:
|
||||
"""Reads chat completion request and augments the messages to handle tools.
|
||||
For eg. for llama_3_1, add system message with the appropriate tools or
|
||||
add user messsage for custom tools, etc.
|
||||
|
|
@ -48,7 +73,6 @@ def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
|
|||
def augment_messages_for_tools_llama_3_1(
|
||||
request: ChatCompletionRequest,
|
||||
) -> List[Message]:
|
||||
|
||||
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
|
||||
|
||||
existing_messages = request.messages
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class RoutableProviderForModels(RoutableProvider):
|
||||
|
||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]):
|
||||
for routing_key in routing_keys:
|
||||
if routing_key not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
|
||||
)
|
||||
|
||||
def map_to_provider_model(self, routing_key: str) -> str:
|
||||
model = resolve_model(routing_key)
|
||||
if not model:
|
||||
raise ValueError(f"Unknown model: `{routing_key}`")
|
||||
|
||||
if routing_key not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Model {routing_key} not found in map {self.stack_to_provider_models_map}"
|
||||
)
|
||||
|
||||
return self.stack_to_provider_models_map[routing_key]
|
||||
|
|
@ -31,6 +31,10 @@ class RedisKVStoreConfig(CommonConfig):
|
|||
host: str = "localhost"
|
||||
port: int = 6379
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"redis://{self.host}:{self.port}"
|
||||
|
||||
|
||||
class SqliteKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
|
||||
|
|
|
|||
|
|
@ -146,22 +146,22 @@ class EmbeddingIndex(ABC):
|
|||
|
||||
@dataclass
|
||||
class BankWithIndex:
|
||||
bank: MemoryBank
|
||||
bank: MemoryBankDef
|
||||
index: EmbeddingIndex
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
model = get_embedding_model(self.bank.config.embedding_model)
|
||||
model = get_embedding_model(self.bank.embedding_model)
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks = make_overlapped_chunks(
|
||||
doc.document_id,
|
||||
content,
|
||||
self.bank.config.chunk_size_in_tokens,
|
||||
self.bank.config.overlap_size_in_tokens
|
||||
or (self.bank.config.chunk_size_in_tokens // 4),
|
||||
self.bank.chunk_size_in_tokens,
|
||||
self.bank.overlap_size_in_tokens
|
||||
or (self.bank.chunk_size_in_tokens // 4),
|
||||
)
|
||||
if not chunks:
|
||||
continue
|
||||
|
|
@ -189,6 +189,6 @@ class BankWithIndex:
|
|||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
model = get_embedding_model(self.bank.config.embedding_model)
|
||||
model = get_embedding_model(self.bank.embedding_model)
|
||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||
return await self.index.query(query_vector, k)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue