Merge branch 'main' into clarifai-inference-provider

This commit is contained in:
sanjaychelliah 2024-11-26 18:01:45 +05:30 committed by GitHub
commit 4b9085d312
536 changed files with 34661 additions and 12116 deletions


@@ -1,257 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from openai import OpenAI
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import DatabricksImplConfig
DATABRICKS_SUPPORTED_MODELS = {
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
}
class DatabricksInferenceAdapter(Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> OpenAI:
return OpenAI(
base_url=self.config.url,
api_key=self.config.api_token
)
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_databricks_messages(self, messages: list[Message]) -> list:
databricks_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
databricks_messages.append({"role": role, "content": message.content})
return databricks_messages
def resolve_databricks_model(self, model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in DATABRICKS_SUPPORTED_MODELS
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(DATABRICKS_SUPPORTED_MODELS.keys())}"
return DATABRICKS_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
options = self.get_databricks_chat_options(request)
databricks_model = self.resolve_databricks_model(request.model)
if not request.stream:
r = self.client.chat.completions.create(
model=databricks_model,
messages=self._messages_to_databricks_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if r.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
for chunk in self.client.chat.completions.create(
model=databricks_model,
messages=self._messages_to_databricks_messages(messages),
stream=True,
**options,
):
if chunk.choices[0].finish_reason:
if (
stop_reason is None
and chunk.choices[0].finish_reason == "stop"
):
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
and chunk.choices[0].finish_reason == "length"
):
stop_reason = StopReason.out_of_tokens
break
text = chunk.choices[0].delta.content
if text is None:
continue
# check if it's a tool call (i.e., it starts with <|python_tag|>)
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
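
A minimal usage sketch for the adapter above (illustrative only): the import path and environment variable names are assumptions, while the config fields (`url`, `api_token`), the `chat_completion` signature, and the streaming chunk shape come from the code being removed here.

```python
import asyncio
import os

from llama_models.llama3.api.datatypes import UserMessage
# Import path is an assumption; the classes are the ones defined above.
from llama_stack.providers.adapters.inference.databricks import (
    DatabricksImplConfig,
    DatabricksInferenceAdapter,
)


async def demo() -> None:
    config = DatabricksImplConfig(
        url=os.environ["DATABRICKS_URL"],          # serving endpoint base URL (assumed env var)
        api_token=os.environ["DATABRICKS_TOKEN"],  # personal access token (assumed env var)
    )
    adapter = DatabricksInferenceAdapter(config)
    await adapter.initialize()

    # chat_completion is an async generator; with stream=True it yields
    # ChatCompletionResponseStreamChunk events (start, progress, complete).
    async for chunk in adapter.chat_completion(
        model="Llama3.1-70B-Instruct",
        messages=[UserMessage(content="Say hello.")],
        stream=True,
    ):
        print(chunk.event.event_type, chunk.event.delta)

    await adapter.shutdown()


if __name__ == "__main__":
    asyncio.run(demo())
```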


@@ -1,247 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
}
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
def __init__(self, config: FireworksImplConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> Fireworks:
return Fireworks(api_key=self.config.api_key)
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
fireworks_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
fireworks_messages.append({"role": role, "content": message.content})
return fireworks_messages
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to fireworks
options = self.get_fireworks_chat_options(request)
fireworks_model = self.map_to_provider_model(request.model)
if not request.stream:
r = await self.client.chat.completions.acreate(
model=fireworks_model,
messages=self._messages_to_fireworks_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if r.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
async for chunk in self.client.chat.completions.acreate(
model=fireworks_model,
messages=self._messages_to_fireworks_messages(messages),
stream=True,
**options,
):
if chunk.choices[0].finish_reason:
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
and chunk.choices[0].finish_reason == "length"
):
stop_reason = StopReason.out_of_tokens
break
text = chunk.choices[0].delta.content
if text is None:
continue
# check if it's a tool call (i.e., it starts with <|python_tag|>)
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
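
For contrast with the streaming sketch above, a brief hedged sketch of the non-streaming path: with `stream=False` the async generator still yields, but exactly one `ChatCompletionResponse`. The import path is an assumption; `FireworksImplConfig.api_key` is the field used by the adapter above.

```python
from llama_models.llama3.api.datatypes import UserMessage
# Import path is an assumption.
from llama_stack.providers.adapters.inference.fireworks import (
    FireworksImplConfig,
    FireworksInferenceAdapter,
)


async def demo(api_key: str) -> None:
    adapter = FireworksInferenceAdapter(FireworksImplConfig(api_key=api_key))
    await adapter.initialize()

    # With stream=False the generator yields exactly one ChatCompletionResponse.
    async for response in adapter.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="One-line summary of Llama Stack?")],
        stream=False,
    ):
        print(response.completion_message.content)

    await adapter.shutdown()
```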


@@ -1,266 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
}
class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
def __init__(self, url: str) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
)
self.url = url
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> AsyncClient:
return AsyncClient(host=self.url)
async def initialize(self) -> None:
print("Initializing Ollama, checking connectivity to server...")
try:
await self.client.ps()
except httpx.ConnectError as e:
raise RuntimeError(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
) from e
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
ollama_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
ollama_messages.append({"role": role, "content": message.content})
return ollama_messages
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
if (
request.sampling_params.repetition_penalty is not None
and request.sampling_params.repetition_penalty != 1.0
):
options["repeat_penalty"] = request.sampling_params.repetition_penalty
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.map_to_provider_model(request.model)
res = await self.client.ps()
need_model_pull = True
for r in res["models"]:
if ollama_model == r["model"]:
need_model_pull = False
break
if need_model_pull:
print(f"Pulling model: {ollama_model}")
status = await self.client.pull(ollama_model)
assert (
status["status"] == "success"
), f"Failed to pull model {ollama_model} in ollama"
if not request.stream:
r = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(messages),
stream=False,
options=options,
)
stop_reason = None
if r["done"]:
if r["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif r["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r["message"]["content"], stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
stream = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(messages),
stream=True,
options=options,
)
buffer = ""
ipython = False
stop_reason = None
async for chunk in stream:
if chunk["done"]:
if stop_reason is None and chunk["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif stop_reason is None and chunk["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
break
text = chunk["message"]["content"]
# check if it's a tool call (i.e., it starts with <|python_tag|>)
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
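
A small sketch of how `SamplingParams` fields end up in the ollama `options` dict built by `get_ollama_chat_options` above. The URL is the stock ollama default and the import path is an assumption; the field names and the `repeat_penalty` rename come from the code being removed.

```python
from llama_models.llama3.api.datatypes import SamplingParams, UserMessage
from llama_stack.apis.inference import ChatCompletionRequest
# Import path is an assumption.
from llama_stack.providers.adapters.inference.ollama.ollama import OllamaInferenceAdapter

adapter = OllamaInferenceAdapter(url="http://localhost:11434")
request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",
    messages=[UserMessage(content="hi")],
    sampling_params=SamplingParams(
        temperature=0.7, top_p=0.9, max_tokens=256, repetition_penalty=1.1
    ),
)

# Note the rename: repetition_penalty is passed to ollama as "repeat_penalty".
print(adapter.get_ollama_chat_options(request))
# e.g. {'temperature': 0.7, 'top_p': 0.9, 'max_tokens': 256, 'repeat_penalty': 1.1}
# (key order may vary because the adapter iterates over a set)
```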


@@ -1,260 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
logger = logging.getLogger(__name__)
class _HfAdapter(Inference, RoutableProvider):
client: AsyncInferenceClient
max_tokens: int
model_id: str
def __init__(self) -> None:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def get_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)
input_tokens = len(model_input.tokens)
max_new_tokens = min(
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
print(f"Calculated max_new_tokens: {max_new_tokens}")
options = self.get_chat_options(request)
if not request.stream:
response = await self.client.text_generation(
prompt=prompt,
stream=False,
details=True,
max_new_tokens=max_new_tokens,
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options,
)
stop_reason = None
if response.details.finish_reason:
if response.details.finish_reason in ["stop", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif response.details.finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
response.generated_text,
stop_reason,
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
tokens = []
async for response in await self.client.text_generation(
prompt=prompt,
stream=True,
details=True,
max_new_tokens=max_new_tokens,
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options,
):
token_result = response.token
buffer += token_result.text
tokens.append(token_result.id)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# parse tool calls and report errors
message = self.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(
model=config.model_id, token=config.api_token
)
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
class InferenceEndpointAdapter(_HfAdapter):
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
# Get the inference endpoint details
api = HfApi(token=config.api_token)
endpoint = api.get_inference_endpoint(config.endpoint_name)
# Wait for the endpoint to be ready (if not already)
endpoint.wait(timeout=60)
# Initialize the adapter
self.client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(
endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
)
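
A short sketch of wiring up the TGI flavor of the adapter above. The import path, URL, and port are assumptions; the config fields (`url`, `api_token`) and the endpoint-info lookup come from `TGIAdapter.initialize`.

```python
# Import path is an assumption.
from llama_stack.providers.adapters.inference.tgi import TGIAdapter, TGIImplConfig


async def connect_tgi() -> TGIAdapter:
    adapter = TGIAdapter()
    # initialize() reads max_total_tokens and model_id from the TGI endpoint info,
    # which the adapter later uses to cap max_new_tokens.
    await adapter.initialize(TGIImplConfig(url="http://localhost:8080", api_token=None))
    return adapter
```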


@@ -1,265 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
}
class TogetherInferenceAdapter(
Inference, NeedsRequestProviderData, RoutableProviderForModels
):
def __init__(self, config: TogetherImplConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> Together:
return Together(api_key=self.config.api_key)
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_together_messages(self, messages: list[Message]) -> list:
together_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
together_messages.append({"role": role, "content": message.content})
return together_messages
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
client = Together(api_key=together_api_key)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
# accumulate sampling params and other options to pass to together
options = self.get_together_chat_options(request)
together_model = self.map_to_provider_model(request.model)
messages = augment_messages_for_tools(request)
if not request.stream:
# TODO: might need to add back an async here
r = client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if (
r.choices[0].finish_reason == "stop"
or r.choices[0].finish_reason == "eos"
):
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
for chunk in client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=True,
**options,
):
if finish_reason := chunk.choices[0].finish_reason:
if stop_reason is None and finish_reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif stop_reason is None and finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = chunk.choices[0].delta.content
if text is None:
continue
# check if it's a tool call (i.e., it starts with <|python_tag|>)
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
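
When `TogetherImplConfig.api_key` is unset, the adapter above expects the key per request via the `X-LlamaStack-ProviderData` header. A hedged sketch of a raw HTTP call against a running Llama Stack server: the server URL and route are assumptions, while the header name and JSON shape come from the error message in the adapter.

```python
import json

import httpx

headers = {
    "Content-Type": "application/json",
    # Picked up via NeedsRequestProviderData on the adapter above.
    "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"}),
}

body = {
    "model": "Llama3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
}

# Server URL and route are assumptions about the local distribution.
resp = httpx.post("http://localhost:5000/inference/chat_completion", headers=headers, json=body)
print(resp.status_code, resp.text[:200])
```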


@@ -1,8 +0,0 @@
from .config import WeaviateConfig
async def get_adapter_impl(config: WeaviateConfig, _deps):
from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config)
await impl.initialize()
return impl


@@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
class WeaviateRequestProviderData(BaseModel):
# if there _is_ provider data, it must specify the API KEY
# if you want it to be optional, use Optional[str]
weaviate_api_key: str
weaviate_cluster_url: str
@json_schema_type
class WeaviateConfig(BaseModel):
collection: str = Field(default="MemoryBank")


@@ -1,192 +0,0 @@
import json
import uuid
from typing import List, Optional, Dict, Any
from numpy.typing import NDArray
import weaviate
import weaviate.classes as wvc
from weaviate.classes.init import Auth
from llama_stack.apis.memory import *
from llama_stack.distribution.request_headers import get_request_provider_data
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import WeaviateConfig, WeaviateRequestProviderData
class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection: str):
self.client = client
self.collection = collection
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
data_objects = []
for i, chunk in enumerate(chunks):
data_objects.append(wvc.data.DataObject(
properties={
"chunk_content": chunk,
},
vector = embeddings[i].tolist()
))
# Inserting chunks into a prespecified Weaviate collection
assert self.collection is not None, "Collection name must be specified"
my_collection = self.client.collections.get(self.collection)
await my_collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
assert self.collection is not None, "Collection name must be specified"
my_collection = self.client.collections.get(self.collection)
results = my_collection.query.near_vector(
near_vector = embedding.tolist(),
limit = k,
return_metadata = wvc.query.MetadataQuery(distance=True)
)
chunks = []
scores = []
for doc in results.objects:
try:
chunk = doc.properties["chunk_content"]
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
except Exception as e:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {e}")
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class WeaviateMemoryAdapter(Memory):
def __init__(self, config: WeaviateConfig) -> None:
self.config = config
self.client = None
self.cache = {}
def _get_client(self) -> weaviate.Client:
request_provider_data = get_request_provider_data()
if request_provider_data is not None:
assert isinstance(request_provider_data, WeaviateRequestProviderData)
# Connect to Weaviate Cloud
return weaviate.connect_to_weaviate_cloud(
cluster_url = request_provider_data.weaviate_cluster_url,
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
)
async def initialize(self) -> None:
try:
self.client = self._get_client()
# Create collection if it doesn't exist
if not self.client.collections.exists(self.config.collection):
self.client.collections.create(
name = self.config.collection,
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
]
)
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to Weaviate server") from e
async def shutdown(self) -> None:
self.client = self._get_client()
if self.client:
self.client.close()
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
self.client = self._get_client()
# Store the bank as a new collection in Weaviate
self.client.collections.create(
name=bank_id
)
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(client=self.client, collection=bank_id),
)
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
self.client = self._get_client()
if bank_id in self.cache:
return self.cache[bank_id]
collections = self.client.collections.list_all().keys()
for collection in collections:
if collection == bank_id:
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(self.client, collection),
)
self.cache[bank_id] = index
return index
return None
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)
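
A minimal construction sketch for the memory adapter above. The import path is an assumption; the config field and the per-request provider-data fields (`weaviate_cluster_url`, `weaviate_api_key`) come from the classes being removed.

```python
# Import path is an assumption.
from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig, get_adapter_impl


async def build_weaviate_memory():
    # Only the collection name is static config; the cluster URL and API key arrive
    # per request via X-LlamaStack-ProviderData as
    # {"weaviate_cluster_url": ..., "weaviate_api_key": ...}, so initialize() must run
    # inside a request that carries that header.
    config = WeaviateConfig(collection="MemoryBank")
    return await get_adapter_impl(config, _deps={})
```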


@@ -1,120 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import traceback
from typing import Any, Dict, List
import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
SUPPORTED_SHIELD_TYPES = [
"bedrock_guardrail",
]
class BedrockSafetyAdapter(Safety, RoutableProvider):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config
async def initialize(self) -> None:
try:
print(f"initializing with profile --- > {self.config}")
self.boto_client = boto3.Session(
profile_name=self.config.aws_profile
).client("bedrock-runtime")
except Exception as e:
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
}
]```
However, the incoming messages are of type UserMessage(content=...) coming from
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py.
They contain content and role. For now we extract the content and default the "qualifiers" to ["query"].
"""
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if "guardrailIdentifier" not in params:
raise RuntimeError(
"Error running request for BedrockGuardrails: missing guardrailIdentifier in request"
)
if "guardrailVersion" not in params:
raise RuntimeError(
"Error running request for BedrockGuardrails: missing guardrailVersion in request"
)
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get("guardrailIdentifier"),
guardrailVersion=params.get("guardrailVersion"),
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
logger.debug(f"run_shield:: response: {response}::")
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
except Exception:
error_str = traceback.format_exc()
logger.error(
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
)
return None
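
A hedged sketch of invoking the shield above: the guardrail ID and version are placeholders and the import path is an assumption, while the shield type, `aws_profile` field, and `params` keys come from the adapter code.

```python
from llama_models.llama3.api.datatypes import UserMessage
# Import path is an assumption.
from llama_stack.providers.adapters.safety.bedrock import (
    BedrockSafetyAdapter,
    BedrockSafetyConfig,
)


async def check(text: str):
    adapter = BedrockSafetyAdapter(BedrockSafetyConfig(aws_profile="default"))
    await adapter.initialize()
    return await adapter.run_shield(
        shield_type="bedrock_guardrail",
        messages=[UserMessage(content=text)],
        params={
            "guardrailIdentifier": "<your-guardrail-id>",  # placeholder
            "guardrailVersion": "1",
        },
    )
```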


@@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
class BedrockSafetyConfig(BaseModel):
"""Configuration information for a guardrail that you want to use in the request."""
aws_profile: str = Field(
default="default",
description="The profile on the machine with valid AWS credentials. This ensures separation of creation from invocation.",
)


@@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
class TogetherProviderDataValidator(BaseModel):
together_api_key: str
@json_schema_type
class TogetherSafetyConfig(BaseModel):
url: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: Optional[str] = Field(
default=None,
description="The Together AI API Key (default for the distribution, if any)",
)


@@ -1,97 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from together import Together
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig
SAFETY_SHIELD_TYPES = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
model_name = SAFETY_SHIELD_TYPES[shield_type]
# messages can have role assistant or user
api_messages = []
for message in messages:
if message.role in (Role.user.value, Role.assistant.value):
api_messages.append({"role": message.role, "content": message.content})
violation = await get_safety_response(
together_api_key, model_name, api_messages
)
return RunShieldResponse(violation=violation)
async def get_safety_response(
api_key: str, model_name: str, messages: List[Dict[str, str]]
) -> Optional[SafetyViolation]:
client = Together(api_key=api_key)
response = client.chat.completions.create(messages=messages, model=model_name)
if len(response.choices) == 0:
return None
response_text = response.choices[0].message.content
if response_text == "safe":
return None
parts = response_text.split("\n")
if len(parts) != 2:
return None
if parts[0] == "unsafe":
return SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message="unsafe",
metadata={"violation_type": parts[1]},
)
return None
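
The helper above relies on Llama Guard's plain-text reply format ("safe", or "unsafe" plus a category on a second line). A stripped-down restatement of that parsing logic as a worked example; the "S2" category code is illustrative, not taken from this diff.

```python
def parse_guard_reply(response_text: str):
    # Mirrors the checks in get_safety_response above.
    if response_text == "safe":
        return None
    parts = response_text.split("\n")
    if len(parts) == 2 and parts[0] == "unsafe":
        return {"violation_type": parts[1]}
    return None


assert parse_guard_reply("safe") is None
assert parse_guard_reply("unsafe\nS2") == {"violation_type": "S2"}
```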


@@ -1,201 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from opentelemetry import metrics, trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import (
ConsoleMetricExporter,
PeriodicExportingMetricReader,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig):
self.config = config
self.resource = Resource.create(
{ResourceAttributes.SERVICE_NAME: "foobar-service"}
)
# Set up tracing with Jaeger exporter
jaeger_exporter = JaegerExporter(
agent_host_name=self.config.jaeger_host,
agent_port=self.config.jaeger_port,
)
trace_provider = TracerProvider(resource=self.resource)
trace_processor = BatchSpanProcessor(jaeger_exporter)
trace_provider.add_span_processor(trace_processor)
trace.set_tracer_provider(trace_provider)
self.tracer = trace.get_tracer(__name__)
# Set up metrics
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
metric_provider = MeterProvider(
resource=self.resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(self, event: Event) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event)
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
span = trace.get_current_span()
span.add_event(
name=event.message,
attributes={"severity": event.severity.value, **event.attributes},
timestamp=event.timestamp,
)
def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int):
self.meter.create_counter(
name=event.metric,
unit=event.unit,
description=f"Counter for {event.metric}",
).add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
self.meter.create_gauge(
name=event.metric,
unit=event.unit,
description=f"Gauge for {event.metric}",
).set(event.value, attributes=event.attributes)
def _log_structured(self, event: StructuredLogEvent) -> None:
if isinstance(event.payload, SpanStartPayload):
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.span_id),
is_remote=True,
)
)
)
span = self.tracer.start_span(
name=event.payload.name,
kind=trace.SpanKind.INTERNAL,
context=context,
attributes=event.attributes,
)
if event.payload.parent_span_id:
span.set_parent(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.payload.parent_span_id),
is_remote=True,
)
)
elif isinstance(event.payload, SpanEndPayload):
span = trace.get_current_span()
span.set_status(
trace.Status(
trace.StatusCode.OK
if event.payload.status == SpanStatus.OK
else trace.StatusCode.ERROR
)
)
span.end(end_time=event.timestamp)
async def get_trace(self, trace_id: str) -> Trace:
# we need to look up the root span id
raise NotImplementedError("not yet no")
# Usage example
async def main():
telemetry = OpenTelemetryAdapter(OpenTelemetryConfig(jaeger_host="localhost", jaeger_port=6831))  # field names as used in __init__ above
await telemetry.initialize()
# Log an unstructured event
await telemetry.log_event(
UnstructuredLogEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
message="This is a log message",
severity=LogSeverity.INFO,
)
)
# Log a metric event
await telemetry.log_event(
MetricEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
metric="my_metric",
value=42,
unit="count",
)
)
# Log a structured event (span start)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanStartPayload(name="my_operation"),
)
)
# Log a structured event (span end)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanEndPayload(status=SpanStatus.OK),
)
)
await telemetry.shutdown()
if __name__ == "__main__":
import asyncio
asyncio.run(main())
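
A quick worked example of the id helpers defined at the top of this file: they pack the string's UTF-8 bytes into an unsigned big-endian integer, and span ids keep only the first 8 bytes.

```python
assert string_to_trace_id("abc") == int.from_bytes(b"abc", byteorder="big") == 6382179
assert string_to_span_id("span456-extra") == int.from_bytes(b"span456-", byteorder="big")
```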


@@ -6,10 +6,18 @@
from enum import Enum
from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
@json_schema_type
class Api(Enum):
@@ -17,17 +25,55 @@ class Api(Enum):
safety = "safety"
agents = "agents"
memory = "memory"
datasetio = "datasetio"
scoring = "scoring"
eval = "eval"
telemetry = "telemetry"
models = "models"
shields = "shields"
memory_banks = "memory_banks"
datasets = "datasets"
scoring_functions = "scoring_functions"
eval_tasks = "eval_tasks"
# built-in API
inspect = "inspect"
class ModelsProtocolPrivate(Protocol):
async def register_model(self, model: Model) -> None: ...
async def unregister_model(self, model_id: str) -> None: ...
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
class MemoryBanksProtocolPrivate(Protocol):
async def list_memory_banks(self) -> List[MemoryBank]: ...
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ...
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
class EvalTasksProtocolPrivate(Protocol):
async def register_eval_task(self, eval_task: EvalTask) -> None: ...
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
@@ -40,24 +86,24 @@ class ProviderSpec(BaseModel):
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
deprecation_warning: Optional[str] = Field(
default=None,
description="If this provider is deprecated, specify the warning message here",
)
deprecation_error: Optional[str] = Field(
default=None,
description="If this provider is deprecated and does NOT work, specify the error message here",
)
# used internally by the resolver; this is a hack for now
deps__: List[str] = Field(default_factory=list)
class RoutingTable(Protocol):
def get_routing_keys(self) -> List[str]: ...
def get_provider_impl(self, routing_key: str) -> Any: ...
class RoutableProvider(Protocol):
"""
A provider which sits behind the RoutingTable and can get routed to.
All Inference / Safety / Memory providers fall into this bucket.
"""
async def validate_routing_keys(self, keys: List[str]) -> None: ...
# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
@@ -113,21 +159,27 @@ Fully-qualified name of the module to import. The module is expected to have:
class RemoteProviderConfig(BaseModel):
host: str = "localhost"
port: int
port: Optional[int] = None
protocol: str = "http"
@property
def url(self) -> str:
return f"http://{self.host}:{self.port}"
if self.port is None:
return f"{self.protocol}://{self.host}"
return f"{self.protocol}://{self.host}:{self.port}"
@classmethod
def from_url(cls, url: str) -> "RemoteProviderConfig":
parsed = urlparse(url)
return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
adapter: AdapterSpec = Field(
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
API responses, specify the adapter here.
""",
)
@ -137,34 +189,21 @@ as being "Llama Stack compatible"
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return f"llama_stack.apis.{self.api.value}.client"
return self.adapter.module
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
return self.adapter.pip_packages
@property
def provider_data_validator(self) -> Optional[str]:
if self.adapter:
return self.adapter.provider_data_validator
return None
return self.adapter.provider_data_validator
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
adapter=adapter,
)
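A standalone sketch of the URL handling this hunk introduces (assuming only `pydantic` is installed; `RemoteProviderConfigSketch` and the example URL are illustrative, not part of the commit): the port becomes optional and the scheme recovered by `from_url` is preserved.

```python
from typing import Optional
from urllib.parse import urlparse

from pydantic import BaseModel


class RemoteProviderConfigSketch(BaseModel):
    host: str = "localhost"
    port: Optional[int] = None
    protocol: str = "http"

    @property
    def url(self) -> str:
        # mirror the updated property: omit the port when it isn't set
        if self.port is None:
            return f"{self.protocol}://{self.host}"
        return f"{self.protocol}://{self.host}:{self.port}"

    @classmethod
    def from_url(cls, url: str) -> "RemoteProviderConfigSketch":
        parsed = urlparse(url)
        return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)


# https URLs without an explicit port now round-trip cleanly.
cfg = RemoteProviderConfigSketch.from_url("https://api.example.com")
assert cfg.url == "https://api.example.com"
print(cfg.url)
```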

View file

@ -1,109 +0,0 @@
# LocalInference
LocalInference provides a local inference implementation powered by [executorch](https://github.com/pytorch/executorch/).
Llama Stack currently supports on-device inference for iOS, with Android coming soon. You can run on-device inference on Android today using [executorch](https://github.com/pytorch/executorch/tree/main/examples/demo-apps/android/LlamaDemo), PyTorch's on-device inference library.
## Installation
We're working on making LocalInference easier to set up. For now, you'll need to import it via `.xcframework`:
1. Clone the executorch submodule in this repo and its dependencies: `git submodule update --init --recursive`
1. Install [CMake](https://cmake.org/) for the executorch build
1. Drag `LocalInference.xcodeproj` into your project
1. Add `LocalInference` as a framework in your app target
1. Add a package dependency on https://github.com/pytorch/executorch (branch latest)
1. Add all the kernels / backends from executorch (but not executorch itself!) as frameworks in your app target:
- backend_coreml
- backend_mps
- backend_xnnpack
- kernels_custom
- kernels_optimized
- kernels_portable
- kernels_quantized
1. In "Build Settings" > "Other Linker Flags" > "Any iOS Simulator SDK", add:
```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```
1. In "Build Settings" > "Other Linker Flags" > "Any iOS SDK", add:
```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```
## Preparing a model
1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md#step-2-prepare-model)
2. Bundle the `.pte` and `tokenizer.model` file into your app
## Using LocalInference
1. Instantiate LocalInference with a DispatchQueue. Optionally, pass it into your agents service:
```swift
init () {
runnerQueue = DispatchQueue(label: "org.meta.llamastack")
inferenceService = LocalInferenceService(queue: runnerQueue)
agentsService = LocalAgentsService(inference: inferenceService)
}
```
2. Before making any inference calls, load your model from your bundle:
```swift
let mainBundle = Bundle.main
inferenceService.loadModel(
modelPath: mainBundle.url(forResource: "llama32_1b_spinquant", withExtension: "pte"),
tokenizerPath: mainBundle.url(forResource: "tokenizer", withExtension: "model"),
completion: {_ in } // use to handle load failures
)
```
3. Make inference calls (or agents calls) as you normally would with LlamaStack:
```swift
for await chunk in try await agentsService.initAndCreateTurn(
messages: [
.UserMessage(Components.Schemas.UserMessage(
content: .case1("Call functions as needed to handle any actions in the following text:\n\n" + text),
role: .user))
]
) {
```
## Troubleshooting
If you receive errors like "missing package product" or "invalid checksum", try cleaning the build folder and resetting the Swift package cache:
(Opt+Click) Product > Clean Build Folder Immediately
```
rm -rf \
~/Library/org.swift.swiftpm \
~/Library/Caches/org.swift.swiftpm \
~/Library/Caches/com.apple.dt.Xcode \
~/Library/Developer/Xcode/DerivedData
```

View file

@ -1,46 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceImplConfig(BaseModel):
model: str = Field(
default="Llama3.1-8B-Instruct",
description="Model descriptor from `llama model list`",
)
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
return model
@property
def model_parallel_size(self) -> int:
# HACK ALERT: this will be fixed when we move inference configuration
# to ModelsRegistry and we can explicitly ask for `model_parallel_size`
# as configuration there
resolved = resolve_model(self.model)
assert resolved is not None
return resolved.pth_file_count

View file

@ -1,225 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import AsyncIterator, List, Union
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator
# there's a single model-parallel process serving the model; for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, RoutableProvider):
def __init__(self, config: MetaReferenceImplConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model
# TODO: verify that the checkpoint actually is for this model
async def initialize(self) -> None:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
assert (
len(routing_keys) == 1
), f"Only one routing key is supported {routing_keys}"
assert routing_keys[0] == self.config.model
async def shutdown(self) -> None:
self.generator.stop()
# hm, when stream=False, we should not be doing SSE :/ which is what the
# top-level server is going to do. make the typing more specific here
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
]:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
async with SEMAPHORE:
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
tokens = []
logprobs = []
stop_reason = None
buffer = ""
ipython = False
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
assert (
len(token_result.logprobs) == 1
), "Expected logprob to contain 1 result for the current token"
assert (
request.logprobs.top_k == 1
), "Only top_k=1 is supported for LogProbConfig"
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)
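The module-level `SEMAPHORE = asyncio.Semaphore(1)` plus the `SEMAPHORE.locked()` check above implement a reject-if-busy policy rather than queueing. A minimal, self-contained sketch of that pattern (the `generate` coroutine below is a stand-in, not the real model-parallel generator):

```python
import asyncio

# Illustrative only: one permit, and a second caller is rejected instead of queued.
BUSY = asyncio.Semaphore(1)


async def generate(prompt: str) -> str:
    if BUSY.locked():
        raise RuntimeError("Only one concurrent request is supported")
    async with BUSY:
        await asyncio.sleep(0.1)  # stand-in for token generation
        return f"echo: {prompt}"


async def main() -> None:
    first = asyncio.create_task(generate("hello"))
    await asyncio.sleep(0)  # let the first task acquire the semaphore
    try:
        await generate("world")
    except RuntimeError as e:
        print(e)  # "Only one concurrent request is supported"
    print(await first)


asyncio.run(main())
```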

View file

@ -1,106 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from typing import Optional
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from termcolor import cprint
from torch import Tensor
from llama_stack.apis.inference import QuantizationType
from llama_stack.providers.impls.meta_reference.inference.config import (
MetaReferenceImplConfig,
)
def is_fbgemm_available() -> bool:
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
return True
except ImportError:
return False
def swiglu_wrapper(
self,
x: Tensor,
):
from .fp8_impls import ffn_swiglu
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
model: Transformer,
config: MetaReferenceImplConfig,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
checkpoint = config.checkpoint_config.checkpoint
# Move weights to GPU with quantization
if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow")
fp8_scales_path = os.path.join(
checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
else:
cprint("Quantizing fp8 weights from bf16...", "yellow")
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
return model

View file

@ -1,129 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import uuid
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,
)
from llama_stack.providers.utils.telemetry import tracing
from .config import FaissImplConfig
logger = logging.getLogger(__name__)
class FaissIndex(EmbeddingIndex):
id_by_index: Dict[int, str]
chunk_by_index: Dict[int, Chunk]
def __init__(self, dimension: int):
self.index = faiss.IndexFlatL2(dimension)
self.id_by_index = {}
self.chunk_by_index = {}
@tracing.span(name="add_chunks")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
self.id_by_index[indexlen + i] = chunk.document_id
self.index.add(np.array(embeddings).astype(np.float32))
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
distances, indices = self.index.search(
embedding.reshape(1, -1).astype(np.float32), k
)
chunks = []
scores = []
for d, i in zip(distances[0], indices[0]):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
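`FaissIndex` above is a thin wrapper over `faiss.IndexFlatL2` that keeps chunk payloads alongside vector positions and converts squared L2 distances into scores via `1/distance`. A NumPy-only sketch of that add/query shape for small data (the `TinyL2Index` name and sample vectors are illustrative, not from the commit):

```python
import numpy as np


class TinyL2Index:
    def __init__(self, dimension: int) -> None:
        self.vectors = np.empty((0, dimension), dtype=np.float32)
        self.payloads: list[str] = []

    def add(self, embeddings: np.ndarray, payloads: list[str]) -> None:
        # append vectors and remember which payload sits at each row
        self.vectors = np.vstack([self.vectors, embeddings.astype(np.float32)])
        self.payloads.extend(payloads)

    def query(self, embedding: np.ndarray, k: int):
        # squared L2 distances, like faiss.IndexFlatL2 returns
        dists = np.linalg.norm(self.vectors - embedding.astype(np.float32), axis=1) ** 2
        order = np.argsort(dists)[:k]
        return [(self.payloads[i], 1.0 / float(dists[i])) for i in order]


index = TinyL2Index(dimension=3)
index.add(np.eye(3), ["doc-a", "doc-b", "doc-c"])
print(index.query(np.array([1.0, 0.1, 0.0]), k=2))
```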
class FaissMemoryImpl(Memory, RoutableProvider):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
assert url is None, "URL is not supported for this implementation"
assert (
config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}"
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
index = self.cache.get(bank_id)
if index is None:
return None
return index.bank
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id)
if index is None:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id)
if index is None:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SafetyConfig
async def get_provider_impl(config: SafetyConfig, deps):
from .safety import MetaReferenceSafetyImpl
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -1,52 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Optional
from llama_models.sku_list import CoreModelId, safety_models
from pydantic import BaseModel, validator
class MetaReferenceShieldType(Enum):
llama_guard = "llama_guard"
code_scanner_guard = "code_scanner_guard"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
class LlamaGuardShieldConfig(BaseModel):
model: str = "Llama-Guard-3-1B"
excluded_categories: List[str] = []
disable_input_check: bool = False
disable_output_check: bool = False
@validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
m.descriptor()
for m in safety_models()
if (
m.core_model_id
in {
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_1b,
CoreModelId.llama_guard_3_11b_vision,
}
)
]
if model not in permitted_models:
raise ValueError(
f"Invalid model: {model}. Must be one of {permitted_models}"
)
return model
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
enable_prompt_guard: Optional[bool] = False

View file

@ -1,110 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api, RoutableProvider
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
)
from .config import MetaReferenceShieldType, SafetyConfig
from .shields import CodeScannerShield, LlamaGuardShield, ShieldBase
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
class MetaReferenceSafetyImpl(Safety, RoutableProvider):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
if self.config.enable_prompt_guard:
from .shields import PromptGuardShield
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
_ = PromptGuardShield.instance(model_dir)
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
available_shields = [v.value for v in MetaReferenceShieldType]
for key in routing_keys:
if key not in available_shields:
raise ValueError(f"Unknown safety shield type: {key}")
async def run_shield(
self,
shield_type: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
available_shields = [v.value for v in MetaReferenceShieldType]
assert shield_type in available_shields, f"Unknown shield {shield_type}"
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
messages = messages.copy()
# some shields, like llama-guard, require the first message to be a user message;
# since this might be a tool call, the first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages)
violation = None
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
violation = SafetyViolation(
violation_level=(
ViolationLevel.ERROR
if shield.on_violation_action == OnViolationAction.RAISE
else ViolationLevel.WARN
),
user_message=res.violation_return_message,
metadata={
"violation_type": res.violation_type,
},
)
return RunShieldResponse(violation=violation)
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config
if typ == MetaReferenceShieldType.llama_guard:
cfg = cfg.llama_guard_shield
assert (
cfg is not None
), "Cannot use LlamaGuardShield since not present in config"
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,
)
elif typ == MetaReferenceShieldType.jailbreak_shield:
from .shields import JailbreakShield
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
return JailbreakShield.instance(model_dir)
elif typ == MetaReferenceShieldType.injection_shield:
from .shields import InjectionShield
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
return InjectionShield.instance(model_dir)
elif typ == MetaReferenceShieldType.code_scanner_guard:
return CodeScannerShield.instance()
else:
raise ValueError(f"Unknown shield type: {typ}")

View file

@ -1,33 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# suppress warnings and log spew from Hugging Face
import transformers
from .base import ( # noqa: F401
DummyShield,
OnViolationAction,
ShieldBase,
ShieldResponse,
TextShield,
)
from .code_scanner import CodeScannerShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,
JailbreakShield,
PromptGuardShield,
)
transformers.logging.set_verbosity_error()
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore")

View file

@ -1,64 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from pydantic import BaseModel
from llama_stack.apis.safety import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
# TODO: clean this up; just remove this type completely
class ShieldResponse(BaseModel):
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
# TODO: this is a caller / agent concern
class OnViolationAction(Enum):
IGNORE = 0
WARN = 1
RAISE = 2
class ShieldBase(ABC):
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
self.on_violation_action = on_violation_action
@abstractmethod
async def run(self, messages: List[Message]) -> ShieldResponse:
raise NotImplementedError()
def message_content_as_str(message: Message) -> str:
return interleaved_text_media_as_str(message.content)
# For shields that operate on simple strings
class TextShield(ShieldBase):
def convert_messages_to_text(self, messages: List[Message]) -> str:
return "\n".join([message_content_as_str(m) for m in messages])
async def run(self, messages: List[Message]) -> ShieldResponse:
text = self.convert_messages_to_text(messages)
return await self.run_impl(text)
@abstractmethod
async def run_impl(self, text: str) -> ShieldResponse:
raise NotImplementedError()
class DummyShield(TextShield):
async def run_impl(self, text: str) -> ShieldResponse:
# Dummy return LOW to test e2e
return ShieldResponse(is_violation=False)

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from .base import ShieldResponse, TextShield
class CodeScannerShield(TextShield):
async def run_impl(self, text: str) -> ShieldResponse:
from codeshield.cs import CodeShield
cprint(f"Running CodeScannerShield on {text[:50]}", color="magenta")
result = await CodeShield.scan_code(text)
if result.is_insecure:
return ShieldResponse(
is_violation=True,
violation_type=",".join(
[issue.pattern_id for issue in result.issues_found]
),
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(is_violation=False)

View file

@ -1,145 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import auto, Enum
from typing import List
import torch
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
class PromptGuardShield(TextShield):
class Mode(Enum):
INJECTION = auto()
JAILBREAK = auto()
_instances = {}
_model_cache = None
@staticmethod
def instance(
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action=OnViolationAction.RAISE,
) -> "PromptGuardShield":
action_value = on_violation_action.value
key = (model_dir, threshold, temperature, mode, action_value)
if key not in PromptGuardShield._instances:
PromptGuardShield._instances[key] = PromptGuardShield(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=mode,
on_violation_action=on_violation_action,
)
return PromptGuardShield._instances[key]
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.device = "cuda"
if PromptGuardShield._model_cache is None:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
PromptGuardShield._model_cache = (tokenizer, model)
self.tokenizer, self.model = PromptGuardShield._model_cache
self.temperature = temperature
self.threshold = threshold
self.mode = mode
def convert_messages_to_text(self, messages: List[Message]) -> str:
return message_content_as_str(messages[-1])
async def run_impl(self, text: str) -> ShieldResponse:
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
if self.mode == self.Mode.INJECTION and (
score_embedded + score_malicious > self.threshold
):
return ShieldResponse(
is_violation=True,
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
return ShieldResponse(
is_violation=True,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return ShieldResponse(
is_violation=False,
)
class JailbreakShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.JAILBREAK,
on_violation_action=on_violation_action,
)
class InjectionShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.INJECTION,
on_violation_action=on_violation_action,
)
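The core of `PromptGuardShield.run_impl` above is a temperature-scaled softmax over three logits, with the jailbreak mode looking only at the malicious class and the injection mode at embedded + malicious. A hand-worked version of that scoring (the logit values below are made up for illustration):

```python
import math


def scores(logits, temperature: float = 1.0):
    # softmax over (benign, embedded, malicious) logits with temperature scaling
    exps = [math.exp(logit / temperature) for logit in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return probs[1], probs[2]  # (score_embedded, score_malicious)


score_embedded, score_malicious = scores([0.2, 1.5, 3.0])
threshold = 0.9
print("jailbreak violation:", score_malicious > threshold)
print("injection violation:", score_embedded + score_malicious > threshold)
```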

View file

@ -1,356 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import uuid
from typing import Any
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
CompletionMessage,
InterleavedTextMedia,
Message,
StopReason,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from llama_stack.apis.inference import ChatCompletionRequest, Inference
from llama_stack.apis.inference.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
LogProbConfig,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from .config import VLLMConfig
log = logging.getLogger(__name__)
def _random_uuid() -> str:
return str(uuid.uuid4().hex)
def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
"""Convert sampling params to vLLM sampling params."""
if sampling_params is None:
return SamplingParams()
# TODO convert what I saw in my first test ... but surely there's more to do here
kwargs = {
"temperature": sampling_params.temperature,
}
if sampling_params.top_k >= 1:
kwargs["top_k"] = sampling_params.top_k
if sampling_params.top_p:
kwargs["top_p"] = sampling_params.top_p
if sampling_params.max_tokens >= 1:
kwargs["max_tokens"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
return SamplingParams().from_optional(**kwargs)
class VLLMInferenceImpl(Inference, RoutableProviderForModels):
"""Inference implementation for vLLM."""
HF_MODEL_MAPPINGS = {
# TODO: seems like we should be able to build this table dynamically ...
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
}
def __init__(self, config: VLLMConfig):
Inference.__init__(self)
RoutableProviderForModels.__init__(
self,
stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
)
self.config = config
self.engine = None
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
async def initialize(self):
"""Initialize the vLLM inference adapter."""
log.info("Initializing vLLM inference adapter")
# Disable usage stats reporting. Most people would be surprised to find
# this enabled by default.
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
if "VLLM_NO_USAGE_STATS" not in os.environ:
os.environ["VLLM_NO_USAGE_STATS"] = "1"
hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model)
# TODO -- there are a ton of options supported here ...
engine_args = AsyncEngineArgs()
engine_args.model = hf_model
# We will need a new config item for this in the future if model support becomes
# broader than it is today (Llama only)
engine_args.tokenizer = hf_model
engine_args.tensor_parallel_size = self.config.tensor_parallel_size
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
async def shutdown(self):
"""Shutdown the vLLM inference adapter."""
log.info("Shutting down vLLM inference adapter")
if self.engine:
self.engine.shutdown_background_loop()
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Any | None = ...,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
log.info("vLLM completion")
messages = [Message(role="user", content=content)]
async for result in self.chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
):
yield result
async def chat_completion(
self,
model: str,
messages: list[Message],
sampling_params: Any | None = ...,
tools: list[ToolDefinition] | None = ...,
tool_choice: ToolChoice | None = ...,
tool_prompt_format: ToolPromptFormat | None = ...,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
log.info("vLLM chat completion")
assert self.engine is not None
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
log.info("Sampling params: %s", sampling_params)
vllm_sampling_params = _vllm_sampling_params(sampling_params)
messages = augment_messages_for_tools(request)
log.info("Augmented messages: %s", messages)
prompt = "".join([str(message.content) for message in messages])
request_id = _random_uuid()
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
)
if not stream:
# Non-streaming case
final_output = None
stop_reason = None
async for request_output in results_generator:
final_output = request_output
if stop_reason is None and request_output.outputs:
reason = request_output.outputs[-1].stop_reason
if reason == "stop":
stop_reason = StopReason.end_of_turn
elif reason == "length":
stop_reason = StopReason.out_of_tokens
if not stop_reason:
stop_reason = StopReason.end_of_message
if final_output:
response = "".join([output.text for output in final_output.outputs])
yield ChatCompletionResponse(
completion_message=CompletionMessage(
content=response,
stop_reason=stop_reason,
),
logprobs=None,
)
else:
# Streaming case
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
last_chunk = ""
ipython = False
stop_reason = None
async for chunk in results_generator:
if not chunk.outputs:
log.warning("Empty chunk received")
continue
if chunk.outputs[-1].stop_reason:
reason = chunk.outputs[-1].stop_reason
if stop_reason is None and reason == "stop":
stop_reason = StopReason.end_of_turn
elif stop_reason is None and reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = "".join([output.text for output in chunk.outputs])
# check if it's a tool call (i.e. it starts with <|python_tag|>)
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
last_chunk_len = len(last_chunk)
last_chunk = text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text[last_chunk_len:],
stop_reason=stop_reason,
)
)
if not stop_reason:
stop_reason = StopReason.end_of_message
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
async def embeddings(
self, model: str, contents: list[InterleavedTextMedia]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO
raise NotImplementedError()

View file

@ -21,6 +21,7 @@ async def get_provider_impl(
deps[Api.inference],
deps[Api.memory],
deps[Api.safety],
deps[Api.memory_banks],
)
await impl.initialize()
return impl

View file

@ -6,6 +6,7 @@
import asyncio
import copy
import logging
import os
import re
import secrets
@ -19,11 +20,11 @@ from urllib.parse import urlparse
import httpx
from termcolor import cprint
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.kvstore import KVStore
@ -42,6 +43,8 @@ from .tools.builtin import (
)
from .tools.safety import SafeTool
log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(
@ -56,6 +59,7 @@ class ChatAgent(ShieldRunnerMixin):
agent_config: AgentConfig,
inference_api: Inference,
memory_api: Memory,
memory_banks_api: MemoryBanks,
safety_api: Safety,
persistence_store: KVStore,
):
@ -63,6 +67,7 @@ class ChatAgent(ShieldRunnerMixin):
self.agent_config = agent_config
self.inference_api = inference_api
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.safety_api = safety_api
self.storage = AgentPersistence(agent_id, persistence_store)
@ -108,7 +113,7 @@ class ChatAgent(ShieldRunnerMixin):
# Maybe this should be a parameter of the agentic instance
# that can define its behavior in a custom way
for m in turn.input_messages:
msg = m.copy()
msg = m.model_copy()
if isinstance(msg, UserMessage):
msg.context = None
messages.append(msg)
@ -134,7 +139,6 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason=StopReason.end_of_turn,
)
)
# print_dialog(messages)
return messages
async def create_session(self, name: str) -> str:
@ -144,6 +148,8 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
@ -151,7 +157,7 @@ class ChatAgent(ShieldRunnerMixin):
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if len(turns) == 0 and self.agent_config.instructions != "":
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
@ -180,10 +186,8 @@ class ChatAgent(ShieldRunnerMixin):
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
cprint(
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
"white",
attrs=["bold"],
)
output_message = chunk
continue
@ -392,17 +396,11 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0
while True:
msg = input_messages[-1]
if msg.role == Role.user.value:
color = "blue"
elif msg.role == Role.ipython.value:
color = "yellow"
else:
color = None
if len(str(msg)) > 1000:
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
else:
msg_str = str(msg)
cprint(f"{msg_str}", color=color)
log.info(f"{msg_str}")
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -419,7 +417,7 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason = None
with tracing.span("inference"):
async for chunk in self.inference_api.chat_completion(
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
@ -501,12 +499,12 @@ class ChatAgent(ShieldRunnerMixin):
)
if n_iter >= self.agent_config.max_infer_iters:
cprint("Done with MAX iterations, exiting.")
log.info("Done with MAX iterations, exiting.")
yield message
break
if stop_reason == StopReason.out_of_tokens:
cprint("Out of token budget, exiting.")
log.info("Out of token budget, exiting.")
yield message
break
@ -520,10 +518,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + attachments
yield message
else:
cprint(f"Partial message: {str(message)}", color="green")
log.info(f"Partial message: {str(message)}")
input_messages = input_messages + [message]
else:
cprint(f"{str(message)}", color="green")
log.info(f"{str(message)}")
try:
tool_call = message.tool_calls[0]
@ -635,14 +633,14 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {session_id} not found")
if session_info.memory_bank_id is None:
memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session_id}",
config=VectorMemoryBankConfig(
bank_id = f"memory_bank_{session_id}"
await self.memory_banks_api.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
bank_id = memory_bank.bank_id
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
@ -735,9 +733,8 @@ class ChatAgent(ShieldRunnerMixin):
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
cprint(
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
"red",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
@ -781,7 +778,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
print(f"Downloading {url} -> {filepath}")
log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
@ -821,20 +818,3 @@ async def execute_tool_call_maybe(
tool = tools_dict[name]
result_messages = await tool.run(messages)
return result_messages
def print_dialog(messages: List[Message]):
for i, m in enumerate(messages):
if m.role == Role.user.value:
color = "red"
elif m.role == Role.assistant.value:
color = "white"
elif m.role == Role.ipython.value:
color = "yellow"
elif m.role == Role.system.value:
color = "green"
else:
color = "white"
s = str(m)
cprint(f"{i} ::: {s[:100]}...", color=color)

View file

@ -11,6 +11,7 @@ from typing import AsyncGenerator
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403
@ -30,11 +31,14 @@ class MetaReferenceAgentsImpl(Agents):
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
memory_banks_api: MemoryBanks,
):
self.config = config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.memory_banks_api = memory_banks_api
self.in_memory_store = InmemoryKVStoreImpl()
async def initialize(self) -> None:
@ -48,7 +52,7 @@ class MetaReferenceAgentsImpl(Agents):
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.json(),
value=agent_config.model_dump_json(),
)
return AgentCreateResponse(
agent_id=agent_id,
@ -81,6 +85,7 @@ class MetaReferenceAgentsImpl(Agents):
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,
memory_banks_api=self.memory_banks_api,
persistence_store=(
self.persistence_store
if agent_config.enable_session_persistence
@ -113,16 +118,76 @@ class MetaReferenceAgentsImpl(Agents):
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
agent = await self.get_agent(agent_id)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
attachments=attachments,
stream=stream,
stream=True,
)
if stream:
return self._create_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _create_agent_turn_streaming(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
async def get_agents_turn(
self, agent_id: str, session_id: str, turn_id: str
) -> Turn:
turn = await self.persistence_store.get(
f"session:{agent_id}:{session_id}:{turn_id}"
)
turn = json.loads(turn)
turn = Turn(**turn)
return turn
async def get_agents_step(
self, agent_id: str, session_id: str, turn_id: str, step_id: str
) -> AgentStepResponse:
turn = await self.persistence_store.get(
f"session:{agent_id}:{session_id}:{turn_id}"
)
turn = json.loads(turn)
turn = Turn(**turn)
steps = turn.steps
for step in steps:
if step.step_id == step_id:
return AgentStepResponse(step=step)
raise ValueError(f"Provided step_id {step_id} could not be found")
async def get_agents_session(
self,
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session:
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
session = Session(**json.loads(session), turns=[])
turns = []
if turn_ids:
for turn_id in turn_ids:
turn = await self.persistence_store.get(
f"session:{agent_id}:{session_id}:{turn_id}"
)
turn = json.loads(turn)
turn = Turn(**turn)
turns.append(turn)
return Session(
session_name=session.session_name,
session_id=session_id,
turns=turns if turns else [],
started_at=session.started_at,
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
async def delete_agents(self, agent_id: str) -> None:
await self.persistence_store.delete(f"agent:{agent_id}")

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"persistence_store": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="agents_store.db",
)
}

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
import logging
import uuid
from datetime import datetime
@ -15,6 +15,8 @@ from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
session_id: str
@ -37,7 +39,7 @@ class AgentPersistence:
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
value=session_info.model_dump_json(),
)
return session_id
@ -58,13 +60,13 @@ class AgentPersistence:
session_info.memory_bank_id = bank_id
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
value=session_info.model_dump_json(),
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.json(),
value=turn.model_dump_json(),
)
async def get_session_turns(self, session_id: str) -> List[Turn]:
@ -78,7 +80,7 @@ class AgentPersistence:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
print(f"Error parsing turn: {e}")
log.error(f"Error parsing turn: {e}")
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
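The `json()` → `model_dump_json()` changes in this file (and the `copy()` → `model_copy()` change earlier in the commit) follow the Pydantic v2 renames. A minimal illustration of the equivalence, using a made-up model:

```python
from pydantic import BaseModel


class SessionInfo(BaseModel):  # made-up stand-in for the models serialized above
    session_id: str
    session_name: str


info = SessionInfo(session_id="s-1", session_name="demo")
as_json = info.model_dump_json()  # Pydantic v2 name for v1's info.json()
clone = info.model_copy()         # Pydantic v2 name for v1's info.copy()
print(as_json, clone.session_name)
```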

View file

@ -10,8 +10,6 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from termcolor import cprint # noqa: F401
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
@ -36,7 +34,6 @@ async def generate_rag_query(
query = await llm_rag_query_generator(config, messages, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
# cprint(f"Generated query >>>: {query}", color="green")
return query
@ -63,13 +60,12 @@ async def llm_rag_query_generator(
model = config.model
message = UserMessage(content=content)
response = inference_api.chat_completion(
response = await inference_api.chat_completion(
model=model,
messages=[message],
stream=False,
)
async for chunk in response:
query = chunk.completion_message.content
query = response.completion_message.content
return query

View file

@ -5,14 +5,16 @@
# the root directory of this source tree.
import asyncio
import logging
from typing import List
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)
class SafetyException(Exception): # noqa: N818
def __init__(self, violation: SafetyViolation):
@ -32,18 +34,18 @@ class ShieldRunnerMixin:
self.output_shields = output_shields
async def run_multiple_shields(
self, messages: List[Message], shield_types: List[str]
self, messages: List[Message], identifiers: List[str]
) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(
shield_type=shield_type,
shield_id=identifier,
messages=messages,
)
for shield_type in shield_types
for identifier in identifiers
]
)
for shield_type, response in zip(shield_types, responses):
for identifier, response in zip(identifiers, responses):
if not response.violation:
continue
@ -51,7 +53,4 @@ class ShieldRunnerMixin:
if violation.violation_level == ViolationLevel.ERROR:
raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN:
cprint(
f"[Warn]{shield_type} raised a warning",
color="red",
)
log.warning(f"[Warn]{identifier} raised a warning")
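`run_multiple_shields` fans out one `run_shield` call per identifier with `asyncio.gather` and only then walks the results. A small self-contained sketch of that fan-out pattern (the fake `run_shield` below is illustrative, not the Safety API):

```python
import asyncio


async def run_shield(shield_id: str, text: str) -> bool:
    # stand-in for safety_api.run_shield(shield_id=..., messages=...)
    await asyncio.sleep(0.01)
    return shield_id == "prompt_guard" and "ignore previous instructions" in text


async def run_all(text: str) -> None:
    identifiers = ["llama_guard", "prompt_guard"]
    # run all shield checks concurrently, then inspect results in order
    violations = await asyncio.gather(*(run_shield(i, text) for i in identifiers))
    for identifier, violated in zip(identifiers, violations):
        if violated:
            print(f"[Warn] {identifier} raised a warning")


asyncio.run(run_all("please ignore previous instructions"))
```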

View file

@ -16,7 +16,7 @@ from llama_stack.apis.agents import * # noqa: F403
from ..agents import (
AGENT_INSTANCES_BY_ID,
MetaReferenceAgentsImpl,
MetaReferenceImplConfig,
MetaReferenceInferenceConfig,
)
@ -26,6 +26,7 @@ class MockInferenceAPI:
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -79,7 +80,7 @@ class MockInferenceAPI:
class MockSafetyAPI:
async def run_shield(
self, shield_type: str, messages: List[Message]
self, shield_id: str, messages: List[Message]
) -> RunShieldResponse:
return RunShieldResponse(violation=None)
@ -166,7 +167,7 @@ def mock_memory_api():
@pytest.fixture
async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
impl = MetaReferenceAgentsImpl(
config=MetaReferenceImplConfig(),
config=MetaReferenceInferenceConfig(),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
memory_api=mock_memory_api,

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
import logging
import re
import tempfile
@ -12,7 +13,6 @@ from abc import abstractmethod
from typing import List, Optional
import requests
from termcolor import cprint
from .ipython_tool.code_execution import (
CodeExecutionContext,
@ -27,6 +27,9 @@ from llama_stack.apis.agents import * # noqa: F403
from .base import BaseTool
log = logging.getLogger(__name__)
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
@ -86,10 +89,13 @@ class PhotogenTool(SingleMessageBuiltinTool):
class SearchTool(SingleMessageBuiltinTool):
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
self.api_key = api_key
self.engine_type = engine
if engine == SearchEngineType.bing:
self.engine = BingSearch(api_key, **kwargs)
elif engine == SearchEngineType.brave:
self.engine = BraveSearch(api_key, **kwargs)
elif engine == SearchEngineType.tavily:
self.engine = TavilySearch(api_key, **kwargs)
else:
raise ValueError(f"Unknown search engine: {engine}")
@ -257,6 +263,21 @@ class BraveSearch:
return {"query": query, "top_k": clean_response}
class TavilySearch:
def __init__(self, api_key: str) -> None:
self.api_key = api_key
async def search(self, query: str) -> str:
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": self.api_key, "query": query},
)
return json.dumps(self._clean_tavily_response(response.json()))
def _clean_tavily_response(self, search_response, top_k=3):
return {"query": search_response["query"], "top_k": search_response["results"]}
class WolframAlphaTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None:
self.api_key = api_key
@ -365,7 +386,7 @@ class CodeInterpreterTool(BaseTool):
if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr":
cprint(f"ipython tool error: ↓\n{res_out}", color="red")
log.error(f"ipython tool error: ↓\n{res_out}")
message = ToolResponseMessage(
call_id=tool_call.call_id,

View file

@ -11,6 +11,7 @@ A custom Matplotlib backend that overrides the show method to return image bytes
import base64
import io
import json as _json
import logging
import matplotlib
from matplotlib.backend_bases import FigureManagerBase
@ -18,6 +19,8 @@ from matplotlib.backend_bases import FigureManagerBase
# Import necessary components from Matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg
log = logging.getLogger(__name__)
class CustomFigureCanvas(FigureCanvasAgg):
def show(self):
@ -80,7 +83,7 @@ def show():
)
req_con.send_bytes(_json_dump.encode("utf-8"))
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
print(resp)
log.info(resp)
FigureCanvas = CustomFigureCanvas

View file

@ -9,8 +9,7 @@ from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import LocalFSDatasetIOConfig
async def get_provider_impl(
config: LocalFSDatasetIOConfig,
_deps,
):
from .datasetio import LocalFSDatasetIOImpl
impl = LocalFSDatasetIOImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import * # noqa: F401, F403
class LocalFSDatasetIOConfig(BaseModel): ...

View file

@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from abc import ABC, abstractmethod
from dataclasses import dataclass
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from .config import LocalFSDatasetIOConfig
class BaseDataset(ABC):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@abstractmethod
def __len__(self) -> int:
raise NotImplementedError()
@abstractmethod
def __getitem__(self, idx):
raise NotImplementedError()
@abstractmethod
def load(self):
raise NotImplementedError()
@dataclass
class DatasetInfo:
dataset_def: Dataset
dataset_impl: BaseDataset
class PandasDataframeDataset(BaseDataset):
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.dataset_def = dataset_def
self.df = None
def __len__(self) -> int:
assert self.df is not None, "Dataset not loaded. Please call .load() first"
return len(self.df)
def __getitem__(self, idx):
assert self.df is not None, "Dataset not loaded. Please call .load() first"
if isinstance(idx, slice):
return self.df.iloc[idx].to_dict(orient="records")
else:
return self.df.iloc[idx].to_dict()
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
# note that we will drop any columns in dataset that are not in the schema
df = df[self.dataset_def.dataset_schema.keys()]
# check all columns in dataset schema are present
assert len(df.columns) == len(self.dataset_def.dataset_schema)
# TODO: type checking against column types in dataset schema
return df
def load(self) -> None:
if self.df is not None:
return
df = get_dataframe_from_url(self.dataset_def.url)
if df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
self.df = self._validate_dataset_schema(df)
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
def __init__(self, config: LocalFSDatasetIOConfig) -> None:
self.config = config
# local registry for keeping track of datasets within the provider
self.dataset_infos = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def register_dataset(
self,
dataset: Dataset,
) -> None:
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
)
async def get_rows_paginated(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
next_page_token = int(page_token)
start = next_page_token
if rows_in_page == -1:
end = len(dataset_info.dataset_impl)
else:
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
rows = dataset_info.dataset_impl[start:end]
return PaginatedRowsResult(
rows=rows,
total_count=len(rows),
next_page_token=str(end),
)
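
The pagination contract above (numeric string page tokens, rows_in_page=-1 meaning all remaining rows) can be restated without the provider classes. paginate below is an illustrative helper over a plain list, not part of the API, but it follows the same token scheme as get_rows_paginated.

from typing import Any, Dict, List, Optional, Tuple

def paginate(
    rows: List[Dict[str, Any]], rows_in_page: int, page_token: Optional[str] = None
) -> Tuple[List[Dict[str, Any]], str]:
    """Return (page, next_page_token): tokens are stringified row offsets."""
    if page_token and not page_token.isnumeric():
        raise ValueError("Invalid page_token")
    start = int(page_token) if page_token else 0
    end = len(rows) if rows_in_page == -1 else min(start + rows_in_page, len(rows))
    return rows[start:end], str(end)

data = [{"i": i} for i in range(5)]
page, token = paginate(data, rows_in_page=2)                      # rows 0-1, token "2"
page, token = paginate(data, rows_in_page=2, page_token=token)    # rows 2-3, token "4"
page, token = paginate(data, rows_in_page=-1, page_token=token)   # remaining row 4
print(page, token)
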

View file

@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceEvalConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
deps: Dict[Api, ProviderSpec],
):
from .eval import MetaReferenceEvalImpl
impl = MetaReferenceEvalImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
deps[Api.scoring],
deps[Api.inference],
deps[Api.agents],
)
await impl.initialize()
return impl

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
class MetaReferenceEvalConfig(BaseModel):
kvstore: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
) # Uses SQLite config specific to Meta Reference Eval storage

View file

@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.llama3.api.datatypes import * # noqa: F403
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from tqdm import tqdm
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:"
class ColumnName(Enum):
input_query = "input_query"
expected_answer = "expected_answer"
chat_completion_input = "chat_completion_input"
completion_input = "completion_input"
generated_answer = "generated_answer"
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
def __init__(
self,
config: MetaReferenceEvalConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
scoring_api: Scoring,
inference_api: Inference,
agents_api: Agents,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_api = scoring_api
self.inference_api = inference_api
self.agents_api = agents_api
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.eval_tasks = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing eval_tasks from kvstore
start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff"
stored_eval_tasks = await self.kvstore.range(start_key, end_key)
for eval_task in stored_eval_tasks:
eval_task = EvalTask.model_validate_json(eval_task)
self.eval_tasks[eval_task.identifier] = eval_task
async def shutdown(self) -> None: ...
async def register_eval_task(self, task_def: EvalTask) -> None:
# Store in kvstore
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
await self.kvstore.set(
key=key,
value=task_def.model_dump_json(),
)
self.eval_tasks[task_def.identifier] = task_def
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
expected_schemas = [
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.completion_input.value: CompletionInputType(),
},
]
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def run_eval(
self,
task_id: str,
task_config: EvalTaskConfig,
) -> Job:
task_def = self.eval_tasks[task_id]
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(
-1 if task_config.num_examples is None else task_config.num_examples
),
)
res = await self.evaluate_rows(
task_id=task_id,
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
task_config=task_config,
)
# TODO: currently needs to wait for generation before returning
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id)
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
) -> List[Dict[str, Any]]:
candidate = task_config.eval_candidate
create_response = await self.agents_api.create_agent(candidate.config)
agent_id = create_response.agent_id
generations = []
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
input_messages = [UserMessage(**x) for x in input_messages]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agents_api.create_agent_session(
agent_id, f"session-{i}"
)
session_id = session_create_response.session_id
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=input_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await self.agents_api.create_agent_turn(
**turn_request
)
]
final_event = turn_response[-1].event.payload
generations.append(
{
ColumnName.generated_answer.value: final_event.turn.output_message.content
}
)
return generations
async def _run_model_generation(
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
) -> List[Dict[str, Any]]:
candidate = task_config.eval_candidate
assert (
candidate.sampling_params.max_tokens is not None
), "SamplingParams.max_tokens must be provided"
generations = []
for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
input_content = eval(str(x[ColumnName.completion_input.value]))
response = await self.inference_api.completion(
model=candidate.model,
content=input_content,
sampling_params=candidate.sampling_params,
)
generations.append(
{
ColumnName.generated_answer.value: response.completion_message.content
}
)
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_str = str(
x[ColumnName.chat_completion_input.value]
)
input_messages = eval(chat_completion_input_str)
input_messages = [UserMessage(**x) for x in input_messages]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages += input_messages
response = await self.inference_api.chat_completion(
model_id=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
)
generations.append(
{
ColumnName.generated_answer.value: response.completion_message.content
}
)
else:
raise ValueError("Invalid input row")
return generations
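
The generation loops above re-hydrate chat_completion_input by calling eval on a stringified list of message dicts. A small sketch of that parsing step, using ast.literal_eval as a safer stand-in (my substitution for illustration, not what the provider ships); the sample row is made up.

import ast

# A row as it might appear in an eval dataset: the chat input is a stringified list of dicts.
row = {
    "input_query": "capital of France?",
    "expected_answer": "Paris",
    "chat_completion_input": "[{'role': 'user', 'content': 'What is the capital of France?'}]",
}

# literal_eval only accepts Python literals, so arbitrary expressions are rejected.
input_messages = ast.literal_eval(row["chat_completion_input"])
assert isinstance(input_messages, list) and input_messages[0]["role"] == "user"
print(input_messages)
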
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse:
candidate = task_config.eval_candidate
if candidate.type == "agent":
generations = await self._run_agent_generation(input_rows, task_config)
elif candidate.type == "model":
generations = await self._run_model_generation(input_rows, task_config)
else:
raise ValueError(f"Invalid candidate type: {candidate.type}")
# scoring with generated_answer
score_input_rows = [
input_r | generated_r
for input_r, generated_r in zip(input_rows, generations)
]
if task_config.type == "app" and task_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {
scoring_fn_id: None for scoring_fn_id in scoring_functions
}
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
if job_id in self.jobs:
return JobStatus.completed
return None
async def job_cancel(self, task_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(task_id, job_id)
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")
return self.jobs[job_id]
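
To illustrate the scoring hand-off at the end of evaluate_rows: each input row is merged with its generation via dict union (Python 3.9+), and per-function scoring params default to None unless an app task config supplies them. The function ids and params below are made up for the sketch.

input_rows = [{"input_query": "2+2?", "expected_answer": "4"}]
generations = [{"generated_answer": "4"}]

# Right-hand side wins on key clashes, so generated columns are appended to each input row.
score_input_rows = [inp | gen for inp, gen in zip(input_rows, generations)]

scoring_functions = ["equality", "llm-judge"]                      # illustrative ids
scoring_params = {"llm-judge": {"judge_model": "some-judge"}}      # illustrative params

# Per-function params: explicit overrides where given, None otherwise.
scoring_functions_dict = {fn: scoring_params.get(fn) for fn in scoring_functions}
print(score_input_rows)
print(scoring_functions_dict)
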

View file

@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceImplConfig # noqa
from typing import Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps,
):
from .inference import MetaReferenceInferenceImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceInferenceConfig(BaseModel):
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
# (including our testing code) who might be using llama-stack as a library.
create_distributed_process_group: bool = True
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
return model
@property
def model_parallel_size(self) -> int:
resolved = resolve_model(self.model)
return resolved.pth_file_count
@classmethod
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
"max_seq_len": 4096,
"checkpoint_dir": checkpoint_dir,
}
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
quantization: QuantizationConfig
@classmethod
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
**kwargs,
) -> Dict[str, Any]:
config = super().sample_run_config(model, checkpoint_dir, **kwargs)
config["quantization"] = {
"type": "fp8",
}
return config
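
The config above relies on pydantic's field_validator to reject unknown model descriptors at construction time. A self-contained sketch of the same pattern with a hard-coded allow-list; the allow-list and class name are illustrative, not the real supported_inference_models().

from typing import Optional

from pydantic import BaseModel, Field, field_validator

ALLOWED_MODELS = {"Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"}  # illustrative

class InferenceConfigSketch(BaseModel):
    model: str = Field(default="Llama3.2-3B-Instruct")
    max_seq_len: int = 4096
    checkpoint_dir: Optional[str] = None

    @field_validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        if model not in ALLOWED_MODELS:
            raise ValueError(f"Unknown model: `{model}`. Choose from {sorted(ALLOWED_MODELS)}")
        return model

print(InferenceConfigSketch().model_dump())      # defaults pass validation
try:
    InferenceConfigSketch(model="not-a-model")
except Exception as e:                            # pydantic wraps the ValueError in a ValidationError
    print(type(e).__name__)
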

View file

@ -8,12 +8,13 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import logging
import math
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional
from typing import Generator, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@ -24,24 +25,32 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from termcolor import cprint
from pydantic import BaseModel
from llama_stack.apis.inference import QuantizationType
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
)
from .config import MetaReferenceImplConfig
from .config import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
MetaReferenceInferenceConfig,
MetaReferenceQuantizedInferenceConfig,
)
log = logging.getLogger(__name__)
def model_checkpoint_dir(model) -> str:
@ -58,8 +67,7 @@ def model_checkpoint_dir(model) -> str:
return str(checkpoint_dir)
@dataclass
class TokenResult:
class TokenResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None
@ -67,7 +75,11 @@ class TokenResult:
class Llama:
@staticmethod
def build(config: MetaReferenceImplConfig):
def build(
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
],
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
@ -76,15 +88,7 @@ class Llama:
and loads the pre-trained model and tokenizer.
"""
model = resolve_model(config.model)
if (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
):
from .quantization.loader import is_fbgemm_available
if not is_fbgemm_available():
raise ImportError("fbgemm-gpu is required for FP8 quantization")
llama_model = model.core_model_id.value
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
@ -105,7 +109,10 @@ class Llama:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_dir = model_checkpoint_dir(model)
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
ckpt_dir = model_checkpoint_dir(model)
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
@ -126,31 +133,48 @@ class Llama:
**params,
)
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
tokenizer = Tokenizer(model_path=tokenizer_path)
tokenizer = Tokenizer.get_instance()
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
fp8 = (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
)
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
from .quantization.loader import convert_to_fp8_quantized_model
if fp8:
from .quantization.loader import convert_to_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
elif isinstance(config.quantization, Int4QuantizationConfig):
from .quantization.loader import convert_to_int4_quantized_model
# load on CPU in bf16 so that the int4 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_quantized_model(model, config)
model = convert_to_int4_quantized_model(model, model_args, config)
model.load_state_dict(state_dict, strict=True)
if (
model_args.quantization_args is not None
and model_args.quantization_args.spinquant
):
# Add a wrapper for adding hadamard transform for spinquant.
# This needs to be done after loading the state dict otherwise an error will be raised while
# loading the state dict.
from .quantization.hadamard_utils import (
add_hadamard_transform_for_spinquant,
)
add_hadamard_transform_for_spinquant(model)
else:
raise NotImplementedError(
"Currently int4 and fp8 are the only supported quantization methods."
)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@ -163,14 +187,21 @@ class Llama:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args, llama_model)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
def __init__(
self,
model: Transformer,
tokenizer: Tokenizer,
args: ModelArgs,
llama_model: str,
):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
self.llama_model = llama_model
@torch.inference_mode()
def generate(
@ -182,14 +213,17 @@ class Llama:
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
print_input_tokens: bool = False,
logits_processor: Optional["LogitsProcessor"] = None,
) -> Generator:
params = self.model.params
# input_tokens = [
# self.formatter.vision_token if t == 128256 else t
# for t in model_input.tokens
# ]
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
if print_input_tokens:
input_tokens = [
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
prompt_tokens = [model_input.tokens]
bsz = 1
@ -199,9 +233,7 @@ class Llama:
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
)
log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
@ -240,8 +272,7 @@ class Llama:
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(
@ -257,6 +288,9 @@ class Llama:
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if logits_processor is not None:
logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
@ -307,15 +341,12 @@ class Llama:
if all(eos_reached):
break
def text_completion(
def completion(
self,
content: InterleavedTextMedia,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
request: CompletionRequest,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@ -323,26 +354,32 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = self.formatter.encode_content(content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: ChatCompletionRequest,
) -> Generator:
messages = chat_completion_request_to_messages(request, self.llama_model)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@ -353,13 +390,18 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
tool_prompt_format,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)
@ -386,3 +428,64 @@ def sample_top_p(probs, p):
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def process_logits(
self, tokens: torch.Tensor, scores: torch.Tensor
) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
if self.mask is not None:
self.mask.fill_(-math.inf)
else:
self.mask = torch.full_like(scores, -math.inf)
self.mask[:, :, allowed_tokens] = 0
scores = scores + self.mask
return scores
def get_logits_processor(
tokenizer: Tokenizer,
vocab_size: int,
response_format: Optional[ResponseFormat],
) -> Optional["LogitsProcessor"]:
if response_format is None:
return None
if response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Unsupported response format type {response_format.type}")
parser = JsonSchemaParser(response_format.json_schema)
data = TokenEnforcerTokenizerData(
_build_regular_tokens_list(tokenizer, vocab_size),
tokenizer.decode,
tokenizer.stop_tokens,
)
token_enforcer = TokenEnforcer(data, parser)
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(
tokenizer: Tokenizer, vocab_size: int
) -> List[Tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []
special_token_ids = set(tokenizer.special_tokens.values())
for token_idx in range(vocab_size):
if token_idx in special_token_ids:
continue
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
decoded_regular = tokenizer.decode([token_idx])
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
return regular_tokens
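
LogitsProcessor constrains decoding by adding a mask of -inf to every token the enforcer disallows at the current step. A torch-only sketch of that masking trick, independent of lm-format-enforcer; the vocabulary size and allowed ids are made up.

import math

import torch

vocab_size = 8
scores = torch.randn(1, 1, vocab_size)  # [batch, seq, vocab] logits for the next token

allowed_tokens = [2, 5]  # whatever the token enforcer permits at this step (illustrative)

# Same trick as LogitsProcessor.process_logits: -inf everywhere, 0 on allowed ids.
mask = torch.full_like(scores, -math.inf)
mask[:, :, allowed_tokens] = 0
constrained = scores + mask

probs = torch.softmax(constrained[:, -1], dim=-1)
print(probs)                             # probability mass only on tokens 2 and 5
print(torch.argmax(probs, dim=-1))       # the sampled/argmax token is always an allowed one
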

View file

@ -0,0 +1,430 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import logging
from typing import AsyncGenerator, List
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
ModelRegistryHelper.__init__(
self,
[
build_model_alias(
model.descriptor(),
model.core_model_id.value,
)
],
)
self.model = model
# verify that the checkpoint actually is for this model lol
async def initialize(self) -> None:
log.info(f"Loading model `{self.model.descriptor()}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
else:
self.generator = Llama.build(self.config)
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
def check_model(self, request) -> None:
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
if request.stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
def impl():
stop_reason = None
for token_result in self.generator.completion(request):
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
logprobs = None
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
]
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
stop_reason=StopReason.out_of_tokens,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
def impl():
tokens = []
logprobs = []
stop_reason = None
tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.token in tokenizer.stop_tokens:
# not quite right semantically
stop_reason = StopReason.end_of_turn
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
content = self.generator.formatter.tokenizer.decode(tokens)
return CompletionResponse(
content=content,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
def impl():
tokens = []
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
return ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
tokens = []
logprobs = []
stop_reason = None
ipython = False
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def request_with_localized_media(
request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequest, CompletionRequest]:
if not request_has_media(request):
return request
async def _convert_single_content(content):
if isinstance(content, ImageMedia):
url = await convert_image_media_to_url(content, download=True)
return ImageMedia(image=URL(uri=url))
else:
return content
async def _convert_content(content):
if isinstance(content, list):
return [await _convert_single_content(c) for c in content]
else:
return await _convert_single_content(content)
if isinstance(request, ChatCompletionRequest):
for m in request.messages:
m.content = await _convert_content(m.content)
else:
request.content = await _convert_content(request.content)
return request
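
The implementation above serializes access to the single model-parallel process with a module-level asyncio.Semaphore(1) and rejects a second request while the first is still streaming. A self-contained sketch of that guard; the fake token generator below stands in for the real one.

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

async def generate(prompt: str):
    if SEMAPHORE.locked():
        raise RuntimeError("Only one concurrent request is supported")
    async with SEMAPHORE:
        for token in prompt.split():
            await asyncio.sleep(0.01)  # pretend to decode a token
            yield token

async def main():
    async for tok in generate("hello from the sketch"):
        print(tok)

    # A second caller arriving while generation is in flight sees the RuntimeError.
    first = generate("one two three")
    await first.__anext__()            # start the first stream; the semaphore is now held
    try:
        async for _ in generate("late request"):
            pass
    except RuntimeError as e:
        print(e)
    await first.aclose()               # closing the generator releases the semaphore

asyncio.run(main())
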

View file

@ -6,47 +6,35 @@
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
from typing import Any, Generator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from .config import MetaReferenceImplConfig
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import ModelParallelProcessGroup
@dataclass
class InferenceArgs:
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat
class ModelRunner:
def __init__(self, llama):
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: InferenceArgs):
return self.llama.chat_completion(
task.messages,
task.temperature,
task.top_p,
task.max_gen_len,
task.logprobs,
task.tool_prompt_format,
)
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequest):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequest):
return self.llama.completion(req)
else:
raise ValueError(f"Unexpected task type {type(req)}")
def init_model_cb(config: MetaReferenceImplConfig):
def init_model_cb(config: MetaReferenceInferenceConfig):
llama = Llama.build(config)
return ModelRunner(llama)
@ -62,7 +50,7 @@ class LlamaModelParallelGenerator:
clear at the callsite why we need to use a context manager.
"""
def __init__(self, config: MetaReferenceImplConfig):
def __init__(self, config: MetaReferenceInferenceConfig):
self.config = config
self.model = resolve_model(self.config.model)
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
@ -88,23 +76,18 @@ class LlamaModelParallelGenerator:
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def chat_completion(
def completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: CompletionRequest,
) -> Generator:
req_obj = InferenceArgs(
messages=deepcopy(messages),
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
tool_prompt_format=tool_prompt_format,
)
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen
def chat_completion(
self,
request: ChatCompletionRequest,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen

View file

@ -4,17 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import multiprocessing
import os
import pickle
import tempfile
import time
import uuid
from typing import Callable, Generator
from enum import Enum
from typing import Callable, Generator, Literal, Optional, Union
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
@ -23,17 +29,99 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult
log = logging.getLogger(__name__)
_END_SENTINEL = "__end_sentinel__"
_CANCEL_SENTINEL = "__cancel_sentinel__"
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
ready_response = "ready_response"
end_sentinel = "end_sentinel"
cancel_sentinel = "cancel_sentinel"
task_request = "task_request"
task_response = "task_response"
exception_response = "exception_response"
class ReadyRequest(BaseModel):
type: Literal[ProcessingMessageName.ready_request] = (
ProcessingMessageName.ready_request
)
class ReadyResponse(BaseModel):
type: Literal[ProcessingMessageName.ready_response] = (
ProcessingMessageName.ready_response
)
class EndSentinel(BaseModel):
type: Literal[ProcessingMessageName.end_sentinel] = (
ProcessingMessageName.end_sentinel
)
class CancelSentinel(BaseModel):
type: Literal[ProcessingMessageName.cancel_sentinel] = (
ProcessingMessageName.cancel_sentinel
)
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: Union[CompletionRequest, ChatCompletionRequest]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = (
ProcessingMessageName.task_response
)
result: TokenResult
class ExceptionResponse(BaseModel):
type: Literal[ProcessingMessageName.exception_response] = (
ProcessingMessageName.exception_response
)
error: str
ProcessingMessage = Union[
ReadyRequest,
ReadyResponse,
EndSentinel,
CancelSentinel,
TaskRequest,
TaskResponse,
ExceptionResponse,
]
class ProcessingMessageWrapper(BaseModel):
payload: Annotated[
ProcessingMessage,
Field(discriminator="type"),
]
def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0
def encode_msg(msg: ProcessingMessage) -> bytes:
return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
def retrieve_requests(reply_socket_url: str):
if mp_rank_0():
context = zmq.Context()
@ -46,21 +134,24 @@ def retrieve_requests(reply_socket_url: str):
time.sleep(0.01)
continue
reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
ready_response = ReadyResponse()
reply_socket.send_multipart([client_id, encode_msg(ready_response)])
break
def send_obj(obj):
reply_socket.send_multipart([client_id, pickle.dumps(obj)])
def send_obj(obj: ProcessingMessage):
reply_socket.send_multipart([client_id, encode_msg(obj)])
while True:
tasks = [None]
if mp_rank_0():
client_id, task = maybe_get_work(reply_socket)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task != _CANCEL_SENTINEL:
tasks = [task]
client_id, maybe_task_json = maybe_get_work(reply_socket)
if maybe_task_json is not None:
task = maybe_parse_message(maybe_task_json)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task is not None and not isinstance(task, CancelSentinel):
tasks = [task]
torch.distributed.broadcast_object_list(
tasks,
@ -80,35 +171,36 @@ def retrieve_requests(reply_socket_url: str):
for obj in out:
updates = [None]
if mp_rank_0():
_, update = maybe_get_work(reply_socket)
if update == _CANCEL_SENTINEL:
_, update_json = maybe_get_work(reply_socket)
update = maybe_parse_message(update_json)
if isinstance(update, CancelSentinel):
updates = [update]
else:
# only send the update if it's not cancelled otherwise the object sits in the socket
# and gets pulled in the next request lol
send_obj(obj)
send_obj(TaskResponse(result=obj))
torch.distributed.broadcast_object_list(
updates,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
if updates[0] == _CANCEL_SENTINEL:
print("quitting generation loop because request was cancelled")
if isinstance(updates[0], CancelSentinel):
log.info(
"quitting generation loop because request was cancelled"
)
break
if mp_rank_0():
send_obj(_END_SENTINEL)
send_obj(EndSentinel())
except Exception as e:
print(f"[debug] got exception {e}")
import traceback
log.exception("exception in generation loop")
traceback.print_exc()
if mp_rank_0():
send_obj(e)
send_obj(ExceptionResponse(error=str(e)))
if mp_rank_0():
send_obj("DONE")
send_obj(EndSentinel())
def maybe_get_work(sock: zmq.Socket):
@ -116,7 +208,7 @@ def maybe_get_work(sock: zmq.Socket):
client_id = None
try:
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
message = pickle.loads(obj)
message = obj.decode("utf-8")
except zmq.ZMQError as e:
if e.errno != zmq.EAGAIN:
raise e
@ -124,6 +216,22 @@ def maybe_get_work(sock: zmq.Socket):
return client_id, message
def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
if maybe_json is None:
return None
try:
return parse_message(maybe_json)
except json.JSONDecodeError:
return None
except ValueError as e:
return None
def parse_message(json_str: str) -> ProcessingMessage:
data = json.loads(json_str)
return ProcessingMessageWrapper(**data).payload
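
The wire format above is a pydantic discriminated union keyed on the type field, serialized to JSON bytes rather than pickles. A round-trip sketch of the same pattern with made-up message names:

import json
from enum import Enum
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field

class MsgName(str, Enum):
    ping = "ping"
    result = "result"

class Ping(BaseModel):
    type: Literal[MsgName.ping] = MsgName.ping

class Result(BaseModel):
    type: Literal[MsgName.result] = MsgName.result
    value: int

Message = Annotated[Union[Ping, Result], Field(discriminator="type")]

class Wrapper(BaseModel):
    payload: Message

def encode(msg: Union[Ping, Result]) -> bytes:
    return Wrapper(payload=msg).model_dump_json().encode("utf-8")

def decode(raw: bytes) -> Union[Ping, Result]:
    return Wrapper(**json.loads(raw.decode("utf-8"))).payload

wire = encode(Result(value=42))
print(wire)            # e.g. b'{"payload":{"type":"result","value":42}}'
print(decode(wire))    # validated back into a Result via the "type" discriminator
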
def worker_process_entrypoint(
reply_socket_url: str,
init_model_cb: Callable,
@ -142,11 +250,12 @@ def worker_process_entrypoint(
if isinstance(task, str) and task == _END_SENTINEL:
break
result = model(task)
assert isinstance(task, TaskRequest)
result = model(task.task)
except StopIteration:
break
print("[debug] worker process done")
log.info("[debug] worker process done")
def launch_dist_group(
@ -205,9 +314,9 @@ def start_model_parallel_process(
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send_pyobj("READY?")
response = request_socket.recv_pyobj()
print(f"Finished model load {response}")
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
log.info("Loaded model...")
return request_socket, process
@ -235,31 +344,38 @@ class ModelParallelProcessGroup:
def stop(self):
assert self.started, "process group not started"
if self.process.is_alive():
self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
self.process.join()
self.started = False
def run_inference(self, request) -> Generator:
def run_inference(
self, req: Union[CompletionRequest, ChatCompletionRequest]
) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send_pyobj(request)
self.request_socket.send(encode_msg(TaskRequest(task=req)))
try:
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
obj_json = self.request_socket.recv()
obj = parse_message(obj_json)
if isinstance(obj, EndSentinel):
break
if isinstance(obj, Exception):
print(f"[debug] got exception {obj}")
raise obj
if isinstance(obj, ExceptionResponse):
log.error(f"[debug] got exception {obj.error}")
raise Exception(obj.error)
if isinstance(obj, TaskResponse):
yield obj.result
yield obj
except GeneratorExit as e:
self.request_socket.send_pyobj(_CANCEL_SENTINEL)
self.request_socket.send(encode_msg(CancelSentinel()))
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
obj_json = self.request_socket.recv()
obj = parse_message(obj_json)
if isinstance(obj, EndSentinel):
break
finally:
self.running = False

View file

@ -8,14 +8,20 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
import logging
from typing import Optional, Type
log = logging.getLogger(__name__)
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
print("Using efficient FP8 operators in FBGEMM.")
log.info("Using efficient FP8 operators in FBGEMM.")
except ImportError:
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
log.error(
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
)
raise
import torch

View file

@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import math
import re
import torch
from torch import nn
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
"""Hadamard transform.
This function performs the Hadamard transform on the input tensor 'x'.
The Hadamard transform is a linear transformation that multiplies the input
tensor by the Hadamard matrix of dimension n x n, where n is the size of
the last dimension of the input tensor.
"""
*_, n = x.shape
m = int(math.log2(n))
assert n == 1 << m, "n must be a power of 2"
x = x[..., None]
inv_sqrt2 = 0.5**0.5
for _ in range(m):
top = x[..., ::2, :] + x[..., 1::2, :]
bot = x[..., ::2, :] - x[..., 1::2, :]
x = torch.cat((top, bot), dim=-1)
x *= inv_sqrt2
res = x.squeeze(-2)
return res
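
As a quick sanity check, the butterfly above should agree with multiplying by the explicit Sylvester-ordered Hadamard matrix, normalized by sqrt(n). The snippet below restates the function so it runs on its own and compares it against H_4 built as a Kronecker product of H_2 with itself.

import math

import torch

def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    """Same butterfly as above, restated so this check is self-contained."""
    *_, n = x.shape
    m = int(math.log2(n))
    assert n == 1 << m, "n must be a power of 2"
    x = x[..., None]
    inv_sqrt2 = 0.5**0.5
    for _ in range(m):
        top = x[..., ::2, :] + x[..., 1::2, :]
        bot = x[..., ::2, :] - x[..., 1::2, :]
        x = torch.cat((top, bot), dim=-1)
        x *= inv_sqrt2
    return x.squeeze(-2)

# Sylvester construction: H_4 = H_2 kron H_2, then normalize by sqrt(n).
H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
H4 = torch.kron(H2, H2)

x = torch.randn(3, 4)                   # a batch of 3 vectors of length 4
fast = hadamard_transform(x)
dense = x @ H4.T / 4 ** 0.5             # per-row H_4 @ v / sqrt(4); H_4 is symmetric
print(torch.allclose(fast, dense, atol=1e-6))   # expected: True
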
class HadamardModule(torch.nn.Module):
"""A module that applies the Hadamard transform to the input tensor.
Args:
group_size: The size of the groups that the input tensor will be divided into
before applying the Hadamard transform.
"""
def __init__(self, group_size: int) -> None:
super().__init__()
self.group_size = group_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
reshape_back = False
orig_shape = x.shape
if self.group_size != x.shape[-1]:
reshape_back = True
x = x.reshape(-1, x.shape[-1] // self.group_size, self.group_size)
x = hadamard_transform(x)
if reshape_back:
x = x.reshape(orig_shape)
return x
def add_hadamard_transform_for_spinquant(
model: torch.nn.Module, prefix: str = ""
) -> None:
"""
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
This function recursively traverses the model's children and looks for layers that match the pattern
"layers.<digit>.feed_forward.w2", where <digit> is one or more digits. When such a layer is found,
it is replaced with a new sequential module that consists of a HadamardModule followed by the original
layer. The HadamardModule applies the Hadamard transform to the input tensor.
See the `SpinQuant <https://arxiv.org/abs/2405.16406>`_ paper for more details.
Args:
model: An instance of 'torch.nn.Module' (e.g., Transformer model).
prefix: A string prefix to add to the full name of each child module.
Returns:
None
"""
pattern_last_linear_ffn = r"layers\.\d+\.feed_forward\.w2"
for module_name, module in model.named_children():
child_full_name = prefix + "." + module_name
if re.search(pattern_last_linear_ffn, child_full_name):
new_module = nn.Sequential(
HadamardModule(group_size=module.in_features), module
)
del module
setattr(model, module_name, new_module)
else:
add_hadamard_transform_for_spinquant(
module, (prefix + "." if prefix else prefix) + module_name
)

View file

@ -0,0 +1,340 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import logging
import os
from typing import Any, Dict, List, Optional
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from torch import nn, Tensor
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from ..config import MetaReferenceQuantizedInferenceConfig
log = logging.getLogger(__name__)
def swiglu_wrapper(
self,
x: Tensor,
):
from .fp8_impls import ffn_swiglu
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
def convert_to_fp8_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
else:
log.info("Quantizing fp8 weights from bf16...")
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
return model
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
"""
Int8DynActInt4WeightLinear with LoRA adaptor.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
device: Device to use.
group_size: Group size for quantization.
precision: Precision of quantization.
scales_precision: Precision of scales.
lora_rank: Rank of LoRA adaptor.
lora_scale: Scale of LoRA adaptor.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias=False,
device=None,
# quantization parameters
group_size: int = 256,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
# LoRA parameters
lora_rank: Optional[int] = None,
lora_scale: Optional[float] = None,
) -> None:
super().__init__(
in_features,
out_features,
bias=bias,
device=device,
groupsize=group_size,
precision=precision,
scales_precision=scales_precision,
)
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
self.adaptor = nn.Sequential()
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
self.lora_scale = lora_scale
else:
self.adaptor = None
self.lora_scale = None
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict:
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
assert prefix + "scales" in state_dict
state_dict[prefix + "zeros"] = torch.zeros_like(
state_dict[prefix + "scales"]
)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
module_out = super().forward(input_)
if self.adaptor is not None:
adaptor_out = self.adaptor(input_) * self.lora_scale
return module_out + adaptor_out
return module_out
class Int8WeightEmbedding(torch.nn.Embedding):
"""An embedding layer to load int8 weights.
Args:
num_embeddings: Number of embeddings.
embedding_dim: Embedding dimension.
padding_idx: Padding index.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
device=None,
) -> None:
super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
class Int8WeightLinear(torch.nn.Linear):
"""A linear layer to load int8 weights.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
"""
def __init__(
self, in_features: int, out_features: int, bias: bool = True, device=None
) -> None:
super().__init__(in_features, out_features, bias, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module,
group_size: int,
lora_rank: Optional[int],
lora_scale: Optional[float],
):
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
Note that the weights of embedding and output layers are quantized to int8.
"""
device = None
for module_name, module in model.named_children():
if module_name == "output":
quantized_module = Int8WeightLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif module_name == "tok_embeddings":
quantized_module = Int8WeightEmbedding(
num_embeddings=module.num_embeddings,
embedding_dim=module.embedding_dim,
padding_idx=module.padding_idx,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features,
out_features=module.out_features,
bias=False,
group_size=group_size,
lora_rank=lora_rank,
lora_scale=lora_scale,
device=device,
)
del module
setattr(model, module_name, quantized_module)
else:
_prepare_model_int4_weight_int8_dynamic_activation(
module, group_size, lora_rank, lora_scale
)
return model
def convert_to_int4_quantized_model(
model: Transformer,
model_args: ModelArgs,
config: MetaReferenceQuantizedInferenceConfig,
) -> Transformer:
"""Convert the model to int4 quantized model."""
if model_args.quantization_args is None:
raise ValueError("'quantization_args' cannot be None. Please specify it.")
quantization_args = model_args.quantization_args
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
raise NotImplementedError(
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
)
group_size = model_args.quantization_args.group_size
if group_size is None:
raise ValueError(
"'group_size' cannot be None in 'quantization_args'. Please specify it."
)
if model_args.lora_args is None:
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
lora_rank = None
lora_scale = None
else:
lora_rank = model_args.lora_args.rank
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(
model, group_size, lora_rank, lora_scale
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return model.to(device)
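The Int8WeightEmbedding and Int8WeightLinear load hooks above dequantize at load time: the stored int8 weight is multiplied by its per-output-channel scales and written back into the state dict as an ordinary floating-point weight. A minimal standalone sketch of just that step (tensor names are made up):

import torch

# Pretend these came out of a quantized checkpoint: int8 weights plus one scale per output row.
state_dict = {
    "layer.weight": torch.randint(-128, 127, (4, 8), dtype=torch.int8),
    "layer.scales": torch.rand(4, 1) * 0.01,
}

# Mirrors what the load hooks do before nn.Linear / nn.Embedding ever sees the weight.
weights = state_dict.pop("layer.weight")
scales = state_dict.pop("layer.scales")
state_dict["layer.weight"] = weights * scales  # broadcasts to a float32 [4, 8] tensor

print(state_dict["layer.weight"].dtype)  # torch.float32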

View file

@ -1,5 +1,11 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
if [[ $# -ne 1 ]]; then
echo "Error: Please provide the name of CONDA environment you wish to create"
exit 1

View file

@ -8,6 +8,7 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import logging
import os
import shutil
import sys
@ -22,12 +23,18 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
from llama.model import ModelArgs, Transformer, TransformerBlock
from llama.tokenizer import Tokenizer
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from torch.nn.parameter import Parameter
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
quantize_fp8,
)
log = logging.getLogger(__name__)
def main(
ckpt_dir: str,
@ -36,7 +43,6 @@ def main(
max_seq_len: Optional[int] = 512,
max_batch_size: Optional[int] = 4,
model_parallel_size: Optional[int] = None,
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
fp8_activation_scale_ub: Optional[float] = 1200.0,
seed: int = 1,
):
@ -99,7 +105,7 @@ def main(
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print(ckpt_path)
log.info(ckpt_path)
assert (
quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None"
@ -112,7 +118,6 @@ def main(
fp8_weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
@ -124,7 +129,6 @@ def main(
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
@ -136,7 +140,6 @@ def main(
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():

View file

@ -9,7 +9,7 @@
set -euo pipefail
set -x
cd $(git rev-parse --show-toplevel)
cd $(dirname "$(realpath "$0")")
MASTER_HOST=$1
RUN_ID=$2
@ -21,7 +21,7 @@ NPROC=$7
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models:/home/$USER/llama-stack" \
torchrun \
--nnodes=$NNODES --nproc_per_node=$NPROC \
--rdzv_id=$RUN_ID \

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import VLLMConfig

View file

@ -15,20 +15,44 @@ class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider."""
model: str = Field(
default="Llama3.1-8B-Instruct",
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
tensor_parallel_size: int = Field(
default=1,
description="Number of tensor parallel replicas (number of GPUs to use).",
)
max_tokens: int = Field(
default=4096,
description="Maximum number of tokens to generate.",
)
enforce_eager: bool = Field(
default=False,
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
)
gpu_memory_utilization: float = Field(
default=0.3,
)
@classmethod
def sample_run_config(cls):
return {
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",
"enforce_eager": "${env.ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
}
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
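sample_run_config above emits "${env.NAME:default}" placeholders that get substituted from the environment when the run config is rendered. The resolver below is not the Llama Stack implementation, just a hedged sketch of the convention so the template strings are easier to read:

import os
import re

_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z0-9_]+):([^}]*)\}")

def resolve(value: str) -> str:
    # Replace each ${env.NAME:default} with os.environ["NAME"] if set, else the default.
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

print(resolve("${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}"))
print(resolve("${env.TENSOR_PARALLEL_SIZE:1}"))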

View file

@ -0,0 +1,225 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import uuid
from typing import AsyncGenerator, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import VLLMConfig
log = logging.getLogger(__name__)
def _random_uuid() -> str:
return uuid.uuid4().hex
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
"""Inference implementation for vLLM."""
def __init__(self, config: VLLMConfig):
self.config = config
self.engine = None
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self):
log.info("Initializing vLLM inference adapter")
# Disable usage stats reporting; it is on by default, which would surprise
# most users.
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
if "VLLM_NO_USAGE_STATS" not in os.environ:
os.environ["VLLM_NO_USAGE_STATS"] = "1"
model = resolve_model(self.config.model)
if model is None:
raise ValueError(f"Unknown model {self.config.model}")
if model.huggingface_repo is None:
raise ValueError(f"Model {self.config.model} needs a huggingface repo")
# TODO -- there are a ton of options supported here ...
engine_args = AsyncEngineArgs(
model=model.huggingface_repo,
tokenizer=model.huggingface_repo,
tensor_parallel_size=self.config.tensor_parallel_size,
enforce_eager=self.config.enforce_eager,
gpu_memory_utilization=self.config.gpu_memory_utilization,
guided_decoding_backend="lm-format-enforcer",
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
async def shutdown(self):
"""Shutdown the vLLM inference adapter."""
log.info("Shutting down vLLM inference adapter")
if self.engine:
self.engine.shutdown_background_loop()
async def register_model(self, model: Model) -> None:
raise ValueError(
"You cannot dynamically add a model to a running vllm instance"
)
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
if sampling_params is None:
return VLLMSamplingParams(max_tokens=self.config.max_tokens)
# TODO convert what I saw in my first test ... but surely there's more to do here
kwargs = {
"temperature": sampling_params.temperature,
"max_tokens": self.config.max_tokens,
}
if sampling_params.top_k:
kwargs["top_k"] = sampling_params.top_k
if sampling_params.top_p:
kwargs["top_p"] = sampling_params.top_p
if sampling_params.max_tokens:
kwargs["max_tokens"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
return VLLMSamplingParams(**kwargs)
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
log.info("vLLM completion")
messages = [UserMessage(content=content)]
return self.chat_completion(
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
log.info("vLLM chat completion")
assert self.engine is not None
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = chat_completion_request_to_prompt(request, self.formatter)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
)
if stream:
return self._stream_chat_completion(request, results_generator)
else:
return await self._nonstream_chat_completion(request, results_generator)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> ChatCompletionResponse:
outputs = [o async for o in results_generator]
final_output = outputs[-1]
assert final_output is not None
outputs = final_output.outputs
finish_reason = outputs[-1].stop_reason
choice = OpenAICompatCompletionChoice(
finish_reason=finish_reason,
text="".join([output.text for output in outputs]),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> AsyncGenerator:
async def _generate_and_convert_to_openai_compat():
cur = []
async for chunk in results_generator:
if not chunk.outputs:
log.warning("Empty chunk received")
continue
output = chunk.outputs[-1]
new_tokens = output.token_ids[len(cur) :]
text = self.formatter.tokenizer.decode(new_tokens)
cur.extend(new_tokens)
choice = OpenAICompatCompletionChoice(
finish_reason=output.finish_reason,
text=text,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
yield chunk
async def embeddings(
self, model_id: str, contents: list[InterleavedTextMedia]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO
raise NotImplementedError()
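For orientation, a rough driver for the provider above; in practice the Llama Stack server constructs it from a run config, and the import paths and model name below are assumptions (running it also requires vLLM and a GPU with the model available):

import asyncio

# Assumed import paths; adjust to the actual package layout.
from llama_stack.apis.inference import SamplingParams, UserMessage
from llama_stack.providers.inline.inference.vllm import VLLMConfig, VLLMInferenceImpl

async def main() -> None:
    impl = VLLMInferenceImpl(VLLMConfig(model="Llama3.2-3B-Instruct"))
    await impl.initialize()
    response = await impl.chat_completion(
        model_id="Llama3.2-3B-Instruct",
        messages=[UserMessage(content="Say hello in one word.")],
        sampling_params=SamplingParams(temperature=0.0),
        stream=False,
    )
    print(response.completion_message.content)
    await impl.shutdown()

asyncio.run(main())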

View file

@ -34,6 +34,10 @@ public class LocalInference: Inference {
}
}
public func stop() {
runnerHolder.runner?.stop()
}
public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
return AsyncStream { continuation in
runnerQueue.async {

View file

@ -81,7 +81,9 @@ func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPay
switch (m.content) {
case .case1(let c):
prompt += _processContent(c)
case .case2(let c):
case .ImageMedia(let c):
prompt += _processContent(c)
case .case3(let c):
prompt += _processContent(c)
}
case .CompletionMessage(let m):

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
@json_schema_type
class FaissImplConfig(BaseModel):
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="faiss_store.db",
)
}

View file

@ -0,0 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
import json
import logging
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,
)
from llama_stack.providers.utils.telemetry import tracing
from .config import FaissImplConfig
logger = logging.getLogger(__name__)
MEMORY_BANKS_PREFIX = "memory_banks:v1::"
class FaissIndex(EmbeddingIndex):
id_by_index: Dict[int, str]
chunk_by_index: Dict[int, Chunk]
def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
self.index = faiss.IndexFlatL2(dimension)
self.id_by_index = {}
self.chunk_by_index = {}
self.kvstore = kvstore
self.bank_id = bank_id
@classmethod
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
instance = cls(dimension, kvstore, bank_id)
await instance.initialize()
return instance
async def initialize(self) -> None:
if not self.kvstore:
return
index_key = f"faiss_index:v1::{self.bank_id}"
stored_data = await self.kvstore.get(index_key)
if stored_data:
data = json.loads(stored_data)
self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
self.chunk_by_index = {
int(k): Chunk.model_validate_json(v)
for k, v in data["chunk_by_index"].items()
}
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))
async def _save_index(self):
if not self.kvstore or not self.bank_id:
return
np_index = faiss.serialize_index(self.index)
buffer = io.BytesIO()
np.savetxt(buffer, np_index)
data = {
"id_by_index": self.id_by_index,
"chunk_by_index": {
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
},
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
}
index_key = f"faiss_index:v1::{self.bank_id}"
await self.kvstore.set(key=index_key, value=json.dumps(data))
async def delete(self):
if not self.kvstore or not self.bank_id:
return
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
@tracing.span(name="add_chunks")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
self.id_by_index[indexlen + i] = chunk.document_id
self.index.add(np.array(embeddings).astype(np.float32))
# Save updated index
await self._save_index()
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
distances, indices = self.index.search(
embedding.reshape(1, -1).astype(np.float32), k
)
chunks = []
scores = []
for d, i in zip(distances[0], indices[0]):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
self.kvstore = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing banks from kvstore
start_key = MEMORY_BANKS_PREFIX
end_key = f"{MEMORY_BANKS_PREFIX}\xff"
stored_banks = await self.kvstore.range(start_key, end_key)
for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex(
bank=bank,
index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
),
)
self.cache[bank.identifier] = index
async def shutdown(self) -> None:
# Cleanup if needed
pass
async def register_memory_bank(
self,
memory_bank: MemoryBank,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
# Store in kvstore
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
await self.kvstore.set(
key=key,
value=memory_bank.model_dump_json(),
)
# Store in cache
index = BankWithIndex(
bank=memory_bank,
index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
),
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]:
return [i.bank for i in self.cache.values()]
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id)
if index is None:
raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id)
if index is None:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)
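FaissIndex persists itself by serializing the index to a uint8 numpy array, writing it through np.savetxt into a bytes buffer, and base64-encoding the result before it goes into the kvstore; initialize() reverses those steps. A standalone round-trip sketch of just that encoding (no kvstore involved):

import base64
import io

import faiss
import numpy as np

index = faiss.IndexFlatL2(8)
index.add(np.random.rand(10, 8).astype(np.float32))

# Save path: index -> uint8 array -> text buffer -> base64 string (the kvstore value).
buf = io.BytesIO()
np.savetxt(buf, faiss.serialize_index(index))
encoded = base64.b64encode(buf.getvalue()).decode("utf-8")

# Load path: base64 string -> text buffer -> uint8 array -> index.
restored = faiss.deserialize_index(
    np.loadtxt(io.BytesIO(base64.b64decode(encoded)), dtype=np.uint8)
)
assert restored.ntotal == index.ntotal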

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
class LogFormat(Enum):
TEXT = "text"
JSON = "json"
@json_schema_type
class ConsoleConfig(BaseModel):
log_format: LogFormat = LogFormat.TEXT

View file

@ -4,8 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Optional
from .config import LogFormat
from llama_stack.apis.telemetry import * # noqa: F403
from .config import ConsoleConfig
@ -38,7 +41,11 @@ class ConsoleTelemetryImpl(Telemetry):
span_name = ".".join(names) if names else None
formatted = format_event(event, span_name)
if self.config.log_format == LogFormat.JSON:
formatted = format_event_json(event, span_name)
else:
formatted = format_event_text(event, span_name)
if formatted:
print(formatted)
@ -69,7 +76,7 @@ SEVERITY_COLORS = {
}
def format_event(event: Event, span_name: str) -> Optional[str]:
def format_event_text(event: Event, span_name: str) -> Optional[str]:
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
span = ""
if span_name:
@ -87,3 +94,23 @@ def format_event(event: Event, span_name: str) -> Optional[str]:
return None
return f"Unknown event type: {event}"
def format_event_json(event: Event, span_name: str) -> Optional[str]:
base_data = {
"timestamp": event.timestamp.isoformat(),
"trace_id": event.trace_id,
"span_id": event.span_id,
"span_name": span_name,
}
if isinstance(event, UnstructuredLogEvent):
base_data.update(
{"type": "log", "severity": event.severity.name, "message": event.message}
)
return json.dumps(base_data)
elif isinstance(event, StructuredLogEvent):
return None
return json.dumps({"error": f"Unknown event type: {event}"})

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import CodeShieldConfig
async def get_provider_impl(config: CodeShieldConfig, deps):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from .config import CodeScannerConfig
from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",
"CodeShield",
]
class MetaReferenceCodeScannerSafetyImpl(Safety):
def __init__(self, config: CodeScannerConfig, deps) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
raise ValueError(
f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
)
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
from codeshield.cs import CodeShield
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
log.info(f"Running CodeScannerShield on {text[50:]}")
result = await CodeShield.scan_code(text)
violation = None
if result.is_insecure:
violation = SafetyViolation(
violation_level=(ViolationLevel.ERROR),
user_message="Sorry, I found security concerns in the code.",
metadata={
"violation_type": ",".join(
[issue.pattern_id for issue in result.issues_found]
)
},
)
return RunShieldResponse(violation=violation)
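The shield is a thin wrapper over codeshield's scanner: it concatenates the message contents and calls scan_code. A minimal direct-use sketch, relying only on the codeshield calls and result fields used above (the snippet being scanned is made up):

import asyncio

from codeshield.cs import CodeShield  # same dependency the shield above imports lazily

async def main() -> None:
    result = await CodeShield.scan_code("import os\nos.system(user_input)")
    if result.is_insecure:
        # pattern_id is what run_shield above joins into the violation metadata.
        print("insecure:", [issue.pattern_id for issue in result.issues_found])
    else:
        print("no issues found")

asyncio.run(main())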

View file

@ -7,6 +7,5 @@
from pydantic import BaseModel
class OpenTelemetryConfig(BaseModel):
jaeger_host: str = "localhost"
jaeger_port: int = 6831
class CodeScannerConfig(BaseModel):
pass

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(
config, LlamaGuardConfig
), f"Unexpected config type: {type(config)}"
impl = LlamaGuardSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from typing import List
from pydantic import BaseModel
@json_schema_type
class ConsoleConfig(BaseModel): ...
class LlamaGuardConfig(BaseModel):
excluded_categories: List[str] = []

View file

@ -7,16 +7,21 @@
import re
from string import Template
from typing import List, Optional
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import LlamaGuardConfig
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
SAFE_RESPONSE = "safe"
_INSTANCE = None
CAT_VIOLENT_CRIMES = "Violent Crimes"
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
@ -68,13 +73,21 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_ELECTIONS,
]
# accept both CoreModelId and huggingface repo id
LLAMA_GUARD_MODEL_IDS = {
CoreModelId.llama_guard_3_8b.value: "meta-llama/Llama-Guard-3-8B",
"meta-llama/Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_1b.value: "meta-llama/Llama-Guard-3-1B",
"meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
CoreModelId.llama_guard_3_11b_vision.value: "meta-llama/Llama-Guard-3-11B-Vision",
"meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
}
MODEL_TO_SAFETY_CATEGORIES_MAP = {
CoreModelId.llama_guard_3_8b.value: (
DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
),
CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES
+ [CAT_CODE_INTERPRETER_ABUSE],
"meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
"meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
}
@ -107,18 +120,56 @@ PROMPT_TEMPLATE = Template(
)
class LlamaGuardShield(ShieldBase):
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: LlamaGuardConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
raise ValueError(
f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
)
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Unknown shield {shield_id}")
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
impl = LlamaGuardShield(
model=model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)
return await impl.run(messages)
class LlamaGuardShield:
def __init__(
self,
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
excluded_categories: Optional[List[str]] = None,
):
super().__init__(on_violation_action)
if excluded_categories is None:
excluded_categories = []
@ -132,8 +183,6 @@ class LlamaGuardShield(ShieldBase):
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
@ -174,18 +223,12 @@ class LlamaGuardShield(ShieldBase):
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
)
return messages
async def run(self, messages: List[Message]) -> ShieldResponse:
async def run(self, messages: List[Message]) -> RunShieldResponse:
messages = self.validate_messages(messages)
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(is_violation=False)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
is_violation=False,
)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
shield_input_message = self.build_vision_shield_input(messages)
@ -194,8 +237,8 @@ class LlamaGuardShield(ShieldBase):
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in self.inference_api.chat_completion(
model=self.model,
async for chunk in await self.inference_api.chat_completion(
model_id=self.model,
messages=[shield_input_message],
stream=True,
):
@ -205,8 +248,7 @@ class LlamaGuardShield(ShieldBase):
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response
return self.get_shield_response(content)
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
@ -260,19 +302,23 @@ class LlamaGuardShield(ShieldBase):
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> ShieldResponse:
def get_shield_response(self, response: str) -> RunShieldResponse:
response = response.strip()
if response == SAFE_RESPONSE:
return ShieldResponse(is_violation=False)
return RunShieldResponse(violation=None)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(is_violation=False)
return ShieldResponse(
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
return RunShieldResponse(violation=None)
return RunShieldResponse(
violation=SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message=CANNED_RESPONSE_TEXT,
metadata={"violation_type": unsafe_code},
),
)
raise ValueError(f"Unexpected response: {response}")

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import PromptGuardConfig # noqa: F401
async def get_provider_impl(config: PromptGuardConfig, deps):
from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps)
await impl.initialize()
return impl

Some files were not shown because too many files have changed in this diff.