Merge branch 'main' into nvidia-e2e-notebook

This commit is contained in:
Jash Gulabrai 2025-04-30 12:05:11 -04:00
commit 012dd6891f
96 changed files with 4675 additions and 426 deletions

View file

@ -23,6 +23,9 @@ from llama_stack.apis.agents import (
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
OpenAIResponseInputMessage,
OpenAIResponseInputTool,
OpenAIResponseObject,
Session,
Turn,
)
@ -40,6 +43,7 @@ from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_imp
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@ -63,9 +67,16 @@ class MetaReferenceAgentsImpl(Agents):
self.tool_groups_api = tool_groups_api
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl = None
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
self.openai_responses_impl = OpenAIResponsesImpl(
self.persistence_store,
inference_api=self.inference_api,
tool_groups_api=self.tool_groups_api,
tool_runtime_api=self.tool_runtime_api,
)
# check if "bwrap" is available
if not shutil.which("bwrap"):
@ -244,3 +255,23 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
) -> ListAgentSessionsResponse:
pass
# OpenAI responses
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.get_openai_response(id)
async def create_openai_response(
self,
input: Union[str, List[OpenAIResponseInputMessage]],
model: str,
previous_response_id: Optional[str] = None,
store: Optional[bool] = True,
stream: Optional[bool] = False,
tools: Optional[List[OpenAIResponseInputTool]] = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input, model, previous_response_id, store, stream, tools
)

View file

@ -0,0 +1,319 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from typing import AsyncIterator, List, Optional, Union, cast
from openai.types.chat import ChatCompletionToolParam
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessage,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseOutput,
OpenAIResponseOutputMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.inference.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIImageURL,
OpenAIMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.kvstore import KVStore
logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
messages: List[OpenAIMessageParam] = []
for output_message in previous_response.output:
if isinstance(output_message, OpenAIResponseOutputMessage):
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
return messages
async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]:
output_messages = []
for choice in choices:
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
# TODO: handle image content
output_messages.append(
OpenAIResponseOutputMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
)
)
return output_messages
class OpenAIResponsesImpl:
def __init__(
self,
persistence_store: KVStore,
inference_api: Inference,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
):
self.persistence_store = persistence_store
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
response_json = await self.persistence_store.get(key=key)
if response_json is None:
raise ValueError(f"OpenAI response with id '{id}' not found")
return OpenAIResponseObject.model_validate_json(response_json)
async def create_openai_response(
self,
input: Union[str, List[OpenAIResponseInputMessage]],
model: str,
previous_response_id: Optional[str] = None,
store: Optional[bool] = True,
stream: Optional[bool] = False,
tools: Optional[List[OpenAIResponseInputTool]] = None,
):
stream = False if stream is None else stream
messages: List[OpenAIMessageParam] = []
if previous_response_id:
previous_response = await self.get_openai_response(previous_response_id)
messages.extend(await _previous_response_to_messages(previous_response))
# TODO: refactor this user_content parsing out into a separate method
user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
if isinstance(input, list):
user_content = []
for user_input in input:
if isinstance(user_input.content, list):
for user_input_content in user_input.content:
if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
if user_input_content.image_url:
image_url = OpenAIImageURL(
url=user_input_content.image_url, detail=user_input_content.detail
)
user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
else:
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
else:
user_content = input
messages.append(OpenAIUserMessageParam(content=user_content))
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
)
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
chat_response_content = []
# TODO: these chunk_ fields are hacky and only take the last chunk into account
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
async for chunk in chat_response:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# TODO: this only works for text content
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
chat_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
else:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: List[OpenAIResponseOutput] = []
if chat_response.choices[0].message.tool_calls:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages)
)
else:
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
status="completed",
output=output_messages,
)
if store:
# Store in kvstore
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
await self.persistence_store.set(
key=key,
value=response.model_dump_json(),
)
if stream:
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
# TODO: response created should actually get emitted much earlier in the process
yield OpenAIResponseObjectStreamResponseCreated(response=response)
yield OpenAIResponseObjectStreamResponseCompleted(response=response)
return async_response()
return response
async def _convert_response_tools_to_chat_tools(
self, tools: List[OpenAIResponseInputTool]
) -> List[ChatCompletionToolParam]:
chat_tools: List[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
if input_tool.type == "web_search":
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
chat_tool = convert_tooldef_to_openai_tool(tool_def)
chat_tools.append(chat_tool)
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools
async def _execute_tool_and_return_final_output(
self, model_id: str, stream: bool, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
) -> List[OpenAIResponseOutput]:
output_messages: List[OpenAIResponseOutput] = []
choice = chat_response.choices[0]
# If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
if not choice.message.tool_calls:
return output_messages
# Add the assistant message with tool_calls response to the messages list
messages.append(choice.message)
for tool_call in choice.message.tool_calls:
tool_call_id = tool_call.id
function = tool_call.function
# If for some reason the tool call doesn't have a function or id, we can't execute it
if not function or not tool_call_id:
continue
# TODO: telemetry spans for tool calls
result = await self._execute_tool_call(function)
# Handle tool call failure
if not result:
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="failed",
)
)
continue
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
),
)
result_content = ""
# TODO: handle other result content types and lists
if isinstance(result.content, str):
result_content = result.content
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
tool_results_chat_response = await self.inference_api.openai_chat_completion(
model=model_id,
messages=messages,
stream=stream,
)
# type cast to appease mypy
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
# TODO: Wire in annotations with URLs, titles, etc to these output messages
output_messages.extend(tool_final_outputs)
return output_messages
async def _execute_tool_call(
self,
function: OpenAIChatCompletionToolCallFunction,
) -> Optional[ToolInvocationResult]:
if not function.name:
return None
function_args = json.loads(function.arguments) if function.arguments else {}
logger.info(f"executing tool call: {function.name} with args: {function_args}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=function_args,
)
logger.debug(f"tool call {function.name} completed with result: {result}")
return result

