forked from phoenix-oss/llama-stack-mirror
impls
-> inline
, adapters
-> remote
(#381)
This commit is contained in:
parent
b10e9f46bb
commit
994732e2e0
169 changed files with 106 additions and 105 deletions
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleAgentsImpl
|
||||
|
||||
impl = SampleAgentsImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
|
||||
|
||||
class SampleAgentsImpl(Agents):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from .bedrock import BedrockInferenceAdapter
|
||||
from .config import BedrockConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BedrockConfig, _deps):
|
||||
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = BedrockInferenceAdapter(config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
return impl
|
|
@ -1,439 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
from botocore.client import BaseClient
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
||||
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
||||
}
|
||||
|
||||
|
||||
# NOTE: this is not quite tested after the recent refactors
|
||||
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
||||
)
|
||||
self._config = config
|
||||
|
||||
self._client = create_bedrock_client(config)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
|
||||
if bedrock_stop_reason == "max_tokens":
|
||||
return StopReason.out_of_tokens
|
||||
return StopReason.end_of_turn
|
||||
|
||||
@staticmethod
|
||||
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
|
||||
for builtin_tool in BuiltinTool:
|
||||
if builtin_tool.value == tool_name_str:
|
||||
return builtin_tool
|
||||
else:
|
||||
return tool_name_str
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
|
||||
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
converse_api_res["stopReason"]
|
||||
)
|
||||
|
||||
bedrock_message = converse_api_res["output"]["message"]
|
||||
|
||||
role = bedrock_message["role"]
|
||||
contents = bedrock_message["content"]
|
||||
|
||||
tool_calls = []
|
||||
text_content = []
|
||||
for content in contents:
|
||||
if "toolUse" in content:
|
||||
tool_use = content["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
|
||||
tool_use["name"]
|
||||
),
|
||||
arguments=tool_use["input"] if "input" in tool_use else None,
|
||||
call_id=tool_use["toolUseId"],
|
||||
)
|
||||
)
|
||||
elif "text" in content:
|
||||
text_content.append(content["text"])
|
||||
|
||||
return CompletionMessage(
|
||||
role=role,
|
||||
content=text_content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _messages_to_bedrock_messages(
|
||||
messages: List[Message],
|
||||
) -> Tuple[List[Dict], Optional[List[Dict]]]:
|
||||
bedrock_messages = []
|
||||
system_bedrock_messages = []
|
||||
|
||||
user_contents = []
|
||||
assistant_contents = None
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content_list = (
|
||||
message.content
|
||||
if isinstance(message.content, list)
|
||||
else [message.content]
|
||||
)
|
||||
if role == "ipython" or role == "user":
|
||||
if not user_contents:
|
||||
user_contents = []
|
||||
|
||||
if role == "ipython":
|
||||
user_contents.extend(
|
||||
[
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.call_id,
|
||||
"content": [
|
||||
{"text": content} for content in content_list
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
user_contents.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
assistant_contents = None
|
||||
elif role == "system":
|
||||
system_bedrock_messages.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not assistant_contents:
|
||||
assistant_contents = []
|
||||
|
||||
assistant_contents.extend(
|
||||
[
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
for content in content_list
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"toolUse": {
|
||||
"input": tool_call.arguments,
|
||||
"name": (
|
||||
tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value
|
||||
),
|
||||
"toolUseId": tool_call.call_id,
|
||||
}
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
user_contents = None
|
||||
else:
|
||||
# Unknown role
|
||||
pass
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
|
||||
if system_bedrock_messages:
|
||||
return bedrock_messages, system_bedrock_messages
|
||||
|
||||
return bedrock_messages, None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
|
||||
inference_config = {}
|
||||
if sampling_params:
|
||||
param_mapping = {
|
||||
"max_tokens": "maxTokens",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
}
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(sampling_params, k):
|
||||
inference_config[v] = getattr(sampling_params, k)
|
||||
|
||||
return inference_config
|
||||
|
||||
@staticmethod
|
||||
def _tool_parameters_to_input_schema(
|
||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]],
|
||||
) -> Dict:
|
||||
input_schema = {"type": "object"}
|
||||
if not tool_parameters:
|
||||
return input_schema
|
||||
|
||||
json_properties = {}
|
||||
required = []
|
||||
for name, param in tool_parameters.items():
|
||||
json_property = {
|
||||
"type": param.param_type,
|
||||
}
|
||||
|
||||
if param.description:
|
||||
json_property["description"] = param.description
|
||||
if param.required:
|
||||
required.append(name)
|
||||
json_properties[name] = json_property
|
||||
|
||||
input_schema["properties"] = json_properties
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def _tools_to_tool_config(
|
||||
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
|
||||
) -> Optional[Dict]:
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
bedrock_tools = []
|
||||
for tool in tools:
|
||||
tool_name = (
|
||||
tool.tool_name
|
||||
if isinstance(tool.tool_name, str)
|
||||
else tool.tool_name.value
|
||||
)
|
||||
|
||||
tool_spec = {
|
||||
"toolSpec": {
|
||||
"name": tool_name,
|
||||
"inputSchema": {
|
||||
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
|
||||
tool.parameters
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if tool.description:
|
||||
tool_spec["toolSpec"]["description"] = tool.description
|
||||
|
||||
bedrock_tools.append(tool_spec)
|
||||
tool_config = {
|
||||
"tools": bedrock_tools,
|
||||
}
|
||||
|
||||
if tool_choice:
|
||||
tool_config["toolChoice"] = (
|
||||
{"any": {}}
|
||||
if tool_choice.value == ToolChoice.required
|
||||
else {"auto": {}}
|
||||
)
|
||||
return tool_config
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params_for_chat_completion(request)
|
||||
converse_api_res = self.client.converse(**params)
|
||||
|
||||
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
|
||||
converse_api_res
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=output_message,
|
||||
logprobs=None,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params_for_chat_completion(request)
|
||||
converse_stream_api_res = self.client.converse_stream(**params)
|
||||
event_stream = converse_stream_api_res["stream"]
|
||||
|
||||
for chunk in event_stream:
|
||||
if "messageStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
elif "contentBlockStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=ToolCall(
|
||||
tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
|
||||
call_id=chunk["contentBlockStart"]["toolUse"][
|
||||
"toolUseId"
|
||||
],
|
||||
),
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "contentBlockDelta" in chunk:
|
||||
if "text" in chunk["contentBlockDelta"]["delta"]:
|
||||
delta = chunk["contentBlockDelta"]["delta"]["text"]
|
||||
else:
|
||||
delta = ToolCallDelta(
|
||||
content=ToolCall(
|
||||
arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
|
||||
"input"
|
||||
]
|
||||
),
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
elif "contentBlockStop" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
elif "messageStop" in chunk:
|
||||
stop_reason = (
|
||||
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
chunk["messageStop"]["stopReason"]
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
elif "metadata" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
else:
|
||||
# Ignored
|
||||
pass
|
||||
|
||||
def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
|
||||
bedrock_model = self.map_to_provider_model(request.model)
|
||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||
request.sampling_params
|
||||
)
|
||||
|
||||
tool_config = BedrockInferenceAdapter._tools_to_tool_config(
|
||||
request.tools, request.tool_choice
|
||||
)
|
||||
bedrock_messages, system_bedrock_messages = (
|
||||
BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
|
||||
)
|
||||
|
||||
converse_api_params = {
|
||||
"modelId": bedrock_model,
|
||||
"messages": bedrock_messages,
|
||||
}
|
||||
if inference_config:
|
||||
converse_api_params["inferenceConfig"] = inference_config
|
||||
|
||||
# Tool use is not supported in streaming mode
|
||||
if tool_config and not request.stream:
|
||||
converse_api_params["toolConfig"] = tool_config
|
||||
if system_bedrock_messages:
|
||||
converse_api_params["system"] = system_bedrock_messages
|
||||
|
||||
return converse_api_params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
|
@ -1,14 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BedrockConfig(BedrockBaseConfig):
|
||||
pass
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
assert isinstance(
|
||||
config, DatabricksImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: str = Field(
|
||||
default=None,
|
||||
description="The Databricks API token",
|
||||
)
|
|
@ -1,127 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
|
||||
DATABRICKS_SUPPORTED_MODELS = {
|
||||
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
|
||||
}
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: FireworksImplConfig, _deps):
|
||||
from .fireworks import FireworksInferenceAdapter
|
||||
|
||||
assert isinstance(
|
||||
config, FireworksImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = FireworksInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,20 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
description="The Fireworks.ai API Key",
|
||||
)
|
|
@ -1,212 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
convert_message_to_dict,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
FIREWORKS_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
||||
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
|
||||
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
||||
"Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
|
||||
"Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
|
||||
}
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = CompletionRequest(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
client = Fireworks(api_key=self.config.api_key)
|
||||
if stream:
|
||||
return self._stream_completion(request, client)
|
||||
else:
|
||||
return await self._nonstream_completion(request, client)
|
||||
|
||||
async def _nonstream_completion(
|
||||
self, request: CompletionRequest, client: Fireworks
|
||||
) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await client.completion.acreate(**params)
|
||||
return process_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_completion(
|
||||
self, request: CompletionRequest, client: Fireworks
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = client.completion.acreate(**params)
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
yield chunk
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
client = Fireworks(api_key=self.config.api_key)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await client.chat.completions.acreate(**params)
|
||||
else:
|
||||
r = await client.completion.acreate(**params)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
if "messages" in params:
|
||||
stream = client.chat.completions.acreate(**params)
|
||||
else:
|
||||
stream = client.completion.acreate(**params)
|
||||
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(
|
||||
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||
) -> dict:
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_dict(m) for m in request.messages
|
||||
]
|
||||
else:
|
||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
request, self.formatter
|
||||
)
|
||||
elif isinstance(request, CompletionRequest):
|
||||
assert (
|
||||
not media_present
|
||||
), "Fireworks does not support media for Completion requests"
|
||||
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
|
||||
else:
|
||||
raise ValueError(f"Unknown request type {type(request)}")
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if "prompt" in input_dict:
|
||||
if input_dict["prompt"].startswith("<|begin_of_text|>"):
|
||||
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
|
||||
|
||||
options = get_sampling_options(request.sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
||||
if fmt := request.response_format:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
options["response_format"] = {
|
||||
"type": "grammar",
|
||||
"grammar": fmt.bnf,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
|
@ -1,19 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
|
||||
class OllamaImplConfig(RemoteProviderConfig):
|
||||
port: int = 11434
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
impl = OllamaInferenceAdapter(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,299 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
convert_image_media_to_url,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
OLLAMA_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
||||
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
|
||||
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
|
||||
"Llama-Guard-3-8B": "llama-guard3:8b",
|
||||
"Llama-Guard-3-1B": "llama-guard3:1b",
|
||||
"Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
|
||||
}
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
return AsyncClient(host=self.url)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print("Initializing Ollama, checking connectivity to server...")
|
||||
try:
|
||||
await self.client.ps()
|
||||
except httpx.ConnectError as e:
|
||||
raise RuntimeError(
|
||||
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
|
||||
) from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Dynamic model registration is not supported")
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
|
||||
|
||||
ret = []
|
||||
res = await self.client.ps()
|
||||
for r in res["models"]:
|
||||
if r["model"] not in ollama_to_llama:
|
||||
print(f"Ollama is running a model unknown to Llama Stack: {r['model']}")
|
||||
continue
|
||||
|
||||
llama_model = ollama_to_llama[r["model"]]
|
||||
ret.append(
|
||||
ModelDef(
|
||||
identifier=llama_model,
|
||||
llama_model=llama_model,
|
||||
metadata={
|
||||
"ollama_model": r["model"],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = CompletionRequest(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.generate(**params)
|
||||
async for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["response"],
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.generate(**params)
|
||||
assert isinstance(r, dict)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["response"],
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
return process_completion_response(response, self.formatter)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _get_params(
|
||||
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||
) -> dict:
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
# This is needed since the Ollama API expects num_predict to be set
|
||||
# for early truncation instead of max_tokens.
|
||||
if sampling_options.get("max_tokens") is not None:
|
||||
sampling_options["num_predict"] = sampling_options["max_tokens"]
|
||||
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
contents = [
|
||||
await convert_message_to_dict_for_ollama(m)
|
||||
for m in request.messages
|
||||
]
|
||||
# flatten the list of lists
|
||||
input_dict["messages"] = [
|
||||
item for sublist in contents for item in sublist
|
||||
]
|
||||
else:
|
||||
input_dict["raw"] = True
|
||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
request, self.formatter
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
not media_present
|
||||
), "Ollama does not support media for Completion requests"
|
||||
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
|
||||
input_dict["raw"] = True
|
||||
|
||||
return {
|
||||
"model": OLLAMA_SUPPORTED_MODELS[request.model],
|
||||
**input_dict,
|
||||
"options": sampling_options,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await self.client.chat(**params)
|
||||
else:
|
||||
r = await self.client.generate(**params)
|
||||
assert isinstance(r, dict)
|
||||
|
||||
if "message" in r:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["message"]["content"],
|
||||
)
|
||||
else:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["response"],
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
if "messages" in params:
|
||||
s = await self.client.chat(**params)
|
||||
else:
|
||||
s = await self.client.generate(**params)
|
||||
async for chunk in s:
|
||||
if "message" in chunk:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["message"]["content"],
|
||||
)
|
||||
else:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["response"],
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
if isinstance(content, ImageMedia):
|
||||
return {
|
||||
"role": message.role,
|
||||
"images": [
|
||||
await convert_image_media_to_url(
|
||||
content, download=True, include_format=False
|
||||
)
|
||||
],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"role": message.role,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
if isinstance(message.content, list):
|
||||
return [await _convert_content(c) for c in message.content]
|
||||
else:
|
||||
return [await _convert_content(message.content)]
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleInferenceImpl
|
||||
|
||||
impl = SampleInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
|
||||
class SampleInferenceImpl(Inference):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
|
@ -1,29 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
|
||||
_deps,
|
||||
):
|
||||
if isinstance(config, TGIImplConfig):
|
||||
impl = TGIAdapter()
|
||||
elif isinstance(config, InferenceAPIImplConfig):
|
||||
impl = InferenceAPIAdapter()
|
||||
elif isinstance(config, InferenceEndpointImplConfig):
|
||||
impl = InferenceEndpointAdapter()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
|
||||
)
|
||||
|
||||
await impl.initialize(config)
|
||||
return impl
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 8080
|
||||
protocol: str = "http"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"{self.protocol}://{self.host}:{self.port}"
|
||||
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="A bearer token if your TGI endpoint is protected.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceEndpointImplConfig(BaseModel):
|
||||
endpoint_name: str = Field(
|
||||
description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceAPIImplConfig(BaseModel):
|
||||
huggingface_repo: str = Field(
|
||||
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
|
@ -1,294 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_model_input_info,
|
||||
completion_request_to_prompt_model_input_info,
|
||||
)
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||
client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor()
|
||||
for model in all_registered_models()
|
||||
if model.huggingface_repo
|
||||
}
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Model registration is not supported for HuggingFace models")
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
repo = self.model_id
|
||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||
return [
|
||||
ModelDef(
|
||||
identifier=identifier,
|
||||
llama_model=identifier,
|
||||
metadata={
|
||||
"huggingface_repo": repo,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = CompletionRequest(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_max_new_tokens(self, sampling_params, input_tokens):
|
||||
return min(
|
||||
sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
|
||||
def _build_options(
|
||||
self,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
fmt: ResponseFormat = None,
|
||||
):
|
||||
options = get_sampling_options(sampling_params)
|
||||
# delete key "max_tokens" from options since its not supported by the API
|
||||
options.pop("max_tokens", None)
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["grammar"] = {
|
||||
"type": "json",
|
||||
"value": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
raise ValueError("Grammar response format not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format: {fmt.type}")
|
||||
|
||||
return options
|
||||
|
||||
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
|
||||
prompt, input_tokens = completion_request_to_prompt_model_input_info(
|
||||
request, self.formatter
|
||||
)
|
||||
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
details=True,
|
||||
max_new_tokens=self._get_max_new_tokens(
|
||||
request.sampling_params, input_tokens
|
||||
),
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = self._get_params_for_completion(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
finish_reason = None
|
||||
if chunk.details:
|
||||
finish_reason = chunk.details.finish_reason
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
text=token_result.text, finish_reason=finish_reason
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = self._get_params_for_completion(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
text="".join(t.text for t in r.details.tokens),
|
||||
)
|
||||
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
return process_completion_response(response, self.formatter)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
text="".join(t.text for t in r.details.tokens),
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
|
||||
choice = OpenAICompatCompletionChoice(text=token_result.text)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt, input_tokens = chat_completion_request_to_model_input_info(
|
||||
request, self.formatter
|
||||
)
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
details=True,
|
||||
max_new_tokens=self._get_max_new_tokens(
|
||||
request.sampling_params, input_tokens
|
||||
),
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
|
||||
|
||||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.huggingface_repo, token=config.api_token
|
||||
)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
|
||||
|
||||
class InferenceEndpointAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
|
||||
# Get the inference endpoint details
|
||||
api = HfApi(token=config.api_token)
|
||||
endpoint = api.get_inference_endpoint(config.endpoint_name)
|
||||
|
||||
# Wait for the endpoint to be ready (if not already)
|
||||
endpoint.wait(timeout=60)
|
||||
|
||||
# Initialize the adapter
|
||||
self.client = endpoint.async_client
|
||||
self.model_id = endpoint.repository
|
||||
self.max_tokens = int(
|
||||
endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
|
||||
)
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherImplConfig, _deps):
|
||||
from .together import TogetherInferenceAdapter
|
||||
|
||||
assert isinstance(
|
||||
config, TogetherImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = TogetherInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Together AI API Key",
|
||||
)
|
|
@ -1,227 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from together import Together
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
convert_message_to_dict,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
TOGETHER_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
"Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
||||
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
||||
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
|
||||
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
|
||||
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
}
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(
|
||||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = CompletionRequest(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_client(self) -> Together:
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
return Together(api_key=together_api_key)
|
||||
|
||||
async def _nonstream_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client().completions.create(**params)
|
||||
return process_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
yield chunk
|
||||
|
||||
def _build_options(
|
||||
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
raise NotImplementedError("Grammar response format not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
r = self._get_client().completions.create(**params)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
if "messages" in params:
|
||||
s = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(
|
||||
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||
) -> dict:
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_dict(m) for m in request.messages
|
||||
]
|
||||
else:
|
||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
request, self.formatter
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
not media_present
|
||||
), "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
|
||||
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
|
||||
from .vllm import VLLMInferenceAdapter
|
||||
|
||||
assert isinstance(
|
||||
config, VLLMInferenceAdapterConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = VLLMInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,26 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMInferenceAdapterConfig(BaseModel):
|
||||
url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The URL for the vLLM model serving endpoint",
|
||||
)
|
||||
max_tokens: int = Field(
|
||||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default="fake",
|
||||
description="The API token",
|
||||
)
|
|
@ -1,151 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import all_registered_models, resolve_model
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
|
||||
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
self.client = None
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor()
|
||||
for model in all_registered_models()
|
||||
if model.huggingface_repo
|
||||
}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Model registration is not supported for vLLM models")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
models = []
|
||||
for model in self.client.models.list():
|
||||
repo = model.id
|
||||
if repo not in self.huggingface_repo_to_llama_model_id:
|
||||
print(f"Unknown model served by vllm: {repo}")
|
||||
continue
|
||||
|
||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||
models.append(
|
||||
ModelDef(
|
||||
identifier=identifier,
|
||||
llama_model=identifier,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, self.client)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, self.client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
|
||||
# generator so this wrapper is not necessary?
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
options = get_sampling_options(request.sampling_params)
|
||||
if "max_tokens" not in options:
|
||||
options["max_tokens"] = self.config.max_tokens
|
||||
|
||||
model = resolve_model(request.model)
|
||||
if model is None:
|
||||
raise ValueError(f"Unknown model: {request.model}")
|
||||
|
||||
return {
|
||||
"model": model.huggingface_repo,
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
|
||||
from .chroma import ChromaMemoryAdapter
|
||||
|
||||
impl = ChromaMemoryAdapter(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,159 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import chromadb
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
||||
|
||||
class ChromaIndex(EmbeddingIndex):
|
||||
def __init__(self, client: chromadb.AsyncHttpClient, collection):
|
||||
self.client = client
|
||||
self.collection = collection
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(
|
||||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
await self.collection.add(
|
||||
documents=[chunk.json() for chunk in chunks],
|
||||
embeddings=embeddings,
|
||||
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
|
||||
)
|
||||
|
||||
async def query(
|
||||
self, embedding: NDArray, k: int, score_threshold: float
|
||||
) -> QueryDocumentsResponse:
|
||||
results = await self.collection.query(
|
||||
query_embeddings=[embedding.tolist()],
|
||||
n_results=k,
|
||||
include=["documents", "distances"],
|
||||
)
|
||||
distances = results["distances"][0]
|
||||
documents = results["documents"][0]
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for dist, doc in zip(distances, documents):
|
||||
try:
|
||||
doc = json.loads(doc)
|
||||
chunk = Chunk(**doc)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse document: {doc}")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(1.0 / float(dist))
|
||||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||
url = url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
||||
if parsed.path and parsed.path != "/":
|
||||
raise ValueError("URL should not contain a path")
|
||||
|
||||
self.host = parsed.hostname
|
||||
self.port = parsed.port
|
||||
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
print(f"Connecting to Chroma server at: {self.host}:{self.port}")
|
||||
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise RuntimeError("Could not connect to Chroma server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
metadata={"bank": memory_bank.json()},
|
||||
)
|
||||
bank_index = BankWithIndex(
|
||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||
)
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
collections = await self.client.list_collections()
|
||||
for collection in collections:
|
||||
try:
|
||||
data = json.loads(collection.metadata["bank"])
|
||||
bank = parse_obj_as(MemoryBankDef, data)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse bank: {collection.metadata}")
|
||||
continue
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=ChromaIndex(self.client, collection),
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
return await index.query_documents(query, params)
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import PGVectorConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: PGVectorConfig, _deps):
|
||||
from .pgvector import PGVectorMemoryAdapter
|
||||
|
||||
impl = PGVectorMemoryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PGVectorConfig(BaseModel):
|
||||
host: str = Field(default="localhost")
|
||||
port: int = Field(default=5432)
|
||||
db: str = Field(default="postgres")
|
||||
user: str = Field(default="postgres")
|
||||
password: str = Field(default="mysecretpassword")
|
|
@ -1,212 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import psycopg2
|
||||
from numpy.typing import NDArray
|
||||
from psycopg2 import sql
|
||||
from psycopg2.extras import execute_values, Json
|
||||
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
||||
from .config import PGVectorConfig
|
||||
|
||||
|
||||
def check_extension_version(cur):
|
||||
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
|
||||
result = cur.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
|
||||
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
|
||||
query = sql.SQL(
|
||||
"""
|
||||
INSERT INTO metadata_store (key, data)
|
||||
VALUES %s
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET data = EXCLUDED.data
|
||||
"""
|
||||
)
|
||||
|
||||
values = [(key, Json(model.dict())) for key, model in keys_models]
|
||||
execute_values(cur, query, values, template="(%s, %s)")
|
||||
|
||||
|
||||
def load_models(cur, cls):
|
||||
cur.execute("SELECT key, data FROM metadata_store")
|
||||
rows = cur.fetchall()
|
||||
return [parse_obj_as(cls, row["data"]) for row in rows]
|
||||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, bank: MemoryBankDef, dimension: int, cursor):
|
||||
self.cursor = cursor
|
||||
self.table_name = f"vector_store_{bank.identifier}"
|
||||
|
||||
self.cursor.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({dimension})
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(
|
||||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
values = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
values.append(
|
||||
(
|
||||
f"{chunk.document_id}:chunk-{i}",
|
||||
Json(chunk.dict()),
|
||||
embeddings[i].tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
query = sql.SQL(
|
||||
f"""
|
||||
INSERT INTO {self.table_name} (id, document, embedding)
|
||||
VALUES %s
|
||||
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
|
||||
"""
|
||||
)
|
||||
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
|
||||
|
||||
async def query(
|
||||
self, embedding: NDArray, k: int, score_threshold: float
|
||||
) -> QueryDocumentsResponse:
|
||||
self.cursor.execute(
|
||||
f"""
|
||||
SELECT document, embedding <-> %s::vector AS distance
|
||||
FROM {self.table_name}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
""",
|
||||
(embedding.tolist(), k),
|
||||
)
|
||||
results = self.cursor.fetchall()
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc, dist in results:
|
||||
chunks.append(Chunk(**doc))
|
||||
scores.append(1.0 / float(dist))
|
||||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: PGVectorConfig) -> None:
|
||||
self.config = config
|
||||
self.cursor = None
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self.conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
database=self.config.db,
|
||||
user=self.config.user,
|
||||
password=self.config.password,
|
||||
)
|
||||
self.conn.autocommit = True
|
||||
self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
|
||||
|
||||
version = check_extension_version(self.cursor)
|
||||
if version:
|
||||
print(f"Vector extension version: {version}")
|
||||
else:
|
||||
raise RuntimeError("Vector extension is not installed.")
|
||||
|
||||
self.cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS metadata_store (
|
||||
key TEXT PRIMARY KEY,
|
||||
data JSONB
|
||||
)
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise RuntimeError("Could not connect to PGVector database server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
upsert_models(
|
||||
self.cursor,
|
||||
[
|
||||
(memory_bank.identifier, memory_bank),
|
||||
],
|
||||
)
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
banks = load_models(self.cursor, MemoryBankDef)
|
||||
for bank in banks:
|
||||
if bank.identifier not in self.cache:
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
return banks
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
return await index.query_documents(query, params)
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import QdrantConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: QdrantConfig, _deps):
|
||||
from .qdrant import QdrantVectorMemoryAdapter
|
||||
|
||||
impl = QdrantVectorMemoryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,25 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QdrantConfig(BaseModel):
|
||||
location: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
port: Optional[int] = 6333
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
https: Optional[bool] = None
|
||||
api_key: Optional[str] = None
|
||||
prefix: Optional[str] = None
|
||||
timeout: Optional[int] = None
|
||||
host: Optional[str] = None
|
||||
path: Optional[str] = None
|
|
@ -1,170 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.adapters.memory.qdrant.config import QdrantConfig
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
||||
|
||||
def convert_id(_id: str) -> str:
|
||||
"""
|
||||
Converts any string into a UUID string based on a seed.
|
||||
|
||||
Qdrant accepts UUID strings and unsigned integers as point ID.
|
||||
We use a seed to convert each string into a UUID string deterministically.
|
||||
This allows us to overwrite the same point with the original ID.
|
||||
"""
|
||||
return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
|
||||
|
||||
|
||||
class QdrantIndex(EmbeddingIndex):
|
||||
def __init__(self, client: AsyncQdrantClient, collection_name: str):
|
||||
self.client = client
|
||||
self.collection_name = collection_name
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(
|
||||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
if not await self.client.collection_exists(self.collection_name):
|
||||
await self.client.create_collection(
|
||||
self.collection_name,
|
||||
vectors_config=models.VectorParams(
|
||||
size=len(embeddings[0]), distance=models.Distance.COSINE
|
||||
),
|
||||
)
|
||||
|
||||
points = []
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
chunk_id = f"{chunk.document_id}:chunk-{i}"
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=convert_id(chunk_id),
|
||||
vector=embedding,
|
||||
payload={"chunk_content": chunk.model_dump()}
|
||||
| {CHUNK_ID_KEY: chunk_id},
|
||||
)
|
||||
)
|
||||
|
||||
await self.client.upsert(collection_name=self.collection_name, points=points)
|
||||
|
||||
async def query(
|
||||
self, embedding: NDArray, k: int, score_threshold: float
|
||||
) -> QueryDocumentsResponse:
|
||||
results = (
|
||||
await self.client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
query=embedding.tolist(),
|
||||
limit=k,
|
||||
with_payload=True,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
).points
|
||||
|
||||
chunks, scores = [], []
|
||||
for point in results:
|
||||
assert isinstance(point, models.ScoredPoint)
|
||||
assert point.payload is not None
|
||||
|
||||
try:
|
||||
chunk = Chunk(**point.payload["chunk_content"])
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(point.score)
|
||||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: QdrantConfig) -> None:
|
||||
self.config = config
|
||||
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
|
||||
self.cache = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=QdrantIndex(self.client, memory_bank.identifier),
|
||||
)
|
||||
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
# Qdrant doesn't have collection level metadata to store the bank properties
|
||||
# So we only return from the cache value
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=QdrantIndex(client=self.client, collection_name=bank_id),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
return await index.query_documents(query, params)
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleMemoryImpl
|
||||
|
||||
impl = SampleMemoryImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
|
||||
class SampleMemoryImpl(Memory):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
# these are the memory banks the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WeaviateConfig, _deps):
|
||||
from .weaviate import WeaviateMemoryAdapter
|
||||
|
||||
impl = WeaviateMemoryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WeaviateRequestProviderData(BaseModel):
|
||||
weaviate_api_key: str
|
||||
weaviate_cluster_url: str
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
pass
|
|
@ -1,192 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import weaviate
|
||||
import weaviate.classes as wvc
|
||||
from numpy.typing import NDArray
|
||||
from weaviate.classes.init import Auth
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData
|
||||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(self, client: weaviate.Client, collection_name: str):
|
||||
self.client = client
|
||||
self.collection_name = collection_name
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(
|
||||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
data_objects = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
data_objects.append(
|
||||
wvc.data.DataObject(
|
||||
properties={
|
||||
"chunk_content": chunk.json(),
|
||||
},
|
||||
vector=embeddings[i].tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
# Inserting chunks into a prespecified Weaviate collection
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
|
||||
# TODO: make this async friendly
|
||||
collection.data.insert_many(data_objects)
|
||||
|
||||
async def query(
|
||||
self, embedding: NDArray, k: int, score_threshold: float
|
||||
) -> QueryDocumentsResponse:
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
|
||||
results = collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(distance=True),
|
||||
)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc in results.objects:
|
||||
chunk_json = doc.properties["chunk_content"]
|
||||
try:
|
||||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(1.0 / doc.metadata.distance)
|
||||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateMemoryAdapter(
|
||||
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
|
||||
):
|
||||
def __init__(self, config: WeaviateConfig) -> None:
|
||||
self.config = config
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
provider_data = self.get_request_provider_data()
|
||||
assert provider_data is not None, "Request provider data must be set"
|
||||
assert isinstance(provider_data, WeaviateRequestProviderData)
|
||||
|
||||
key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
|
||||
if key in self.client_cache:
|
||||
return self.client_cache[key]
|
||||
|
||||
client = weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url=provider_data.weaviate_cluster_url,
|
||||
auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
|
||||
)
|
||||
self.client_cache[key] = client
|
||||
return client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for client in self.client_cache.values():
|
||||
client.close()
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
# Create collection if it doesn't exist
|
||||
if not client.collections.exists(memory_bank.identifier):
|
||||
client.collections.create(
|
||||
name=memory_bank.identifier,
|
||||
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
wvc.config.Property(
|
||||
name="chunk_content",
|
||||
data_type=wvc.config.DataType.TEXT,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
# TODO: right now the Llama Stack is the source of truth for these banks. That is
|
||||
# not ideal. It should be Weaviate which is the source of truth. Unfortunately,
|
||||
# list() happens at Stack startup when the Weaviate client (credentials) is not
|
||||
# yet available. We need to figure out a way to make this work.
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
client = self._get_client()
|
||||
if not client.collections.exists(bank_id):
|
||||
raise ValueError(f"Collection with name `{bank_id}` not found")
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=WeaviateIndex(client=client, collection_name=bank_id),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
return await index.query_documents(query, params)
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BedrockSafetyConfig, _deps) -> Any:
|
||||
from .bedrock import BedrockSafetyAdapter
|
||||
|
||||
impl = BedrockSafetyAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,119 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_stack.apis.safety import * # noqa
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BEDROCK_SUPPORTED_SHIELDS = [
|
||||
ShieldType.generic_content_shield.value,
|
||||
]
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||
self.config = config
|
||||
self.registered_shields = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self.bedrock_runtime_client = create_bedrock_client(self.config)
|
||||
self.bedrock_client = create_bedrock_client(self.config, "bedrock")
|
||||
except Exception as e:
|
||||
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
raise ValueError("Registering dynamic shields is not supported")
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
response = self.bedrock_client.list_guardrails()
|
||||
shields = []
|
||||
for guardrail in response["guardrails"]:
|
||||
# populate the shield def with the guardrail id and version
|
||||
shield_def = ShieldDef(
|
||||
identifier=guardrail["id"],
|
||||
shield_type=ShieldType.generic_content_shield.value,
|
||||
params={
|
||||
"guardrailIdentifier": guardrail["id"],
|
||||
"guardrailVersion": guardrail["version"],
|
||||
},
|
||||
)
|
||||
self.registered_shields.append(shield_def)
|
||||
shields.append(shield_def)
|
||||
return shields
|
||||
|
||||
async def run_shield(
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
{
|
||||
"text": {
|
||||
"text": "Is the AB503 Product a better investment than the S&P 500?"
|
||||
}
|
||||
}
|
||||
]```
|
||||
However the incoming messages are of this type UserMessage(content=....) coming from
|
||||
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
|
||||
|
||||
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
|
||||
"""
|
||||
|
||||
shield_params = shield_def.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(
|
||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
|
||||
response = self.bedrock_runtime_client.apply_guardrail(
|
||||
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
||||
guardrailVersion=shield_params["guardrailVersion"],
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return RunShieldResponse()
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BedrockSafetyConfig(BedrockBaseConfig):
|
||||
pass
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleSafetyImpl
|
||||
|
||||
impl = SampleSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
|
||||
class SampleSafetyImpl(Safety):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
# these are the safety shields the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import TogetherProviderDataValidator, TogetherSafetyConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherSafetyConfig, _deps):
|
||||
from .together import TogetherSafetyImpl
|
||||
|
||||
assert isinstance(
|
||||
config, TogetherSafetyConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = TogetherSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,26 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TogetherProviderDataValidator(BaseModel):
|
||||
together_api_key: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherSafetyConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Together AI API Key (default for the distribution, if any)",
|
||||
)
|
|
@ -1,101 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from together import Together
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
from .config import TogetherSafetyConfig
|
||||
|
||||
|
||||
TOGETHER_SHIELD_MODEL_MAP = {
|
||||
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
}
|
||||
|
||||
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
raise ValueError("Registering dynamic shields is not supported")
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return [
|
||||
ShieldDef(
|
||||
identifier=ShieldType.llama_guard.value,
|
||||
shield_type=ShieldType.llama_guard.value,
|
||||
params={},
|
||||
)
|
||||
]
|
||||
|
||||
async def run_shield(
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
|
||||
model = shield_def.params.get("model", "llama_guard")
|
||||
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
||||
raise ValueError(f"Unsupported safety model: {model}")
|
||||
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
|
||||
# messages can have role assistant or user
|
||||
api_messages = []
|
||||
for message in messages:
|
||||
if message.role in (Role.user.value, Role.assistant.value):
|
||||
api_messages.append({"role": message.role, "content": message.content})
|
||||
|
||||
violation = await get_safety_response(
|
||||
together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
|
||||
async def get_safety_response(
|
||||
api_key: str, model_name: str, messages: List[Dict[str, str]]
|
||||
) -> Optional[SafetyViolation]:
|
||||
client = Together(api_key=api_key)
|
||||
response = client.chat.completions.create(messages=messages, model=model_name)
|
||||
if len(response.choices) == 0:
|
||||
return None
|
||||
|
||||
response_text = response.choices[0].message.content
|
||||
if response_text == "safe":
|
||||
return None
|
||||
|
||||
parts = response_text.split("\n")
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
if parts[0] == "unsafe":
|
||||
return SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata={"violation_type": parts[1]},
|
||||
)
|
||||
|
||||
return None
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import OpenTelemetryConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
|
||||
from .opentelemetry import OpenTelemetryAdapter
|
||||
|
||||
impl = OpenTelemetryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OpenTelemetryConfig(BaseModel):
|
||||
jaeger_host: str = "localhost"
|
||||
jaeger_port: int = 6831
|
|
@ -1,201 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import (
|
||||
ConsoleMetricExporter,
|
||||
PeriodicExportingMetricReader,
|
||||
)
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
from .config import OpenTelemetryConfig
|
||||
|
||||
|
||||
def string_to_trace_id(s: str) -> int:
|
||||
# Convert the string to bytes and then to an integer
|
||||
return int.from_bytes(s.encode(), byteorder="big", signed=False)
|
||||
|
||||
|
||||
def string_to_span_id(s: str) -> int:
|
||||
# Use only the first 8 bytes (64 bits) for span ID
|
||||
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
||||
|
||||
class OpenTelemetryAdapter(Telemetry):
|
||||
def __init__(self, config: OpenTelemetryConfig):
|
||||
self.config = config
|
||||
|
||||
self.resource = Resource.create(
|
||||
{ResourceAttributes.SERVICE_NAME: "foobar-service"}
|
||||
)
|
||||
|
||||
# Set up tracing with Jaeger exporter
|
||||
jaeger_exporter = JaegerExporter(
|
||||
agent_host_name=self.config.jaeger_host,
|
||||
agent_port=self.config.jaeger_port,
|
||||
)
|
||||
trace_provider = TracerProvider(resource=self.resource)
|
||||
trace_processor = BatchSpanProcessor(jaeger_exporter)
|
||||
trace_provider.add_span_processor(trace_processor)
|
||||
trace.set_tracer_provider(trace_provider)
|
||||
self.tracer = trace.get_tracer(__name__)
|
||||
|
||||
# Set up metrics
|
||||
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
|
||||
metric_provider = MeterProvider(
|
||||
resource=self.resource, metric_readers=[metric_reader]
|
||||
)
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
trace.get_tracer_provider().shutdown()
|
||||
metrics.get_meter_provider().shutdown()
|
||||
|
||||
async def log_event(self, event: Event) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event)
|
||||
elif isinstance(event, MetricEvent):
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event)
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
|
||||
span = trace.get_current_span()
|
||||
span.add_event(
|
||||
name=event.message,
|
||||
attributes={"severity": event.severity.value, **event.attributes},
|
||||
timestamp=event.timestamp,
|
||||
)
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if isinstance(event.value, int):
|
||||
self.meter.create_counter(
|
||||
name=event.metric,
|
||||
unit=event.unit,
|
||||
description=f"Counter for {event.metric}",
|
||||
).add(event.value, attributes=event.attributes)
|
||||
elif isinstance(event.value, float):
|
||||
self.meter.create_gauge(
|
||||
name=event.metric,
|
||||
unit=event.unit,
|
||||
description=f"Gauge for {event.metric}",
|
||||
).set(event.value, attributes=event.attributes)
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent) -> None:
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
context = trace.set_span_in_context(
|
||||
trace.NonRecordingSpan(
|
||||
trace.SpanContext(
|
||||
trace_id=string_to_trace_id(event.trace_id),
|
||||
span_id=string_to_span_id(event.span_id),
|
||||
is_remote=True,
|
||||
)
|
||||
)
|
||||
)
|
||||
span = self.tracer.start_span(
|
||||
name=event.payload.name,
|
||||
kind=trace.SpanKind.INTERNAL,
|
||||
context=context,
|
||||
attributes=event.attributes,
|
||||
)
|
||||
|
||||
if event.payload.parent_span_id:
|
||||
span.set_parent(
|
||||
trace.SpanContext(
|
||||
trace_id=string_to_trace_id(event.trace_id),
|
||||
span_id=string_to_span_id(event.payload.parent_span_id),
|
||||
is_remote=True,
|
||||
)
|
||||
)
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = trace.get_current_span()
|
||||
span.set_status(
|
||||
trace.Status(
|
||||
trace.StatusCode.OK
|
||||
if event.payload.status == SpanStatus.OK
|
||||
else trace.StatusCode.ERROR
|
||||
)
|
||||
)
|
||||
span.end(end_time=event.timestamp)
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
# we need to look up the root span id
|
||||
raise NotImplementedError("not yet no")
|
||||
|
||||
|
||||
# Usage example
|
||||
async def main():
|
||||
telemetry = OpenTelemetryTelemetry("my-service")
|
||||
await telemetry.initialize()
|
||||
|
||||
# Log an unstructured event
|
||||
await telemetry.log_event(
|
||||
UnstructuredLogEvent(
|
||||
trace_id="trace123",
|
||||
span_id="span456",
|
||||
timestamp=datetime.now(),
|
||||
message="This is a log message",
|
||||
severity=LogSeverity.INFO,
|
||||
)
|
||||
)
|
||||
|
||||
# Log a metric event
|
||||
await telemetry.log_event(
|
||||
MetricEvent(
|
||||
trace_id="trace123",
|
||||
span_id="span456",
|
||||
timestamp=datetime.now(),
|
||||
metric="my_metric",
|
||||
value=42,
|
||||
unit="count",
|
||||
)
|
||||
)
|
||||
|
||||
# Log a structured event (span start)
|
||||
await telemetry.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id="trace123",
|
||||
span_id="span789",
|
||||
timestamp=datetime.now(),
|
||||
payload=SpanStartPayload(name="my_operation"),
|
||||
)
|
||||
)
|
||||
|
||||
# Log a structured event (span end)
|
||||
await telemetry.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id="trace123",
|
||||
span_id="span789",
|
||||
timestamp=datetime.now(),
|
||||
payload=SpanEndPayload(status=SpanStatus.OK),
|
||||
)
|
||||
)
|
||||
|
||||
await telemetry.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleTelemetryImpl
|
||||
|
||||
impl = SampleTelemetryImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
|
||||
class SampleTelemetryImpl(Telemetry):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
Loading…
Add table
Add a link
Reference in a new issue