Merge remote-tracking branch 'origin/main' into stores
Some checks failed
Installer CI / smoke-test-on-dev (push) Failing after 3s
Installer CI / lint (push) Failing after 3s

This commit is contained in:
Ashwin Bharambe 2025-10-13 11:07:11 -07:00
commit b72154ce5e
1161 changed files with 609896 additions and 42960 deletions

View file

@ -21,7 +21,9 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.safety],
deps[Api.tool_runtime],
deps[Api.tool_groups],
deps[Api.conversations],
policy,
Api.telemetry in deps,
)
await impl.initialize()
return impl

View file

@ -7,8 +7,6 @@
import copy
import json
import re
import secrets
import string
import uuid
import warnings
from collections.abc import AsyncGenerator
@ -51,6 +49,7 @@ from llama_stack.apis.inference import (
Inference,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
@ -84,11 +83,6 @@ from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
@ -110,6 +104,7 @@ class ChatAgent(ShieldRunnerMixin):
persistence_store: KVStore,
created_at: str,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
self.agent_id = agent_id
self.agent_config = agent_config
@ -120,6 +115,7 @@ class ChatAgent(ShieldRunnerMixin):
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.created_at = created_at
self.telemetry_enabled = telemetry_enabled
ShieldRunnerMixin.__init__(
self,
@ -188,28 +184,30 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
turn_id = str(uuid.uuid4())
span = tracing.get_current_span()
if span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
if self.telemetry_enabled:
span = tracing.get_current_span()
if span is not None:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
span = tracing.get_current_span()
if span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
if self.telemetry_enabled:
span = tracing.get_current_span()
if span is not None:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools()
async for chunk in self._run_turn(request):
@ -395,9 +393,12 @@ class ChatAgent(ShieldRunnerMixin):
touchpoint: str,
) -> AsyncGenerator:
async with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if self.telemetry_enabled and span is not None:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
if len(shields) == 0:
span.set_attribute("output", "no shields")
return
step_id = str(uuid.uuid4())
@ -430,7 +431,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute("output", e.violation.model_dump_json())
if self.telemetry_enabled and span is not None:
span.set_attribute("output", e.violation.model_dump_json())
yield CompletionMessage(
content=str(e),
@ -453,7 +455,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute("output", "no violations")
if self.telemetry_enabled and span is not None:
span.set_attribute("output", "no violations")
async def _run(
self,
@ -518,8 +521,9 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason: StopReason | None = None
async with tracing.span("inference") as span:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
if self.telemetry_enabled and span is not None:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
def _serialize_nested(value):
"""Recursively serialize nested Pydantic models to dicts."""
@ -579,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens = getattr(sampling_params, "max_tokens", None)
# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.agent_config.model,
messages=openai_messages,
tools=openai_tools if openai_tools else None,
@ -590,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens=max_tokens,
stream=True,
)
openai_stream = await self.inference_api.openai_chat_completion(params)
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
@ -637,18 +642,19 @@ class ChatAgent(ShieldRunnerMixin):
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
if self.telemetry_enabled and span is not None:
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
n_iter += 1
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
@ -756,7 +762,9 @@ class ChatAgent(ShieldRunnerMixin):
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
}
if self.telemetry_enabled
else {},
) as span:
tool_execution_start_time = datetime.now(UTC).isoformat()
tool_result = await self.execute_tool_call_maybe(
@ -771,7 +779,8 @@ class ChatAgent(ShieldRunnerMixin):
call_id=tool_call.call_id,
content=tool_result.content,
)
span.set_attribute("output", result_message.model_dump_json())
if self.telemetry_enabled and span is not None:
span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step
tool_execution_step = ToolExecutionStep(

View file

@ -30,6 +30,7 @@ from llama_stack.apis.agents import (
)
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import (
Inference,
ToolConfig,
@ -63,7 +64,9 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
conversations_api: Conversations,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
self.config = config
self.inference_api = inference_api
@ -71,6 +74,8 @@ class MetaReferenceAgentsImpl(Agents):
self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.conversations_api = conversations_api
self.telemetry_enabled = telemetry_enabled
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
@ -86,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
conversations_api=self.conversations_api,
)
async def create_agent(
@ -135,6 +141,7 @@ class MetaReferenceAgentsImpl(Agents):
),
created_at=agent_info.created_at,
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
async def create_agent_session(
@ -322,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@ -336,6 +344,7 @@ class MetaReferenceAgentsImpl(Agents):
model,
instructions,
previous_response_id,
conversation,
store,
stream,
temperature,

View file

@ -24,6 +24,11 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.common.errors import (
InvalidConversationIdError,
)
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
OpenAIMessageParam,
@ -39,7 +44,7 @@ from llama_stack.providers.utils.responses.responses_store import (
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext
from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
@ -61,12 +66,14 @@ class OpenAIResponsesImpl:
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
conversations_api: Conversations,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.conversations_api = conversations_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
@ -91,13 +98,15 @@ class OpenAIResponsesImpl:
async def _process_input_with_previous_response(
self,
input: str | list[OpenAIResponseInput],
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
"""Process input with optional previous response context.
Returns:
tuple: (all_input for storage, messages for chat completion)
tuple: (all_input for storage, messages for chat completion, tool context)
"""
tool_context = ToolContext(tools)
if previous_response_id:
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
await self.responses_store.get_response_object(previous_response_id)
@ -108,16 +117,18 @@ class OpenAIResponsesImpl:
# Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
messages = message_adapter.validate_python(previous_response.messages)
new_messages = await convert_response_input_to_chat_messages(input)
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
messages.extend(new_messages)
else:
# Backward compatibility: reconstruct from inputs
messages = await convert_response_input_to_chat_messages(all_input)
tool_context.recover_tools_from_previous_response(previous_response)
else:
all_input = input
messages = await convert_response_input_to_chat_messages(input)
return all_input, messages
return all_input, messages, tool_context
async def _prepend_instructions(self, messages, instructions):
if instructions:
@ -201,6 +212,7 @@ class OpenAIResponsesImpl:
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@ -217,11 +229,27 @@ class OpenAIResponsesImpl:
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
if conversation is not None and previous_response_id is not None:
raise ValueError(
"Mutually exclusive parameters: 'previous_response_id' and 'conversation'. Ensure you are only providing one of these parameters."
)
original_input = input # needed for syncing to Conversations
if conversation is not None:
if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)
# Check conversation exists (raises ConversationNotFoundError if not)
_ = await self.conversations_api.get_conversation(conversation)
input = await self._load_conversation_context(conversation, input)
stream_gen = self._create_streaming_response(
input=input,
original_input=original_input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
conversation=conversation,
store=store,
temperature=temperature,
text=text,
@ -232,24 +260,42 @@ class OpenAIResponsesImpl:
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
final_response = None
final_event_type = None
failed_response = None
if response is None:
raise ValueError("The response stream never completed")
return response
async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}:
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
if failed_response is not None:
error_message = (
failed_response.error.message
if failed_response and failed_response.error
else "Response stream failed without error details"
)
raise RuntimeError(f"OpenAI response failed: {error_message}")
if final_response is None:
raise ValueError("The response stream never reached a terminal state")
return final_response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
original_input: str | list[OpenAIResponseInput] | None = None,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
@ -257,7 +303,9 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id
)
await self._prepend_instructions(messages, instructions)
# Structured outputs
@ -269,11 +317,12 @@ class OpenAIResponsesImpl:
response_tools=tools,
temperature=temperature,
response_format=response_format,
inputs=input,
tool_context=tool_context,
inputs=all_input,
)
# Create orchestrator and delegate streaming logic
response_id = f"resp-{uuid.uuid4()}"
response_id = f"resp_{uuid.uuid4()}"
created_at = int(time.time())
orchestrator = StreamingResponseOrchestrator(
@ -288,18 +337,110 @@ class OpenAIResponsesImpl:
# Stream the response
final_response = None
failed_response = None
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type == "response.completed":
if stream_chunk.type in {"response.completed", "response.incomplete"}:
final_response = stream_chunk.response
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
yield stream_chunk
# Store the response if requested
if store and final_response:
await self._store_response(
response=final_response,
input=all_input,
messages=orchestrator.final_messages,
)
# Store and sync immediately after yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks early
if (
stream_chunk.type in {"response.completed", "response.incomplete"}
and store
and final_response
and failed_response is None
):
await self._store_response(
response=final_response,
input=all_input,
messages=orchestrator.final_messages,
)
if stream_chunk.type in {"response.completed", "response.incomplete"} and conversation and final_response:
# for Conversations, we need to use the original_input if it's available, otherwise use input
sync_input = original_input if original_input is not None else input
await self._sync_response_to_conversation(conversation, sync_input, final_response)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)
async def _load_conversation_context(
self, conversation_id: str, content: str | list[OpenAIResponseInput]
) -> list[OpenAIResponseInput]:
"""Load conversation history and merge with provided content."""
conversation_items = await self.conversations_api.list(conversation_id, order="asc")
context_messages = []
for item in conversation_items.data:
if isinstance(item, OpenAIResponseMessage):
if item.role == "user":
context_messages.append(
OpenAIResponseMessage(
role="user", content=item.content, id=item.id if hasattr(item, "id") else None
)
)
elif item.role == "assistant":
context_messages.append(
OpenAIResponseMessage(
role="assistant", content=item.content, id=item.id if hasattr(item, "id") else None
)
)
# add new content to context
if isinstance(content, str):
context_messages.append(OpenAIResponseMessage(role="user", content=content))
elif isinstance(content, list):
context_messages.extend(content)
return context_messages
async def _sync_response_to_conversation(
self, conversation_id: str, content: str | list[OpenAIResponseInput], response: OpenAIResponseObject
) -> None:
"""Sync content and response messages to the conversation."""
conversation_items = []
# add user content message(s)
if isinstance(content, str):
conversation_items.append(
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": content}]}
)
elif isinstance(content, list):
for item in content:
if not isinstance(item, OpenAIResponseMessage):
raise NotImplementedError(f"Unsupported input item type: {type(item)}")
if item.role == "user":
if isinstance(item.content, str):
conversation_items.append(
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": item.content}],
}
)
elif isinstance(item.content, list):
conversation_items.append({"type": "message", "role": "user", "content": item.content})
else:
raise NotImplementedError(f"Unsupported user message content type: {type(item.content)}")
elif item.role == "assistant":
if isinstance(item.content, list):
conversation_items.append({"type": "message", "role": "assistant", "content": item.content})
else:
raise NotImplementedError(f"Unsupported assistant message content type: {type(item.content)}")
else:
raise NotImplementedError(f"Unsupported message role: {item.role}")
# add assistant response message
for output_item in response.output:
if isinstance(output_item, OpenAIResponseMessage) and output_item.role == "assistant":
if hasattr(output_item, "content") and isinstance(output_item.content, list):
conversation_items.append({"type": "message", "role": "assistant", "content": output_item.content})
if conversation_items:
adapter = TypeAdapter(list[ConversationItem])
validated_items = adapter.validate_python(conversation_items)
await self.conversations_api.add_items(conversation_id, validated_items)

View file

