Merge remote-tracking branch 'origin/main' into resp_branching

This commit is contained in:
Ashwin Bharambe 2025-10-02 09:44:48 -07:00
commit c1c06e6e2f
65 changed files with 56610 additions and 61 deletions

View file

@ -772,6 +772,12 @@ class Agents(Protocol):
#
# Both of these APIs are inherently stateful.
@webmethod(
route="/openai/v1/responses/{response_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_openai_response(
self,
@ -784,6 +790,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
async def create_openai_response(
self,
@ -809,6 +816,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_responses(
self,
@ -828,10 +836,9 @@ class Agents(Protocol):
...
@webmethod(
route="/responses/{response_id}/input_items",
method="GET",
level=LLAMA_STACK_API_V1,
route="/openai/v1/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_response_input_items(
self,
response_id: str,
@ -853,6 +860,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete an OpenAI response by its ID.

View file

@ -43,6 +43,7 @@ class Batches(Protocol):
Note: This API is currently under active development and may undergo changes.
"""
@webmethod(route="/openai/v1/batches", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
async def create_batch(
self,
@ -63,6 +64,7 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch.
@ -72,6 +74,7 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress.
@ -81,6 +84,7 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
async def list_batches(
self,

View file

@ -105,6 +105,7 @@ class OpenAIFileDeleteResponse(BaseModel):
@trace_protocol
class Files(Protocol):
# OpenAI Files API Endpoints
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
async def openai_upload_file(
self,
@ -127,6 +128,7 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_files(
self,
@ -146,6 +148,7 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file(
self,
@ -159,6 +162,7 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_file(
self,
@ -172,6 +176,7 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file_content(
self,

View file

@ -1064,6 +1064,7 @@ class InferenceProvider(Protocol):
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
@ -1115,6 +1116,7 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
@ -1171,6 +1173,7 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
async def openai_embeddings(
self,
@ -1200,6 +1203,7 @@ class Inference(InferenceProvider):
- Embedding models: these models generate embeddings to be used for semantic search.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
async def list_chat_completions(
self,
@ -1218,6 +1222,9 @@ class Inference(InferenceProvider):
"""
raise NotImplementedError("List chat completions is not implemented")
@webmethod(
route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.

View file

@ -111,6 +111,14 @@ class Models(Protocol):
"""
...
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.
:returns: A OpenAIListModelsResponse.
"""
...
@webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_model(
self,

View file

@ -114,6 +114,7 @@ class Safety(Protocol):
"""
...
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
"""Classifies if text and/or image inputs are potentially harmful.

View file

@ -512,6 +512,7 @@ class VectorIO(Protocol):
...
# OpenAI Vector Stores API endpoints
@webmethod(route="/openai/v1/vector_stores", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
async def openai_create_vector_store(
self,
@ -538,6 +539,7 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_vector_stores(
self,
@ -556,6 +558,9 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_vector_store(
self,
@ -568,6 +573,9 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="POST",
@ -590,6 +598,9 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="DELETE",
@ -606,6 +617,12 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/search",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/search",
method="POST",
@ -638,6 +655,12 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="POST",
@ -660,6 +683,12 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="GET",
@ -686,6 +715,12 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
@ -704,6 +739,12 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
method="GET",
@ -722,6 +763,12 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
@ -742,6 +789,12 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
@ -765,6 +818,12 @@ class VectorIO(Protocol):
method="POST",
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
@ -787,6 +846,12 @@ class VectorIO(Protocol):
method="GET",
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
@ -800,6 +865,12 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
@ -828,6 +899,12 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",

View file

@ -115,7 +115,7 @@ docker run -it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \
# NOTE: mount the llama-stack directory if testing local changes else not needed
-v /home/hjshah/git/llama-stack:/app/llama-stack-source \
-v $HOME/git/llama-stack:/app/llama-stack-source \
# localhost/distribution-dell:dev if building / testing locally
llamastack/distribution-{{ name }}\
--port $LLAMA_STACK_PORT \

View file

@ -50,6 +50,12 @@ from llama_stack.apis.inference import (
CompletionMessage,
Inference,
Message,
OpenAIAssistantMessageParam,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
SamplingParams,
StopReason,
SystemMessage,
@ -67,6 +73,11 @@ from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_openai_chat_completion_stream,
convert_tooldef_to_openai_tool,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing
@ -176,12 +187,12 @@ class ChatAgent(ShieldRunnerMixin):
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
turn_id = str(uuid.uuid4())
span = tracing.get_current_span()
if span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
@ -504,26 +515,93 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls = []
content = ""
stop_reason = None
stop_reason: StopReason | None = None
async with tracing.span("inference") as span:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self.tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
def _serialize_nested(value):
"""Recursively serialize nested Pydantic models to dicts."""
from pydantic import BaseModel
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
elif isinstance(value, dict):
return {k: _serialize_nested(v) for k, v in value.items()}
elif isinstance(value, list):
return [_serialize_nested(item) for item in value]
else:
return value
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg)
role = openai_msg.get("role")
if role == "user":
return OpenAIUserMessageParam(**openai_msg)
elif role == "system":
return OpenAISystemMessageParam(**openai_msg)
elif role == "assistant":
return OpenAIAssistantMessageParam(**openai_msg)
elif role == "tool":
return OpenAIToolMessageParam(**openai_msg)
elif role == "developer":
return OpenAIDeveloperMessageParam(**openai_msg)
else:
raise ValueError(f"Unknown message role: {role}")
# Convert messages to OpenAI format
openai_messages: list[OpenAIMessageParam] = [
_add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
]
# Convert tool definitions to OpenAI format
openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
# Extract tool_choice from tool_config for OpenAI compatibility
# Note: tool_choice can only be provided when tools are also provided
tool_choice = None
if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
tc = self.agent_config.tool_config.tool_choice
tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
# Convert tool_choice to OpenAI format
if tool_choice_str in ("auto", "none", "required"):
tool_choice = tool_choice_str
else:
# It's a specific tool name, wrap it in the proper format
tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
# Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
max_tokens = getattr(sampling_params, "max_tokens", None)
# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion(
model=self.agent_config.model,
messages=openai_messages,
tools=openai_tools if openai_tools else None,
tool_choice=tool_choice,
response_format=self.agent_config.response_format,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
stream=True,
sampling_params=sampling_params,
tool_config=self.agent_config.tool_config,
):
)
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True
)
async for chunk in response_stream:
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
stop_reason = event.stop_reason or StopReason.end_of_turn
continue
delta = event.delta
@ -532,7 +610,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls.append(delta.tool_call)
elif delta.parse_status == ToolCallParseStatus.failed:
# If we cannot parse the tools, set the content to the unparsed raw text
content = delta.tool_call
content = str(delta.tool_call)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -559,9 +637,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),