View file

@ -17,10 +17,8 @@ from llama_stack.apis.common.type_system import (
DialogType,
StringType,
)
from llama_stack.apis.datasets import Datasets
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
validate_dataset_schema,
)
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
@ -36,21 +34,3 @@ EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
}
],
}
async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def:
raise ValueError(f"Dataset {dataset_id} does not exist.")
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
if dataset_type not in EXPECTED_DATASET_SCHEMA:
raise ValueError(f"Dataset type {dataset_type} is not supported.")
validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type])

View file

@ -48,9 +48,6 @@ from llama_stack.apis.post_training import (
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.validator import (
validate_input_dataset_schema,
)
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
@ -348,11 +345,9 @@ class LoraFinetuningSingleDevice:
all_rows = await fetch_rows(dataset_id)
rows = all_rows.data
await validate_input_dataset_schema(
datasets_api=self.datasets_api,
dataset_id=dataset_id,
dataset_type=self._data_format.value,
)
# TODO (xiyan): validate dataset schema
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset(
rows,

View file

@ -30,7 +30,7 @@ class TelemetryConfig(BaseModel):
)
service_name: str = Field(
# service name is always the same, use zero-width space to avoid clutter
default="",
default="",
description="The service name to use for telemetry",
)
sinks: List[TelemetrySink] = Field(
@ -52,7 +52,7 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:}",
"service_name": "${env.OTEL_SERVICE_NAME:}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
}

View file

@ -227,6 +227,16 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="llama-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter
adapter = LlamaCompatInferenceAdapter(config)
return adapter

View file

@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class LlamaProviderDataValidator(BaseModel):
llama_api_key: Optional[str] = Field(
default=None,
description="API key for api.llama models",
)
@json_schema_type
class LlamaCompatConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
description="The Llama API key",
)
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
description="The URL for the Llama API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> Dict[str, Any]:
return {
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
"api_key": api_key,
}

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from .models import MODEL_ENTRIES
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Scout-17B-16E-Instruct-FP8",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]

View file

@ -433,6 +433,12 @@ class OllamaInferenceAdapter(
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model)
# ollama still makes tool calls even when tool_choice is "none"
# so we need to remove the tools in that case
if tool_choice == "none" and tools is not None:
tools = None
params = {
k: v
for k, v in {

View file

@ -90,6 +90,9 @@ class LiteLLMOpenAIMixin(
raise ValueError(f"Unsupported model: {model.provider_resource_id}")
return model
def get_litellm_model_name(self, model_id: str) -> str:
return "openai/" + model_id if self.is_openai_compat else model_id
async def completion(
self,
model_id: str,
@ -130,8 +133,7 @@ class LiteLLMOpenAIMixin(
)
params = await self._get_params(request)
if self.is_openai_compat:
params["model"] = "openai/" + params["model"]
params["model"] = self.get_litellm_model_name(params["model"])
logger.debug(f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
@ -220,21 +222,23 @@ class LiteLLMOpenAIMixin(
else request.tool_config.tool_choice
)
return {
"model": request.model,
"api_key": self.get_api_key(),
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
def get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
return {
"model": request.model,
"api_key": api_key,
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
return api_key
async def embeddings(
self,
@ -247,7 +251,7 @@ class LiteLLMOpenAIMixin(
model = await self.model_store.get_model(model_id)
response = litellm.embedding(
model=model.provider_resource_id,
model=self.get_litellm_model_name(model.provider_resource_id),
input=[interleaved_content_as_str(content) for content in contents],
)
@ -278,7 +282,7 @@ class LiteLLMOpenAIMixin(
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=prompt,
best_of=best_of,
echo=echo,
@ -297,6 +301,8 @@ class LiteLLMOpenAIMixin(
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.atext_completion(**params)
@ -328,7 +334,7 @@ class LiteLLMOpenAIMixin(
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
@ -351,6 +357,8 @@ class LiteLLMOpenAIMixin(
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.acompletion(**params)

View file

@ -638,10 +638,13 @@ async def convert_message_to_openai_dict_new(
)
for tool in message.tool_calls
]
params = {}
if tool_calls:
params["tool_calls"] = tool_calls
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=await _convert_message_content(message.content),
tool_calls=tool_calls or None,
**params,
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(