@ -13,6 +13,9 @@ from llama_stack.apis.agents.openai_responses import (
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseContentPartReasoningText,
OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
@ -22,8 +25,11 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseContentPartAdded,
OpenAIResponseObjectStreamResponseContentPartDone,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFailed,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseIncomplete,
OpenAIResponseObjectStreamResponseInProgress,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
@ -31,21 +37,31 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseObjectStreamResponseReasoningTextDelta,
OpenAIResponseObjectStreamResponseReasoningTextDone,
OpenAIResponseObjectStreamResponseRefusalDelta,
OpenAIResponseObjectStreamResponseRefusalDone,
OpenAIResponseOutput,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
OpenAIResponseUsage,
OpenAIResponseUsageInputTokensDetails,
OpenAIResponseUsageOutputTokensDetails,
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
@ -94,113 +110,174 @@ class StreamingResponseOrchestrator:
self.tool_executor = tool_executor
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
self.citation_files: dict[str, str] = {}
# Track accumulated usage across all inference calls
self.accumulated_usage: OpenAIResponseUsage | None = None
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
output_messages: list[OpenAIResponseOutput] = []
# Create initial response and emit response.created immediately
initial_response = OpenAIResponseObject(
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
cloned: list[OpenAIResponseOutput] = []
for item in outputs:
if hasattr(item, "model_copy"):
cloned.append(item.model_copy(deep=True))
else:
cloned.append(item)
return cloned
def _snapshot_response(
self,
status: str,
outputs: list[OpenAIResponseOutput],
*,
error: OpenAIResponseError | None = None,
) -> OpenAIResponseObject:
return OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="in_progress",
output=output_messages.copy(),
status=status,
output=self._clone_outputs(outputs),
text=self.text,
tools=self.ctx.available_tools(),
error=error,
usage=self.accumulated_usage,
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
output_messages: list[OpenAIResponseOutput] = []
# Process all tools (including MCP tools) and emit streaming events
if self.ctx.response_tools:
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
yield stream_event
# Emit response.created followed by response.in_progress to align with OpenAI streaming
yield OpenAIResponseObjectStreamResponseCreated(
response=self._snapshot_response("in_progress", output_messages)
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseInProgress(
response=self._snapshot_response("in_progress", output_messages),
sequence_number=self.sequence_number,
)
async for stream_event in self._process_tools(output_messages):
yield stream_event
n_iter = 0
messages = self.ctx.messages.copy()
final_status = "completed"
last_completion_result: ChatCompletionResult | None = None
while True:
# Text is the default response format for chat completion so don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
)
try:
while True:
# Text is the default response format for chat completion so don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
stream_options={
"include_usage": True,
},
)
completion_result = await self.inference_api.openai_chat_completion(params)
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
last_completion_result = completion_result_data
current_response = self._build_chat_completion(completion_result_data)
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
(
function_tool_calls,
non_function_tool_calls,
approvals,
next_turn_messages,
) = self._separate_tool_calls(current_response, messages)
# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
):
yield evt
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(
await convert_chat_choice_to_response_message(
choice,
self.citation_files,
message_id=completion_result_data.message_item_id,
)
)
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield evt
yield stream_event
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(await convert_chat_choice_to_response_message(choice))
messages = next_turn_messages
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield stream_event
if not function_tool_calls and not non_function_tool_calls:
break
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(
f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
)
final_status = "incomplete"
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
break
if last_completion_result and last_completion_result.finish_reason == "length":
final_status = "incomplete"
messages = next_turn_messages
except Exception as exc: # noqa: BLE001
self.final_messages = messages.copy()
self.sequence_number += 1
error = OpenAIResponseError(code="internal_error", message=str(exc))
failure_response = self._snapshot_response("failed", output_messages, error=error)
yield OpenAIResponseObjectStreamResponseFailed(
response=failure_response,
sequence_number=self.sequence_number,
)
return
self.final_messages = messages.copy() + [current_response.choices[0].message]
self.final_messages = messages.copy()
# Create final response
final_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="completed",
text=self.text,
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
if final_status == "incomplete":
self.sequence_number += 1
final_response = self._snapshot_response("incomplete", output_messages)
yield OpenAIResponseObjectStreamResponseIncomplete(
response=final_response,
sequence_number=self.sequence_number,
)
else:
final_response = self._snapshot_response("completed", output_messages)
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
"""Separate tool calls into function and non-function categories."""
@ -211,6 +288,8 @@ class StreamingResponseOrchestrator:
for choice in current_response.choices:
next_turn_messages.append(choice.message)
logger.debug(f"Choice message content: {choice.message.content}")
logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
if choice.message.tool_calls and self.ctx.response_tools:
for tool_call in choice.message.tool_calls:
@ -227,14 +306,183 @@ class StreamingResponseOrchestrator:
non_function_tool_calls.append(tool_call)
else:
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
next_turn_messages.pop()
else:
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
approvals.append(tool_call)
next_turn_messages.pop()
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
def _accumulate_chunk_usage(self, chunk: OpenAIChatCompletionChunk) -> None:
"""Accumulate usage from a streaming chunk into the response usage format."""
if not chunk.usage:
return
if self.accumulated_usage is None:
# Convert from chat completion format to response format
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=chunk.usage.prompt_tokens,
output_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else None
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else None
),
)
else:
# Accumulate across multiple inference calls
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=self.accumulated_usage.input_tokens + chunk.usage.prompt_tokens,
output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
# Use latest non-null details
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else self.accumulated_usage.input_tokens_details
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else self.accumulated_usage.output_tokens_details
),
)
async def _handle_reasoning_content_chunk(
self,
reasoning_content: str,
reasoning_part_emitted: bool,
reasoning_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Emit content_part.added event for first reasoning chunk
if not reasoning_part_emitted:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=reasoning_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartReasoningText(
text="", # Will be filled incrementally via reasoning deltas
),
sequence_number=self.sequence_number,
)
# Emit reasoning_text.delta event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseReasoningTextDelta(
content_index=reasoning_content_index,
delta=reasoning_content,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
async def _handle_refusal_content_chunk(
self,
refusal_content: str,
refusal_part_emitted: bool,
refusal_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Emit content_part.added event for first refusal chunk
if not refusal_part_emitted:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=refusal_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartRefusal(
refusal="", # Will be filled incrementally via refusal deltas
),
sequence_number=self.sequence_number,
)
# Emit refusal.delta event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseRefusalDelta(
content_index=refusal_content_index,
delta=refusal_content,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
async def _emit_reasoning_done_events(
self,
reasoning_text_accumulated: list[str],
reasoning_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
final_reasoning_text = "".join(reasoning_text_accumulated)
# Emit reasoning_text.done event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseReasoningTextDone(
content_index=reasoning_content_index,
text=final_reasoning_text,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Emit content_part.done for reasoning
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=reasoning_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartReasoningText(
text=final_reasoning_text,
),
sequence_number=self.sequence_number,
)
async def _emit_refusal_done_events(
self,
refusal_text_accumulated: list[str],
refusal_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
final_refusal_text = "".join(refusal_text_accumulated)
# Emit refusal.done event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseRefusalDone(
content_index=refusal_content_index,
refusal=final_refusal_text,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Emit content_part.done for refusal
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=refusal_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartRefusal(
refusal=final_refusal_text,
),
sequence_number=self.sequence_number,
)
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
@ -253,11 +501,23 @@ class StreamingResponseOrchestrator:
tool_call_item_ids: dict[int, str] = {}
# Track content parts for streaming events
content_part_emitted = False
reasoning_part_emitted = False
refusal_part_emitted = False
content_index = 0
reasoning_content_index = 1 # reasoning is a separate content part
refusal_content_index = 2 # refusal is a separate content part
message_output_index = len(output_messages)
reasoning_text_accumulated = []
refusal_text_accumulated = []
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
# Accumulate usage from chunks (typically in final chunk with stream_options)
self._accumulate_chunk_usage(chunk)
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
@ -266,8 +526,10 @@ class StreamingResponseOrchestrator:
content_part_emitted = True
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartOutputText(
text="", # Will be filled incrementally via text deltas
),
@ -275,10 +537,10 @@ class StreamingResponseOrchestrator:
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
content_index=content_index,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
@ -287,6 +549,32 @@ class StreamingResponseOrchestrator:
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Handle reasoning content if present (non-standard field for o1/o3 models)
if hasattr(chunk_choice.delta, "reasoning_content") and chunk_choice.delta.reasoning_content:
async for event in self._handle_reasoning_content_chunk(
reasoning_content=chunk_choice.delta.reasoning_content,
reasoning_part_emitted=reasoning_part_emitted,
reasoning_content_index=reasoning_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
reasoning_part_emitted = True
reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
# Handle refusal content if present
if chunk_choice.delta.refusal:
async for event in self._handle_refusal_content_chunk(
refusal_content=chunk_choice.delta.refusal,
refusal_part_emitted=refusal_part_emitted,
refusal_content_index=refusal_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
refusal_part_emitted = True
refusal_text_accumulated.append(chunk_choice.delta.refusal)
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
@ -378,14 +666,36 @@ class StreamingResponseOrchestrator:
final_text = "".join(chat_response_content)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartOutputText(
text=final_text,
),
sequence_number=self.sequence_number,
)
# Emit reasoning done events if reasoning content was streamed
if reasoning_part_emitted:
async for event in self._emit_reasoning_done_events(
reasoning_text_accumulated=reasoning_text_accumulated,
reasoning_content_index=reasoning_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
# Emit refusal done events if refusal content was streamed
if refusal_part_emitted:
async for event in self._emit_refusal_done_events(
refusal_text_accumulated=refusal_text_accumulated,
refusal_content_index=refusal_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
# Clear content when there are tool calls (OpenAI spec behavior)
if chat_response_tool_calls:
chat_response_content = []
@ -470,6 +780,8 @@ class StreamingResponseOrchestrator:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
self.sequence_number = result.sequence_number
if result.citation_files:
self.citation_files.update(result.citation_files)
if tool_call_log:
output_messages.append(tool_call_log)
@ -518,7 +830,7 @@ class StreamingResponseOrchestrator:
sequence_number=self.sequence_number,
)
async def _process_tools(
async def _process_new_tools(
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process all tools and emit appropriate streaming events."""
@ -573,7 +885,6 @@ class StreamingResponseOrchestrator:
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
sequence_number=self.sequence_number,
)
try:
# Parse allowed/never allowed tools
always_allowed = None
@ -586,14 +897,22 @@ class StreamingResponseOrchestrator:
never_allowed = mcp_tool.allowed_tools.never
# Call list_mcp_tools
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
tool_defs = None
list_id = f"mcp_list_{uuid.uuid4()}"
attributes = {
"server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
"mcp_list_tools_id": list_id,
}
async with tracing.span("list_mcp_tools", attributes):
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
# Create the MCP list tools message
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
id=list_id,
server_label=mcp_tool.server_label,
tools=[],
)
@ -627,39 +946,26 @@ class StreamingResponseOrchestrator:
},
)
)
# Add the MCP list message to output
output_messages.append(mcp_list_message)
# Emit output_item.added for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
# Emit mcp_list_tools.completed
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
sequence_number=self.sequence_number,
)
# Emit output_item.done for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
yield stream_event
except Exception as e:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise
async def _process_tools(
self, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Handle all mcp tool lists from previous response that are still valid:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
yield stream_event
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
return False
@ -694,7 +1000,6 @@ class StreamingResponseOrchestrator:
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
@ -702,3 +1007,60 @@ class StreamingResponseOrchestrator:
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async def _add_mcp_list_tools(
self, mcp_list_message: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Add the MCP list message to output
output_messages.append(mcp_list_message)
# Emit output_item.added for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
# Emit mcp_list_tools.completed
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
sequence_number=self.sequence_number,
)
# Emit output_item.done for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async def _reuse_mcp_list_tools(
self, original: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
for t in original.tools:
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
# convert from input_schema to map of ToolParamDefinitions...
tool_def = ToolDefinition(
tool_name=t.name,
description=t.description,
input_schema=t.input_schema,
)
# ...then can convert that to openai completions tool
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
server_label=original.server_label,
tools=original.tools,
)
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
yield stream_event

View file

@ -11,6 +11,9 @@ from collections.abc import AsyncIterator
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
OpenAIResponseObjectStreamResponseFileSearchCallInProgress,
OpenAIResponseObjectStreamResponseFileSearchCallSearching,
OpenAIResponseObjectStreamResponseMcpCallCompleted,
OpenAIResponseObjectStreamResponseMcpCallFailed,
OpenAIResponseObjectStreamResponseMcpCallInProgress,
@ -35,6 +38,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from .types import ChatCompletionContext, ToolExecutionResult
@ -94,7 +98,10 @@ class ToolExecutor:
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
sequence_number=sequence_number,
final_output_message=output_message,
final_input_message=input_message,
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
)
async def _execute_knowledge_search_via_vector_store(
@ -129,8 +136,6 @@ class ToolExecutor:
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
@ -138,27 +143,58 @@ class ToolExecutor:
)
)
unique_files = set()
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
# Get file_id from attributes if result_item.file_id is empty
file_id = result_item.file_id or (
result_item.attributes.get("document_id") if result_item.attributes else None
)
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
unique_files.add(file_id)
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
citation_instruction = ""
if unique_files:
citation_instruction = (
" Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
"Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
)
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
)
)
# handling missing attributes for old versions
citation_files = {}
for result in search_results:
file_id = result.file_id
if not file_id and result.attributes:
file_id = result.attributes.get("document_id")
filename = result.filename
if not filename and result.attributes:
filename = result.attributes.get("filename")
if not filename:
filename = "unknown"
citation_files[file_id] = filename
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
"citation_files": citation_files,
},
)
@ -188,7 +224,13 @@ class ToolExecutor:
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
elif function_name == "knowledge_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
@ -203,6 +245,16 @@ class ToolExecutor:
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# For file search, emit searching event
if function_name == "knowledge_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool(
self,
function_name: str,
@ -219,12 +271,18 @@ class ToolExecutor:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = mcp_tool_to_server[function_name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
attributes = {
"server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
"tool_name": function_name,
}
async with tracing.span("invoke_mcp_tool", attributes):
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
@ -234,15 +292,20 @@ class ToolExecutor:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
async with tracing.span("knowledge_search", {}):
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
attributes = {
"tool_name": function_name,
}
async with tracing.span("invoke_tool", attributes):
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
@ -278,7 +341,13 @@ class ToolExecutor:
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
elif function_name == "knowledge_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)

View file

@ -12,10 +12,18 @@ from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseTool,
OpenAIResponseToolMCP,
)
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
@ -27,6 +35,7 @@ class ToolExecutionResult(BaseModel):
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
citation_files: dict[str, str] | None = None
@dataclass
@ -54,6 +63,86 @@ class ChatCompletionResult:
return bool(self.tool_calls)
class ToolContext(BaseModel):
"""Holds information about tools from this and (if relevant)
previous response in order to facilitate reuse of previous
listings where appropriate."""
# tools argument passed into current request:
current_tools: list[OpenAIResponseInputTool]
# reconstructed map of tool -> mcp server from previous response:
previous_tools: dict[str, OpenAIResponseInputToolMCP]
# reusable mcp-list-tools objects from previous response:
previous_tool_listings: list[OpenAIResponseOutputMessageMCPListTools]
# tool arguments from current request that still need to be processed:
tools_to_process: list[OpenAIResponseInputTool]
def __init__(
self,
current_tools: list[OpenAIResponseInputTool] | None,
):
super().__init__(
current_tools=current_tools or [],
previous_tools={},
previous_tool_listings=[],
tools_to_process=current_tools or [],
)
def recover_tools_from_previous_response(
self,
previous_response: OpenAIResponseObject,
):
"""Determine which mcp_list_tools objects from previous response we can reuse."""
if self.current_tools and previous_response.tools:
previous_tools_by_label: dict[str, OpenAIResponseToolMCP] = {}
for tool in previous_response.tools:
if isinstance(tool, OpenAIResponseToolMCP):
previous_tools_by_label[tool.server_label] = tool
# collect tool definitions which are the same in current and previous requests:
tools_to_process = []
matched: dict[str, OpenAIResponseInputToolMCP] = {}
for tool in self.current_tools:
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
previous_tool = previous_tools_by_label[tool.server_label]
if previous_tool.allowed_tools == tool.allowed_tools:
matched[tool.server_label] = tool
else:
tools_to_process.append(tool)
else:
tools_to_process.append(tool)
# tools that are not the same or were not previously defined need to be processed:
self.tools_to_process = tools_to_process
# for all matched definitions, get the mcp_list_tools objects from the previous output:
self.previous_tool_listings = [
obj for obj in previous_response.output if obj.type == "mcp_list_tools" and obj.server_label in matched
]
# reconstruct the tool to server mappings that can be reused:
for listing in self.previous_tool_listings:
definition = matched[listing.server_label]
for tool in listing.tools:
self.previous_tools[tool.name] = definition
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.current_tools:
return []
def convert_tool(tool: OpenAIResponseInputTool) -> OpenAIResponseTool:
if isinstance(tool, OpenAIResponseInputToolWebSearch):
return tool
if isinstance(tool, OpenAIResponseInputToolFileSearch):
return tool
if isinstance(tool, OpenAIResponseInputToolFunction):
return tool
if isinstance(tool, OpenAIResponseInputToolMCP):
return OpenAIResponseToolMCP(
server_label=tool.server_label,
allowed_tools=tool.allowed_tools,
)
return [convert_tool(tool) for tool in self.current_tools]
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
@ -61,6 +150,7 @@ class ChatCompletionContext(BaseModel):
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam
tool_context: ToolContext | None
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
@ -71,6 +161,7 @@ class ChatCompletionContext(BaseModel):
response_tools: list[OpenAIResponseInputTool] | None,
temperature: float | None,
response_format: OpenAIResponseFormatParam,
tool_context: ToolContext,
inputs: list[OpenAIResponseInput] | str,
):
super().__init__(
@ -79,6 +170,7 @@ class ChatCompletionContext(BaseModel):
response_tools=response_tools,
temperature=temperature,
response_format=response_format,
tool_context=tool_context,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
@ -95,3 +187,8 @@ class ChatCompletionContext(BaseModel):
if request.name == tool_name and request.arguments == arguments:
return request
return None
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.tool_context:
return []
return self.tool_context.available_tools()

View file

@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import uuid
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
@ -45,7 +47,12 @@ from llama_stack.apis.inference import (
)
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
async def convert_chat_choice_to_response_message(
choice: OpenAIChoice,
citation_files: dict[str, str] | None = None,
*,
message_id: str | None = None,
) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
@ -57,9 +64,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
id=message_id or f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
status="completed",
role="assistant",
)
@ -97,9 +106,13 @@ async def convert_response_content_to_chat_content(
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
previous_messages: list[OpenAIMessageParam] | None = None,
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
:param input: The input to convert
:param previous_messages: Optional previous messages to check for function_call references
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
@ -163,16 +176,53 @@ async def convert_response_input_to_chat_messages(
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
# Skip user messages that duplicate the last user message in previous_messages
# This handles cases where input includes context for function_call_outputs
if previous_messages and input_item.role == "user":
last_user_msg = None
for msg in reversed(previous_messages):
if isinstance(msg, OpenAIUserMessageParam):
last_user_msg = msg
break
if last_user_msg:
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
previous_call_ids = _extract_tool_call_ids(previous_messages)
for call_id in list(tool_call_results.keys()):
if call_id in previous_call_ids:
# Valid: this output references a call from previous messages
# Add the tool message
messages.append(tool_call_results[call_id])
del tool_call_results[call_id]
# If still have unpaired outputs, error
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
"""Extract all tool_call IDs from messages."""
call_ids = set()
for msg in messages:
if isinstance(msg, OpenAIAssistantMessageParam):
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls:
for tool_call in tool_calls:
# tool_call is a Pydantic model, use attribute access
call_ids.add(tool_call.id)
return call_ids
async def convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:
@ -200,6 +250,53 @@ async def get_message_type_by_role(role: str):
return role_to_type.get(role)
def _extract_citations_from_text(
text: str, citation_files: dict[str, str]
) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
"""Extract citation markers from text and create annotations
Args:
text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
citation_files: Dictionary mapping file_id to filename
Returns:
Tuple of (annotations_list, clean_text_without_markers)
"""
file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
annotations = []
parts = []
total_len = 0
last_end = 0
for m in file_id_regex.finditer(text):
# segment before the marker
prefix = text[last_end : m.start()]
# drop one space if it exists (since marker is at sentence end)
if prefix.endswith(" "):
prefix = prefix[:-1]
parts.append(prefix)
total_len += len(prefix)
fid = m.group(1)
if fid in citation_files:
annotations.append(
OpenAIResponseAnnotationFileCitation(
file_id=fid,
filename=citation_files[fid],
index=total_len, # index points to punctuation
)
)
last_end = m.end()
parts.append(text[last_end:])
cleaned_text = "".join(parts)
return annotations, cleaned_text
def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],

View file

@ -22,7 +22,10 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
@ -178,9 +181,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
@ -425,18 +428,23 @@ class ReferenceBatchesImpl(Batches):
valid = False
if batch.endpoint == "/v1/chat/completions":
required_params = [
required_params: list[tuple[str, Any, str]] = [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]
else: # /v1/completions
elif batch.endpoint == "/v1/completions":
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
else: # /v1/embeddings
required_params = [
("model", str, "a string"),
("input", (str, list), "a string or array of strings"),
]
for param, expected_type, type_string in required_params:
if param not in body:
@ -601,7 +609,8 @@ class ReferenceBatchesImpl(Batches):
# TODO(SECURITY): review body for security issues
if request.url == "/v1/chat/completions":
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
chat_params = OpenAIChatCompletionRequestWithExtraBody(**request.body)
chat_response = await self.inference_api.openai_chat_completion(chat_params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
@ -614,8 +623,9 @@ class ReferenceBatchesImpl(Batches):
"body": chat_response.model_dump_json(),
},
}
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
elif request.url == "/v1/completions":
completion_params = OpenAICompletionRequestWithExtraBody(**request.body)
completion_response = await self.inference_api.openai_completion(completion_params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (
@ -630,6 +640,22 @@ class ReferenceBatchesImpl(Batches):
"body": completion_response.model_dump_json(),
},
}
else: # /v1/embeddings
embeddings_response = await self.inference_api.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(**request.body)
)
assert hasattr(embeddings_response, "model_dump_json"), (
"Embeddings response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": embeddings_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {

View file

@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAISystemMessageParam,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@ -168,11 +175,12 @@ class MetaReferenceEvalImpl(
sampling_params["stop"] = candidate.sampling_params.stop
input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.openai_completion(
params = OpenAICompletionRequestWithExtraBody(
model=candidate.model,
prompt=input_content,
**sampling_params,
)
response = await self.inference_api.openai_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
@ -187,11 +195,12 @@ class MetaReferenceEvalImpl(
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=candidate.model,
messages=messages,
**sampling_params,
)
response = await self.inference_api.openai_chat_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
else:
raise ValueError("Invalid input row")

View file

@ -22,6 +22,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""
return f"file-{uuid.uuid4().hex}"
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
def _get_file_path(self, file_id: str) -> Path:
"""Get the filesystem path for a file ID."""
@ -95,7 +96,9 @@ class LocalfsFilesImpl(Files):
raise RuntimeError("Files provider not initialized")
if expires_after is not None:
raise NotImplementedError("File expiration is not supported by this provider")
logger.warning(
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
)
file_id = self._generate_file_id()
file_path = self._get_file_path(file_id)

View file

@ -18,7 +18,7 @@ def model_checkpoint_dir(model_id) -> str:
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
)
return str(checkpoint_dir)

View file

@ -6,16 +6,16 @@
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAICompletion,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
self.generator.stop()
async def openai_completion(self, *args, **kwargs):
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
async def should_refresh_models(self) -> bool:
@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl(
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")

View file

@ -5,17 +5,16 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl(
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -104,9 +104,10 @@ class LoraFinetuningSingleDevice:
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
hf_repo = model.huggingface_repo or f"meta-llama/{model.descriptor()}"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
f"Please download the model using `huggingface-cli download {hf_repo} --local-dir ~/.llama/{model.descriptor()}`"
)
return str(checkpoint_dir)

View file

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from codeshield.cs import CodeShieldScanResult
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -53,7 +53,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)

View file

@ -10,7 +10,12 @@ from string import Template
from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import Inference, Message, UserMessage
from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -159,7 +164,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
@ -169,8 +174,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
if len(messages) > 0 and messages[0].role != "user":
messages[0] = OpenAIUserMessageParam(content=messages[0].content)
# Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model
@ -202,7 +207,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
messages = [input]
# convert to user messages format with role
messages = [UserMessage(content=m) for m in messages]
messages = [OpenAIUserMessageParam(content=m) for m in messages]
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
@ -271,7 +276,7 @@ class LlamaGuardShield:
return final_categories
def validate_messages(self, messages: list[Message]) -> None:
def validate_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
if len(messages) == 0:
raise ValueError("Messages must not be empty")
if messages[0].role != Role.user.value:
@ -282,7 +287,7 @@ class LlamaGuardShield:
return messages
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
messages = self.validate_messages(messages)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
@ -290,20 +295,21 @@ class LlamaGuardShield:
else:
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
def build_text_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
return OpenAIUserMessageParam(content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
def build_vision_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
conversation = []
most_recent_img = None
@ -326,7 +332,7 @@ class LlamaGuardShield:
else:
raise ValueError(f"Unknown content type: {c}")
conversation.append(UserMessage(content=content))
conversation.append(OpenAIUserMessageParam(content=content))
else:
raise ValueError(f"Unknown content type: {m.content}")
@ -335,9 +341,9 @@ class LlamaGuardShield:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=prompt)
return OpenAIUserMessageParam(content=prompt)
def build_prompt(self, messages: list[Message]) -> str:
def build_prompt(self, messages: list[OpenAIMessageParam]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
@ -370,18 +376,20 @@ class LlamaGuardShield:
raise ValueError(f"Unexpected response: {response}")
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
async def run_moderation(self, messages: list[OpenAIMessageParam]) -> ModerationObject:
if not messages:
return self.create_moderation_object(self.model)
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_moderation_object(content)

View file

@ -9,7 +9,7 @@ from typing import Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -22,9 +22,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from .config import PromptGuardConfig, PromptGuardType
@ -56,7 +54,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
@ -93,7 +91,7 @@ class PromptGuardShield:
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_content_as_str(message.content)

View file

@ -6,7 +6,7 @@
import re
from typing import Any
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import Inference, OpenAIChatCompletionRequestWithExtraBody
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
generated_answer=generated_answer,
)
judge_response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=fn_def.params.judge_model,
messages=[
{
@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
}
],
)
judge_response = await self.inference_api.openai_chat_completion(params)
content = judge_response.choices[0].message.content
rating_regexes = fn_def.params.judge_score_regexes

View file

@ -8,7 +8,7 @@
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
@ -65,11 +65,12 @@ async def llm_rag_query_generator(
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content

View file

@ -8,8 +8,6 @@ import asyncio
import base64
import io
import mimetypes
import secrets
import string
from typing import Any
import httpx
@ -52,10 +50,6 @@ from .context_retriever import generate_rag_query
log = get_logger(name=__name__, category="tool_runtime")
def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
@ -331,5 +325,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return ToolInvocationResult(
content=result.content or [],
metadata=result.metadata,
metadata={
**(result.metadata or {}),
"citation_files": getattr(result, "citation_files", None),
},
)

View file

@ -200,12 +200,10 @@ class FaissIndex(EmbeddingIndex):
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence)
@ -227,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# Cleanup if needed
pass
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def health(self) -> HealthResponse:
"""

View file

@ -410,12 +410,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
"""
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.kvstore: KVStore | None = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence)
@ -436,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]

View file

@ -32,9 +32,12 @@ def available_providers() -> list[ProviderSpec]:
Api.inference,
Api.safety,
Api.vector_io,
Api.vector_dbs,
Api.tool_runtime,
Api.tool_groups,
Api.conversations,
],
optional_api_dependencies=[
Api.telemetry,
],
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
),

View file

@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="cerebras",
provider_type="remote::cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="openai",
provider_type="remote::openai",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="anthropic",
provider_type="remote::anthropic",
pip_packages=["litellm"],
pip_packages=["anthropic"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="gemini",
provider_type="remote::gemini",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
adapter_type="vertexai",
provider_type="remote::vertexai",
pip_packages=[
"litellm",
"google-cloud-aiplatform",
],
module="llama_stack.providers.remote.inference.vertexai",
@ -233,9 +228,7 @@ Available Models:
api=Api.inference,
adapter_type="groq",
provider_type="remote::groq",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
@ -245,7 +238,7 @@ Available Models:
api=Api.inference,
adapter_type="llama-openai-compat",
provider_type="remote::llama-openai-compat",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
@ -255,9 +248,7 @@ Available Models:
api=Api.inference,
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
@ -277,7 +268,7 @@ Available Models:
api=Api.inference,
adapter_type="watsonx",
provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"],
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
@ -287,7 +278,7 @@ Available Models:
api=Api.inference,
provider_type="remote::azure",
adapter_type="azure",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.azure",
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",

View file

@ -11,6 +11,7 @@ from llama_stack.providers.datatypes import (
ProviderSpec,
RemoteProviderSpec,
)
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
def available_providers() -> list[ProviderSpec]:
@ -18,9 +19,8 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=[
"chardet",
"pypdf",
pip_packages=DEFAULT_VECTOR_IO_DEPS
+ [
"tqdm",
"numpy",
"scikit-learn",

View file

@ -12,13 +12,16 @@ from llama_stack.providers.datatypes import (
RemoteProviderSpec,
)
# Common dependencies for all vector IO providers that support document processing
DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::meta-reference",
pip_packages=["faiss-cpu"],
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
@ -29,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::faiss",
pip_packages=["faiss-cpu"],
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
@ -82,7 +85,7 @@ more details about Faiss in general.
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite-vec",
pip_packages=["sqlite-vec"],
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
api_dependencies=[Api.inference],
@ -289,7 +292,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite_vec",
pip_packages=["sqlite-vec"],
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
@ -303,7 +306,7 @@ Please refer to the sqlite-vec provider documentation.
api=Api.vector_io,
adapter_type="chromadb",
provider_type="remote::chromadb",
pip_packages=["chromadb-client"],
pip_packages=["chromadb-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@ -345,7 +348,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::chromadb",
pip_packages=["chromadb"],
pip_packages=["chromadb"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@ -389,7 +392,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
api=Api.vector_io,
adapter_type="pgvector",
provider_type="remote::pgvector",
pip_packages=["psycopg2-binary"],
pip_packages=["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
@ -500,7 +503,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
api=Api.vector_io,
adapter_type="weaviate",
provider_type="remote::weaviate",
pip_packages=["weaviate-client>=4.16.5"],
pip_packages=["weaviate-client>=4.16.5"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
@ -541,7 +544,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::qdrant",
pip_packages=["qdrant-client"],
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.qdrant",
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
@ -594,7 +597,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
api=Api.vector_io,
adapter_type="qdrant",
provider_type="remote::qdrant",
pip_packages=["qdrant-client"],
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
@ -607,7 +610,7 @@ Please refer to the inline provider documentation.
api=Api.vector_io,
adapter_type="milvus",
provider_type="remote::milvus",
pip_packages=["pymilvus>=2.4.10"],
pip_packages=["pymilvus>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
@ -813,7 +816,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
pip_packages=["pymilvus[milvus-lite]>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],

View file

@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
@ -198,7 +199,7 @@ class S3FilesImpl(Files):
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
filename = getattr(file, "filename", None) or "uploaded_file"

View file

@ -10,6 +10,6 @@ from .config import AnthropicConfig
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter
impl = AnthropicInferenceAdapter(config)
impl = AnthropicInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,13 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from collections.abc import Iterable
from anthropic import AsyncAnthropic
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AnthropicConfig
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class AnthropicInferenceAdapter(OpenAIMixin):
config: AnthropicConfig
provider_data_api_key_field: str = "anthropic_api_key"
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
# TODO: add support for voyageai, which is where these models are hosted
# embedding_model_metadata = {
@ -23,22 +29,8 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)
self.config = config
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self):
return "https://api.anthropic.com/v1"
async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]

View file

@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):
@json_schema_type
class AnthropicConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Anthropic models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {

View file

@ -10,6 +10,6 @@ from .config import AzureConfig
async def get_adapter_impl(config: AzureConfig, _deps):
from .azure import AzureInferenceAdapter
impl = AzureInferenceAdapter(config)
impl = AzureInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,31 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from urllib.parse import urljoin
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: AzureConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="azure",
api_key_from_config=config.api_key.get_secret_value(),
provider_data_api_key_field="azure_api_key",
openai_compat_api_base=str(config.api_base),
)
self.config = config
class AzureInferenceAdapter(OpenAIMixin):
config: AzureConfig
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "azure_api_key"
def get_base_url(self) -> str:
"""
@ -37,26 +23,3 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Azure specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "azure_api_key", None):
params["api_key"] = provider_data.azure_api_key
if getattr(provider_data, "azure_api_base", None):
params["api_base"] = provider_data.azure_api_base
if getattr(provider_data, "azure_api_version", None):
params["api_version"] = provider_data.azure_api_version
if getattr(provider_data, "azure_api_type", None):
params["api_type"] = provider_data.azure_api_type
else:
params["api_key"] = self.config.api_key.get_secret_value()
params["api_base"] = str(self.config.api_base)
params["api_version"] = self.config.api_version
params["api_type"] = self.config.api_type
return params

View file

@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
api_key: SecretStr = Field(
description="Azure API key for Azure",
)
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)

View file

@ -6,21 +6,21 @@
import json
from collections.abc import AsyncIterator
from typing import Any
from botocore.client import BaseClient
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
@ -125,66 +125,18 @@ class BedrockInferenceAdapter(
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

View file

@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config)
impl = CerebrasInferenceAdapter(config=config)
await impl.initialize()

View file

@ -6,77 +6,23 @@
from urllib.parse import urljoin
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Inference,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
TopKSamplingStrategy,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import CerebrasImplConfig
class CerebrasInferenceAdapter(
OpenAIMixin,
Inference,
):
def __init__(self, config: CerebrasImplConfig) -> None:
self.config = config
# TODO: make this use provider data, etc. like other providers
self._cerebras_client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
)
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
elif isinstance(request, CompletionRequest):
prompt = await completion_request_to_prompt(request)
else:
raise ValueError(f"Unknown request type {type(request)}")
return {
"model": request.model,
"prompt": prompt,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -21,10 +21,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: SecretStr = Field(
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
description="Cerebras API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
impl = DatabricksInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -14,12 +14,13 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
url: str | None = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: SecretStr = Field(
default=SecretStr(None),
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The Databricks API token",
)

View file

@ -4,16 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from collections.abc import Iterable
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import (
Inference,
Model,
OpenAICompletion,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.inference import OpenAICompletion, OpenAICompletionRequestWithExtraBody
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -22,81 +17,28 @@ from .config import DatabricksImplConfig
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(
OpenAIMixin,
Inference,
):
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: DatabricksImplConfig) -> None:
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError()
async def list_models(self) -> list[Model] | None:
self._model_cache = {} # from OpenAIMixin
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
endpoints = ws_client.serving_endpoints.list()
for endpoint in endpoints:
model = Model(
provider_id=self.__provider_id__,
provider_resource_id=endpoint.name,
identifier=endpoint.name,
)
if endpoint.task == "llm/v1/chat":
model.model_type = ModelType.llm # this is redundant, but informative
elif endpoint.task == "llm/v1/embeddings":
if endpoint.name not in self.embedding_model_metadata:
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
continue
model.model_type = ModelType.embedding
model.metadata = self.embedding_model_metadata[endpoint.name]
else:
logger.warning(f"Unknown model type, skipping: {endpoint}")
continue
self._model_cache[endpoint.name] = model
return list(self._model_cache.values())
async def should_refresh_models(self) -> bool:
return False

View file

@ -17,6 +17,6 @@ async def get_adapter_impl(config: FireworksImplConfig, _deps):
from .fireworks import FireworksInferenceAdapter
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
impl = FireworksInferenceAdapter(config)
impl = FireworksInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
)
api_key: SecretStr | None = Field(
default=None,
description="The Fireworks.ai API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -5,124 +5,23 @@
# the root directory of this source tree.
from fireworks.client import Fireworks
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
LogProbConfig,
ResponseFormat,
ResponseFormatType,
SamplingParams,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
request_has_media,
)
from .config import FireworksImplConfig
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
embedding_model_metadata = {
class FireworksInferenceAdapter(OpenAIMixin):
config: FireworksImplConfig
embedding_model_metadata: dict[str, dict[str, int]] = {
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
}
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config
self.allowed_models = config.allowed_models
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def get_api_key(self) -> str:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
)
return provider_data.fireworks_api_key
provider_data_api_key_field: str = "fireworks_api_key"
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
def _get_client(self) -> Fireworks:
fireworks_api_key = self.get_api_key()
return Fireworks(api_key=fireworks_api_key)
def _build_options(
self,
sampling_params: SamplingParams | None,
fmt: ResponseFormat | None,
logprobs: LogProbConfig | None,
) -> dict:
options = get_sampling_options(sampling_params)
options.setdefault("max_tokens", 512)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
options["logprobs"] = logprobs.top_k
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
raise ValueError("Required range: 0 < top_k < 5")
return options
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
# TODO: tools are never added to the request, so we need to add them here
if media_present or not llama_model:
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
# Fireworks always prepends with BOS
if "prompt" in input_dict:
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
params = {
"model": request.model,
**input_dict,
"stream": bool(request.stream),
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logger.debug(f"params to fireworks: {params}")
return params

View file

@ -10,6 +10,6 @@ from .config import GeminiConfig
async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter
impl = GeminiInferenceAdapter(config)
impl = GeminiInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -21,11 +21,6 @@ class GeminiProviderDataValidator(BaseModel):
@json_schema_type
class GeminiConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Gemini models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {

View file

@ -4,33 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
embedding_model_metadata = {
class GeminiInferenceAdapter(OpenAIMixin):
config: GeminiConfig
provider_data_api_key_field: str = "gemini_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
}
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="gemini",
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",
)
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()

View file

@ -11,5 +11,5 @@ async def get_adapter_impl(config: GroqConfig, _deps):
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter
adapter = GroqInferenceAdapter(config)
adapter = GroqInferenceAdapter(config=config)
return adapter

View file

@ -21,12 +21,6 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
# The Groq client library loads the GROQ_API_KEY environment variable by default
default=None,
description="The Groq API key",
)
url: str = Field(
default="https://api.groq.com",
description="The URL for the Groq AI server",

View file

@ -6,30 +6,13 @@
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
_config: GroqConfig
class GroqInferenceAdapter(OpenAIMixin):
config: GroqConfig
def __init__(self, config: GroqConfig):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="groq",
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
)
self.config = config
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "groq_api_key"
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
async def get_adapter_impl(config: LlamaCompatConfig, _deps):
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter
adapter = LlamaCompatInferenceAdapter(config)
adapter = LlamaCompatInferenceAdapter(config=config)
return adapter

View file

@ -21,11 +21,6 @@ class LlamaProviderDataValidator(BaseModel):
@json_schema_type
class LlamaCompatConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="The Llama API key",
)
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
description="The URL for the Llama API server",

View file

@ -3,44 +3,28 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference.inference import OpenAICompletion
from llama_stack.apis.inference.inference import (
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class LlamaCompatInferenceAdapter(OpenAIMixin):
config: LlamaCompatConfig
provider_data_api_key_field: str = "llama_api_key"
"""
Llama API Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="meta_llama",
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.
@ -49,33 +33,14 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
return self.config.openai_compat_api_base
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError()
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -15,7 +15,8 @@ async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
if not isinstance(config, NVIDIAConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = NVIDIAInferenceAdapter(config)
adapter = NVIDIAInferenceAdapter(config=config)
await adapter.initialize()
return adapter

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -40,10 +40,6 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
description="A base url for accessing the NVIDIA NIM",
)
api_key: SecretStr | None = Field(
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
description="The NVIDIA API key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",

View file

@ -8,8 +8,8 @@
from openai import NOT_GIVEN
from llama_stack.apis.inference import (
Inference,
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
@ -22,7 +22,9 @@ from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
class NVIDIAInferenceAdapter(OpenAIMixin):
config: NVIDIAConfig
"""
NVIDIA Inference Adapter for Llama Stack.
@ -37,32 +39,21 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"""
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
}
def __init__(self, config: NVIDIAConfig) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
async def initialize(self) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
if _is_nvidia_hosted(config):
if not config.api_key:
if _is_nvidia_hosted(self.config):
if not self.config.auth_credential:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
)
# elif self._config.api_key:
#
# we don't raise this warning because a user may have deployed their
# self-hosted NIM with an API key requirement.
#
# warnings.warn(
# "API key is not required for self-hosted NVIDIA NIM. "
# "Consider removing the api_key from the configuration."
# )
self._config = config
def get_api_key(self) -> str:
"""
@ -70,7 +61,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
:return: The NVIDIA API key
"""
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
if not _is_nvidia_hosted(self.config):
return "NO KEY REQUIRED"
return None
def get_base_url(self) -> str:
"""
@ -78,15 +75,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
:return: The NVIDIA API base URL
"""
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
OpenAI-compatible embeddings for NVIDIA NIM.
@ -103,11 +96,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
)
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
model=await self._get_provider_model_id(params.model),
input=params.input,
encoding_format=params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
dimensions=params.dimensions if params.dimensions is not None else NOT_GIVEN,
user=params.user if params.user is not None else NOT_GIVEN,
extra_body=extra_body,
)

View file

@ -1,217 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from collections.abc import AsyncGenerator
from typing import Any
from openai import AsyncStream
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
)
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
GreedySamplingStrategy,
JsonSchemaResponseFormat,
TokenLogProbs,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.openai_compat import (
_convert_openai_finish_reason,
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
)
async def convert_chat_completion_request(
request: ChatCompletionRequest,
n: int = 1,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
"""
# model -> model
# messages -> messages
# sampling_params TODO(mattf): review strategy
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
# strategy=top_k -> nvext.top_k = top_k
# temperature -> temperature
# top_p -> top_p
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# response_format -> GrammarResponseFormat TODO(mf)
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
# tools -> tools
# tool_choice ("auto", "required") -> tool_choice
# tool_prompt_format -> TBD
# stream -> stream
# logprobs -> logprobs
if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported."
)
nvext = {}
payload: dict[str, Any] = dict(
model=request.model,
messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
stream=request.stream,
n=n,
extra_body=dict(nvext=nvext),
extra_headers={
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
},
)
if request.response_format:
# server bug - setting guided_json changes the behavior of response_format resulting in an error
# payload.update(response_format="json_object")
nvext.update(guided_json=request.response_format.json_schema)
if request.tools:
payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools])
if request.tool_config.tool_choice:
payload.update(
tool_choice=request.tool_config.tool_choice.value
) # we cannot include tool_choice w/o tools, server will complain
if request.logprobs:
payload.update(logprobs=True)
payload.update(top_logprobs=request.logprobs.top_k)
if request.sampling_params:
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
strategy = request.sampling_params.strategy
if isinstance(strategy, TopPSamplingStrategy):
nvext.update(top_k=-1)
payload.update(top_p=strategy.top_p)
payload.update(temperature=strategy.temperature)
elif isinstance(strategy, TopKSamplingStrategy):
if strategy.top_k != -1 and strategy.top_k < 1:
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=strategy.top_k)
elif isinstance(strategy, GreedySamplingStrategy):
nvext.update(top_k=-1)
else:
raise ValueError(f"Unsupported sampling strategy: {strategy}")
return payload
def convert_completion_request(
request: CompletionRequest,
n: int = 1,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
"""
# model -> model
# prompt -> prompt
# sampling_params TODO(mattf): review strategy
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
# strategy=top_k -> nvext.top_k = top_k
# temperature -> temperature
# top_p -> top_p
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# response_format -> nvext.guided_json
# stream -> stream
# logprobs.top_k -> logprobs
nvext = {}
payload: dict[str, Any] = dict(
model=request.model,
prompt=request.content,
stream=request.stream,
extra_body=dict(nvext=nvext),
extra_headers={
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
},
n=n,
)
if request.response_format:
# this is not openai compliant, it is a nim extension
nvext.update(guided_json=request.response_format.json_schema)
if request.logprobs:
payload.update(logprobs=request.logprobs.top_k)
if request.sampling_params:
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
if request.sampling_params.strategy == "top_p":
nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)
payload.update(temperature=request.sampling_params.temperature)
return payload
def _convert_openai_completion_logprobs(
logprobs: OpenAICompletionLogprobs | None,
) -> list[TokenLogProbs] | None:
"""
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
"""
if not logprobs:
return None
return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs]
def convert_openai_completion_choice(
choice: OpenAIChoice,
) -> CompletionResponse:
"""
Convert an OpenAI Completion Choice into a CompletionResponse.
"""
return CompletionResponse(
content=choice.text,
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)
async def convert_openai_completion_stream(
stream: AsyncStream[OpenAICompletion],
) -> AsyncGenerator[CompletionResponse, None]:
"""
Convert a stream of OpenAI Completions into a stream
of ChatCompletionResponseStreamChunks.
"""
async for chunk in stream:
choice = chunk.choices[0]
yield CompletionResponseStreamChunk(
delta=choice.text,
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)

View file

@ -4,53 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import httpx
from llama_stack.log import get_logger
from . import NVIDIAConfig
logger = get_logger(name=__name__, category="inference::nvidia")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url
async def _get_health(url: str) -> tuple[bool, bool]:
"""
Query {url}/v1/health/{live,ready} to check if the server is running and ready
Args:
url (str): URL of the server
Returns:
Tuple[bool, bool]: (is_live, is_ready)
"""
async with httpx.AsyncClient() as client:
live = await client.get(f"{url}/v1/health/live")
ready = await client.get(f"{url}/v1/health/ready")
return live.status_code == 200, ready.status_code == 200
async def check_health(config: NVIDIAConfig) -> None:
"""
Check if the server is running and ready
Args:
url (str): URL of the server
Raises:
RuntimeError: If the server is not running or ready
"""
if not _is_nvidia_hosted(config):
logger.info("Checking NVIDIA NIM health...")
try:
is_live, is_ready = await _get_health(config.url)
if not is_live:
raise ConnectionError("NVIDIA NIM is not running")
if not is_ready:
raise ConnectionError("NVIDIA NIM is not ready")
# TODO(mf): should we wait for the server to be ready?
except httpx.ConnectError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e

View file

@ -10,6 +10,6 @@ from .config import OllamaImplConfig
async def get_adapter_impl(config: OllamaImplConfig, _deps):
from .ollama import OllamaInferenceAdapter
impl = OllamaInferenceAdapter(config)
impl = OllamaInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
@ -14,11 +14,9 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:

View file

@ -6,58 +6,29 @@
import asyncio
from typing import Any
from ollama import AsyncClient as AsyncOllamaClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
InferenceProvider,
JsonSchemaResponseFormat,
Message,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
convert_image_content_to_url,
request_has_media,
)
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
OpenAIMixin,
ModelRegistryHelper,
InferenceProvider,
ModelsProtocolPrivate,
):
class OllamaInferenceAdapter(OpenAIMixin):
config: OllamaImplConfig
# automatically set by the resolver when instantiating the provider
__provider_id__: str
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"all-minilm:l6-v2": {
"embedding_dimension": 384,
"context_length": 512,
@ -76,29 +47,8 @@ class OllamaInferenceAdapter(
},
}
def __init__(self, config: OllamaImplConfig) -> None:
# TODO: remove ModelRegistryHelper.__init__ when completion and
# chat_completion are. this exists to satisfy the input /
# output processing for llama models. specifically,
# tool_calling is handled by raw template processing,
# instead of using the /api/chat endpoint w/ tools=...
ModelRegistryHelper.__init__(
self,
model_entries=[
build_hf_repo_model_entry(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
],
)
self.config = config
# Ollama does not support image urls, so we need to download the image and convert it to base64
self.download_images = True
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
download_images: bool = True
_clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property
def ollama_client(self) -> AsyncOllamaClient:
@ -109,7 +59,7 @@ class OllamaInferenceAdapter(
return self._clients[loop]
def get_api_key(self):
return "NO_KEY"
return "NO KEY REQUIRED"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"
@ -122,9 +72,6 @@ class OllamaInferenceAdapter(
f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
)
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Ollama server.
@ -142,50 +89,6 @@ class OllamaInferenceAdapter(
async def shutdown(self) -> None:
self._clients.clear()
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
async def _get_params(self, request: ChatCompletionRequest) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options.get("max_tokens") is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
input_dict: dict[str, Any] = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if media_present or not llama_model:
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
# flatten the list of lists
input_dict["messages"] = [item for sublist in contents for item in sublist]
else:
input_dict["raw"] = True
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
llama_model,
)
if fmt := request.response_format:
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["format"] = fmt.json_schema
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format is not supported")
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
params = {
"model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}
logger.debug(f"params to ollama: {params}")
return params
async def register_model(self, model: Model) -> Model:
if await self.check_model_availability(model.provider_model_id):
return model
@ -197,24 +100,3 @@ class OllamaInferenceAdapter(
return model
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
return {
"role": message.role,
"images": [await convert_image_content_to_url(content, download=True, include_format=False)],
}
else:
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {
"role": message.role,
"content": text,
}
if isinstance(message.content, list):
return [await _convert_content(c) for c in message.content]
else:
return [await _convert_content(message.content)]

View file

@ -10,6 +10,6 @@ from .config import OpenAIConfig
async def get_adapter_impl(config: OpenAIConfig, _deps):
from .openai import OpenAIInferenceAdapter
impl = OpenAIInferenceAdapter(config)
impl = OpenAIInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -21,10 +21,6 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type
class OpenAIConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for OpenAI models",
)
base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for OpenAI API",

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
@ -14,53 +13,22 @@ logger = get_logger(name=__name__, category="inference::openai")
#
# This OpenAI adapter implements Inference methods using two mixins -
# This OpenAI adapter implements Inference methods using OpenAIMixin
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
# | completion | LiteLLMOpenAIMixin |
# | chat_completion | LiteLLMOpenAIMixin |
# | embedding | LiteLLMOpenAIMixin |
# | openai_completion | OpenAIMixin |
# | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | OpenAIMixin |
#
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class OpenAIInferenceAdapter(OpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
embedding_model_metadata = {
config: OpenAIConfig
provider_data_api_key_field: str = "openai_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
}
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="openai",
api_key_from_config=config.api_key,
provider_data_api_key_field="openai_api_key",
)
self.config = config
# we set is_openai_compat so users can use the canonical
# openai model names like "gpt-4" or "gpt-3.5-turbo"
# and the model name will be translated to litellm's
# "openai/gpt-4" or "openai/gpt-3.5-turbo" transparently.
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Get the OpenAI API base URL.
@ -68,9 +36,3 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Returns the OpenAI API base URL from the configuration.
"""
return self.config.base_url
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()

View file

@ -13,15 +13,15 @@ from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import PassthroughImplConfig
@ -31,12 +31,6 @@ class PassthroughInferenceAdapter(Inference):
ModelRegistryHelper.__init__(self)
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
@ -76,120 +70,37 @@ class PassthroughInferenceAdapter(Inference):
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
params = params.model_copy()
params.model = model_obj.provider_resource_id
return await client.inference.openai_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_completion(**request_params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = params.model_copy()
params.model = model_obj.provider_resource_id
return await client.inference.openai_chat_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_chat_completion(**request_params)
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
json_params = {}

View file

@ -11,6 +11,6 @@ async def get_adapter_impl(config: RunpodImplConfig, _deps):
from .runpod import RunpodInferenceAdapter
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
impl = RunpodInferenceAdapter(config)
impl = RunpodInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -18,8 +18,9 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
default=None,
description="The URL for the Runpod model serving endpoint",
)
api_token: str | None = Field(
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)

View file

@ -4,75 +4,39 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import RunpodImplConfig
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
RUNPOD_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}
SAFETY_MODELS_ENTRIES = []
class RunpodInferenceAdapter(OpenAIMixin):
"""
Adapter for RunPod's OpenAI-compatible API endpoints.
Supports VLLM for serverless endpoint self-hosted or public endpoints.
Can work with any runpod endpoints that support OpenAI-compatible API
"""
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
config: RunpodImplConfig
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return self.config.url
class RunpodInferenceAdapter(
ModelRegistryHelper,
Inference,
):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
self.config = config
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def openai_embeddings(
async def openai_chat_completion(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Override to add RunPod-specific stream_options requirement."""
params = params.model_copy()
if params.stream and not params.stream_options:
params.stream_options = {"include_usage": True}
return await super().openai_chat_completion(params)

View file

@ -11,6 +11,6 @@ async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
impl = SambaNovaInferenceAdapter(config)
impl = SambaNovaInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -25,10 +25,6 @@ class SambaNovaImplConfig(RemoteInferenceProviderConfig):
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -5,40 +5,20 @@
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import SambaNovaImplConfig
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class SambaNovaInferenceAdapter(OpenAIMixin):
config: SambaNovaImplConfig
provider_data_api_key_field: str = "sambanova_api_key"
download_images: bool = True # SambaNova does not support image downloads server-size, perform them on the client
"""
SambaNova Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of LiteLLMOpenAIMixin.check_model_availability().
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
"""
def __init__(self, config: SambaNovaImplConfig):
self.config = config
self.environment_available_models: list[str] = []
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="sambanova",
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key",
openai_compat_api_base=self.config.url,
download_images=True, # SambaNova requires base64 image encoding
json_schema_strict=False, # SambaNova doesn't support strict=True yet
)
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.

View file

@ -13,6 +13,8 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TGIImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = Field(
description="The URL for the TGI serving endpoint",
)

View file

@ -5,53 +5,24 @@
# the root directory of this source tree.
from collections.abc import Iterable
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
)
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = get_logger(name=__name__, category="inference::tgi")
def build_hf_repo_model_entries():
return [
build_hf_repo_model_entry(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class _HfAdapter(
OpenAIMixin,
Inference,
):
class _HfAdapter(OpenAIMixin):
url: str
api_key: SecretStr
@ -61,98 +32,18 @@ class _HfAdapter(
overwrite_completion_id = True # TGI always returns id=""
def __init__(self) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}
def get_api_key(self):
return self.api_key.get_secret_value()
return "NO KEY REQUIRED"
def get_base_url(self):
return self.url
async def shutdown(self) -> None:
pass
async def list_models(self) -> list[Model] | None:
models = []
async for model in self.client.models.list():
models.append(
Model(
identifier=model.id,
provider_resource_id=model.id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
)
return models
async def register_model(self, model: Model) -> Model:
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
)
return model
async def unregister_model(self, model_id: str) -> None:
pass
def _get_max_new_tokens(self, sampling_params, input_tokens):
return min(
sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
def _build_options(
self,
sampling_params: SamplingParams | None = None,
fmt: ResponseFormat = None,
):
options = get_sampling_options(sampling_params)
# TGI does not support temperature=0 when using greedy sampling
# We set it to 1e-3 instead, anything lower outputs garbage from TGI
# We can use top_p sampling strategy to specify lower temperature
if abs(options["temperature"]) < 1e-10:
options["temperature"] = 1e-3
# delete key "max_tokens" from options since its not supported by the API
options.pop("max_tokens", None)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["grammar"] = {
"type": "json",
"value": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise ValueError("Grammar response format not supported yet")
else:
raise ValueError(f"Unexpected response format: {fmt.type}")
return options
async def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = await chat_completion_request_to_model_input_info(
request, self.register_helper.get_llama_model(request.model)
)
return dict(
prompt=prompt,
stream=request.stream,
details=True,
max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens),
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**self._build_options(request.sampling_params, request.response_format),
)
async def list_provider_model_ids(self) -> Iterable[str]:
return [self.model_id]
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -17,6 +17,6 @@ async def get_adapter_impl(config: TogetherImplConfig, _deps):
from .together import TogetherInferenceAdapter
assert isinstance(config, TogetherImplConfig), f"Unexpected config type: {type(config)}"
impl = TogetherInferenceAdapter(config)
impl = TogetherInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The Together AI API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:

View file

@ -5,41 +5,30 @@
# the root directory of this source tree.
from openai import AsyncOpenAI
from collections.abc import Iterable
from together import AsyncTogether
from together.constants import BASE_URL
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
LogProbConfig,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
)
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
request_has_media,
)
from .config import TogetherImplConfig
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
embedding_model_metadata = {
class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
config: TogetherImplConfig
embedding_model_metadata: dict[str, dict[str, int]] = {
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
@ -47,27 +36,16 @@ class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData)
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config
self.allowed_models = config.allowed_models
self._model_cache: dict[str, Model] = {}
_model_cache: dict[str, Model] = {}
def get_api_key(self):
return self.config.api_key.get_secret_value()
provider_data_api_key_field: str = "together_api_key"
def get_base_url(self):
return BASE_URL
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def _get_client(self) -> AsyncTogether:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
if config_api_key:
together_api_key = config_api_key
else:
@ -79,97 +57,13 @@ class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData)
together_api_key = provider_data.together_api_key
return AsyncTogether(api_key=together_api_key)
def _get_openai_client(self) -> AsyncOpenAI:
together_client = self._get_client().client
return AsyncOpenAI(
base_url=together_client.base_url,
api_key=together_client.api_key,
)
def _build_options(
self,
sampling_params: SamplingParams | None,
logprobs: LogProbConfig | None,
fmt: ResponseFormat,
) -> dict:
options = get_sampling_options(sampling_params)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
if logprobs.top_k != 1:
raise ValueError(
f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided",
)
options["logprobs"] = 1
return options
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if media_present or not llama_model:
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
params = {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
logger.debug(f"params to together: {params}")
return params
async def list_models(self) -> list[Model] | None:
self._model_cache = {}
async def list_provider_model_ids(self) -> Iterable[str]:
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
for m in await self._get_client().models.list():
if m.type == "embedding":
if m.id not in self.embedding_model_metadata:
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
continue
metadata = self.embedding_model_metadata[m.id]
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.embedding,
metadata=metadata,
)
else:
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
return self._model_cache.values()
async def should_refresh_models(self) -> bool:
return True
async def check_model_availability(self, model):
return model in self._model_cache
return [m.id for m in await self._get_client().models.list()]
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Together's OpenAI-compatible embeddings endpoint is not compatible with
@ -181,26 +75,28 @@ class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData)
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
"""
# Together support ticket #13332 -> will not fix
if user is not None:
if params.user is not None:
raise ValueError("Together's embeddings endpoint does not support user param.")
# Together support ticket #13333 -> escalated
if dimensions is not None:
if params.dimensions is not None:
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format,
model=await self._get_provider_model_id(params.model),
input=params.input,
encoding_format=params.encoding_format,
)
response.model = model # return the user the same model id they provided, avoid exposing the provider model id
response.model = (
params.model
) # return the user the same model id they provided, avoid exposing the provider model id
# Together support ticket #13330 -> escalated
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
if not hasattr(response, "usage") or response.usage is None:
logger.warning(
f"Together's embedding endpoint for {model} did not return usage information, substituting -1s."
f"Together's embedding endpoint for {params.model} did not return usage information, substituting -1s."
)
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return response
return response # type: ignore[no-any-return]

View file

@ -10,6 +10,6 @@ from .config import VertexAIConfig
async def get_adapter_impl(config: VertexAIConfig, _deps):
from .vertexai import VertexAIInferenceAdapter
impl = VertexAIInferenceAdapter(config)
impl = VertexAIInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -25,6 +25,8 @@ class VertexAIProviderDataValidator(BaseModel):
@json_schema_type
class VertexAIConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
project: str = Field(
description="Google Cloud project ID for Vertex AI",
)

View file

@ -4,29 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import google.auth.transport.requests
from google.auth import default
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import VertexAIConfig
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: VertexAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="vertex_ai",
api_key_from_config=None, # Vertex AI uses ADC, not API keys
provider_data_api_key_field="vertex_project", # Use project for validation
)
self.config = config
class VertexAIInferenceAdapter(OpenAIMixin):
config: VertexAIConfig
provider_data_api_key_field: str = "vertex_project"
def get_api_key(self) -> str:
"""
@ -41,8 +31,7 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
credentials.refresh(google.auth.transport.requests.Request())
return str(credentials.token)
except Exception:
# If we can't get credentials, return empty string to let LiteLLM handle it
# This allows the LiteLLM mixin to work with ADC directly
# If we can't get credentials, return empty string to let the env work with ADC directly
return ""
def get_base_url(self) -> str:
@ -53,23 +42,3 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Vertex AI specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "vertex_project", None):
params["vertex_project"] = provider_data.vertex_project
if getattr(provider_data, "vertex_location", None):
params["vertex_location"] = provider_data.vertex_location
else:
params["vertex_project"] = self.config.project
params["vertex_location"] = self.config.location
# Remove api_key since Vertex AI uses ADC
params.pop("api_key", None)
return params

View file

@ -17,6 +17,6 @@ async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter
assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
impl = VLLMInferenceAdapter(config)
impl = VLLMInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,7 +6,7 @@
from pathlib import Path
from pydantic import Field, field_validator
from pydantic import Field, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -22,18 +22,15 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
default=4096,
description="Maximum number of tokens to generate.",
)
api_token: str | None = Field(
default="fake",
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)
tls_verify: bool | str = Field(
default=True,
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
)
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
@field_validator("tls_verify")
@classmethod

View file

@ -3,56 +3,24 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from collections.abc import AsyncIterator
from urllib.parse import urljoin
import httpx
from openai import APIConnectionError
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from pydantic import ConfigDict
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
ModelStore,
OpenAIChatCompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAIChatCompletionRequestWithExtraBody,
ToolChoice,
ToolDefinition,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_message_to_openai_dict,
convert_tool_call,
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -61,210 +29,17 @@ from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries():
return [
build_hf_repo_model_entry(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class VLLMInferenceAdapter(OpenAIMixin):
config: VLLMInferenceAdapterConfig
model_config = ConfigDict(arbitrary_types_allowed=True)
def _convert_to_vllm_tool_calls_in_response(
tool_calls,
) -> list[ToolCall]:
if not tool_calls:
return []
provider_data_api_key_field: str = "vllm_api_token"
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call.function.arguments,
)
for call in tool_calls
]
def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]:
compat_tools = []
for tool in tools:
# The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string.
tool_name = tool.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
compat_tool = {
"type": "function",
"function": {
"name": tool_name,
"description": tool.description,
"parameters": tool.input_schema
or {
"type": "object",
"properties": {},
"required": [],
},
},
}
compat_tools.append(compat_tool)
return compat_tools
def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None,
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]:
chunks = []
if finish_reason is not None:
stop_reason = _convert_to_vllm_finish_reason(finish_reason)
else:
stop_reason = StopReason.end_of_message
tool_call_bufs = tool_call_bufs or {}
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=current_event_type,
delta=ToolCallDelta(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),
)
)
)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=str(tool_call_buf),
parse_status=ToolCallParseStatus.failed,
),
)
)
)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=last_chunk_content or ""),
logprobs=None,
stop_reason=stop_reason,
)
)
)
return chunks
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
tool_call_bufs: dict[str, UnparseableToolCall] = {}
end_of_stream_processed = False
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
return
choice = chunk.choices[0]
if choice.delta.tool_calls:
for delta_tool_call in choice.delta.tool_calls:
tool_call = convert_tool_call(delta_tool_call)
if delta_tool_call.index not in tool_call_bufs:
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[delta_tool_call.index]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
)
if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content,
current_event_type=event_type,
tool_call_bufs=tool_call_bufs,
)
for c in chunks:
yield c
end_of_stream_processed = True
elif not choice.delta.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
)
)
event_type = ChatCompletionResponseEventType.progress
if end_of_stream_processed:
return
# the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
)
for c in chunks:
yield c
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
model_entries=build_hf_repo_model_entries(),
litellm_provider_name="vllm",
api_key_from_config=config.api_token,
provider_data_api_key_field="vllm_api_token",
openai_compat_api_base=config.url,
)
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str | None:
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
return "NO KEY REQUIRED"
def get_base_url(self) -> str:
"""Get the base URL from config."""
@ -278,31 +53,6 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)
async def should_refresh_models(self) -> bool:
# Strictly respecting the refresh_models directive
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=self.__provider_id__,
metadata={},
model_type=model_type,
)
)
return models
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the remote vLLM server.
@ -324,120 +74,38 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
try:
res = self.client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
) from e
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM. "
f"Available models: {', '.join(available_models)}"
)
return model
async def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
input_dict: dict[str, Any] = {}
# Only include the 'tools' param if there is any. It can break things if an empty list is sent to the vLLM.
if isinstance(request, ChatCompletionRequest) and request.tools:
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
if fmt := request.response_format:
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["extra_body"] = {"guided_json": fmt.json_schema}
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if request.logprobs and request.logprobs.top_k:
input_dict["logprobs"] = request.logprobs.top_k
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**options,
}
async def check_model_availability(self, model: str) -> bool:
"""
Skip the check when running without authentication.
"""
if not self.config.auth_credential:
model_ids = []
async for m in self.client.models.list():
if m.id == model: # Found exact match
return True
model_ids.append(m.id)
raise ValueError(f"Model '{model}' not found. Available models: {model_ids}")
log.warning(f"Not checking model availability for {model} as API token may trigger OAuth workflow")
return True
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
max_tokens = max_tokens or self.config.max_tokens
params = params.model_copy()
# Apply vLLM-specific defaults
if params.max_tokens is None and self.config.max_tokens:
params.max_tokens = self.config.max_tokens
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_choice is not None:
tool_choice = ToolChoice.none.value
if not params.tools and params.tool_choice is not None:
params.tool_choice = ToolChoice.none.value
return await super().openai_chat_completion(
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await super().openai_chat_completion(params)

View file

@ -4,19 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import WatsonXConfig
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
async def get_adapter_impl(config: WatsonXConfig, _deps):
# import dynamically so the import is used only when it is needed
from .watsonx import WatsonXInferenceAdapter
if not isinstance(config, WatsonXConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = WatsonXInferenceAdapter(config)
return adapter
__all__ = ["get_adapter_impl", "WatsonXConfig"]

View file

@ -7,16 +7,18 @@
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel):
url: str
api_key: str
project_id: str
model_config = ConfigDict(
from_attributes=True,
extra="forbid",
)
watsonx_api_key: str | None
@json_schema_type
@ -25,13 +27,9 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
api_key: SecretStr | None = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
description="The watsonx API key",
)
project_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key",
default=None,
description="The watsonx.ai project ID",
)
timeout: int = Field(
default=60,

View file

@ -1,47 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/llama-3-3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-2-13b-chat",
CoreModelId.llama2_13b.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]

View file

@ -4,246 +4,120 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from ibm_watsonx_ai.foundation_models import Model
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
import requests
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
GreedySamplingStrategy,
Inference,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
request_has_media,
)
from . import WatsonXConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::watsonx")
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
# Note on structured output
# WatsonX returns responses with a json embedded into a string.
# Examples:
class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
_model_cache: dict[str, Model] = {}
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
# Not even a valid JSON, but we can still extract the JSON from the content
def __init__(self, config: WatsonXConfig):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="watsonx",
api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
provider_data_api_key_field="watsonx_api_key",
)
self.available_models = None
self.config = config
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
# "year_born": "1963", "year_retired": "2003"\\}}$')
# Find the start of the boxed content
def get_base_url(self) -> str:
return self.config.url
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config
self._openai_client: AsyncOpenAI | None = None
self._project_id = self._config.project_id
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def _get_client(self, model_id) -> Model:
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_url = self._config.url
project_id = self._config.project_id
credentials = {"url": config_url, "apikey": config_api_key}
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self._config.url}/openai/v1",
api_key=self._config.api_key,
)
return self._openai_client
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
if request.sampling_params:
if request.sampling_params.strategy:
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
if request.sampling_params.max_tokens:
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["params"][GenParams.TEMPERATURE] = 0.0
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
params = {
**input_dict,
}
# Add watsonx.ai specific parameters
params["project_id"] = self.config.project_id
params["time_limit"] = self.config.timeout
return params
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
# Copied from OpenAIMixin
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the provider's /v1/models.
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
if not self._model_cache:
await self.list_models()
return model in self._model_cache
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def list_models(self) -> list[Model] | None:
self._model_cache = {}
models = []
for model_spec in self._get_model_specs():
functions = [f["id"] for f in model_spec.get("functions", [])]
# Format: {"embedding_dimension": 1536, "context_length": 8192}
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# watsonx.ai sometimes adds usage data to the stream
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
# Example of an embedding model:
# {'model_id': 'ibm/granite-embedding-278m-multilingual',
# 'label': 'granite-embedding-278m-multilingual',
# 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
# ...
provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
if "embedding" in functions:
embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
context_length = model_spec["model_limits"]["max_sequence_length"]
embedding_metadata = {
"embedding_dimension": embedding_dimension,
"context_length": context_length,
}
model = Model(
identifier=model_spec["model_id"],
provider_resource_id=provider_resource_id,
provider_id=self.__provider_id__,
metadata=embedding_metadata,
model_type=ModelType.embedding,
)
self._model_cache[provider_resource_id] = model
models.append(model)
if "text_chat" in functions:
model = Model(
identifier=model_spec["model_id"],
provider_resource_id=provider_resource_id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
# In theory, I guess it is possible that a model could be both an embedding model and a text chat model.
# In that case, the cache will record the generator Model object, and the list which we return will have
# both the generator Model object and the text chat Model object. That's fine because the cache is
# only used for check_model_availability() anyway.
self._model_cache[provider_resource_id] = model
models.append(model)
return models
seen_finish_reason = False
async for chunk in stream:
# Final usage chunk with no choices that the user didn't request, so discard
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break
# LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
# So we need to implement our own method to list models by calling the watsonx.ai API.
def _get_model_specs(self) -> list[dict[str, Any]]:
"""
Retrieves foundation model specifications from the watsonx.ai API.
"""
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
headers = {
# Note that there is no authorization header. Listing models does not require authentication.
"Content-Type": "application/json",
}
response = requests.get(url, headers=headers)
# --- Process the Response ---
# Raise an exception for bad status codes (4xx or 5xx)
response.raise_for_status()
# If the request is successful, parse and return the JSON response.
# The response should contain a list of model specifications
response_data = response.json()
if "resources" not in response_data:
raise ValueError("Resources not found in response")
return response_data["resources"]

View file

@ -7,7 +7,7 @@
import json
from typing import Any
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -56,7 +56,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:

View file

@ -8,12 +8,11 @@ from typing import Any
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import NVIDIASafetyConfig
@ -44,7 +43,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
@ -118,7 +117,7 @@ class NeMoGuardrails:
response.raise_for_status()
return response.json()
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
"""
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
@ -132,10 +131,9 @@ class NeMoGuardrails:
Raises:
requests.HTTPError: If the POST request fails.
"""
request_messages = [await convert_message_to_openai_dict_new(message) for message in messages]
request_data = {
"model": self.model,
"messages": request_messages,
"messages": [{"role": message.role, "content": message.content} for message in messages],
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,

View file

@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
import litellm
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -21,7 +20,6 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig
@ -72,7 +70,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
@ -80,12 +78,8 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = litellm.completion(
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
)
response = litellm.completion(model=shield.provider_resource_id, messages=messages, api_key=self._get_api_key())
shield_message = response.choices[0].message.content
if "unsafe" in shield_message.lower():

View file

@ -140,14 +140,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
inference_api: Api.inference,
files_api: Files | None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.client = None
self.cache = {}
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.files_api = files_api
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence)
@ -168,7 +167,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
pass
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,

View file

@ -309,14 +309,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
inference_api: Inference,
files_api: Files | None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.cache = {}
self.client = None
self.inference_api = inference_api
self.files_api = files_api
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
@ -351,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def shutdown(self) -> None:
self.client.close()
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,

View file

@ -345,14 +345,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
inference_api: Api.inference,
files_api: Files | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.conn = None
self.cache = {}
self.files_api = files_api
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
@ -392,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
if self.conn is not None:
self.conn.close()
log.info("Connection to PGVector database server closed")
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None:
# Persist vector DB metadata in the KV store

View file

@ -27,7 +27,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
@ -162,14 +162,12 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
inference_api: Api.inference,
files_api: Files | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.client: AsyncQdrantClient = None
self.cache = {}
self.inference_api = inference_api
self.files_api = files_api
self.vector_db_store = None
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self._qdrant_lock = asyncio.Lock()
async def initialize(self) -> None:
@ -193,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def shutdown(self) -> None:
await self.client.close()
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,

View file

@ -284,14 +284,12 @@ class WeaviateVectorIOAdapter(
inference_api: Api.inference,
files_api: Files | None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.client_cache = {}
self.cache = {}
self.files_api = files_api
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
def _get_client(self) -> weaviate.WeaviateClient:
@ -349,6 +347,8 @@ class WeaviateVectorIOAdapter(
async def shutdown(self) -> None:
for client in self.client_cache.values():
client.close()
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,

View file

@ -12,6 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import RemoteInference
class BedrockBaseConfig(RemoteInferenceProviderConfig):
auth_credential: None = Field(default=None, exclude=True)
aws_access_key_id: str | None = Field(
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",

View file

@ -17,6 +17,7 @@ if TYPE_CHECKING:
from llama_stack.apis.inference import (
ModelStore,
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
@ -32,26 +33,22 @@ class SentenceTransformerEmbeddingMixin:
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
# Convert input to list format if it's a single string
input_list = [input] if isinstance(input, str) else input
input_list = [params.input] if isinstance(params.input, str) else params.input
if not input_list:
raise ValueError("Empty list not supported")
# Get the model and generate embeddings
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
# Convert embeddings to the requested format
data = []
for i, embedding in enumerate(embeddings):
if encoding_format == "base64":
if params.encoding_format == "base64":
# Convert float array to base64 string
float_bytes = struct.pack(f"{len(embedding)}f", *embedding)
embedding_value = base64.b64encode(float_bytes).decode("ascii")
@ -70,7 +67,7 @@ class SentenceTransformerEmbeddingMixin:
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return OpenAIEmbeddingsResponse(
data=data,
model=model,
model=params.model,
usage=usage,
)

View file

@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import struct
from collections.abc import AsyncIterator
from typing import Any
import litellm
@ -15,18 +16,19 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolChoice,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import (
b64_encode_openai_embeddings_response,
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
get_sampling_options,
@ -188,16 +190,12 @@ class LiteLLMOpenAIMixin(
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
# Convert input to list if it's a string
input_list = [input] if isinstance(input, str) else input
input_list = [params.input] if isinstance(params.input, str) else params.input
# Call litellm embedding function
# litellm.drop_params = True
@ -206,11 +204,11 @@ class LiteLLMOpenAIMixin(
input=input_list,
api_key=self.get_api_key(),
api_base=self.api_base,
dimensions=dimensions,
dimensions=params.dimensions,
)
# Convert response to OpenAI format
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response["usage"]["prompt_tokens"],
@ -225,116 +223,78 @@ class LiteLLMOpenAIMixin(
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model_obj = await self.model_store.get_model(params.model)
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
frequency_penalty=params.frequency_penalty,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_tokens=params.max_tokens,
n=params.n,
presence_penalty=params.presence_penalty,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
suffix=params.suffix,
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.atext_completion(**params)
return await litellm.atext_completion(**request_params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
# Add usage tracking for streaming when telemetry is active
from llama_stack.providers.utils.telemetry.tracing import get_current_span
if stream and get_current_span() is not None:
stream_options = params.stream_options
if params.stream and get_current_span() is not None:
if stream_options is None:
stream_options = {"include_usage": True}
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model_obj = await self.model_store.get_model(params.model)
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
messages=params.messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
functions=params.functions,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_completion_tokens=params.max_completion_tokens,
max_tokens=params.max_tokens,
n=params.n,
parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=params.presence_penalty,
response_format=params.response_format,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
top_logprobs=params.top_logprobs,
top_p=params.top_p,
user=params.user,
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.acompletion(**params)
return await litellm.acompletion(**request_params)
async def check_model_availability(self, model: str) -> bool:
"""
@ -349,3 +309,28 @@ class LiteLLMOpenAIMixin(
return False
return model in litellm.models_by_provider[self.litellm_provider_name]
def b64_encode_openai_embeddings_response(
response_data: list[dict], encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
"""
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
"""
data = []
for i, embedding_data in enumerate(response_data):
if encoding_format == "base64":
byte_array = bytearray()
for embedding_value in embedding_data["embedding"]:
byte_array.extend(struct.pack("f", float(embedding_value)))
response_embedding = base64.b64encode(byte_array).decode("utf-8")
else:
response_embedding = embedding_data["embedding"]
data.append(
OpenAIEmbeddingData(
embedding=response_embedding,
index=i,
)
)
return data

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType
@ -24,6 +24,15 @@ class RemoteInferenceProviderConfig(BaseModel):
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically from the provider",
)
auth_credential: SecretStr | None = Field(
default=None,
description="Authentication credential for the provider",
alias="api_key",
)
# TODO: this class is more confusing than useful right now. We need to make it

View file

@ -3,9 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import struct
import time
import uuid
import warnings
@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
Message,
OpenAIChatCompletion,
OpenAIEmbeddingData,
OpenAIMessageParam,
OpenAIResponseFormatParam,
SamplingParams,
@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
params["user"] = user
return params
def b64_encode_openai_embeddings_response(
response_data: dict, encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
"""
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
"""
data = []
for i, embedding_data in enumerate(response_data):
if encoding_format == "base64":
byte_array = bytearray()
for embedding_value in embedding_data.embedding:
byte_array.extend(struct.pack("f", float(embedding_value)))
response_embedding = base64.b64encode(byte_array).decode("utf-8")
else:
response_embedding = embedding_data.embedding
data.append(
OpenAIEmbeddingData(
embedding=response_embedding,
index=i,
)
)
return data

View file

@ -7,47 +7,62 @@
import base64
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Iterable
from typing import Any
from openai import NOT_GIVEN, AsyncOpenAI
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.inference import (
Model,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
This is an abstract base class that requires child classes to implement:
- get_api_key(): Method to retrieve the API key
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
The behavior of this class can be customized by child classes in the following ways:
- overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
- download_images: If True, downloads images and converts to base64 for providers that require it
- embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
- provider_data_api_key_field: Optional field name in provider data to look for API key
- list_provider_model_ids: Method to list available models from the provider
- get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client
Expected Dependencies:
- self.model_store: Injected by the Llama Stack distribution system at runtime.
This provides model registry functionality for looking up registered models.
The model_store is set in routing_tables/common.py during provider initialization.
"""
# Allow extra fields so the routing infra can inject model_store, __provider_id__, etc.
model_config = ConfigDict(extra="allow")
config: RemoteInferenceProviderConfig
# Allow subclasses to control whether to overwrite the 'id' field in OpenAI responses
# is overwritten with a client-side generated id.
#
@ -73,20 +88,15 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
# automatically set by the resolver when instantiating the provider
__provider_id__: str
@abstractmethod
def get_api_key(self) -> str:
def get_api_key(self) -> str | None:
"""
Get the API key.
This method must be implemented by child classes to provide the API key
for authenticating with the OpenAI API or compatible endpoints.
:return: The API key as a string
:return: The API key as a string, or None if not set
"""
pass
if self.config.auth_credential is None:
return None
return self.config.auth_credential.get_secret_value()
@abstractmethod
def get_base_url(self) -> str:
@ -111,6 +121,41 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
"""
return {}
async def list_provider_model_ids(self) -> Iterable[str]:
"""
List available models from the provider.
Child classes can override this method to provide a custom implementation
for listing models. The default implementation uses the AsyncOpenAI client
to list models from the OpenAI-compatible endpoint.
:return: An iterable of model IDs or None if not implemented
"""
client = self.client
async with client:
model_ids = [m.id async for m in client.models.list()]
return model_ids
async def initialize(self) -> None:
"""
Initialize the OpenAI mixin.
This method provides a default implementation that does nothing.
Subclasses can override this method to perform initialization tasks
such as setting up clients, validating configurations, etc.
"""
pass
async def shutdown(self) -> None:
"""
Shutdown the OpenAI mixin.
This method provides a default implementation that does nothing.
Subclasses can override this method to perform cleanup tasks
such as closing connections, releasing resources, etc.
"""
pass
@property
def client(self) -> AsyncOpenAI:
"""
@ -130,13 +175,11 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
api_key = getattr(provider_data, self.provider_data_api_key_field)
if not api_key: # TODO: let get_api_key return None
raise ValueError(
"API key is not set. Please provide a valid API key in the "
"provider data header, e.g. x-llamastack-provider-data: "
f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
"or in the provider config."
)
if not api_key:
message = "API key not provided."
if self.provider_data_api_key_field:
message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
raise ValueError(message)
return AsyncOpenAI(
api_key=api_key,
@ -181,96 +224,47 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
"""
Direct OpenAI completion API call.
"""
# Handle parameters that are not supported by OpenAI API, but may be by the provider
# prompt_logprobs is supported by vLLM
# guided_choice is supported by vLLM
# TODO: test coverage
extra_body: dict[str, Any] = {}
if prompt_logprobs is not None and prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs
if guided_choice:
extra_body["guided_choice"] = guided_choice
# TODO: fix openai_completion to return type compatible with OpenAI's API response
resp = await self.client.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
),
extra_body=extra_body,
completion_kwargs = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
frequency_penalty=params.frequency_penalty,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_tokens=params.max_tokens,
n=params.n,
presence_penalty=params.presence_penalty,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
suffix=params.suffix,
)
if extra_body := params.model_extra:
completion_kwargs["extra_body"] = extra_body
resp = await self.client.completions.create(**completion_kwargs)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Direct OpenAI chat completion API call.
"""
messages = params.messages
if self.download_images:
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
@ -289,55 +283,61 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
messages = [await _localize_image_url(m) for m in messages]
params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
request_params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
functions=params.functions,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_completion_tokens=params.max_completion_tokens,
max_tokens=params.max_tokens,
n=params.n,
parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=params.presence_penalty,
response_format=params.response_format,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
top_logprobs=params.top_logprobs,
top_p=params.top_p,
user=params.user,
)
resp = await self.client.chat.completions.create(**params)
if extra_body := params.model_extra:
request_params["extra_body"] = extra_body
resp = await self.client.chat.completions.create(**request_params)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Direct OpenAI embeddings API call.
"""
# Prepare request parameters
request_params = {
"model": await self._get_provider_model_id(params.model),
"input": params.input,
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
"user": params.user if params.user is not None else NOT_GIVEN,
}
# Add extra_body if present
extra_body = params.model_extra
if extra_body:
request_params["extra_body"] = extra_body
# Call OpenAI embeddings API with properly typed parameters
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
)
response = await self.client.embeddings.create(**request_params)
data = []
for i, embedding_data in enumerate(response.data):
@ -355,7 +355,7 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
return OpenAIEmbeddingsResponse(
data=data,
model=model,
model=params.model,
usage=usage,
)
@ -371,7 +371,7 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
async def register_model(self, model: Model) -> Model:
if not await self.check_model_availability(model.provider_model_id):
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}") # type: ignore[attr-defined]
return model
async def unregister_model(self, model_id: str) -> None:
@ -387,41 +387,87 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
"""
self._model_cache = {}
async for m in self.client.models.list():
if self.allowed_models and m.id not in self.allowed_models:
logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
try:
iterable = await self.list_provider_model_ids()
except Exception as e:
logger.error(f"{self.__class__.__name__}.list_provider_model_ids() failed with: {e}")
raise
if not hasattr(iterable, "__iter__"):
raise TypeError(
f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
f"strings, but returned {type(iterable).__name__}"
)
provider_models_ids = list(iterable)
logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models")
for provider_model_id in provider_models_ids:
if not isinstance(provider_model_id, str):
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
if self.allowed_models and provider_model_id not in self.allowed_models:
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
continue
if metadata := self.embedding_model_metadata.get(m.id):
# This is an embedding model - augment with metadata
if metadata := self.embedding_model_metadata.get(provider_model_id):
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
provider_resource_id=provider_model_id,
identifier=provider_model_id,
model_type=ModelType.embedding,
metadata=metadata,
)
else:
# This is an LLM
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
provider_resource_id=provider_model_id,
identifier=provider_model_id,
model_type=ModelType.llm,
)
self._model_cache[m.id] = model
self._model_cache[provider_model_id] = model
return list(self._model_cache.values())
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the provider's /v1/models.
Check if a specific model is available from the provider's /v1/models or pre-registered.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
:return: True if the model is available dynamically or pre-registered, False otherwise.
"""
# First check if the model is pre-registered in the model store
if hasattr(self, "model_store") and self.model_store:
if await self.model_store.has_model(model):
return True
# Then check the provider's dynamic model cache
if not self._model_cache:
await self.list_models()
return model in self._model_cache
async def should_refresh_models(self) -> bool:
return False
return self.config.refresh_models
#
# The model_dump implementations are to avoid serializing the extra fields,
# e.g. model_store, which are not pydantic.
#
def _filter_fields(self, **kwargs):
"""Helper to exclude extra fields from serialization."""
# Exclude any extra fields stored in __pydantic_extra__
if hasattr(self, "__pydantic_extra__") and self.__pydantic_extra__:
exclude = kwargs.get("exclude", set())
if not isinstance(exclude, set):
exclude = set(exclude) if exclude else set()
exclude.update(self.__pydantic_extra__.keys())
kwargs["exclude"] = exclude
return kwargs
def model_dump(self, **kwargs):
"""Override to exclude extra fields from serialization."""
kwargs = self._filter_fields(**kwargs)
return super().model_dump(**kwargs)
def model_dump_json(self, **kwargs):
"""Override to exclude extra fields from JSON serialization."""
kwargs = self._filter_fields(**kwargs)
return super().model_dump_json(**kwargs)

Some files were not shown because too many files have changed in this diff Show more